32#ifndef REMORA_KERNELS_CLBLAST_GEMM_HPP
33#define REMORA_KERNELS_CLBLAST_GEMM_HPP
35#include "../../expression_types.hpp"
36#include "../../detail/traits.hpp"
38namespace remora{
namespace kernels{
41template <
typename MatA,
typename MatB,
typename MatC>
43 matrix_expression<MatA, gpu_tag>
const& A,
44 matrix_expression<MatB, gpu_tag>
const& B,
45 matrix_expression<MatC, gpu_tag>& C,
46 typename MatC::value_type
const& alpha
48 REMORA_SIZE_CHECK(A().size1() == C().size1());
49 REMORA_SIZE_CHECK(B().size2() == C().size2());
50 REMORA_SIZE_CHECK(A().size2()== B().size1());
52 static_assert(std::is_same<typename MatA::value_type, typename MatC::value_type>::value,
"[gemm] Arguments do not have same element type");
53 static_assert(std::is_same<typename MatA::value_type, typename MatB::value_type>::value,
"[gemm] Arguments do not have same element type");
54 static_assert(std::is_same<typename MatA::evaluation_category::tag, dense_tag>::value,
"[gemm] A is not dense");
55 static_assert(std::is_same<typename MatB::evaluation_category::tag, dense_tag>::value,
"[gemm] B is not dense");
56 static_assert(std::is_base_of<dense_tag, typename MatC::storage_type::storage_tag>::value,
"[gemm] C does not have dense storage layout");
59 auto const& Aeval = eval_expression(A);
60 auto const& Beval = eval_expression(B);
62 using namespace clblast;
65 auto transA = std::is_same<typename MatA::orientation,typename MatC::orientation>::value? Transpose::kNo : Transpose::kYes;
66 auto transB = std::is_same<typename MatB::orientation,typename MatC::orientation>::value? Transpose::kNo : Transpose::kYes;
67 auto layout = std::is_same<typename MatC::orientation::orientation, row_major>::value? Layout::kRowMajor: Layout::kColMajor;
68 std::size_t m = C().size1();
69 std::size_t n = C().size2();
70 std::size_t k = A().size2();
73 auto storageA = Aeval.raw_storage();
74 auto storageB = Beval.raw_storage();
75 auto storageC = C().raw_storage();
78 cl_event*
event =
nullptr;
79 auto code = Gemm(layout, transA, transB,
82 storageA.buffer.get(), storageA.offset, storageA.leading_dimension,
83 storageB.buffer.get(), storageB.offset, storageB.leading_dimension,
84 typename MatC::value_type(1),
85 storageC.buffer.get(), storageC.offset, storageC.leading_dimension,
86 &C().queue().get(), event
89 assert(code == StatusCode::kSuccess);