32#ifndef REMORA_KERNELS_GPU_GEMM_HPP
33#define REMORA_KERNELS_GPU_GEMM_HPP
35#include "../../expression_types.hpp"
36#include "../../detail/traits.hpp"
37#include <boost/compute/functional/operator.hpp>
39namespace remora{
namespace kernels{
// Builds, compiles and enqueues an OpenCL kernel named "blas_gemm" that performs
//     C(i,j) += alpha * sum_k A(i,k) * B(k,j)
// for three GPU matrix expressions.
//
// Parameters:
//   A_unreg - left operand; its size1() must equal C's size1().
//   B_unreg - right operand; its size2() must equal C's size2(), and its
//             size1() must equal A's size2().
//   C_unreg - matrix that is updated in place; its queue() is used both for
//             compiling the kernel and for enqueueing it.
//   alpha   - scalar factor applied to the accumulated product.
//
// NOTE(review): this text looks extraction-damaged. Numeric prefixes are fused
// onto many lines, and the function name, the opening/closing braces of the
// body and the closing braces of the generated kernel loops are not visible
// here. Confirm against the pristine file before relying on this block.
42template <
typename MatA,
typename MatB,
typename MatC>
44 matrix_expression<MatA, gpu_tag>
const& A_unreg,
45 matrix_expression<MatB, gpu_tag>
const& B_unreg,
46 matrix_expression<MatC, gpu_tag>& C_unreg,
47 typename MatC::value_type
const& alpha
// Dimension checks: rows(A)==rows(C), cols(B)==cols(C), cols(A)==rows(B).
49 REMORA_SIZE_CHECK(A_unreg().size1() == C_unreg().size1());
50 REMORA_SIZE_CHECK(B_unreg().size2() == C_unreg().size2());
51 REMORA_SIZE_CHECK(A_unreg().size2()== B_unreg().size1());
// Host-side copies of the tiling parameters. They are only used below to
// compute the work sizes; the kernel itself receives its values through the
// -D options string, so the two sets of numbers must be kept in sync by hand.
// NUM_WORKERS evaluates to 32 / 4 = 8 work-items per dimension of a group.
62 std::size_t BLOCK_SIZE = 4;
63 std::size_t TILE_SIZE = 32;
64 std::size_t NUM_WORKERS = TILE_SIZE / BLOCK_SIZE;
// Compile-time definitions for the kernel. TILE_SIZE_K (tile extent along the
// shared dimension K) is defined only here and has no host-side counterpart.
66 char const* options =
"-DTILE_SIZE=32ul -DBLOCK_SIZE=4ul -DTILE_SIZE_K=16ul";
67 typedef typename MatC::value_type value_type;
// Kernel builder. The returned indices of the scalar arguments M, N, K and
// alpha are kept so that the values can be bound with set_arg() after compiling.
69 gpu::detail::meta_kernel k(
"blas_gemm");
70 std::size_t M_index = k.add_arg<std::size_t>(
"M");
71 std::size_t N_index = k.add_arg<std::size_t>(
"N");
72 std::size_t K_index = k.add_arg<std::size_t>(
"K");
73 std::size_t alpha_index = k.add_arg<value_type>(
"alpha");
// Register the three matrix expressions with the kernel; A, B and C are then
// used below as callables taking (row, column) index expressions.
74 auto A = k.register_args(to_functor(A_unreg));
75 auto B = k.register_args(to_functor(B_unreg));
76 auto C = k.register_args(to_functor(C_unreg));
// __local tiles of A and B for the current K-tile, indexed [k][i].
// NOTE(review): the second extent is padded to TILE_SIZE+2; presumably to
// avoid local-memory bank conflicts - confirm.
81 k <<
"__local " <<k.decl<value_type>(
"Asub")<<
"[TILE_SIZE_K][TILE_SIZE+2];\n";
82 k <<
"__local " <<k.decl<value_type>(
"Bsub")<<
"[TILE_SIZE_K][TILE_SIZE+2];\n";
// numWorkers is read from dimension 0 only; the launch below uses the same
// local size (NUM_WORKERS) in both dimensions.
83 k <<
" const ulong numWorkers = get_local_size(0);\n";
// Per-work-item BLOCK_SIZE x BLOCK_SIZE accumulator, zero-initialised.
// The literal 0.0f is emitted regardless of value_type.
91 k << k.decl<value_type>(
"acc") <<
"[BLOCK_SIZE][BLOCK_SIZE];\n";
92 k <<
"for (ulong wm=0; wm<BLOCK_SIZE; wm++){\n";
93 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
94 k <<
" acc[wm][wn] = 0.0f;\n";
// Loop over the K dimension in tiles of TILE_SIZE_K (count rounded up).
100 k <<
"ulong numTiles = (K+TILE_SIZE_K-1)/TILE_SIZE_K;\n";
101 k <<
"for (ulong t=0; t<numTiles; t++){\n";
// Number of valid k-entries in this tile; smaller than TILE_SIZE_K only for
// the last tile when K is not a multiple of TILE_SIZE_K.
104 k <<
" const ulong curTileK = min(TILE_SIZE_K, K - t*TILE_SIZE_K);\n";
// Load step: the work-items of a group fill Asub and Bsub together, striding
// i by numWorkers from get_local_id(0) and k by numWorkers from get_local_id(1).
107 k <<
" for(ulong i = get_local_id(0); i < TILE_SIZE; i += numWorkers){\n";
108 k <<
" for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
109 k <<
" ulong ktile = t * TILE_SIZE_K + k;\n";
// Asub[k][i] takes A(row, ktile); the row index is clamped to M-1 so that
// edge groups whose tile extends past the last row re-read that row instead
// of indexing beyond it.
110 k <<
" Asub[k][i] ="<< A(k.expr<cl_ulong>(
"min(M-1,TILE_SIZE * get_group_id(0)+i)"),k.expr<cl_ulong>(
"ktile"))<<
";\n";
// Bsub[k][i] takes B(ktile, column); the column index is clamped to N-1 in
// the same way. The same i serves as row offset for A and column offset for B.
111 k <<
" Bsub[k][i] ="<< B(k.expr<cl_ulong>(
"ktile"),k.expr<cl_ulong>(
"min(N-1,TILE_SIZE * get_group_id(1)+i)"))<<
";\n";
// Local-memory barrier emitted between filling the tiles and reading them.
116 k <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
// Compute step: for each k of the tile, copy BLOCK_SIZE values of Bsub into
// Breg, then for each of BLOCK_SIZE values of Asub (Areg) accumulate
// Areg * Breg[wn] into acc[wm][wn].
120 k <<
" for (ulong k=0; k<curTileK; k++){\n";
122 k << k.decl<value_type>(
"Breg")<<
"[BLOCK_SIZE];\n";
// Entry wn of this work-item lives at tile column get_local_id(1) + wn*numWorkers.
123 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
124 k <<
" Breg[wn] = Bsub[k][get_local_id(1) + wn * numWorkers];\n";
// Entry wm of this work-item lives at tile row get_local_id(0) + wm*numWorkers.
128 k <<
" for (ulong wm = 0; wm<BLOCK_SIZE; wm++){\n";
129 k << k.decl<value_type>(
"Areg") <<
"= Asub[k][get_local_id(0) + wm * numWorkers];\n";
130 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
131 k <<
" acc[wm][wn] += Areg * Breg[wn];\n";
// Second local-memory barrier, emitted after the compute step of the tile.
137 k <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
// Write-back: maxCi / maxCj restrict the group to the part of its
// TILE_SIZE x TILE_SIZE output tile that lies inside C; offTileCi / offTileCj
// are the tile's starting row and column in C.
141 k <<
"const ulong maxCi = min(TILE_SIZE, M - get_group_id(0) * TILE_SIZE);\n";
142 k <<
"const ulong maxCj = min(TILE_SIZE, N - get_group_id(1) * TILE_SIZE);\n";
143 k <<
"const ulong offTileCi = TILE_SIZE * get_group_id(0);\n";
144 k <<
"const ulong offTileCj = TILE_SIZE * get_group_id(1);\n";
// i and j use the same numWorkers stride as the compute step, while wm and wn
// count the steps taken and so select the matching accumulator entry.
145 k <<
"ulong wm = 0;\n";
146 k <<
"for (ulong i = get_local_id(0); i < maxCi; i += numWorkers, wm++){\n";
147 k <<
" ulong wn = 0;\n";
148 k <<
" for (ulong j =get_local_id(1); j < maxCj; j += numWorkers, wn++){\n";
// C is updated in place: C(i,j) += alpha * acc.
149 k << C(k.expr<cl_ulong>(
"(offTileCi + i)"), k.expr<cl_ulong>(
"(offTileCj + j)")) <<
"+= alpha * acc[wm][wn];\n";
// Compile the generated source for the context of C's queue, passing the -D options.
153 boost::compute::kernel kernel = k.compile(C_unreg().queue().get_context(), options);
// Bind the scalar arguments: M = rows(C), N = cols(C), K = cols(A), alpha.
156 kernel.set_arg(M_index, C_unreg().size1());
157 kernel.set_arg(N_index, C_unreg().size2());
158 kernel.set_arg(K_index, A_unreg().size2());
159 kernel.set_arg(alpha_index, alpha);
// One work-group per TILE_SIZE x TILE_SIZE tile of C (tile counts rounded up),
// each with NUM_WORKERS x NUM_WORKERS work-items.
161 std::size_t global_work_size[2] = {
162 (C_unreg().size1()+TILE_SIZE-1)/ TILE_SIZE * NUM_WORKERS,
163 (C_unreg().size2()+TILE_SIZE-1)/ TILE_SIZE * NUM_WORKERS
165 std::size_t local_work_size[2] = {NUM_WORKERS, NUM_WORKERS};
// Enqueue the 2-D launch on C's queue; no global offset is passed (nullptr).
166 C_unreg().queue().enqueue_nd_range_kernel(kernel, 2,
nullptr, global_work_size, local_work_size);