31#ifndef REMORA_KERNELS_CLBLAST_CONV2D_HPP
32#define REMORA_KERNELS_CLBLAST_CONV2D_HPP
34#include "../../expression_types.hpp"
35#include "../../detail/traits.hpp"
37namespace remora{
namespace bindings {
39template<
class E1,
class E2,
class M>
41 matrix_expression<E1, gpu_tag>
const& images,
42 vector_expression<E2, gpu_tag>
const& filter,
43 matrix_expression<M, gpu_tag>& outputs,
44 std::size_t num_channels,
45 std::size_t num_filters,
46 std::size_t image_height,
47 std::size_t image_width,
48 std::size_t filter_height,
49 std::size_t filter_width,
50 std::size_t padding_height,
51 std::size_t padding_width
53 static_assert(std::is_same<typename E1::orientation, row_major>::value,
"[conv2d] Column major not implemented");
54 static_assert(std::is_same<typename E1::storage_type::storage_tag, continuous_dense_tag>::value,
"[conv2d] Subranges not implemented");
55 static_assert(std::is_same<typename M::orientation, row_major>::value,
"[conv2d] Column major not implemented");
57 static_assert(std::is_same<typename E1::value_type, typename E2::value_type>::value,
"[conv2d] Arguments do not have same value type");
58 static_assert(std::is_same<typename E1::value_type, typename M::value_type>::value,
"[conv2d] Arguments do not have same value type");
60 static_assert(std::is_base_of<dense_tag, typename E1::evaluation_category::tag>::value,
"[conv2d] images is not dense");
61 static_assert(std::is_base_of<continuous_dense_tag, typename E2::storage_type::storage_tag>::value,
"[conv2d] filter does not have continuous dense storage layout");
62 static_assert(std::is_base_of<continuous_dense_tag, typename M::storage_type::storage_tag>::value,
"[conv2d] outputs does not have dense storage layout");
65 typedef typename E1::value_type value_type;
68 auto const& images_eval = eval_expression(images);
98 std::size_t output_height = (image_height - filter_height +1 + padding_height);
99 std::size_t output_width = (image_width - filter_width +1 + padding_width);
100 std::size_t filter_size = filter_width * filter_height * num_channels;
101 std::size_t num_images = images().size1();
103 REMORA_SIZE_CHECK(outputs().size1() == images().size1());
104 REMORA_SIZE_CHECK(outputs().size2() == num_filters * output_width * output_height);
105 REMORA_SIZE_CHECK(images().size2() == num_channels * image_width * image_height);
106 REMORA_SIZE_CHECK(filter().size() == num_filters * filter_size);
132 auto storage_images = images_eval.raw_storage();
133 auto storage_filter = filter().raw_storage();
134 auto storage_outputs = outputs().raw_storage();
137 std::size_t num_multiplications = output_width * output_height;
138 std::vector<value_type> alphas(num_multiplications,value_type(1));
139 std::vector<value_type>
const& betas = alphas;
141 std::vector<std::size_t> outputs_offsets(num_multiplications);
142 std::vector<std::size_t> im_offsets(num_multiplications,0);
143 std::vector<std::size_t> filter_offsets(num_multiplications,0);
144 for(std::size_t i = 0; i != output_height; ++i){
145 for(std::size_t j = 0; j != output_width; ++j){
146 std::size_t index = i * output_width + j;
147 outputs_offsets[index] = index * num_filters;
151 for(std::size_t k = 0; k != filter_height; ++k){
153 for(std::size_t i = 0; i != output_height; ++i){
154 for(std::size_t j = 0; j != output_width; ++j){
155 std::size_t index = i * output_width + j;
156 im_offsets[index] = ((i+k) * image_width + j) * num_channels;
160 using namespace clblast;
161 cl_event*
event =
nullptr;
162 auto status = GemmBatched<value_type>(
163 Layout::kRowMajor, Transpose::kNo, Transpose::kYes,
164 num_images, num_filters, num_channels * filter_width,
166 storage_images.buffer.get(), im_offsets.data(), storage_images.leading_dimension,
167 storage_filter.buffer.get(), filter_offsets.data(), filter_size,
169 storage_outputs.buffer.get(), outputs_offsets.data(), storage_outputs.leading_dimension,
171 &outputs().queue().get(), event
174 assert(status == StatusCode::kSuccess);
175 if(k < filter_height -1){
176 for(
auto& offset: filter_offsets){
177 offset += num_channels * filter_width;