dense_gemm.hpp
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief cblas binding for dense gemm
6 *
7 * \author O. Krause
8 * \date 2016
9 *
10 *
11 * \par Copyright 1995-2015 Shark Development Team
12 *
13 * <BR><HR>
14 * This file is part of Shark.
15 * <http://image.diku.dk/shark/>
16 *
17 * Shark is free software: you can redistribute it and/or modify
18 * it under the terms of the GNU Lesser General Public License as published
19 * by the Free Software Foundation, either version 3 of the License, or
20 * (at your option) any later version.
21 *
22 * Shark is distributed in the hope that it will be useful,
23 * but WITHOUT ANY WARRANTY; without even the implied warranty of
24 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25 * GNU Lesser General Public License for more details.
26 *
27 * You should have received a copy of the GNU Lesser General Public License
28 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29 *
30 */
31//===========================================================================
32#ifndef REMORA_KERNELS_CBLAS_DENSE_GEMM_HPP
33#define REMORA_KERNELS_CBLAS_DENSE_GEMM_HPP
34
35#include "cblas_inc.hpp"
36#include "../../proxy_expressions.hpp"
37#include "../../assignment.hpp"
38#include "../../dense.hpp"
39#include "../default/simd.hpp"
40#include <type_traits>
41namespace remora{ namespace bindings {
42
43inline void dense_gemm(
44 CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
45 int M, int N, int K,
46 float alpha, float const *A, int lda,
47 float const *B, int ldb,
48 float beta, float *C, int ldc
49){
50 cblas_sgemm(
51 Order, TransA, TransB,
52 M, N, K,
53 alpha, A, lda,
54 B, ldb,
55 beta, C, ldc
56 );
57}
58
59inline void dense_gemm(
60 CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
61 int M, int N, int K,
62 double alpha, double const *A, int lda,
63 double const *B, int ldb,
64 double beta, double *C, int ldc
65){
66 cblas_dgemm(
67 Order, TransA, TransB,
68 M, N, K,
69 alpha,
70 A, lda,
71 B, ldb,
72 beta,
73 C, ldc
74 );
75}
76
77inline void dense_gemm(
78 CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
79 int M, int N, int K,
80 float alpha,
81 std::complex<float> const *A, int lda,
82 std::complex<float> const *B, int ldb,
83 float beta,
84 std::complex<float>* C, int ldc
85) {
86 std::complex<float> alphaArg(alpha,0);
87 std::complex<float> betaArg(beta,0);
88 cblas_cgemm(
89 Order, TransA, TransB,
90 M, N, K,
91 reinterpret_cast<cblas_float_complex_type const *>(&alphaArg),
92 reinterpret_cast<cblas_float_complex_type const *>(A), lda,
93 reinterpret_cast<cblas_float_complex_type const *>(B), ldb,
94 reinterpret_cast<cblas_float_complex_type const *>(&betaArg),
95 reinterpret_cast<cblas_float_complex_type *>(C), ldc
96 );
97}
98
99inline void dense_gemm(
100 CBLAS_ORDER const Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
101 int M, int N, int K,
102 double alpha,
103 std::complex<double> const *A, int lda,
104 std::complex<double> const *B, int ldb,
105 double beta,
106 std::complex<double>* C, int ldc
107) {
108 std::complex<double> alphaArg(alpha,0);
109 std::complex<double> betaArg(beta,0);
110 cblas_zgemm(
111 Order, TransA, TransB,
112 M, N, K,
113 reinterpret_cast<cblas_double_complex_type const *>(&alphaArg),
114 reinterpret_cast<cblas_double_complex_type const *>(A), lda,
115 reinterpret_cast<cblas_double_complex_type const *>(B), ldb,
116 reinterpret_cast<cblas_double_complex_type const *>(&betaArg),
117 reinterpret_cast<cblas_double_complex_type *>(C), ldc
118 );
119}
120
121//optimized cblas version
122template <typename MatA, typename MatB, typename MatC>
123void dense_gemm(
124 matrix_expression<MatA, cpu_tag> const& A,
125 matrix_expression<MatB, cpu_tag> const& B,
126 matrix_expression<MatC, cpu_tag>& C,
127 typename MatC::value_type alpha,
128 std::true_type
129){
130 static_assert(std::is_same<typename MatC::orientation,row_major>::value,"C must be row major");
131
132 CBLAS_TRANSPOSE transA = std::is_same<typename MatA::orientation,typename MatC::orientation>::value?CblasNoTrans:CblasTrans;
133 CBLAS_TRANSPOSE transB = std::is_same<typename MatB::orientation,typename MatC::orientation>::value?CblasNoTrans:CblasTrans;
134 std::size_t m = C().size1();
135 std::size_t n = C().size2();
136 std::size_t k = A().size2();
137 CBLAS_ORDER stor_ord = (CBLAS_ORDER) storage_order<typename MatC::orientation >::value;
138
139 auto storageA = A().raw_storage();
140 auto storageB = B().raw_storage();
141 auto storageC = C().raw_storage();
142 dense_gemm(stor_ord, transA, transB, (int)m, (int)n, (int)k, alpha,
143 storageA.values,
144 storageA.leading_dimension,
145 storageB.values,
146 storageB.leading_dimension,
147 typename MatC::value_type(1),
148 storageC.values,
149 storageC.leading_dimension
150 );
151}
152
153template <typename MatA, typename MatB, typename MatC>
154void dense_gemm(
155 matrix_expression<MatA, cpu_tag> const& A,
156 matrix_expression<MatB, cpu_tag> const& B,
157 matrix_expression<MatC, cpu_tag>& C,
158 typename MatC::value_type alpha,
159 std::false_type
160){
161 typedef typename MatC::value_type value_type;
162 std::size_t const tile_size = 512;
163 static const std::size_t align = 64;
164 std::size_t size1 = C().size1();
165 std::size_t size2 = C().size2();
166 std::size_t num_blocks = (A().size2()+tile_size-1)/tile_size;
167 boost::alignment::aligned_allocator<value_type,align> allocator;
168 value_type* A_pointer = allocator.allocate(size1 * tile_size);
169 value_type* B_pointer = allocator.allocate(size2 * tile_size);
170 for(std::size_t k = 0; k != num_blocks; ++k){
171 std::size_t start_k = k * tile_size;
172 std::size_t current_size = std::min(tile_size,A().size2() - start_k);
173 auto A_block = adapt_matrix(size1, current_size, A_pointer);
174 auto B_block = adapt_matrix(current_size, size2, B_pointer);
175 noalias(A_block) = subrange(A, 0, size1, start_k, start_k + current_size);
176 noalias(B_block) = subrange(B, start_k, start_k + current_size, 0, size2);
177 dense_gemm(A_block, B_block, C, alpha, std::true_type());
178 }
179 allocator.deallocate(A_pointer, size1 * tile_size);
180 allocator.deallocate(B_pointer, size1 * tile_size);
181}
182
183
184template<class M1, class M2, class M3>
185struct has_optimized_gemm: std::integral_constant<bool,
186 allowed_cblas_type<typename M1::value_type>::type::value
187 && std::is_same<typename M1::value_type, typename M2::value_type>::value
188 && std::is_same<typename M1::value_type, typename M3::value_type>::value
189 && std::is_base_of<dense_tag, typename M1::storage_type::storage_tag>::value
190 && std::is_base_of<dense_tag, typename M2::storage_type::storage_tag>::value
191 && std::is_base_of<dense_tag, typename M3::storage_type::storage_tag>::value
192>{};
193
194template <typename MatA, typename MatB, typename MatC>
195void dense_gemm(
196 matrix_expression<MatA, cpu_tag> const& A,
197 matrix_expression<MatB, cpu_tag> const& B,
198 matrix_expression<MatC, cpu_tag>& C,
199 typename MatC::value_type alpha
200){
201 REMORA_SIZE_CHECK(A().size1() == C().size1());
202 REMORA_SIZE_CHECK(B().size2() == C().size2());
203 REMORA_SIZE_CHECK(A().size2()== B().size1());
204 dense_gemm(
205 A,B,C,alpha,
206 typename has_optimized_gemm<MatA,MatB,MatC>::type()
207 );
208}
209
210}}
211
212#endif