32#ifndef SHARK_DATA_BATCHINTERFACE_H
33#define SHARK_DATA_BATCHINTERFACE_H
38#include <boost/utility/enable_if.hpp>
39#include <boost/mpl/if.hpp>
47template<
class BatchType>
50 typedef BatchType type;
52 typedef typename type::reference reference;
54 typedef typename type::const_reference const_reference;
58 typedef typename type::value_type value_type;
61 typedef typename type::iterator iterator;
63 typedef typename type::const_iterator const_iterator;
67 static type
createBatch(value_type
const& input, std::size_t size = 1){
68 return type(size,input);
72 template<
class Iterator>
73 static type createBatchFromRange(Iterator
const& begin, Iterator
const& end){
75 typename type::reference c=batch[0];
77 std::copy(begin,end,batch.begin());
82 static void resize(T& batch, std::size_t
batchSize, std::size_t elements){
94 static std::size_t size(T
const& batch){
return batch.size();}
97 static typename T::reference get(T& batch, std::size_t i){
101 static typename T::const_reference get(T
const& batch, std::size_t i){
105 static typename T::iterator begin(T& batch){
106 return batch.begin();
109 static typename T::const_iterator begin(T
const& batch){
110 return batch.begin();
113 static typename T::iterator end(T& batch){
117 static typename T::const_iterator end(T
const& batch){
125template<
class Matrix>
126class MatrixRowReference:
public blas::detail::matrix_row_optimizer<
127 typename blas::closure<Matrix>::type
130 typedef typename blas::detail::matrix_row_optimizer<
131 typename blas::closure<Matrix>::type
134 typedef typename blas::vector_temporary<Matrix>::type Vector;
136 MatrixRowReference( Matrix& matrix, std::size_t i)
137 :row_type(row(matrix,i)){}
138 MatrixRowReference(row_type
const& matrixrow)
139 :row_type(matrixrow){}
142 MatrixRowReference(MatrixRowReference<M2>
const& matrixrow)
143 :row_type(matrixrow){}
146 const MatrixRowReference& operator=(
const T& argument){
147 static_cast<row_type&
>(*this)=argument;
152 return Vector(*
this);
166template<
class Matrix>
169 typedef typename blas::matrix_temporary<Matrix>::type type;
172 typedef typename blas::vector_temporary<Matrix>::type value_type;
176 typedef detail::MatrixRowReference<Matrix> reference;
178 typedef detail::MatrixRowReference<const Matrix> const_reference;
182 typedef ProxyIterator<Matrix, value_type, reference > iterator;
184 typedef ProxyIterator<const Matrix, value_type, const_reference > const_iterator;
187 template<
class Element>
188 static type
createBatch(Element
const& input, std::size_t size = 1){
189 return type(size,input.size());
192 template<
class Iterator>
193 static type createBatchFromRange(Iterator
const& pos, Iterator
const& end){
194 type batch(end - pos,pos->size());
195 std::copy(pos,end,begin(batch));
200 static void resize(Matrix& batch, std::size_t
batchSize, std::size_t elements){
/// \brief Number of elements stored in a dense matrix batch.
/// Each element occupies one row of the matrix (get() returns a row reference built from
/// row(matrix,i), and end() is positioned at index batch.size1()), so the element count
/// is the number of rows, batch.size1().
204 static std::size_t size(Matrix
const& batch){
return batch.size1();}
205 static reference get( Matrix& batch, std::size_t i){
206 return reference(batch,i);
208 static const_reference get( Matrix
const& batch, std::size_t i){
209 return const_reference(batch,i);
212 static iterator begin(Matrix& batch){
213 return iterator(batch,0);
215 static const_iterator begin(Matrix
const& batch){
216 return const_iterator(batch,0);
219 static iterator end(Matrix& batch){
220 return iterator(batch,batch.size1());
222 static const_iterator end(Matrix
const& batch){
223 return const_iterator(batch,batch.size1());
240:
public std::conditional<
241 std::is_arithmetic<T>::value,
242 detail::SimpleBatch<blas::vector<T> >,
243 detail::SimpleBatch<std::vector<T> >
/// \brief Batch type for dense blas vectors.
/// A batch of blas::vector<T, Device> elements is stored as a row-major
/// blas::matrix<T, blas::row_major, Device>, one vector per row. All batch operations
/// (createBatch, get, begin/end, size, ...) are inherited from detail::VectorBatch.
247template<
class T,
class Device>
248struct Batch<blas::vector<T, Device> >:
public detail::VectorBatch<blas::matrix<T, blas::row_major, Device> >{};
254 typedef shark::blas::compressed_matrix<T> type;
257 typedef shark::blas::compressed_vector<T> value_type;
272 template<
class Element>
273 static type
createBatch(Element
const& input, std::size_t size = 1){
274 return type(size,input.size());
277 template<
class Iterator>
280 std::size_t nonzeros = 0;
281 for(Iterator pos = start; pos != end; ++pos){
282 nonzeros += pos->nnz();
285 std::size_t size = end - start;
286 type batch(size,start->size(),nonzeros);
288 Iterator pos = start;
289 for(std::size_t i = 0; i != size; ++i, ++pos){
290 auto row_start = batch.major_end(i);
291 for(
auto elem_pos = pos->begin(); elem_pos != pos->end(); ++elem_pos){
292 row_start = batch.set_element(row_start, elem_pos.index(), *elem_pos);
/// \brief Number of elements stored in a sparse (compressed_matrix) batch.
/// Elements are stored one per row: createBatchFromRange builds the matrix with
/// (end - start) rows. The element count is therefore batch.size1().
304 static std::size_t
size(type
const& batch){
return batch.size1();}
320 return iterator(batch,batch.size1());
328struct Batch<detail::MatrixRowReference<M> >
329:
public Batch<typename detail::MatrixRowReference<M>::Vector>{};
332template<
class BatchType>
337template<
class T,
class Device>
345template<
class T,
class Tag,
class Device>
346struct BatchTraits<blas::dense_matrix_adaptor<T, blas::row_major, Tag, Device> >{
347 typedef detail::VectorBatch<blas::dense_matrix_adaptor<T, blas::row_major, Tag, Device> > type;
352struct batch_to_element{
356struct batch_to_element<T&>{
358 typedef typename BatchTraits<T>::type::value_type type;
361struct batch_to_element<T const&>{
363 typedef typename BatchTraits<T>::type::value_type type;
367struct batch_to_reference{
368 typedef typename BatchTraits<T>::type::reference type;
371struct batch_to_reference<T&>{
372 typedef typename BatchTraits<T>::type::reference type;
375struct batch_to_reference<T const&>{
376 typedef typename BatchTraits<T>::type::const_reference type;
380struct element_to_batch{
381 typedef typename Batch<T>::type type;
384struct element_to_batch<T&>{
385 typedef typename Batch<T>::type& type;
388struct element_to_batch<T const&>{
389 typedef typename Batch<T>::type
const& type;
392struct element_to_batch<detail::MatrixRowReference<M> >{
393 typedef typename Batch<typename detail::MatrixRowReference<M>::Vector>::type& type;
396struct element_to_batch<detail::MatrixRowReference<M const> >{
397 typedef typename Batch<typename detail::MatrixRowReference<M>::Vector>::type
const& type;
403template<
class T,
class Range>
414template<
class T,
class Iterator>
419template<
class BatchT>
424template<
class BatchT>
429template<
class BatchT>
434template<
class BatchT>
439template<
class BatchT>
444template<
class BatchT>
449template<
class BatchT>