AbstractModel.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 * \author T.Glasmachers, O. Krause
4 * \date 2010
5 * \file
6 *
7 * \par Copyright 1995-2017 Shark Development Team
8 *
9 * <BR><HR>
10 * This file is part of Shark.
11 * <https://shark-ml.github.io/Shark/>
12 *
13 * Shark is free software: you can redistribute it and/or modify
14 * it under the terms of the GNU Lesser General Public License as published
15 * by the Free Software Foundation, either version 3 of the License, or
16 * (at your option) any later version.
17 *
18 * Shark is distributed in the hope that it will be useful,
19 * but WITHOUT ANY WARRANTY; without even the implied warranty of
20 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 * GNU Lesser General Public License for more details.
22 *
23 * You should have received a copy of the GNU Lesser General Public License
24 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
25 *
26 */
27//===========================================================================
28
29#ifndef SHARK_MODELS_ABSTRACTMODEL_H
30#define SHARK_MODELS_ABSTRACTMODEL_H
31
32/// \defgroup models Models
33///
34/// \brief Model classes for statistical prediction.
35///
36/// Models in shark define the classes that can perform statistical predictions on supplied input data.
37/// Models can have different types of inputs and outputs so that they can be used for classification and regression tasks.
38#include <shark/Core/Flags.h>
41#include <shark/Core/State.h>
42#include <shark/Core/Shape.h>
43#include <shark/Core/Random.h>
45
46namespace shark {
47
48///\brief Base class for all Models.
49///
50/// \par
51/// A model is one of the three foundations of supervised learning: a model, an error measure
52/// and an optimization algorithm.
53/// It is a concept of a function which performs a mapping \f$ x \rightarrow f_w(x)\f$.
54/// In contrast to an error function it has two sets of parameters:
55/// The first is the current point to map \f$x\f$, the others are the internal model parameters \f$w\f$
56/// which define the mapping.
57/// Often a model is used to find an optimal mapping for a problem, for example a function which
58/// best fits the points of a given dataset. Therefore, AbstractModel does not only offer
59/// the mapping itself, but also a set of special derivatives with respect to \f$ x \f$ and \f$ w \f$.
60/// Most of the time, only the derivative with respect to \f$ w \f$ is needed, but in some special problems,
61/// like finding optimal stimuli or stacking models, also the input derivative is needed.
62///
63///\par Models are optimized for batch processing. This means that instead of only one data point at a time, a model can
64/// evaluate a large set of inputs at the same time, using optimized routines for this task.
65///
66/// \par
67/// The derivatives are weighted, which means that the derivatives of every single output are added together
68/// weighted by coefficients (see #weightedParameterDerivative). This is an optimization for the chain rule
69/// which is very efficient to calculate most of the time.
70///
71/// \par
72/// It is allowed to store intermediate values during #eval and use them to speed up calculation of
73/// derivatives. Therefore it must be guaranteed that eval() is called before calculating derivatives.
74/// This is no restriction, since typical error measures need the mapping itself and not only the derivative.
75///
76/// \par
77/// Models have names, can be serialised, and have parameters. The type of the parameter vector
78/// can be set as the third template argument. By default, this is RealVector.
79/// \ingroup models
80template<class InputTypeT, class OutputTypeT, class ParameterVectorType=RealVector>
81class AbstractModel : public IParameterizable<ParameterVectorType>, public INameable, public ISerializable
82{
83public:
84 /// \brief Defines the input type of the model.
85 typedef InputTypeT InputType;
86 /// \brief Defines the output type of the model.
87 typedef OutputTypeT OutputType;
88 /// \brief Defines the output type of the model compatible with standard functors
89 typedef OutputType result_type;
90
91 ///\brief Defines the BaseType used by the model (this type). Useful for creating derived models
93
94 /// \brief defines the batch type of the input type.
95 ///
96 /// This could for example be std::vector<InputType> but for example for RealVector it could be RealMatrix
98 /// \brief defines the batch type of the output type
100
101
103
104 virtual ~AbstractModel() { }
105
111
112 /// \brief Returns true when the first parameter derivative is implemented.
116 /// \brief Returns true when the first input derivative is implemented.
120
121 ///\brief Returns the expected shape of the input.
122 virtual Shape inputShape() const = 0;
123 ///\brief Returns the shape of the output.
124 virtual Shape outputShape() const = 0;
125
126 ///\brief Creates an internal state of the model.
127 ///
128 ///The state is needed when the derivatives are to be
129 ///calculated. Eval can store a state which is then reused to speed up
130 ///the calculations of the derivatives. This also allows eval to be
131 ///evaluated in parallel!
132 virtual boost::shared_ptr<State> createState() const
133 {
136 {
137 throw SHARKEXCEPTION("[AbstractModel::createState] createState must be overridden by models with derivatives");
138 }
139 return boost::shared_ptr<State>(new EmptyState());
140 }
141
142 /// \brief From ISerializable, reads a model from an archive.
143 virtual void read( InArchive & archive ){
144 m_features.read(archive);
146 archive & p;
147 this->setParameterVector(p);
148 }
149
150 /// \brief writes a model to an archive
151 ///
152 /// the default implementation just saves the parameters, not the structure!
153 virtual void write( OutArchive & archive ) const{
154 m_features.write(archive);
156 archive & p;
157 }
158
159 /// \brief Standard interface for evaluating the response of the model to a batch of patterns.
160 ///
161 /// \param patterns the inputs of the model
162 /// \param outputs the predictions or response of the model to every pattern
163 virtual void eval(BatchInputType const & patterns, BatchOutputType& outputs) const{
164 boost::shared_ptr<State> state = createState();
165 eval(patterns,outputs,*state);
166 }
167
168 /// \brief Standard interface for evaluating the response of the model to a batch of patterns.
169 ///
170 /// \param patterns the inputs of the model
171 /// \param outputs the predictions or response of the model to every pattern
172 /// \param state intermediate results stored by eval which can be reused for derivative computation.
173 virtual void eval(BatchInputType const & patterns, BatchOutputType& outputs, State& state) const = 0;
174
175 /// \brief Standard interface for evaluating the response of the model to a single pattern.
176 ///
177 /// \param pattern the input of the model
178 /// \param output the prediction or response of the model to the pattern
179 virtual void eval(InputType const & pattern, OutputType& output)const{
180 BatchInputType patternBatch=Batch<InputType>::createBatch(pattern);
181 getBatchElement(patternBatch,0) = pattern;
182 BatchOutputType outputBatch;
183 eval(patternBatch,outputBatch);
184 output = getBatchElement(outputBatch,0);
185 }
186
187 /// \brief Model evaluation as an operator for a whole dataset. This is a convenience function
188 ///
189 /// \param patterns the input of the model
190 /// \returns the responses of the model
192 return transform(patterns,*this);
193 }
194
195 /// \brief Model evaluation as an operator for a single pattern. This is a convenience function
196 ///
197 /// \param pattern the input of the model
198 /// \returns the response of the model
199 OutputType operator()(InputType const & pattern)const{
200 OutputType output;
201 eval(pattern,output);
202 return output;
203 }
204
205 /// \brief Model evaluation as an operator for a batch of patterns. This is a convenience function
206 ///
207 /// \param patterns the input of the model
208 /// \returns the response of the model
210 BatchOutputType output;
211 eval(patterns,output);
212 return output;
213 }
214
215 /// \brief calculates the weighted sum of derivatives w.r.t the parameters.
216 ///
217 /// \param pattern the patterns to evaluate
218 /// \param outputs the target outputs
219 /// \param coefficients the coefficients which are used to calculate the weighted sum for every pattern
220 /// \param state intermediate results stored by eval to speed up calculations of the derivatives
221 /// \param derivative the calculated derivative as sum over all derivatives of all patterns
223 BatchInputType const & pattern,
224 BatchOutputType const& outputs,
225 BatchOutputType const & coefficients,
226 State const& state,
227 ParameterVectorType& derivative
228 )const{
230 }
231
232 ///\brief calculates the weighted sum of derivatives w.r.t the inputs
233 ///
234 /// \param pattern the patterns to evaluate
235 /// \param outputs the target outputs
236 /// \param coefficients the coefficients which are used to calculate the weighted sum for every pattern
237 /// \param state intermediate results stored by eval to speed up calculations of the derivatives
238 /// \param derivative the calculated derivative for every pattern
240 BatchInputType const & pattern,
241 BatchOutputType const& outputs,
242 BatchOutputType const & coefficients,
243 State const& state,
244 BatchInputType& derivative
245 )const{
247 }
248
249 ///\brief calculates weighted input and parameter derivative at the same time
250 ///
251 /// Sometimes, both derivatives are needed at the same time. But sometimes, when calculating the
252 /// weighted parameter derivative, the input derivative can be calculated for free. This is for example true for
253 /// the feed-forward neural networks. However, there exists the obvious default implementation that just calculates
254 /// the derivatives one after another.
255 /// \param patterns the patterns to evaluate
256 /// \param outputs the target outputs
257 /// \param coefficients the coefficients which are used to calculate the weighted sum
258 /// \param state intermediate results stored by eval to speed up calculations of the derivatives
259 /// \param parameterDerivative the calculated parameter derivative as sum over all derivatives of all patterns
260 /// \param inputDerivative the calculated derivative for every pattern
262 BatchInputType const & patterns,
263 BatchOutputType const& outputs,
264 BatchOutputType const & coefficients,
265 State const& state,
266 ParameterVectorType& parameterDerivative,
267 BatchInputType& inputDerivative
268 )const{
269 weightedParameterDerivative(patterns, outputs, coefficients,state,parameterDerivative);
270 weightedInputDerivative(patterns, outputs, coefficients,state,inputDerivative);
271 }
272};
273
274
275/**
276 * \ingroup shark_globals
277 *
278 * @{
279 */
280
281/// \brief Initialize model parameters normally distributed.
282///
283/// \param model model to be initialized
284/// \param s variance of mean-free normal distribution
285template <class InputType, class OutputType, class ParameterVectorType>
287 typedef typename ParameterVectorType::value_type Float;
288 typedef typename ParameterVectorType::device_type Device;
289 auto weights = blas::normal(random::globalRng, model.numberOfParameters(), Float(0), Float(s), Device() );
290 model.setParameterVector(weights);
291}
292
293
294/// \brief Initialize model parameters uniformly at random.
295///
296/// \param model model to be initialized
297/// \param lower lower bound of initialization interval
298/// \param upper upper bound of initialization interval
299template <class InputType, class OutputType, class ParameterVectorType>
301 typedef typename ParameterVectorType::value_type Float;
302 typedef typename ParameterVectorType::device_type Device;
303 auto weights = blas::uniform(random::globalRng, model.numberOfParameters(), Float(lower), Float(upper), Device() );
304 model.setParameterVector(weights);
305}
306
307/** @}*/
308
309namespace detail{
310//Required for correct shape inference of transform
311template<class I, class O, class V>
312struct InferShape<AbstractModel<I,O,V> >{
313 static Shape infer(AbstractModel<I,O,V> const& f){return f.outputShape();}
314};
315
316}
317
318}
319
320
321#endif