32#ifndef REMORA_KERNELS_CLBLAST_TRSM_HPP
33#define REMORA_KERNELS_CLBLAST_TRSM_HPP
35#include "../../expression_types.hpp"
36#include "../../detail/traits.hpp"
37#include "../../proxy_expressions.hpp"
39namespace remora{
namespace kernels{
42template <
class Triangular,
typename MatA,
typename MatB>
44 matrix_expression<MatA, gpu_tag>
const& A,
45 matrix_expression<MatB, gpu_tag>& B,
49 REMORA_SIZE_CHECK(A().size1() == A().size2());
50 REMORA_SIZE_CHECK(A().size2() == B().size1());
52 static_assert(std::is_same<typename MatA::value_type, typename MatB::value_type>::value,
"[trsm] Arguments do not have same element type");
53 static_assert(std::is_same<typename MatA::evaluation_category::tag, dense_tag>::value,
"[trsm] A is not dense");
54 static_assert(std::is_base_of<dense_tag, typename MatB::storage_type::storage_tag>::value,
"[trsm] B does not have dense storage layout");
57 auto const& Aeval = eval_expression(A);
59 using namespace clblast;
62 auto transA = std::is_same<typename MatA::orientation,typename MatB::orientation>::value? Transpose::kNo : Transpose::kYes;
63 auto layout = std::is_same<typename MatB::orientation::orientation, row_major>::value? Layout::kRowMajor : Layout::kColMajor;
64 auto diagonal = Triangular::is_unit? Diagonal::kUnit : Diagonal::kNonUnit;
65 auto triangular = Triangular::is_upper? Triangle::kUpper : Triangle::kLower;
66 if(transA == Transpose::kYes){
67 triangular = Triangular::is_upper? Triangle::kLower : Triangle::kUpper;
69 std::size_t m = B().size1();
70 std::size_t n = B().size2();
73 auto storageA = Aeval.raw_storage();
74 auto storageB = B().raw_storage();
76 cl_event*
event =
nullptr;
77 auto code = Trsm(layout, Side::kLeft, triangular, transA, diagonal,
78 m, n,
typename MatB::value_type(1),
79 storageA.buffer.get(), storageA.offset, storageA.leading_dimension,
80 storageB.buffer.get(), storageB.offset, storageB.leading_dimension,
81 &B().queue().get(), event
83 assert(code == StatusCode::kSuccess);
87template <
class Triangular,
typename MatA,
typename MatB>
89 matrix_expression<MatA, gpu_tag>
const& A,
90 matrix_expression<MatB, gpu_tag>& B,
94 auto transB = trans(B);
95 trsm_impl(trans(A), transB,
typename Triangular::transposed_orientation(), left());
99template <
class Triangular,
class S
ide,
typename MatA,
typename MatB>
101 matrix_expression<MatA, gpu_tag>
const& A,
102 matrix_expression<MatB, gpu_tag>& B
104 trsm_impl(A,B,Triangular(),Side());