32#ifndef REMORA_KERNELS_CLBLAS_SYRK_HPP
33#define REMORA_KERNELS_CLBLAS_SYRK_HPP
35#include "../../expression_types.hpp"
36#include "../../detail/traits.hpp"
37#include <boost/compute/functional/operator.hpp>
39namespace remora{
namespace kernels{
// Generates, compiles and enqueues an OpenCL kernel ("blas_syrk") that adds
// alpha * A * A^T to one triangle (including the diagonal) of the square
// matrix C.
//
// Template parameters:
//   Upper - forwarded to the kernel as the `upper` argument; when true only
//           entries with column >= row are updated, otherwise only entries
//           with column <= row.
//   MatA  - type of the input matrix expression.
//   MatC  - type of the output matrix expression; its value_type is used for
//           alpha, the accumulators and the local tiles.
//
// Parameters:
//   A_unreg - input matrix; size1() must equal C's size1(), size2() is passed
//             to the kernel as K.
//   C_unreg - output matrix, must be square; its queue supplies the context
//             the kernel is compiled for and is used to enqueue it.
//   alpha   - scalar factor applied to every accumulated product.
42template <
bool Upper,
typename MatA,
typename MatC>
44 matrix_expression<MatA, gpu_tag>
const& A_unreg,
45 matrix_expression<MatC, gpu_tag>& C_unreg,
46 typename MatC::value_type
const& alpha
// A must have as many rows as C, and C must be square.
48 REMORA_SIZE_CHECK(A_unreg().size1() == C_unreg().size1());
49 REMORA_SIZE_CHECK(C_unreg().size1()== C_unreg().size2());
// Host-side tiling constants. The same numbers are handed to the kernel
// compiler as -D defines in `options` below, so the two must be kept in sync.
60 std::size_t BLOCK_SIZE = 4;
61 std::size_t TILE_SIZE = 32;
// Work-items per work-group dimension: each one covers BLOCK_SIZE positions
// of a TILE_SIZE-wide tile.
62 std::size_t NUM_WORKERS = TILE_SIZE / BLOCK_SIZE;
63 char const* options =
"-DTILE_SIZE=32ul -DBLOCK_SIZE=4ul -DTILE_SIZE_K=16ul";
64 typedef typename MatC::value_type value_type;
// Declare the scalar kernel arguments; the returned indices are used further
// down to bind the actual values.
66 gpu::detail::meta_kernel k(
"blas_syrk");
67 std::size_t N_index = k.add_arg<std::size_t>(
"N");
68 std::size_t K_index = k.add_arg<std::size_t>(
"K");
69 std::size_t upper_index = k.add_arg<std::size_t>(
"upper");
70 std::size_t alpha_index = k.add_arg<value_type>(
"alpha");
// Register the matrix expressions with the kernel; A and C are then used
// below to emit element-access expressions into the kernel source.
71 auto A = k.register_args(to_functor(A_unreg));
72 auto C = k.register_args(to_functor(C_unreg));
// Work-groups whose tile lies entirely in the triangle that is not requested
// exit immediately.
74 k <<
"if((upper && get_group_id(1) < get_group_id(0))) return;\n";
75 k <<
"if((!upper && get_group_id(1) > get_group_id(0))) return;\n";
// Local-memory tiles holding TILE_SIZE_K columns of A for the row range of
// group 0 (Asub) and of group 1 (Bsub).
// NOTE(review): the "+2" on the inner dimension is presumably padding to
// avoid local-memory bank conflicts - confirm.
80 k <<
"__local " <<k.decl<value_type>(
"Asub")<<
"[TILE_SIZE_K][TILE_SIZE+2];\n";
81 k <<
"__local " <<k.decl<value_type>(
"Bsub")<<
"[TILE_SIZE_K][TILE_SIZE+2];\n";
82 k <<
" const ulong numWorkers = get_local_size(0);\n";
// Each work-item keeps a BLOCK_SIZE x BLOCK_SIZE block of partial sums,
// initialised to zero.
90 k << k.decl<value_type>(
"acc") <<
"[BLOCK_SIZE][BLOCK_SIZE];\n";
91 k <<
"for (ulong wm=0; wm<BLOCK_SIZE; wm++){\n";
92 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
93 k <<
" acc[wm][wn] = 0.0f;\n";
// Walk over the K dimension in chunks of TILE_SIZE_K (rounded up).
99 k <<
"ulong numTiles = (K+TILE_SIZE_K-1)/TILE_SIZE_K;\n";
100 k <<
"for (ulong t=0; t<numTiles; t++){\n";
// The last chunk may contain fewer than TILE_SIZE_K columns.
103 k <<
" const ulong curTileK = min(TILE_SIZE_K, K - t*TILE_SIZE_K);\n";
// Cooperative load of the current chunk into local memory: work-items stride
// over tile rows via local id 0 and over chunk columns via local id 1.
106 k <<
" for(ulong i = get_local_id(0); i < TILE_SIZE; i += numWorkers){\n";
107 k <<
" for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
108 k <<
" ulong ktile = t * TILE_SIZE_K + k;\n";
// Row indices are clamped to N-1, so a tile that overhangs the matrix edge
// re-reads the last row instead of indexing past it. The write-back loop
// below is bounded by maxCi/maxCj, so results for those rows are not stored.
109 k <<
" Asub[k][i] ="<< A(k.expr<cl_ulong>(
"min(N-1,TILE_SIZE * get_group_id(0)+i)"), k.expr<cl_ulong>(
"ktile"))<<
";\n";
110 k <<
" Bsub[k][i] ="<< A(k.expr<cl_ulong>(
"min(N-1,TILE_SIZE * get_group_id(1)+i)"), k.expr<cl_ulong>(
"ktile"))<<
";\n";
// Wait until the whole chunk has been loaded before any work-item reads it.
115 k <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
// Accumulate the products for this chunk: for every column k of the chunk,
// add Asub[k][row] * Bsub[k][col] for the rows/cols owned by this work-item
// (local id + w * numWorkers).
119 k <<
" for (ulong k=0; k<curTileK; k++){\n";
// Copy this work-item's Bsub entries for column k into a small array first.
121 k << k.decl<value_type>(
"Breg")<<
"[BLOCK_SIZE];\n";
122 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
123 k <<
" Breg[wn] = Bsub[k][get_local_id(1) + wn * numWorkers];\n";
127 k <<
" for (ulong wm = 0; wm<BLOCK_SIZE; wm++){\n";
128 k << k.decl<value_type>(
"Areg") <<
"= Asub[k][get_local_id(0) + wm * numWorkers];\n";
129 k <<
" for (ulong wn=0; wn<BLOCK_SIZE; wn++){\n";
130 k <<
" acc[wm][wn] += Areg * Breg[wn];\n";
// All work-items must be done with the local tiles before the next chunk
// overwrites them.
136 k <<
" barrier(CLK_LOCAL_MEM_FENCE);\n";
// Write-back: maxCi/maxCj are the numbers of valid rows/columns of C inside
// this work-group's tile, offTileCi/offTileCj the tile's origin in C.
140 k <<
"const ulong maxCi = min(TILE_SIZE, N - get_group_id(0) * TILE_SIZE);\n";
141 k <<
"const ulong maxCj = min(TILE_SIZE, N - get_group_id(1) * TILE_SIZE);\n";
142 k <<
"const ulong offTileCi = TILE_SIZE * get_group_id(0);\n";
143 k <<
"const ulong offTileCj = TILE_SIZE * get_group_id(1);\n";
// wm/wn advance together with i/j so that acc[wm][wn] matches the element
// (i, j) = (local id 0 + wm * numWorkers, local id 1 + wn * numWorkers).
144 k <<
"ulong wm = 0;\n";
145 k <<
"for (ulong i = get_local_id(0); i < maxCi; i += numWorkers, wm++){\n";
146 k <<
" ulong wn = 0;\n";
147 k <<
" for (ulong j =get_local_id(1); j < maxCj; j += numWorkers, wn++){\n";
// Off-diagonal tiles are written completely; tiles on the diagonal only
// update the requested triangle (diagonal included).
148 k <<
" if(get_group_id(1) != get_group_id(0) || (upper && j >= i) || (!upper && j <= i) ){\n";
149 k << C(k.expr<cl_ulong>(
"(offTileCi + i)"), k.expr<cl_ulong>(
"(offTileCj + j)")) <<
"+= alpha * acc[wm][wn];\n";
// Compile the generated source for the context of C's queue, passing the
// tiling defines.
154 boost::compute::kernel kernel = k.compile(C_unreg().queue().get_context(), options);
// Bind the scalar arguments declared above.
157 kernel.set_arg(N_index, C_unreg().size1());
158 kernel.set_arg(K_index, A_unreg().size2());
159 kernel.set_arg(alpha_index, alpha);
160 kernel.set_arg(upper_index, (std::size_t)Upper);
// One work-group of NUM_WORKERS x NUM_WORKERS work-items per TILE_SIZE x
// TILE_SIZE tile of C; the tile count per dimension is rounded up.
162 std::size_t global_work_size[2] = {
163 (C_unreg().size1()+TILE_SIZE-1)/ TILE_SIZE * NUM_WORKERS,
164 (C_unreg().size2()+TILE_SIZE-1)/ TILE_SIZE * NUM_WORKERS
166 std::size_t local_work_size[2] = {NUM_WORKERS, NUM_WORKERS};
// Launch as a 2-D range on C's queue with no global offset.
167 C_unreg().queue().enqueue_nd_range_kernel(kernel, 2,
nullptr, global_work_size, local_work_size);