ErrorFunction.h
/*!
 *
 *
 * \brief Error function for supervised learning
 *
 *
 *
 * \author T. Voss, T. Glasmachers, O. Krause
 * \date 2010-2011
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H
#define SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H


#include <shark/Models/AbstractModel.h>
#include <shark/ObjectiveFunctions/Loss/AbstractLoss.h>
#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <shark/Data/Dataset.h>
#include <shark/Data/WeightedDataset.h>
#include "Impl/ErrorFunction.inl"

#include <boost/scoped_ptr.hpp>

namespace shark{

///
/// \brief Objective function for supervised learning
///
/// \par
/// An ErrorFunction object is an objective function for
/// learning the parameters of a model from data by
/// minimizing a cost function. The value of the
/// objective function is the cost of the model predictions
/// on the training data, given the targets.
/// \par
/// It supports mini-batch learning via an optional fourth argument to
/// the constructor (see the sketch below). With mini-batch learning enabled,
/// each iteration draws a random batch from the dataset; the mini-batch size is
/// therefore the batch size used by the dataset. Normalization ensures that
/// batches of different sizes have approximately the same magnitude of error
/// and derivative.
///
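/// A construction sketch (dataset, model, and loss assumed set up as in the
/// example further below):
/// \code
/// ErrorFunction<> error(data, &model, &loss, true); // true enables mini-batch mode
/// \endcode
///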
///\par
/// It automatically infers the input and label type from the given dataset and
/// the output type of the model in the constructor, and it ensures that model
/// and loss match. Thus the user does not need to provide the types as template
/// parameters.
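/// \par
/// A minimal usage sketch, assuming a regression setup with Shark's LinearModel
/// and SquaredLoss (any model/loss pair with matching types works the same way):
/// \code
/// LabeledData<RealVector, RealVector> data = ...; // training data, assumed given
/// LinearModel<> model(inputDimension(data), labelDimension(data));
/// SquaredLoss<> loss;
/// ErrorFunction<> error(data, &model, &loss);
/// error.init();
/// double startError = error.eval(model.parameterVector());
/// \endcode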
/// \ingroup objfunctions
template<class SearchPointType = RealVector>
class ErrorFunction : public AbstractObjectiveFunction<SearchPointType, double>
{
private:
    typedef AbstractObjectiveFunction<SearchPointType, double> FunctionType;
public:
    typedef typename FunctionType::FirstOrderDerivative FirstOrderDerivative;
    typedef typename FunctionType::ResultType ResultType;

    ///\brief Constructs the error of a model on a labeled dataset under a given loss.
    template<class InputType, class LabelType, class OutputType>
    ErrorFunction(
        LabeledData<InputType, LabelType> const& dataset,
        AbstractModel<InputType, OutputType, SearchPointType>* model,
        AbstractLoss<LabelType, OutputType>* loss,
        bool useMiniBatches = false
    ){
        m_regularizer = nullptr;
        mp_wrapper.reset(new detail::ErrorFunctionImpl<InputType, LabelType, OutputType, SearchPointType>(dataset, model, loss, useMiniBatches));

        this->m_features = mp_wrapper->features();
    }
    ///\brief Constructs the error of a model on a weighted labeled dataset.
    template<class InputType, class LabelType, class OutputType>
    ErrorFunction(
        WeightedLabeledData<InputType, LabelType> const& dataset,
        AbstractModel<InputType, OutputType, SearchPointType>* model,
        AbstractLoss<LabelType, OutputType>* loss
    ){
        m_regularizer = nullptr;
        mp_wrapper.reset(new detail::WeightedErrorFunctionImpl<InputType, LabelType, OutputType, SearchPointType>(dataset, model, loss));
        this->m_features = mp_wrapper->features();
    }
    ///\brief Copy constructor; deep-copies the wrapped implementation.
    ErrorFunction(ErrorFunction const& op)
    :mp_wrapper(op.mp_wrapper->clone()){
        // also copy the regularizer settings so the copy evaluates the same objective
        m_regularizer = op.m_regularizer;
        m_regularizationStrength = op.m_regularizationStrength;
        this->m_features = mp_wrapper->features();
    }
    ///\brief Copy-and-swap assignment.
    ErrorFunction& operator=(ErrorFunction const& op){
        ErrorFunction copy(op);
        swap(copy.mp_wrapper, mp_wrapper);
        swap(copy.m_features, this->m_features);
        m_regularizer = copy.m_regularizer;
        m_regularizationStrength = copy.m_regularizationStrength;
        return *this;
    }

    std::string name() const
    { return "ErrorFunction"; }

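    ///\brief Adds a regularizer to the objective, weighted by the given factor.
    ///
    /// After this call, eval(p) returns the data error plus factor * regularizer->eval(p).
    /// A sketch, assuming Shark's TwoNormRegularizer as the penalty term:
    /// \code
    /// TwoNormRegularizer<> l2;
    /// error.setRegularizer(0.01, &l2);
    /// \endcode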
    void setRegularizer(double factor, FunctionType* regularizer){
        m_regularizer = regularizer;
        m_regularizationStrength = factor;
    }

    SearchPointType proposeStartingPoint() const{
        return mp_wrapper->proposeStartingPoint();
    }
    std::size_t numberOfVariables() const{
        return mp_wrapper->numberOfVariables();
    }

    void init(){
        mp_wrapper->setRng(this->mep_rng);
        mp_wrapper->init();
    }

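    ///\brief Evaluates the error of the model on the data; adds the weighted regularizer, if one is set.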
    double eval(SearchPointType const& input) const{
        ++this->m_evaluationCounter;
        double value = mp_wrapper->eval(input);
        if(m_regularizer)
            value += m_regularizationStrength * m_regularizer->eval(input);
        return value;
    }
    ///\brief Evaluates the error and writes its gradient w.r.t. the parameters into derivative.
    double evalDerivative(SearchPointType const& input, FirstOrderDerivative& derivative) const{
        ++this->m_evaluationCounter;
        double value = mp_wrapper->evalDerivative(input, derivative);
        if(m_regularizer){
            FirstOrderDerivative regularizerDerivative;
            value += m_regularizationStrength * m_regularizer->evalDerivative(input, regularizerDerivative);
            noalias(derivative) += m_regularizationStrength * regularizerDerivative;
        }
        return value;
    }
private:
    // Type-erased implementation, chosen in the constructor to match the dataset type.
    boost::scoped_ptr<detail::FunctionWrapperBase<SearchPointType> > mp_wrapper;
    FunctionType* m_regularizer; // optional regularizer, not owned
    double m_regularizationStrength;
};

}
#endif