// trmm.hpp
//===========================================================================
/*!
 *
 *
 * \brief Triangular matrix-matrix multiplication (trmm) kernels for the GPU backend.
 *
 * \author O. Krause
 * \date 2016
 *
 *
 * \par Copyright 1995-2015 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://image.diku.dk/shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#ifndef REMORA_KERNELS_CLBLAS_TRMM_HPP
#define REMORA_KERNELS_CLBLAS_TRMM_HPP

#include "../../expression_types.hpp"
#include "../../detail/traits.hpp"
#include "../gemm.hpp"
#include <boost/compute/functional/operator.hpp> //for multiplies
#include <string> //for std::to_string when building kernel compile options

namespace remora{namespace bindings {

//Bundle of the compiled OpenCL block kernel together with the positions of
//its runtime arguments (as returned by meta_kernel::add_arg), so that
//trmm_recursive can set them via kernel.set_arg(index, value) before each
//enqueue.
struct trmm_kernel{
	boost::compute::kernel kernel;
	std::size_t K_index;    //number of columns of B
	std::size_t start_index;//first row/column of the diagonal block of A
	std::size_t end_index;  //one past the last row/column of the diagonal block
	std::size_t unit_index; //nonzero: A is unit-triangular, diagonal scaling is skipped
	std::size_t upper_index;//nonzero: A is upper-triangular
};
//Generates and compiles the OpenCL kernel computing
//    B[start:end,:] = A[start:end,start:end] * B[start:end,:]
//for a single triangular diagonal block of A (row-major). Both the lower
//and the upper triangular variant are compiled into the kernel; the choice,
//as well as unit-diagonal handling, is selected at enqueue time via the
//"upper" and "unit" kernel arguments. The compile options must define
//TILE_SIZE and TILE_SIZE_K (see kernels::trmm below).
template<class MatA, class MatB>
trmm_kernel createTRMMBlockKernel(
	matrix_expression<MatA, gpu_tag> const& A_unreg,
	matrix_expression<MatB, gpu_tag>& B_unreg,
	char const* options
){
	typedef typename MatA::value_type value_typeA;
	typedef typename MatB::value_type value_typeB;
	boost::compute::multiplies<value_typeB> prod;

	gpu::detail::meta_kernel k("blas_trmm");
	std::size_t K_index = k.add_arg<std::size_t>("K");//number of columns in B
	std::size_t start_index = k.add_arg<std::size_t>("start");//start of block of A
	std::size_t end_index = k.add_arg<std::size_t>("end");//end of block of A
	std::size_t unit_index = k.add_arg<std::size_t>("unit");//whether A is unit triangular
	std::size_t upper_index = k.add_arg<std::size_t>("upper");//whether A is upper triangular
	auto A = k.register_args(to_functor(A_unreg));
	auto B = k.register_args(to_functor(B_unreg));
	// Local memory to fit a tile of A and B
	// we store B as column major in local memory
	// we also allocate memory to store results of B
	k << "__local " <<k.decl<value_typeA>("Asub")<< "[TILE_SIZE][TILE_SIZE+2];\n";//+2 to avoid bank conflicts
	k << "__local " <<k.decl<value_typeB>("Bsub")<< "[TILE_SIZE_K][TILE_SIZE+2];\n";//+2 to avoid bank conflicts
	k << "__local " <<k.decl<value_typeB>("BResult")<< "[TILE_SIZE_K][TILE_SIZE+2];\n";//+2 to avoid bank conflicts
	k << "const ulong numWorkers = get_local_size(0);\n";

	// Load tile of A into local memory
	k << "const ulong curTileA = end-start;\n";
	k << "for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
	k << " for(ulong j = get_local_id(1); j < curTileA; j += numWorkers){\n";
	k << " Asub[i][j] ="<< A(k.expr<cl_ulong>("(i+start)"),k.expr<cl_ulong>("(j+start)"))<<";\n";
	k << " }\n";
	k << "}\n";

	//ensure we are not reading out of bounds
	k << "const ulong t = get_group_id(1);\n";
	k << "const ulong curTileK = min(TILE_SIZE_K, K - t*TILE_SIZE_K);\n";
	// Load tile of B into local memory, store columns of B as rows
	// (the device-side loop variable "k" below lives only in the generated
	// OpenCL source and does not clash with the host-side meta_kernel k)
	k << "for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
	k << " for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
	k << " Bsub[k][i] ="<< B(k.expr<cl_ulong>("(i+start)"),k.expr<cl_ulong>("(t * TILE_SIZE_K+k)"))<<";\n";
	k << " }\n";
	k << "}\n";
	// Synchronise to make sure the tile is loaded
	k << "barrier(CLK_LOCAL_MEM_FENCE);\n";

	// Loop over the values of a single tile
	// by computing outer products into the local accumulation registers acc.
	// Results go to the separate BResult tile because Bsub entries are still
	// read by other work-items.
	//lower case
	k << "if(!upper){\n";
	k << " for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
	k << " for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
	k << " BResult[k][i] = Bsub[k][i];\n";
	k << " if(!unit){BResult[k][i] *= Asub[i][i];}\n";//unit-triangular: diagonal is implicitly 1
	k << " for(ulong j = 0; j < i; ++j){\n";
	k << " BResult[k][i] +="<< prod(k.expr<value_typeB>("Bsub[k][j]"), k.expr<value_typeA>("Asub[i][j]"))<<";\n";
	k << " }\n";
	k << " }\n";
	k << " }\n";
	k << "}else{\n";
	//upper case
	k << " for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
	k << " for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
	k << " BResult[k][i] = Bsub[k][i];\n";
	k << " if(!unit){BResult[k][i] *= Asub[i][i];}\n";
	k << " for(ulong j = i+1; j < curTileA; ++j){\n";
	k << " BResult[k][i] +="<< prod(k.expr<value_typeB>("Bsub[k][j]"), k.expr<value_typeA>("Asub[i][j]"))<<";\n";
	k << " }\n";
	k << " }\n";
	k << " }\n";
	k << "}\n";
	// Synchronise before loading the next tile
	k << "barrier(CLK_LOCAL_MEM_FENCE);\n";
	// Store the final results back in B
	k << "for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
	k << " for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
	k << B(k.expr<cl_ulong>("(start+i)"),k.expr<cl_ulong>("(t * TILE_SIZE_K+k)"))<<" = BResult[k][i];\n";
	k << " }\n";
	k << "}\n";

	boost::compute::kernel kernel = k.compile(B_unreg().queue().get_context(), options);
	return {kernel,K_index,start_index,end_index,unit_index,upper_index};
}
134
135template <typename MatA, typename MatB, typename Triangular>
136void trmm_recursive(
137 matrix_expression<MatA, gpu_tag> const& Afull,
138 matrix_expression<MatB, gpu_tag> & Bfull,
139 trmm_kernel& kernel,
140 std::size_t start,
141 std::size_t end,
142 std::size_t tileSizeA,
143 std::size_t tileSizeB,
144 std::size_t numWorkers,
145 Triangular t
146){
147 std::size_t size = end-start;
148
149 //if the matrix is small enough, call the computation kernel directly for the block
150 if(size <= tileSizeA){
151 //enqueue kernel with kernel args
152 kernel.kernel.set_arg(kernel.K_index, Bfull().size2());
153 kernel.kernel.set_arg(kernel.start_index, start);
154 kernel.kernel.set_arg(kernel.end_index, end);
155 kernel.kernel.set_arg(kernel.unit_index, (std::size_t)Triangular::is_unit);
156 kernel.kernel.set_arg(kernel.upper_index, (std::size_t)Triangular::is_upper);
157
158 std::size_t global_work_size[2] = {
159 numWorkers,
160 (Bfull().size2()+tileSizeB-1)/ tileSizeB * numWorkers
161 };
162 std::size_t local_work_size[2] = {numWorkers, numWorkers};
163 Bfull().queue().enqueue_nd_range_kernel(kernel.kernel, 2,nullptr, global_work_size, local_work_size);
164 return;
165 }
166 //otherwise run the kernel recursively
167 std::size_t split = (size+tileSizeA-1)/tileSizeA/2*tileSizeA;//split at the next multiple of the TileSize
168 auto Aul = subrange(Afull,start,start+split,start,start+split);
169 auto BFront = subrange(Bfull,start,start+split,0,Bfull().size2());
170 auto Bback =subrange(Bfull,start+split,end,0,Bfull().size2());
171
172
173 if(Triangular::is_upper){ //Upper triangular case
174 auto Aur = subrange(Afull,start,start+split,start+split,end);
175 trmm_recursive(Afull, Bfull, kernel, start, start+split, tileSizeA, tileSizeB, numWorkers, t);
176 kernels::gemm(Aur, Bback, BFront, 1.0);
177 trmm_recursive(Afull, Bfull, kernel, start+split, end, tileSizeA, tileSizeB, numWorkers, t);
178 }else{// Lower triangular caste
179 auto All = subrange(Afull,start+split,end,start,start+split);
180 trmm_recursive(Afull, Bfull, kernel, start+split, end, tileSizeA, tileSizeB, numWorkers, t);
181 kernels::gemm(All, BFront, Bback, 1.0);
182 trmm_recursive(Afull, Bfull, kernel, start, start+split, tileSizeA, tileSizeB, numWorkers, t);
183 }
184
185}
}//namespace bindings
namespace kernels{
//main entry point: runs the block kernel above recursively over the diagonal
//and applies the off-diagonal parts with gemm
189template <bool Upper,bool Unit,typename MatA, typename MatB>
190void trmm(
191 matrix_expression<MatA, gpu_tag> const& A,
192 matrix_expression<MatB, gpu_tag>& B
193){
194 REMORA_SIZE_CHECK(A().size1() == A().size2());
195 REMORA_SIZE_CHECK(A().size2() == B().size1());
196
197 std::size_t const TileSizeA = 32;//size of the diagonal blocks where the single kernel runs
198 std::size_t const TileSizeB = 32;// size of the blocks B is partitioned into along the number of columns
199 std::size_t const numWorkers = 8; //number of workers in two dimensions (e.g. 8x8=64)
200 char const* options ="-DTILE_SIZE=32ul -DTILE_SIZE_K=32ul";
201 auto kernel = bindings::createTRMMBlockKernel(A,B,options);
202
203 bindings::trmm_recursive(A,B,kernel,0,A().size1(), TileSizeA, TileSizeB, numWorkers, triangular_tag<Upper,Unit>());
204
205}

}}//namespace remora::kernels
#endif