Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Core
Shape.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Class Describing the Shape of an Input
6
*
7
*
8
*
9
* \author O. Krause
10
* \date 2017
11
*
12
*
13
* \par Copyright 1995-2017 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <https://shark-ml.github.io/Shark/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
//===========================================================================
34
35
#ifndef SHARK_CORE_SHAPE_H
36
#define SHARK_CORE_SHAPE_H
37
38
#include <vector>
39
#include <initializer_list>
40
#include <ostream>
41
#include <boost/serialization/vector.hpp>
42
namespace
shark
{
43
44
45
/// \brief Represents the Shape of an input or output
///
/// Mostly used for vector data, the Shape describes
/// the expected structure of a model's inputs or outputs.
/// An N-D shape with shape variables (n1,n2,..nN)
/// expects an input of size n1*n2*...*nN which is then interpreted as a tensor
/// with the dimensionalities n1 x n2 x ... x nN.
/// A batch of inputs is then treated as each element having this shape, so
/// the batch size is not a part of the shape.
///
/// The standard shape is the 1-D shape, which just describes that a model
/// interprets every input as a 1-D input.
/// A 0-D shape describes the inputs of a model where the input cannot be
/// described by a shape; for example, class labels or other scalar values have 0-D shapes.
/// A 3-D shape could describe an image patch with rows x columns x channels.
///
/// Shapes can be flattened; this way a 3-D image patch can also be treated as a simple
/// vector input.
///
/// Shark currently does not enforce Shapes, it only checks that input data is compatible
/// with a shape, i.e. a vector has the right number of dimensions.
class Shape{
public:
	/// \brief Creates a 0-D shape (no dimensions). Its element count is the empty product, 1.
	Shape(): m_numElements(1){}

	/// \brief Creates a 1-D shape whose single dimension has \p size entries.
	Shape(std::size_t size): m_dims(1, size), m_numElements(size){}

	/// \brief Creates a shape with the given dimensions.
	///
	/// The number of elements is the product of all dimensions (1 for an empty list).
	Shape(std::initializer_list<std::size_t> dims): m_dims(dims), m_numElements(1){
		for(auto dim: m_dims){
			m_numElements *= dim;
		}
	}

	/// \brief Returns the number of dimensions of the shape (0 for a 0-D shape).
	std::size_t size() const{
		return m_dims.size();
	}

	/// \brief Returns the extent of dimension \p i. No bounds check is performed.
	std::size_t operator[](std::size_t i) const{
		return m_dims[i];
	}

	/// \brief Returns the total number of elements, i.e. the product of all dimensions.
	std::size_t numElements() const{
		return m_numElements;
	}

	///\brief Returns a 1-D shape with the same number of elements
	Shape flatten() const{
		return Shape({m_numElements});
	}

	/// \brief Stride of elements in memory when increasing dimension \p dim by 1,
	/// assuming the underlying memory is contiguous (last dimension varies fastest).
	///
	/// This is the product of all dimensions after \p dim. If \p dim is the last
	/// dimension, or not a valid dimension at all (including every \p dim of a
	/// 0-D shape), the product is empty and 1 is returned.
	std::size_t stride(std::size_t dim) const{
		std::size_t val = 1;
		// Guard first: the loop below must never index past m_dims, and dim + 1
		// cannot overflow once dim < size() is known.
		if(dim >= size())
			return val;
		for(std::size_t i = dim + 1; i < size(); ++i){
			val *= m_dims[i];
		}
		return val;
	}

	/// \brief Boost.Serialization hook: (de)serializes the dimensions and the cached element count.
	template<class Archive>
	void serialize(Archive & archive, unsigned int /*version*/){
		archive & m_dims;
		archive & m_numElements;
	}
private:
	std::vector<std::size_t> m_dims;  ///< extent of each dimension
	std::size_t m_numElements;        ///< cached product of m_dims
};
112
113
/// \brief Two shapes are equal iff they have the same number of dimensions
/// and agree in the extent of every dimension.
inline bool operator == (Shape const& shape1, Shape const& shape2){
	std::size_t const rank = shape1.size();
	bool same = (rank == shape2.size());
	// Stop scanning as soon as a mismatch has been found.
	for(std::size_t k = 0; same && k < rank; ++k){
		same = (shape1[k] == shape2[k]);
	}
	return same;
}
123
124
/// \brief Negation of operator==: true when the shapes differ in rank or in any dimension.
inline bool operator != (Shape const& shape1, Shape const& shape2){
	bool const same = (shape1 == shape2);
	return !same;
}
127
128
template
<
class
E,
class
T>
129
std::basic_ostream<E, T> &
operator <<
(std::basic_ostream<E, T> &os, Shape
const
& shape) {
130
os<<
'('
;
131
for
(std::size_t i = 0; i != shape.size(); ++i){
132
os<<shape[i];
133
if
(i != shape.size() -1)
134
os<<
", "
;
135
}
136
os<<
')'
;
137
return
os;
138
}
139
140
}
141
#endif