trsm.hpp
Go to the documentation of this file.
1/*!
2 *
3 *
4 * \brief CBLAS bindings for trsm, which solves systems of linear equations with a triangular matrix
5 *
6 * \author O. Krause
7 * \date 2011
8 *
9 *
10 * \par Copyright 1995-2015 Shark Development Team
11 *
12 * <BR><HR>
13 * This file is part of Shark.
14 * <http://image.diku.dk/shark/>
15 *
16 * Shark is free software: you can redistribute it and/or modify
17 * it under the terms of the GNU Lesser General Public License as published
18 * by the Free Software Foundation, either version 3 of the License, or
19 * (at your option) any later version.
20 *
21 * Shark is distributed in the hope that it will be useful,
22 * but WITHOUT ANY WARRANTY; without even the implied warranty of
23 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 * GNU Lesser General Public License for more details.
25 *
26 * You should have received a copy of the GNU Lesser General Public License
27 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28 *
29 */
30
31#ifndef REMORA_KERNELS_CBLAS_TRSM_HPP
32#define REMORA_KERNELS_CBLAS_TRSM_HPP
33
34#include "cblas_inc.hpp"
35#include "../../proxy_expressions.hpp"
36#include <type_traits>
37///solves systems of linear equations whose coefficient matrix is triangular
38
39namespace remora{namespace bindings {
40inline void trsm(
41 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
42 CBLAS_SIDE side, CBLAS_DIAG unit,
43 int n, int nRHS,
44 float const *A, int lda, float *B, int ldb
45) {
46 cblas_strsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
47}
48
49inline void trsm(
50 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
51 CBLAS_SIDE side, CBLAS_DIAG unit,
52 int n, int nRHS,
53 double const *A, int lda, double *B, int ldb
54) {
55 cblas_dtrsm(order, side, uplo, transA, unit,n, nRHS, 1.0, A, lda, B, ldb);
56}
57
58inline void trsm(
59 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
60 CBLAS_SIDE side, CBLAS_DIAG unit,
61 int n, int nRHS,
62 std::complex<float> const *A, int lda, std::complex<float> *B, int ldb
63) {
64 std::complex<float> alpha(1.0,0);
65 cblas_ctrsm(order, side, uplo, transA, unit,n, nRHS,
66 reinterpret_cast<cblas_float_complex_type const *>(&alpha),
67 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
68 reinterpret_cast<cblas_float_complex_type *>(B), ldb);
69}
70inline void trsm(
71 CBLAS_ORDER order, CBLAS_UPLO uplo,CBLAS_TRANSPOSE transA,
72 CBLAS_SIDE side, CBLAS_DIAG unit,
73 int n, int nRHS,
74 std::complex<double> const *A, int lda, std::complex<double> *B, int ldb
75) {
76 std::complex<double> alpha(1.0,0);
77 cblas_ztrsm(order, side, uplo, transA, unit,n, nRHS,
78 reinterpret_cast<cblas_double_complex_type const *>(&alpha),
79 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
80 reinterpret_cast<cblas_double_complex_type *>(B), ldb);
81}
82
83// trsm(): solves A system of linear equations A * X = B
84// when A is a triangular matrix
85template <class Triangular, typename MatA, typename MatB>
86void trsm_impl(
87 matrix_expression<MatA, cpu_tag> const &A,
88 matrix_expression<MatB, cpu_tag> &B,
89 std::true_type, left
90){
91 REMORA_SIZE_CHECK(A().size1() == A().size2());
92 REMORA_SIZE_CHECK(A().size1() == B().size1());
93
94 //orientation is defined by the second argument
95 CBLAS_ORDER const storOrd = (CBLAS_ORDER)storage_order<typename MatB::orientation>::value;
96 //if orientations do not match, wecan interpret this as transposing A
97 bool transposeA = !std::is_same<typename MatA::orientation,typename MatB::orientation>::value;
98
99 CBLAS_DIAG cblasUnit = Triangular::is_unit?CblasUnit:CblasNonUnit;
100 CBLAS_UPLO cblasUplo = (Triangular::is_upper != transposeA)?CblasUpper:CblasLower;
101 CBLAS_TRANSPOSE transA = transposeA?CblasTrans:CblasNoTrans;
102
103 int m = B().size1();
104 int nrhs = B().size2();
105 auto storageA = A().raw_storage();
106 auto storageB = B().raw_storage();
107 trsm(storOrd, cblasUplo, transA, CblasLeft,cblasUnit, m, nrhs,
108 storageA.values,
109 storageA.leading_dimension,
110 storageB.values,
111 storageB.leading_dimension
112 );
113}
114
115template <class Triangular, typename MatA, typename MatB>
116void trsm_impl(
117 matrix_expression<MatA, cpu_tag> const &A,
118 matrix_expression<MatB, cpu_tag> &B,
119 std::true_type, right
120){
121 auto transB = trans(B);
122 trsm_impl<typename Triangular::transposed_orientation>(trans(A), transB, std::true_type(), left());
123}
124
125template <class Triangular, class Side, typename MatA, typename MatB>
126void trsm(
127 matrix_expression<MatA, cpu_tag> const &A,
128 matrix_expression<MatB, cpu_tag> &B,
129 std::true_type
130){
131 trsm_impl<Triangular>(A,B, std::true_type(), Side());
132}
133
134template<class M1, class M2>
135struct has_optimized_trsm: std::integral_constant<bool,
136 allowed_cblas_type<typename M1::value_type>::type::value
137 && std::is_same<typename M1::value_type, typename M2::value_type>::value
138 && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
139 && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value
140>{};
141
142}}
143#endif