trsv.hpp
Go to the documentation of this file.
1/*!
2 *
3 *
 * \brief CBLAS bindings for trsv, the solver for triangular systems of linear equations.
5 *
6 * \author O. Krause
7 * \date 2011
8 *
9 *
10 * \par Copyright 1995-2015 Shark Development Team
11 *
12 * <BR><HR>
13 * This file is part of Shark.
14 * <http://image.diku.dk/shark/>
15 *
16 * Shark is free software: you can redistribute it and/or modify
17 * it under the terms of the GNU Lesser General Public License as published
18 * by the Free Software Foundation, either version 3 of the License, or
19 * (at your option) any later version.
20 *
21 * Shark is distributed in the hope that it will be useful,
22 * but WITHOUT ANY WARRANTY; without even the implied warranty of
23 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 * GNU Lesser General Public License for more details.
25 *
26 * You should have received a copy of the GNU Lesser General Public License
27 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28 *
29 */
30
31#ifndef REMORA_KERNELS_CBLAS_TRSV_HPP
32#define REMORA_KERNELS_CBLAS_TRSV_HPP
33
34#include "cblas_inc.hpp"
35#include <type_traits>
36
///Solves triangular systems of linear equations by forwarding to the CBLAS ?trsv routines.
38
39namespace remora {namespace bindings {
40inline void trsv(
41 CBLAS_ORDER order, CBLAS_UPLO uplo,
42 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
43 int n,
44 float const *A, int lda, float *b, int strideX
45){
46 cblas_strsv(order, uplo, transA, unit,n, A, lda, b, strideX);
47}
48
49inline void trsv(
50 CBLAS_ORDER order, CBLAS_UPLO uplo,
51 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
52 int n,
53 double const *A, int lda, double *b, int strideX
54){
55 cblas_dtrsv(order, uplo, transA, unit,n, A, lda, b, strideX);
56}
57
58inline void trsv(
59 CBLAS_ORDER order, CBLAS_UPLO uplo,
60 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
61 int n,
62 std::complex<float> const *A, int lda, std::complex<float> *b, int strideX
63){
64 cblas_ctrsv(order, uplo, transA, unit,n,
65 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
66 reinterpret_cast<cblas_float_complex_type *>(b), strideX);
67}
68inline void trsv(
69 CBLAS_ORDER order, CBLAS_UPLO uplo,
70 CBLAS_TRANSPOSE transA, CBLAS_DIAG unit,
71 int n,
72 std::complex<double> const *A, int lda, std::complex<double> *b, int strideX
73){
74 cblas_ztrsv(order, uplo, transA, unit,n,
75 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
76 reinterpret_cast<cblas_double_complex_type *>(b), strideX);
77}
78
79// trsv(): solves A system of linear equations A * x = b
80// when A is A triangular matrix.
81template <class Triangular,typename MatA, typename V>
82void trsv_impl(
83 matrix_expression<MatA, cpu_tag> const &A,
84 vector_expression<V, cpu_tag> &b,
85 std::true_type, left
86){
87 REMORA_SIZE_CHECK(A().size1() == A().size2());
88 REMORA_SIZE_CHECK(A().size1()== b().size());
89 CBLAS_DIAG cblasUnit = Triangular::is_unit?CblasUnit:CblasNonUnit;
90 CBLAS_ORDER const storOrd= (CBLAS_ORDER)storage_order<typename MatA::orientation>::value;
91 CBLAS_UPLO uplo = Triangular::is_upper?CblasUpper:CblasLower;
92
93
94 int const n = A().size1();
95 auto storageA = A().raw_storage();
96 auto storageb = b().raw_storage();
97 trsv(storOrd, uplo, CblasNoTrans,cblasUnit, n,
98 storageA.values,
99 storageA.leading_dimension,
100 storageb.values,
101 storageb.stride
102 );
103}
104
105//right is mapped onto left via transposing A
106template <class Triangular,typename MatA, typename V>
107void trsv_impl(
108 matrix_expression<MatA, cpu_tag> const &A,
109 vector_expression<V, cpu_tag> &b,
110 std::true_type, right
111){
112 trsv_impl<typename Triangular::transposed_orientation>(trans(A), b, std::true_type(), left());
113}
114
115//dispatcher
116
117template <class Triangular, class Side,typename MatA, typename V>
118void trsv(
119 matrix_expression<MatA, cpu_tag> const& A,
120 vector_expression<V, cpu_tag> & b,
121 std::true_type//optimized
122){
123 trsv_impl<Triangular>(A,b,std::true_type(), Side());
124}
125
126template<class M, class V>
127struct has_optimized_trsv: std::integral_constant<bool,
128 allowed_cblas_type<typename M::value_type>::type::value
129 && std::is_same<typename M::value_type, typename V::value_type>::value
130 && std::is_base_of<dense_tag, typename M::storage_type::storage_tag>::value
131 && std::is_base_of<dense_tag, typename V::storage_type::storage_tag>::value
132>{};
133
134}}
135#endif