trmm.hpp
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Bindings for the CBLAS triangular matrix-matrix product routines (strmm, dtrmm, ctrmm, ztrmm).
6 *
7 * \author O. Krause
8 * \date 2010
9 *
10 *
11 * \par Copyright 1995-2015 Shark Development Team
12 *
13 * <BR><HR>
14 * This file is part of Shark.
15 * <http://image.diku.dk/shark/>
16 *
17 * Shark is free software: you can redistribute it and/or modify
18 * it under the terms of the GNU Lesser General Public License as published
19 * by the Free Software Foundation, either version 3 of the License, or
20 * (at your option) any later version.
21 *
22 * Shark is distributed in the hope that it will be useful,
23 * but WITHOUT ANY WARRANTY; without even the implied warranty of
24 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25 * GNU Lesser General Public License for more details.
26 *
27 * You should have received a copy of the GNU Lesser General Public License
28 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29 *
30 */
31//===========================================================================
32#ifndef REMORA_KERNELS_CBLAS_TRMM_HPP
33#define REMORA_KERNELS_CBLAS_TRMM_HPP
34
35#include "cblas_inc.hpp"
36#include <type_traits>
37
38namespace remora{namespace bindings {
39
40inline void trmm(
41 CBLAS_ORDER const order,
42 CBLAS_SIDE const side,
43 CBLAS_UPLO const uplo,
44 CBLAS_TRANSPOSE const transA,
45 CBLAS_DIAG const unit,
46 int const M,
47 int const N,
48 float const *A, int const lda,
49 float* B, int const incB
50) {
51 cblas_strmm(order, side, uplo, transA, unit, M, N,
52 1.0,
53 A, lda,
54 B, incB
55 );
56}
57
58inline void trmm(
59 CBLAS_ORDER const order,
60 CBLAS_SIDE const side,
61 CBLAS_UPLO const uplo,
62 CBLAS_TRANSPOSE const transA,
63 CBLAS_DIAG const unit,
64 int const M,
65 int const N,
66 double const *A, int const lda,
67 double* B, int const incB
68) {
69 cblas_dtrmm(order, side, uplo, transA, unit, M, N,
70 1.0,
71 A, lda,
72 B, incB
73 );
74}
75
76
77inline void trmm(
78 CBLAS_ORDER const order,
79 CBLAS_SIDE const side,
80 CBLAS_UPLO const uplo,
81 CBLAS_TRANSPOSE const transA,
82 CBLAS_DIAG const unit,
83 int const M,
84 int const N,
85 std::complex<float> const *A, int const lda,
86 std::complex<float>* B, int const incB
87) {
88 std::complex<float> alpha = 1.0;
89 cblas_ctrmm(order, side, uplo, transA, unit, M, N,
90 reinterpret_cast<cblas_float_complex_type const *>(&alpha),
91 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
92 reinterpret_cast<cblas_float_complex_type *>(B), incB
93 );
94}
95
96inline void trmm(
97 CBLAS_ORDER const order,
98 CBLAS_SIDE const side,
99 CBLAS_UPLO const uplo,
100 CBLAS_TRANSPOSE const transA,
101 CBLAS_DIAG const unit,
102 int const M,
103 int const N,
104 std::complex<double> const *A, int const lda,
105 std::complex<double>* B, int const incB
106) {
107 std::complex<double> alpha = 1.0;
108 cblas_ztrmm(order, side, uplo, transA, unit, M, N,
109 reinterpret_cast<cblas_double_complex_type const *>(&alpha),
110 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
111 reinterpret_cast<cblas_double_complex_type *>(B), incB
112 );
113}
114
115template <bool upper, bool unit, typename MatA, typename MatB>
116void trmm(
117 matrix_expression<MatA, cpu_tag> const& A,
118 matrix_expression<MatB, cpu_tag>& B,
119 std::true_type
120){
121 REMORA_SIZE_CHECK(A().size1() == A().size2());
122 REMORA_SIZE_CHECK(A().size2() == B().size1());
123 std::size_t n = A().size1();
124 std::size_t m = B().size2();
125 CBLAS_DIAG cblasUnit = unit?CblasUnit:CblasNonUnit;
126 CBLAS_UPLO cblasUplo = upper?CblasUpper:CblasLower;
127 CBLAS_ORDER stor_ord= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
128 CBLAS_TRANSPOSE trans=CblasNoTrans;
129
130 //special case: MatA and MatB do not have same storage order. in this case compute as
131 //AB->B^TA^T where transpose of B is done implicitely by exchanging storage order
132 CBLAS_ORDER stor_ordB= (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
133 if(stor_ord != stor_ordB){
134 trans = CblasTrans;
135 cblasUplo= upper?CblasLower:CblasUpper;
136 }
137
138 auto storageA = A().raw_storage();
139 auto storageB = B().raw_storage();
140 trmm(stor_ordB, CblasLeft, cblasUplo, trans, cblasUnit,
141 (int)n, int(m),
142 storageA.values,
143 storageA.leading_dimension,
144 storageB.values,
145 storageB.leading_dimension
146 );
147}
148
149
150template<class M1, class M2>
151struct has_optimized_trmm: std::integral_constant<bool,
152 allowed_cblas_type<typename M1::value_type>::type::value
153 && std::is_same<typename M1::value_type, typename M2::value_type>::value
154 && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
155 && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value
156>{};
157
158}}
159#endif