31#ifndef REMORA_KERNELS_DEFAULT_TRSM_HPP
32#define REMORA_KERNELS_DEFAULT_TRSM_HPP
#include "../../expression_types.hpp"
#include "../../proxy_expressions.hpp"
#include "../../detail/structure.hpp"

#include <algorithm>   // std::min
#include <cstddef>     // std::size_t
#include <stdexcept>   // std::invalid_argument
#include <type_traits> // std::false_type
42namespace remora{
namespace bindings {
45template<std::
size_t maxBlockSize1,std::
size_t maxBlockSize2,
bool Unit,
class MatA,
class MatB>
47 matrix_expression<MatA, cpu_tag>
const& A,
48 matrix_expression<MatB, cpu_tag> &B,
52 REMORA_SIZE_CHECK(A().size1() <= maxBlockSize1);
53 typedef typename MatA::value_type value_typeA;
54 typedef typename MatB::value_type value_typeB;
58 std::size_t size = A().size1();
59 value_typeA blockA[maxBlockSize1][maxBlockSize1];
60 for(std::size_t i = 0; i != size; ++i){
61 for(std::size_t j = 0; j <= i; ++j){
62 blockA[i][j] = A()(i,j);
67 value_typeB blockB[maxBlockSize2][maxBlockSize1];
68 std::size_t numBlocks = (B().size2()+maxBlockSize2-1)/maxBlockSize2;
69 for(std::size_t i = 0; i != numBlocks; ++i){
70 std::size_t startB= i*maxBlockSize2;
71 std::size_t curBlockSize2 =std::min(maxBlockSize2, B().size2() - startB);
74 for(std::size_t i = 0; i != size; ++i){
75 for(std::size_t k = 0; k != curBlockSize2; ++k){
76 blockB[k][i] = B()(i,startB+k);
80 for(std::size_t k = 0; k != curBlockSize2; ++k){
81 for (std::size_t i = 0; i != size; ++i) {
82 for (std::size_t j = 0; j != i; ++j) {
83 blockB[k][i] -= blockA[i][j]*blockB[k][j];
86 if(blockA[i][i] == value_typeA())
87 throw std::invalid_argument(
"[TRSM] Matrix is singular!");
88 blockB[k][i] /= blockA[i][i];
93 for(std::size_t i = 0; i != size; ++i){
94 for(std::size_t k = 0; k != curBlockSize2; ++k){
95 B()(i,startB+k) = blockB[k][i];
102template<std::
size_t maxBlockSize1,std::
size_t maxBlockSize2,
bool Unit,
class MatA,
class MatB>
104 matrix_expression<MatA, cpu_tag>
const& A,
105 matrix_expression<MatB, cpu_tag>& B,
109 typedef typename MatA::value_type value_type;
111 std::size_t size = A().size1();
112 value_type blockA[maxBlockSize1][maxBlockSize1];
113 for(std::size_t i = 0; i != size; ++i){
114 for(std::size_t j = 0; j <= i; ++j){
115 blockA[i][j] = A()(i,j);
120 for(std::size_t k = 0; k != B().size2(); ++k){
121 for (std::size_t i = 0; i != size; ++i) {
122 for (std::size_t j = 0; j != i; ++j) {
123 B()(i,k) -= blockA[i][j] * B()(j,k);
126 if(blockA[i][i] == value_type())
127 throw std::invalid_argument(
"[TRSM] Matrix is singular!");
128 B()(i,k) /= blockA[i][i];
136template<std::
size_t maxBlockSize1, std::
size_t maxBlockSize2,
bool Unit,
class MatA,
class MatB>
138 matrix_expression<MatA, cpu_tag>
const& A,
139 matrix_expression<MatB, cpu_tag> &B,
143 REMORA_SIZE_CHECK(A().size1() <= maxBlockSize1);
144 typedef typename MatA::value_type value_typeA;
145 typedef typename MatB::value_type value_typeB;
148 std::size_t size = A().size1();
149 value_typeA blockA[maxBlockSize1][maxBlockSize1];
150 for(std::size_t i = 0; i != size; ++i){
151 for(std::size_t j = i; j != size; ++j){
152 blockA[i][j] = A()(i,j);
156 value_typeB blockB[maxBlockSize2][maxBlockSize1];
157 std::size_t numBlocks = (B().size2()+maxBlockSize2-1)/maxBlockSize2;
158 for(std::size_t i = 0; i != numBlocks; ++i){
159 std::size_t startB= i*maxBlockSize2;
160 std::size_t curBlockSize2 =std::min(maxBlockSize2, B().size2() - startB);
163 for(std::size_t i = 0; i != size; ++i){
164 for(std::size_t k = 0; k != curBlockSize2; ++k){
165 blockB[k][i] = B()(i,startB+k);
169 for(std::size_t k = 0; k != curBlockSize2; ++k){
170 for (std::size_t n = 0; n != size; ++n) {
171 std::size_t i = size-n-1;
172 for (std::size_t j = i+1; j != size; ++j) {
173 blockB[k][i] -= blockA[i][j] * blockB[k][j];
176 if(blockA[i][i] == value_typeA())
177 throw std::invalid_argument(
"[TRSM] Matrix is singular!");
178 blockB[k][i] /= blockA[i][i];
183 for(std::size_t i = 0; i != size; ++i){
184 for(std::size_t j = 0; j != curBlockSize2; ++j){
185 B()(i,startB+j) = blockB[j][i];
192template<std::
size_t maxBlockSize1,std::
size_t maxBlockSize2,
bool Unit,
class MatA,
class MatB>
194 matrix_expression<MatA, cpu_tag>
const& A,
195 matrix_expression<MatB, cpu_tag>& B,
199 typedef typename MatA::value_type value_type;
201 std::size_t size = A().size1();
202 value_type blockA[maxBlockSize1][maxBlockSize1];
203 for(std::size_t i = 0; i != size; ++i){
204 for(std::size_t j = i; j != size; ++j){
205 blockA[i][j] = A()(i,j);
210 for(std::size_t k = 0; k != B().size2(); ++k){
211 for (std::size_t n = 0; n != size; ++n) {
212 std::size_t i = size-n-1;
213 for (std::size_t j = i+1; j != size; ++j) {
214 B()(i,k) -= blockA[i][j] * B()(j,k);
217 if(blockA[i][i] == value_type())
218 throw std::invalid_argument(
"[TRSM] Matrix is singular!");
219 B()(i,k) /= blockA[i][i];
225template <
typename MatA,
typename MatB,
class Triangular>
227 matrix_expression<MatA, cpu_tag>
const& Afull,
228 matrix_expression<MatB, cpu_tag> & Bfull,
234 static std::size_t
const Block_Size = 32;
235 std::size_t num_rhs = Bfull().size2();
236 auto A = subrange(Afull,start,end,start,end);
237 auto B = subrange(Bfull,start,end,0,num_rhs);
239 if(A.size1() <= Block_Size){
240 trsm_block<Block_Size,16,Triangular::is_unit>(A,B,triangular_tag<Triangular::is_upper,false>(),
typename MatB::orientation());
243 std::size_t size = A.size1();
244 std::size_t numBlocks =(A.size1()+Block_Size-1)/Block_Size;
245 std::size_t split = numBlocks/2*Block_Size;
246 auto Bfront = subrange(B,0,split,0,num_rhs);
247 auto Bback = subrange(B,split,size,0,num_rhs);
250 if(Triangular::is_upper){
251 trsm_recursive(Afull, Bfull,start+split,end, t, l);
252 kernels::gemm(subrange(A,0,split,split,size), Bback, Bfront, -1.0);
253 trsm_recursive(Afull, Bfull,start,start+split, t, l);
255 trsm_recursive(Afull, Bfull,start,start+split, t, l);
256 kernels::gemm(subrange(A,split,size,0,split), Bfront, Bback, -1.0);
257 trsm_recursive(Afull, Bfull,start+split,end, t, l);
261template <
typename MatA,
typename MatB,
class Triangular>
263 matrix_expression<MatA, cpu_tag>
const& Afull,
264 matrix_expression<MatB, cpu_tag> & Bfull,
270 auto transA = trans(Afull);
271 auto transB = trans(Bfull);
272 trsm_recursive(transA,transB,start,end,
typename Triangular::transposed_orientation(),left());
276template <
class Triangular,
class S
ide,
typename MatA,
typename MatB>
278 matrix_expression<MatA, cpu_tag>
const& A,
279 matrix_expression<MatB, cpu_tag>& B,
283 bindings::trsm_recursive(A,B,0,A().size1(), Triangular(), Side());