Shape.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Class Describing the Shape of an Input
6 *
7 *
8 *
9 * \author O. Krause
10 * \date 2017
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35#ifndef SHARK_CORE_SHAPE_H
36#define SHARK_CORE_SHAPE_H
37
38#include <vector>
39#include <initializer_list>
40#include <ostream>
41#include <boost/serialization/vector.hpp>
42namespace shark{
43
44
/// \brief Represents the Shape of an input or output
///
/// Mostly used for vector data, the Shape describes
/// the expected structure of a model's inputs or outputs.
/// An N-D shape with shape variables (n1,n2,..nN)
/// expects an input of size n1*n2*...*nN which is then interpreted as a tensor
/// with the dimensionalities n1 x n2 x ... x nN.
/// A batch of inputs is then treated as each element having this shape, so
/// the batch size is not a part of the shape.
///
/// The standard shape
/// is the 1-D shape just describing that a model interprets every
/// input as 1-D input.
/// A 0-D shape describes the inputs of a model where the input can not be
/// described by a shape, for example a class label or other scalar values are 0-D shapes.
/// A 3-D shape could describe an image patch with rows x columns x channels.
///
/// Shapes can be flattened, this way a 3-D image patch can also be treated as a simple
/// vector input.
///
/// Shark currently does not enforce Shapes, it only checks that input data is compatible
/// with a shape, i.e. a vector has the right number of dimensions.
class Shape{
public:
	///\brief Creates a 0-D shape: no dimensions, one element.
	Shape(): m_numElements(1){}

	///\brief Creates a 1-D shape with \p size elements.
	///
	/// Not explicit, so a plain size converts implicitly to a 1-D Shape;
	/// callers elsewhere may rely on that conversion.
	Shape(std::size_t size): m_dims(1,size), m_numElements(size){}

	///\brief Creates a shape with the given dimensions, e.g. Shape({rows, columns, channels}).
	///
	/// An empty list yields a 0-D shape with one element.
	Shape(std::initializer_list<std::size_t> dims): m_dims(dims), m_numElements(1){
		for(auto dim: m_dims){
			m_numElements *= dim;
		}
	}

	///\brief Number of dimensions of the shape (0 for a 0-D shape).
	std::size_t size()const{
		return m_dims.size();
	}

	///\brief Extent of dimension \p i.
	///
	/// No bounds check is performed; requires i < size().
	std::size_t operator[](std::size_t i) const{
		return m_dims[i];
	}

	///\brief Total number of elements, i.e. the product of all dimensions (1 for a 0-D shape).
	std::size_t numElements()const{
		return m_numElements;
	}

	///\brief Returns a 1-D shape with the same number of elements
	Shape flatten() const{
		return Shape({m_numElements});
	}

	///\brief Stride of elements in memory when increasing dimension \p dim by 1.
	///
	/// Assumes the underlying memory is contiguous with the last dimension varying
	/// fastest, so the stride is the product of all dimensions after \p dim.
	/// For dim >= size() (including any dim of a 0-D shape) that product is empty
	/// and 1 is returned.
	std::size_t stride(std::size_t dim) const{
		// Guard first: counting down from size()-1 towards dim would wrap around
		// and read past the end of m_dims whenever dim >= size().
		if(dim >= size()) return 1;
		std::size_t val = 1;
		for(std::size_t i = dim + 1; i != size(); ++i){
			val *= m_dims[i];
		}
		return val;
	}

	///\brief Boost.Serialization hook; stores the dimensions and the cached element count.
	template<class Archive>
	void serialize(Archive & archive,unsigned int /*version*/){
		archive & m_dims;
		archive & m_numElements;
	}
private:
	std::vector<std::size_t> m_dims; ///< extent of each dimension
	std::size_t m_numElements;       ///< cached product of all entries of m_dims
};
112
113inline bool operator == (Shape const& shape1, Shape const& shape2){
114 if(shape1.size() != shape2.size())
115 return false;
116 for(std::size_t i = 0; i != shape1.size(); ++i){
117 if(shape1[i] != shape2[i]){
118 return false;
119 }
120 }
121 return true;
122}
123
124inline bool operator != (Shape const& shape1, Shape const& shape2){
125 return ! (shape1 == shape2);
126}
127
128template<class E, class T>
129std::basic_ostream<E, T> &operator << (std::basic_ostream<E, T> &os, Shape const& shape) {
130 os<<'(';
131 for(std::size_t i = 0; i != shape.size(); ++i){
132 os<<shape[i];
133 if(i != shape.size() -1)
134 os<<", ";
135 }
136 os<<')';
137 return os;
138}
139
140}
141#endif