gemm.hpp
/*!
 *
 *
 * \brief Default (fallback) gemm kernels for dense and sparse matrix arguments on the cpu.
 *
 * \author O. Krause
 * \date 2010
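 *
 * A minimal usage sketch (this assumes the kernels::gemm dispatcher declared in
 * ../gemm.hpp and remora's dense/compressed matrix types; the exact public
 * interface is defined elsewhere):
 * \code
 * remora::matrix<double> A(100, 50, 1.0);       // dense
 * remora::compressed_matrix<double> B(50, 200); // sparse
 * remora::matrix<double> C(100, 200, 0.0);
 * // ... fill B with nonzeros ...
 * remora::kernels::gemm(A, B, C, 2.0); // C += 2.0 * A * B, handled by the dense-sparse bindings below
 * \endcode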
 *
 *
 * \par Copyright 1995-2015 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <http://image.diku.dk/shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */

#ifndef REMORA_KERNELS_DEFAULT_GEMM_HPP
#define REMORA_KERNELS_DEFAULT_GEMM_HPP

#include "../gemv.hpp" //for dispatching to gemv
#include "../vector_assign.hpp" //assignment of vectors
#include "../../dense.hpp" //sparse gemm needs temporary vector
#include "../../proxy_expressions.hpp" //matrix row, column, transpose, range
#include <type_traits> //std::false_type marker for unoptimized, std::common_type

namespace remora{ namespace bindings {


// Dense-Sparse gemm
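// m is computed one row at a time: row(m,i) += alpha * row(e1,i) * e2,
// delegated to the gemv kernel as row(m,i) += alpha * trans(e2) * row(e1,i).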
template <class E1, class E2, class M, class Orientation>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, row_major, Orientation,
	dense_tag, sparse_tag
){
	for (std::size_t i = 0; i != e1().size1(); ++i) {
		auto row_m = row(m,i);
		kernels::gemv(trans(e2),row(e1,i),row_m,alpha);
	}
}

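// m is computed one column at a time: column(m,j) += alpha * e1 * column(e2,j),
// delegated to the gemv kernel.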
template <class E1, class E2, class M>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, column_major, column_major,
	dense_tag, sparse_tag
){
	for (std::size_t j = 0; j != e2().size2(); ++j) {
		auto column_m = column(m,j);
		kernels::gemv(e1,column(e2,j),column_m,alpha);
	}
}

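// The product is accumulated as a sum of scaled rows of e2:
// row(m,i) += alpha * e1(i,k) * row(e2,k) for every k,
// using the elementwise multiply-and-add assignment kernel.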
template <class E1, class E2, class M>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, column_major, row_major,
	dense_tag, sparse_tag
){
	typedef typename M::value_type value_type;
	typedef device_traits<cpu_tag>::multiply_and_add<value_type> MultAdd;
	for (std::size_t k = 0; k != e1().size2(); ++k) {
		for(std::size_t i = 0; i != e1().size1(); ++i){
			auto row_m = row(m,i);
			kernels::assign(row_m, row(e2,k), MultAdd(alpha * e1()(i,k)));
		}
	}
}

// Sparse-Dense gemm
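// Same row-wise reduction to gemv as in the dense-sparse case above, now with a sparse e1.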
template <class E1, class E2, class M, class Orientation>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, row_major, Orientation,
	sparse_tag, dense_tag
){
	for (std::size_t i = 0; i != e1().size1(); ++i) {
		auto row_m = row(m,i);
		kernels::gemv(trans(e2),row(e1,i),row_m,alpha);
	}
}

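// Column-wise as above: column(m,j) += alpha * e1 * column(e2,j) via the gemv kernel.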
template <class E1, class E2, class M>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, column_major, column_major,
	sparse_tag, dense_tag
){
	for (std::size_t j = 0; j != e2().size2(); ++j) {
		auto column_m = column(m,j);
		kernels::gemv(e1,column(e2,j),column_m,alpha);
	}
}

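// e1 is traversed along its major (column) direction; every stored nonzero e1(i,k)
// adds a scaled row of e2 to row(m,i) via the multiply-and-add assignment kernel.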
template <class E1, class E2, class M>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, column_major, row_major,
	sparse_tag, dense_tag
){
	typedef typename M::value_type value_type;
	typedef device_traits<cpu_tag>::multiply_and_add<value_type> MultAdd;
	for (std::size_t k = 0; k != e1().size2(); ++k) {
		auto e1end = e1().major_end(k);
		for(auto e1pos = e1().major_begin(k); e1pos != e1end; ++e1pos){
			std::size_t i = e1pos.index();
			auto row_m = row(m,i);
			kernels::assign(row_m, row(e2,k), MultAdd(alpha * (*e1pos)));
		}
	}
}

// Sparse-Sparse gemm
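// Each row of the product is first accumulated into a dense temporary via gemv; its
// nonzero entries are then merged into the corresponding sparse row of m in index order.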
template<class M, class E1, class E2>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, row_major, row_major,
	sparse_tag, sparse_tag
) {
	typedef typename M::value_type value_type;
	value_type zero = value_type();
	vector<value_type> temporary(e2().size2(), zero); //dense vector for quick random access
	for (std::size_t i = 0; i != e1().size1(); ++i) {
		kernels::gemv(trans(e2),row(e1,i),temporary,alpha);
		auto insert_pos = m().major_begin(i);
		for (std::size_t j = 0; j != temporary.size(); ++j) {
			if (temporary(j) != zero) {
				//find the element with that index
				auto row_end = m().major_end(i);
				while(insert_pos != row_end && insert_pos.index() < j)
					++insert_pos;
				//check whether the element already exists
				if(insert_pos != row_end && insert_pos.index() == j){
					*insert_pos += temporary(j);
				}else{ //create a new element
					insert_pos = m().set_element(insert_pos,j,temporary(j));
				}
				temporary(j) = zero; //clear the temporary for the next row
			}
		}
	}
}

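// Computed column by column: column(m,j) += alpha * e1 * column(e2,j) via the gemv kernel.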
template<class M, class E1, class E2>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, row_major, column_major,
	sparse_tag, sparse_tag
) {
	for (std::size_t j = 0; j != e2().size2(); ++j) {
		auto column_m = column(m,j);
		kernels::gemv(e1,column(e2,j),column_m,alpha);
	}
}

template <class E1, class E2, class M, class Orientation>
void gemm(
	matrix_expression<E1, cpu_tag> const& e1,
	matrix_expression<E2, cpu_tag> const& e2,
	matrix_expression<M, cpu_tag>& m,
	typename M::value_type alpha,
	row_major, column_major, Orientation o,
	sparse_tag t1, sparse_tag t2
){
	//The best way to compute this is to transpose e1 in memory. The alternative would be
	//to accumulate a sequence of outer products, which is far more expensive.
	typename transposed_matrix_temporary<E1>::type e1_trans(e1);
	gemm(e1_trans,e2,m,alpha,row_major(),row_major(),o,t1,t2);
}

}}

#endif