32#ifndef REMORA_KERNELS_CLBLAS_TRMM_HPP 
   33#define REMORA_KERNELS_CLBLAS_TRMM_HPP 
   35#include "../../expression_types.hpp" 
   36#include "../../detail/traits.hpp" 
   37#include <boost/compute/functional/operator.hpp>  
   40namespace remora{
namespace bindings {
 
   // NOTE(review): these look like the data members of trmm_kernel (the type
   // returned by createTRMMBlockKernel); the enclosing declaration line is not
   // visible in this excerpt. A K_index member is also used by the code below
   // but its declaration is not visible here either.
   // Compiled kernel object on which the arguments are set and which is enqueued.
   43    boost::compute::kernel kernel;
 
   // Index of the kernel argument "start", passed to set_arg.
   45    std::size_t start_index;
 
   // Index of the kernel argument "end", passed to set_arg.
   46    std::size_t end_index;
 
   // Index of the kernel argument "unit", passed to set_arg.
   47    std::size_t unit_index;
 
   // Index of the kernel argument "upper", passed to set_arg.
   48    std::size_t upper_index;
 
   // Generates the OpenCL source of the "blas_trmm" meta kernel, compiles it for
   // the context of B's queue and returns the kernel together with the indices
   // of its scalar arguments (K, start, end, unit, upper).
   //
   // The generated kernel works on the square block A[start:end, start:end] and
   // on rows start..end of B: it loads both into local memory, computes the
   // triangular product per column of B and writes the result back into B.
   //
   // NOTE(review): several lines of this definition are not visible in this
   // excerpt (the third parameter -- presumably the `options` string used in the
   // compile call --, the opening brace and the closing braces emitted for the
   // generated loops). Comments below describe only what the visible lines show.
   51template<
class MatA, 
class MatB>
 
   52trmm_kernel createTRMMBlockKernel(
 
   53    matrix_expression<MatA, gpu_tag> 
const& A_unreg,
 
   54    matrix_expression<MatB, gpu_tag>& B_unreg,
 
   57    typedef typename MatA::value_type value_typeA;
 
   58    typedef typename MatB::value_type value_typeB;
 
   // Functor used to emit the element product inside the generated source.
   59    boost::compute::multiplies<value_typeB> prod;
 
   61    gpu::detail::meta_kernel k(
"blas_trmm");
 
   // Scalar kernel arguments. The returned indices are handed back to the
   // caller, which uses them with set_arg before enqueueing the kernel.
   62    std::size_t K_index = k.add_arg<std::size_t>(
"K");
 
   63    std::size_t start_index = k.add_arg<std::size_t>(
"start");
 
   64    std::size_t end_index = k.add_arg<std::size_t>(
"end");
 
   65    std::size_t unit_index = k.add_arg<std::size_t>(
"unit");
 
   66    std::size_t upper_index = k.add_arg<std::size_t>(
"upper");
 
   // Register the matrix expressions so that A(i,j) / B(i,j) can be streamed
   // into the kernel source as element accesses.
   67    auto A = k.register_args(to_functor(A_unreg));
 
   68    auto B = k.register_args(to_functor(B_unreg));
 
   // Local-memory buffers: Asub for the block of A, Bsub for the tile of B
   // (indexed [column][row], see the load below) and BResult for the results.
   // NOTE(review): the "+2" padding of the second dimension is presumably there
   // to avoid local-memory bank conflicts -- confirm.
   72    k << 
"__local " <<k.decl<value_typeA>(
"Asub")<< 
"[TILE_SIZE][TILE_SIZE+2];\n";
 
   73    k << 
"__local " <<k.decl<value_typeB>(
"Bsub")<< 
"[TILE_SIZE_K][TILE_SIZE+2];\n";
 
   74    k << 
"__local " <<k.decl<value_typeB>(
"BResult")<< 
"[TILE_SIZE_K][TILE_SIZE+2];\n";
 
   // The size of local dimension 0 is used as the loop stride for both dimensions.
   75    k << 
"const ulong numWorkers = get_local_size(0);\n";
 
   // Load the block A[start:end, start:end] into Asub. Every work-item starts at
   // its local id and advances by numWorkers in each dimension.
   78    k << 
"const ulong curTileA =  end-start;\n";
 
   79    k << 
"for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
 
   80    k << 
"  for(ulong j = get_local_id(1); j < curTileA; j += numWorkers){\n";
 
   81    k << 
"      Asub[i][j] ="<< A(k.expr<cl_ulong>(
"(i+start)"),k.expr<cl_ulong>(
"(j+start)"))<<
";\n";
 
   // The group id along dimension 1 selects the column tile of B. The last tile
   // may contain fewer than TILE_SIZE_K columns, hence the min().
   86    k << 
"const ulong t = get_group_id(1);\n";
 
   87    k << 
"const ulong curTileK =  min(TILE_SIZE_K, K - t*TILE_SIZE_K);\n";
 
   // Load rows start..end of the selected column tile of B into Bsub, stored
   // with the column index first (Bsub[k][i] holds B(start+i, t*TILE_SIZE_K+k)).
   89    k << 
"for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
 
   90    k << 
"  for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
 
   91    k << 
"      Bsub[k][i] ="<< B(k.expr<cl_ulong>(
"(i+start)"),k.expr<cl_ulong>(
"(t * TILE_SIZE_K+k)"))<<
";\n";
 
   // All local tiles must be completely written before any work-item reads them.
   95    k << 
"barrier(CLK_LOCAL_MEM_FENCE);\n";
 
   // Lower-triangular case (upper == 0):
   //   BResult[k][i] = Bsub[k][i] * (unit ? 1 : Asub[i][i]) + sum_{j<i} Bsub[k][j]*Asub[i][j]
   // i.e. only the entries of row i of A left of (and, if not unit, on) the diagonal are used.
  100    k << 
"if(!upper){\n";
 
  101    k << 
"  for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
 
  102    k << 
"      for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
 
  103    k << 
"          BResult[k][i] = Bsub[k][i];\n";
 
  104    k << 
"          if(!unit){BResult[k][i] *= Asub[i][i];}\n";
 
  105    k << 
"          for(ulong j = 0; j < i; ++j){\n";
 
  106    k << 
"              BResult[k][i] +="<< prod(k.expr<value_typeB>(
"Bsub[k][j]"), k.expr<value_typeA>(
"Asub[i][j]"))<<
";\n";
 
   // Same computation, but summing over j = i+1 .. curTileA-1, i.e. the entries
   // of row i of A right of the diagonal.
   // NOTE(review): presumably this is the else-branch for the upper-triangular
   // case; the lines closing the block above and opening the else are not visible.
  112    k << 
"  for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
 
  113    k << 
"      for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
 
  114    k << 
"          BResult[k][i] = Bsub[k][i];\n";
 
  115    k << 
"          if(!unit){BResult[k][i] *= Asub[i][i];}\n";
 
  116    k << 
"          for(ulong j = i+1; j < curTileA; ++j){\n";
 
  117    k << 
"              BResult[k][i] +="<< prod(k.expr<value_typeB>(
"Bsub[k][j]"), k.expr<value_typeA>(
"Asub[i][j]"))<<
";\n";
 
   // Make sure BResult is completely computed before it is copied back.
  123    k << 
"barrier(CLK_LOCAL_MEM_FENCE);\n";
 
   // Write the results back into rows start..end of the selected column tile of B.
  125    k << 
"for(ulong i = get_local_id(0); i < curTileA; i += numWorkers){\n";
 
  126    k << 
"  for(ulong k = get_local_id(1); k < curTileK; k += numWorkers){\n";
 
  127    k << B(k.expr<cl_ulong>(
"(start+i)"),k.expr<cl_ulong>(
"(t * TILE_SIZE_K+k)"))<<
" =  BResult[k][i];\n";
 
   // Compile for the context of B's queue. `options` carries the compiler
   // flags; the caller in this file passes the TILE_SIZE / TILE_SIZE_K macros.
  131    boost::compute::kernel kernel = k.compile(B_unreg().queue().get_context(), options);
 
   // Aggregate-initialise the result: kernel followed by the argument indices.
  132    return {kernel,K_index,start_index,end_index,unit_index,upper_index};
 
  // Recursive driver for the triangular product on the row range [start, end).
  // Ranges of at most tileSizeA rows are handled by a single launch of the block
  // kernel; larger ranges are split into a front and a back part, which are
  // processed recursively and coupled through one kernels::gemm call.
  //
  // NOTE(review): the function-name line and some parameters (`kernel`, `start`,
  // `end` and the trailing `t`) are not visible in this excerpt; the name
  // trmm_recursive is taken from the recursive calls below. Closing braces, the
  // else line and any return statement are also not visible.
  135template <
typename MatA, 
typename MatB, 
typename Triangular>
 
  137    matrix_expression<MatA, gpu_tag> 
const& Afull, 
 
  138    matrix_expression<MatB, gpu_tag> & Bfull,
 
  142    std::size_t tileSizeA,
 
  143    std::size_t tileSizeB,
 
  144    std::size_t numWorkers,
 
  147    std::size_t size = end-start;
 
  // Base case: the whole range fits into one tile of A, so run the block kernel.
  150    if(size <= tileSizeA){
 
  // K is the number of columns of B; unit/upper are passed as 0/1 size_t flags.
  152        kernel.kernel.set_arg(kernel.K_index, Bfull().size2());
 
  153        kernel.kernel.set_arg(kernel.start_index, start);
 
  154        kernel.kernel.set_arg(kernel.end_index, end);
 
  155        kernel.kernel.set_arg(kernel.unit_index, (std::size_t)Triangular::is_unit);
 
  156        kernel.kernel.set_arg(kernel.upper_index, (std::size_t)Triangular::is_upper);
 
  // Dimension 1: one work group of numWorkers items per tileSizeB columns of B
  // (rounded up). NOTE(review): the first array element is not visible here.
  158        std::size_t global_work_size[2] = {
 
  160            (Bfull().size2()+tileSizeB-1)/ tileSizeB * numWorkers
 
  162        std::size_t local_work_size[2] = {numWorkers, numWorkers};
 
  163        Bfull().queue().enqueue_nd_range_kernel(kernel.kernel, 2,
nullptr, global_work_size, local_work_size);
 
  // Split point in rows: half the (rounded-up) number of tileSizeA-sized tiles,
  // rounded down, times tileSizeA -- so the front part is a multiple of tileSizeA.
  167    std::size_t split = (size+tileSizeA-1)/tileSizeA/2*tileSizeA;
 
  // Sub-blocks of A and B for the front rows [start, start+split) and the back
  // rows [start+split, end). Aul is not used in the lines visible here.
  168    auto Aul = subrange(Afull,start,start+split,start,start+split);
 
  169    auto BFront  = subrange(Bfull,start,start+split,0,Bfull().size2());
 
  170    auto Bback =subrange(Bfull,start+split,end,0,Bfull().size2());
 
  // Upper case: the front rows are processed first, then the off-diagonal block
  // Aur is combined with Bback (not yet modified at that point) into BFront,
  // and finally the back rows are processed.
  // NOTE(review): presumably gemm accumulates BFront += 1.0 * Aur * Bback --
  // confirm against kernels::gemm.
  173    if(Triangular::is_upper){ 
 
  174        auto Aur = subrange(Afull,start,start+split,start+split,end);
 
  175        trmm_recursive(Afull, Bfull, kernel, start, start+split, tileSizeA, tileSizeB, numWorkers, t);
 
  176        kernels::gemm(Aur, Bback, BFront, 1.0);
 
  177        trmm_recursive(Afull, Bfull, kernel, start+split, end, tileSizeA, tileSizeB, numWorkers, t);
 
  // Mirror image of the sequence above with the order reversed: back rows first,
  // then the off-diagonal block All combined with BFront (not yet modified)
  // into Bback, then the front rows.
  // NOTE(review): presumably the else-branch (lower case); the else line is not visible.
  179        auto All = subrange(Afull,start+split,end,start,start+split);
 
  180        trmm_recursive(Afull, Bfull, kernel, start+split, end, tileSizeA, tileSizeB, numWorkers, t);
 
  181        kernels::gemm(All, BFront, Bback, 1.0);
 
  182        trmm_recursive(Afull, Bfull, kernel, start, start+split, tileSizeA, tileSizeB, numWorkers, t);
 
  // Entry point: checks that A is square and matches the row count of B, builds
  // the block kernel with fixed tile sizes and runs the recursive driver over
  // all rows [0, A().size1()). Upper/Unit are forwarded as triangular_tag<Upper,Unit>.
  //
  // NOTE(review): the function-name line, opening brace and closing lines are
  // not visible in this excerpt.
  189template <
bool Upper,
bool Unit,
typename MatA, 
typename MatB>
 
  191    matrix_expression<MatA, gpu_tag> 
const& A, 
 
  192    matrix_expression<MatB, gpu_tag>& B
 
  194    REMORA_SIZE_CHECK(A().size1() == A().size2());
 
  195    REMORA_SIZE_CHECK(A().size2() == B().size1());
 
  // Tile sizes handed to the recursive driver (base-case threshold for A,
  // columns of B per work group).
  197    std::size_t 
const TileSizeA = 32;
 
  198    std::size_t 
const TileSizeB = 32;
 
  // Work groups are launched as numWorkers x numWorkers work-items.
  199    std::size_t 
const numWorkers = 8; 
 
  // Macros consumed by the generated kernel source. The values are the same
  // numbers as TileSizeA / TileSizeB above.
  // NOTE(review): presumably they must be kept in sync by hand -- confirm.
  200    char const* options =
"-DTILE_SIZE=32ul -DTILE_SIZE_K=32ul";
 
  201    auto kernel = bindings::createTRMMBlockKernel(A,B,options);
 
  203    bindings::trmm_recursive(A,B,kernel,0,A().size1(), TileSizeA, TileSizeB, numWorkers, triangular_tag<Upper,Unit>());