/*!
 *
 *
 * \brief Creates pooling layers
 *
 * \author O.Krause
 * \date 2018
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_MODELS_POOLING_LAYER_H
#define SHARK_MODELS_POOLING_LAYER_H

#include <shark/LinAlg/Base.h>
#include <shark/Models/AbstractModel.h>
#include <shark/Core/Images/Padding.h>
#include <shark/Core/Images/CPU/Pooling.h>
#ifdef SHARK_USE_OPENCL
#include <shark/Core/Images/OpenCL/Pooling.h>
#endif

namespace shark{

enum class Pooling{
	Maximum
};

/// \brief Performs pooling operations for a given input image.
///
/// Pooling partitions the input image into rectangular regions, typically 2x2, and computes
/// a statistic over each region, for example the maximum or the average of its values. This is
/// done channel-by-channel. The output is a smaller image in which each pixel holds the
/// computed statistic for each channel. Thus, with a 2x2 patch the output image has half the width and height of the input.
///
/// \ingroup models
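/// A minimal usage sketch (illustrative only; the 32x32x3 input and the 2x2 patch are
/// example values, not defaults):
/// \code
/// PoolingLayer<RealVector> pool({32, 32, 3}, {2, 2}); // maximum pooling, valid padding
/// // pool.outputShape() is {16, 16, 3}: both spatial dimensions are halved, channels are kept
/// \endcode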
template<class VectorType = RealVector>
class PoolingLayer : public AbstractModel<VectorType, VectorType, VectorType>{
private:
	typedef AbstractModel<VectorType, VectorType, VectorType> base_type;
	typedef typename VectorType::value_type value_type;
public:
	typedef typename base_type::BatchInputType BatchInputType;
	typedef typename base_type::BatchOutputType BatchOutputType;
	typedef typename base_type::ParameterVectorType ParameterVectorType;

	PoolingLayer(){
		base_type::m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
		base_type::m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
	}

	PoolingLayer(Shape const& inputShape, Shape const& patchShape, Pooling pooling = Pooling::Maximum, Padding padding = Padding::Valid){
		base_type::m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
		base_type::m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
		setStructure(inputShape, patchShape, pooling, padding);
	}

	/// \brief From INameable: return the class name.
	std::string name() const
	{ return "PoolingLayer"; }

	Shape inputShape() const{
		return m_inputShape;
	}

	Shape outputShape() const{
		return m_outputShape;
	}

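	// Pooling has no trainable parameters, so the parameter interface below is trivial.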
	/// obtain the parameter vector
	ParameterVectorType parameterVector() const{
		return ParameterVectorType();
	}

	/// overwrite the parameter vector
	void setParameterVector(ParameterVectorType const& newParameters){
		SIZE_CHECK(newParameters.size() == 0);
	}

	/// returns the number of parameters
	size_t numberOfParameters() const{
		return 0;
	}

	boost::shared_ptr<State> createState() const{
		return boost::shared_ptr<State>(new EmptyState());
	}

	///\brief Configures the model.
	///
	/// \arg inputShape Shape of the input image imHeight x imWidth x channel
	/// \arg patchShape Shape of the pooling patch, e.g. 2x2
	/// \arg type Pooling statistic to compute, default is Pooling::Maximum
	/// \arg padding Padding strategy; currently only Padding::Valid is supported
	void setStructure(
		Shape const& inputShape, Shape const& patchShape, Pooling type = Pooling::Maximum, Padding padding = Padding::Valid
	){
		SHARK_RUNTIME_CHECK(padding == Padding::Valid, "Padding not implemented");
		m_inputShape = inputShape;
		m_patch = patchShape;
		m_padding = padding;
		m_type = type;
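		// With Padding::Valid each spatial dimension is divided by the patch size (integer
		// division, so border rows/columns that do not fill a whole patch are dropped);
		// otherwise the division is rounded up so that partial patches at the border still
		// produce an output pixel.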
		if(m_padding == Padding::Valid)
			m_outputShape = {m_inputShape[0]/m_patch[0], m_inputShape[1]/m_patch[1], m_inputShape[2]};
		else
			m_outputShape = {
				(m_inputShape[0] + m_patch[0] - 1)/m_patch[0],
				(m_inputShape[1] + m_patch[1] - 1)/m_patch[1],
				m_inputShape[2]
			};
	}

	using base_type::eval;

	void eval(BatchInputType const& inputs, BatchOutputType& outputs, State& state) const{
		SIZE_CHECK(inputs.size2() == m_inputShape.numElements());
		outputs.resize(inputs.size1(), m_outputShape.numElements());
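		// Dispatch on the pooling statistic; only Pooling::Maximum is implemented at the moment.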
		switch(m_type){
			case Pooling::Maximum:
				image::maxPooling<value_type>(inputs, m_inputShape, m_patch, outputs);
			break;
		}
	}

	///\brief Calculates the first derivative w.r.t. the parameters and sums it over all inputs of the last computed batch
	void weightedParameterDerivative(
		BatchInputType const& inputs,
		BatchOutputType const& outputs,
		BatchOutputType const& coefficients,
		State const& state,
		ParameterVectorType& gradient
	) const{
		SIZE_CHECK(coefficients.size1() == outputs.size1());
		SIZE_CHECK(coefficients.size2() == outputs.size2());
		gradient.resize(0);
	}
	///\brief Calculates the first derivative w.r.t. the inputs and sums it over all inputs of the last computed batch
	void weightedInputDerivative(
		BatchInputType const& inputs,
		BatchOutputType const& outputs,
		BatchOutputType const& coefficients,
		State const& state,
		BatchInputType& derivative
	) const{
		SIZE_CHECK(coefficients.size1() == outputs.size1());
		SIZE_CHECK(coefficients.size2() == outputs.size2());
		derivative.resize(inputs.size1(), inputs.size2());
		switch(m_type){
			case Pooling::Maximum:
				image::maxPoolingDerivative<value_type>(inputs, coefficients, m_inputShape, m_patch, derivative);
			break;
		}
	}

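	// Serialization note: the enum class members are read and written through their underlying int representation.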
	/// From ISerializable
	void read(InArchive& archive){
		archive >> m_inputShape;
		archive >> m_outputShape;
		archive >> m_patch;
		archive >> (int&)m_padding;
		archive >> (int&)m_type;
	}
	/// From ISerializable
	void write(OutArchive& archive) const{
		archive << m_inputShape;
		archive << m_outputShape;
		archive << m_patch;
		archive << (int&)m_padding;
		archive << (int&)m_type;
	}
private:
	Shape m_inputShape;
	Shape m_outputShape;
	Shape m_patch;
	Padding m_padding;
	Pooling m_type;
};


}

#endif