Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Models
AbstractModel.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
* \author T.Glasmachers, O. Krause
4
* \date 2010
5
* \file
6
*
7
* \par Copyright 1995-2017 Shark Development Team
8
*
9
* <BR><HR>
10
* This file is part of Shark.
11
* <https://shark-ml.github.io/Shark/>
12
*
13
* Shark is free software: you can redistribute it and/or modify
14
* it under the terms of the GNU Lesser General Public License as published
15
* by the Free Software Foundation, either version 3 of the License, or
16
* (at your option) any later version.
17
*
18
* Shark is distributed in the hope that it will be useful,
19
* but WITHOUT ANY WARRANTY; without even the implied warranty of
20
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21
* GNU Lesser General Public License for more details.
22
*
23
* You should have received a copy of the GNU Lesser General Public License
24
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
25
*
26
*/
27
//===========================================================================
28
29
#ifndef SHARK_MODELS_ABSTRACTMODEL_H
30
#define SHARK_MODELS_ABSTRACTMODEL_H
31
32
/// \defgroup models Models
33
///
34
/// \brief Model classes for statistical prediction.
35
///
36
/// Models in shark define the classes that can perform statistical predictions on supplied input data.
37
/// Models can have different types of inputs and outputs so that they can be used for classification and regression tasks.
38
#include <
shark/Core/Flags.h
>
39
#include <
shark/Core/IParameterizable.h
>
40
#include <
shark/Core/INameable.h
>
41
#include <
shark/Core/State.h
>
42
#include <
shark/Core/Shape.h
>
43
#include <
shark/Core/Random.h
>
44
#include<
shark/Data/Dataset.h
>
45
46
namespace
shark
{
47
48
///\brief Base class for all Models.
///
/// \par
/// A model is one of the three fundaments of supervised learning: model, error measure
/// and an optimization algorithm.
/// It is a concept of a function which performs a mapping \f$ x \rightarrow f_w(x)\f$.
/// In contrast to an error function it has two sets of parameters:
/// The first is the current point to map \f$x\f$, the others are the internal model parameters \f$w\f$
/// which define the mapping.
/// Often a model is used to find an optimal mapping for a problem, for example a function which
/// best fits the points of a given dataset. Therefore, AbstractModel does not only offer
/// the mapping itself, but also a set of special derivatives with respect to \f$ x \f$ and \f$ w \f$.
/// Most of the time, only the derivative with respect to \f$ w \f$ is needed, but in some special problems,
/// like finding optimal stimuli or stacking models, also the input derivative is needed.
///
///\par Models are optimized for batch processing. This means, that instead of only one data point at a time, it can
/// evaluate a big set of inputs at the same time, using optimized routines for this task.
///
/// \par
/// The derivatives are weighted, which means that the derivatives of every single output are added together
/// weighted by coefficients (see #weightedParameterDerivative). This is an optimization for the chain rule
/// which is very efficient to calculate most of the time.
///
/// \par
/// It is allowed to store intermediate values during #eval and use them to speed up calculation of
/// derivatives. Therefore it must be guaranteed that eval() is called before calculating derivatives.
/// This is no restriction, since typical error measures need the mapping itself and not only the derivative.
///
/// \par
/// Models have names and can be serialised and have parameters. The type of the parameter vector
/// can be set as third argument. By default, this is RealVector.
/// \ingroup models
template<class InputTypeT, class OutputTypeT, class ParameterVectorType=RealVector>
class AbstractModel : public IParameterizable<ParameterVectorType>, public INameable, public ISerializable
{
public:
	/// \brief Defines the input type of the model.
	typedef InputTypeT InputType;
	/// \brief Defines the output type of the model.
	typedef OutputTypeT OutputType;
	/// \brief Defines the output type of the model compatible with standard functors.
	typedef OutputType result_type;

	///\brief Defines the BaseType used by the model (this type). Useful for creating derived models.
	typedef AbstractModel<InputTypeT,OutputTypeT,ParameterVectorType> ModelBaseType;

	/// \brief Defines the batch type of the input type.
	///
	/// This could for example be std::vector<InputType> but for example for RealVector it could be RealMatrix
	typedef typename Batch<InputType>::type BatchInputType;
	/// \brief Defines the batch type of the output type.
	typedef typename Batch<OutputType>::type BatchOutputType;

	/// \brief Default constructor; the model starts with no announced features.
	AbstractModel() { }

	/// \brief Virtual destructor, required because models are deleted through base pointers.
	virtual ~AbstractModel() { }

	/// \brief Feature flags a derived model can announce via m_features.
	///
	/// A model sets these bits to declare which derivatives it implements;
	/// the hasFirst*Derivative() queries below test them.
	enum Feature {
		HAS_FIRST_PARAMETER_DERIVATIVE = 1,
		HAS_FIRST_INPUT_DERIVATIVE = 4,
	};
	// Declares the m_features flag set and its helpers (see shark/Core/Flags.h).
	SHARK_FEATURE_INTERFACE;

	/// \brief Returns true when the first parameter derivative is implemented.
	bool hasFirstParameterDerivative() const{
		return m_features & HAS_FIRST_PARAMETER_DERIVATIVE;
	}
	/// \brief Returns true when the first input derivative is implemented.
	bool hasFirstInputDerivative() const{
		return m_features & HAS_FIRST_INPUT_DERIVATIVE;
	}

	///\brief Returns the expected shape of the input.
	virtual Shape inputShape() const = 0;
	///\brief Returns the shape of the output.
	virtual Shape outputShape() const = 0;

	///\brief Creates an internal state of the model.
	///
	///The state is needed when the derivatives are to be
	///calculated. Eval can store a state which is then reused to speed up
	///the calculations of the derivatives. This also allows eval to be
	///evaluated in parallel!
	///
	///The default implementation returns an EmptyState and is only valid
	///for models without derivatives; models announcing any derivative
	///feature must override it (enforced by the throw below).
	virtual boost::shared_ptr<State> createState() const
	{
		if (hasFirstParameterDerivative()
		|| hasFirstInputDerivative())
		{
			throw SHARKEXCEPTION("[AbstractModel::createState] createState must be overridden by models with derivatives");
		}
		return boost::shared_ptr<State>(new EmptyState());
	}

	/// \brief From ISerializable, reads a model from an archive.
	///
	/// Restores the feature flags, then reads the parameter vector and
	/// installs it via setParameterVector().
	virtual void read( InArchive & archive ){
		m_features.read(archive);
		ParameterVectorType p;
		archive & p;
		this->setParameterVector(p);
	}

	/// \brief Writes a model to an archive.
	///
	/// The default implementation just saves the parameters, not the structure!
	virtual void write( OutArchive & archive ) const{
		m_features.write(archive);
		ParameterVectorType p = this->parameterVector();
		archive & p;
	}

	/// \brief Standard interface for evaluating the response of the model to a batch of patterns.
	///
	/// Convenience overload: creates a throwaway state and forwards to the
	/// stateful eval() below.
	///
	/// \param patterns the inputs of the model
	/// \param outputs the predictions or response of the model to every pattern
	virtual void eval(BatchInputType const & patterns, BatchOutputType& outputs) const{
		boost::shared_ptr<State> state = createState();
		eval(patterns,outputs,*state);
	}

	/// \brief Standard interface for evaluating the response of the model to a batch of patterns.
	///
	/// \param patterns the inputs of the model
	/// \param outputs the predictions or response of the model to every pattern
	/// \param state intermediate results stored by eval which can be reused for derivative computation.
	virtual void eval(BatchInputType const & patterns, BatchOutputType& outputs, State& state) const = 0;

	/// \brief Standard interface for evaluating the response of the model to a single pattern.
	///
	/// Wraps the pattern into a batch of size one and forwards to the batch eval().
	///
	/// \param pattern the input of the model
	/// \param output the prediction or response of the model to the pattern
	virtual void eval(InputType const & pattern, OutputType& output)const{
		BatchInputType patternBatch=Batch<InputType>::createBatch(pattern);
		getBatchElement(patternBatch,0) = pattern;
		BatchOutputType outputBatch;
		eval(patternBatch,outputBatch);
		output = getBatchElement(outputBatch,0);
	}

	/// \brief Model evaluation as an operator for a whole dataset. This is a convenience function.
	///
	/// \param patterns the input of the model
	/// \returns the responses of the model
	Data<OutputType> operator()(Data<InputType> const& patterns)const{
		return transform(patterns,*this);
	}

	/// \brief Model evaluation as an operator for a single pattern. This is a convenience function.
	///
	/// \param pattern the input of the model
	/// \returns the response of the model
	OutputType operator()(InputType const & pattern)const{
		OutputType output;
		eval(pattern,output);
		return output;
	}

	/// \brief Model evaluation as an operator for a batch of patterns. This is a convenience function.
	///
	/// \param patterns the input of the model
	/// \returns the response of the model
	BatchOutputType operator()(BatchInputType const & patterns)const{
		BatchOutputType output;
		eval(patterns,output);
		return output;
	}

	/// \brief Calculates the weighted sum of derivatives w.r.t the parameters.
	///
	/// The default implementation throws; derived classes implementing this
	/// must also set HAS_FIRST_PARAMETER_DERIVATIVE in m_features.
	///
	/// \param pattern the patterns to evaluate
	/// \param outputs the target outputs
	/// \param coefficients the coefficients which are used to calculate the weighted sum for every pattern
	/// \param state intermediate results stored by eval to speed up calculations of the derivatives
	/// \param derivative the calculated derivative as sum over all derivates of all patterns
	virtual void weightedParameterDerivative(
		BatchInputType const & pattern,
		BatchOutputType const& outputs,
		BatchOutputType const & coefficients,
		State const& state,
		ParameterVectorType& derivative
	)const{
		SHARK_FEATURE_EXCEPTION(HAS_FIRST_PARAMETER_DERIVATIVE);
	}

	///\brief Calculates the weighted sum of derivatives w.r.t the inputs.
	///
	/// The default implementation throws; derived classes implementing this
	/// must also set HAS_FIRST_INPUT_DERIVATIVE in m_features.
	///
	/// \param pattern the patterns to evaluate
	/// \param outputs the target outputs
	/// \param coefficients the coefficients which are used to calculate the weighted sum for every pattern
	/// \param state intermediate results stored by eval to speed up calculations of the derivatives
	/// \param derivative the calculated derivative for every pattern
	virtual void weightedInputDerivative(
		BatchInputType const & pattern,
		BatchOutputType const & outputs,
		BatchOutputType const & coefficients,
		State const& state,
		BatchInputType& derivative
	)const{
		SHARK_FEATURE_EXCEPTION(HAS_FIRST_INPUT_DERIVATIVE);
	}

	///\brief Calculates weighted input and parameter derivative at the same time.
	///
	/// Sometimes, both derivatives are needed at the same time. But sometimes, when calculating the
	/// weighted parameter derivative, the input derivative can be calculated for free. This is for example true for
	/// the feed-forward neural networks. However, there exists the obvious default implementation that just calculates
	/// the derivatives one after another.
	/// \param patterns the patterns to evaluate
	/// \param outputs the target outputs
	/// \param coefficients the coefficients which are used to calculate the weighted sum
	/// \param state intermediate results stored by eval to speed up calculations of the derivatives
	/// \param parameterDerivative the calculated parameter derivative as sum over all derivates of all patterns
	/// \param inputDerivative the calculated derivative for every pattern
	virtual void weightedDerivatives(
		BatchInputType const & patterns,
		BatchOutputType const & outputs,
		BatchOutputType const & coefficients,
		State const& state,
		ParameterVectorType& parameterDerivative,
		BatchInputType& inputDerivative
	)const{
		weightedParameterDerivative(patterns, outputs, coefficients,state,parameterDerivative);
		weightedInputDerivative(patterns, outputs, coefficients,state,inputDerivative);
	}
};
273
274
275
/**
276
* \ingroup shark_globals
277
*
278
* @{
279
*/
280
281
/// \brief Initialize model parameters normally distributed.
282
///
283
/// \param model: model to be initialized
284
/// \param s: variance of mean-free normal distribution
285
template
<
class
InputType,
class
OutputType,
class
ParameterVectorType>
286
void
initRandomNormal
(
AbstractModel<InputType, OutputType, ParameterVectorType>
& model,
double
s){
287
typedef
typename
ParameterVectorType::value_type Float;
288
typedef
typename
ParameterVectorType::device_type Device;
289
auto
weights = blas::normal(
random::globalRng
, model.
numberOfParameters
(), Float(0), Float(s), Device() );
290
model.
setParameterVector
(weights);
291
}
292
293
294
/// \brief Initialize model parameters uniformly at random.
295
///
296
/// \param model model to be initialized
297
/// \param lower lower bound of initialization interval
298
/// \param upper upper bound of initialization interval
299
template
<
class
InputType,
class
OutputType,
class
ParameterVectorType>
300
void
initRandomUniform
(
AbstractModel<InputType, OutputType, ParameterVectorType>
& model,
double
lower,
double
upper){
301
typedef
typename
ParameterVectorType::value_type Float;
302
typedef
typename
ParameterVectorType::device_type Device;
303
auto
weights = blas::uniform(
random::globalRng
, model.
numberOfParameters
(), Float(lower), Float(upper), Device() );
304
model.
setParameterVector
(weights);
305
}
306
307
/** @}*/
308
309
namespace
detail{
310
//Required for correct shape infering of transform
311
template
<
class
I,
class
O,
class
V>
312
struct
InferShape<AbstractModel<I,O,V> >{
313
static
Shape infer(AbstractModel<I,O,V>
const
& f){
return
f.outputShape();}
314
};
315
316
}
317
318
}
319
320
321
#endif