include/shark/ObjectiveFunctions/ErrorFunction.h
/*!
 *
 * \brief Error function for supervised learning
 *
 * \author T. Voss, T. Glasmachers, O. Krause
 * \date 2010-2011
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H
#define SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H

#include <shark/Models/AbstractModel.h>
#include <shark/ObjectiveFunctions/Loss/AbstractLoss.h>
#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <shark/Data/Dataset.h>
#include <shark/Data/WeightedDataset.h>
#include "Impl/ErrorFunction.inl"

#include <boost/scoped_ptr.hpp>

namespace shark{

///
/// \brief Objective function for supervised learning
///
/// \par
/// An ErrorFunction object is an objective function for
/// learning the parameters of a model from data by means
/// of minimization of a cost function. The value of the
/// objective function is the cost of the model predictions
/// on the training data, given the targets.
/// \par
/// It supports mini-batch learning via an optional fourth argument to
/// the constructor. With mini-batch learning enabled, a random batch is
/// drawn from the dataset in each iteration; the size of the mini-batch is
/// therefore the size of the batches in the dataset. Normalization ensures
/// that batches of different sizes have approximately the same magnitude of
/// error and derivative.
///
/// \par
/// The constructor automatically infers the input and label types from the
/// given dataset, and the output type from the model, and ensures that model
/// and loss match. Thus the user does not need to provide these types as
/// template parameters.
/// \ingroup objfunctions
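///
/// \par Example
/// A minimal usage sketch. The model, loss, and regularizer below are
/// placeholders picked from other parts of the library; any compatible
/// combination is used the same way:
/// \code
/// LabeledData<RealVector, RealVector> data = loadMyData(); // hypothetical loader
/// LinearModel<> model(inputDimension(data), labelDimension(data));
/// SquaredLoss<> loss;
/// ErrorFunction<> error(data, &model, &loss); // pass true as 4th argument for mini-batches
/// TwoNormRegularizer<> regularizer;
/// error.setRegularizer(1.e-4, &regularizer); // optional weight decay
/// error.init();
/// double objective = error.eval(model.parameterVector());
/// \endcode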
template<class SearchPointType = RealVector>
class ErrorFunction : public AbstractObjectiveFunction<SearchPointType, double>
{
private:
	typedef AbstractObjectiveFunction<SearchPointType, double> FunctionType;
public:
	typedef typename FunctionType::ResultType ResultType;
	typedef typename FunctionType::FirstOrderDerivative FirstOrderDerivative;

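	///
	/// \brief Constructs the error function from a labeled dataset, a model, and a loss.
	///
	/// \param dataset the labeled training data
	/// \param model the model whose parameters are to be learned
	/// \param loss the loss function comparing model outputs to the labels
	/// \param useMiniBatches if true, every evaluation uses only a single random batch of the dataset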
	template<class InputType, class LabelType, class OutputType>
	ErrorFunction(
		LabeledData<InputType, LabelType> const& dataset,
		AbstractModel<InputType, OutputType, SearchPointType>* model,
		AbstractLoss<LabelType, OutputType>* loss,
		bool useMiniBatches = false
	){
		m_regularizer = nullptr;
		mp_wrapper.reset(new detail::ErrorFunctionImpl<InputType, LabelType, OutputType, SearchPointType>(dataset, model, loss, useMiniBatches));
		this->m_features = mp_wrapper->features();
	}
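	///
	/// \brief Constructs the error function from a weighted dataset, a model, and a loss.
	///
	/// Each example contributes to the error in proportion to its weight.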
	template<class InputType, class LabelType, class OutputType>
	ErrorFunction(
		WeightedLabeledData<InputType, LabelType> const& dataset,
		AbstractModel<InputType, OutputType, SearchPointType>* model,
		AbstractLoss<LabelType, OutputType>* loss
	){
		m_regularizer = nullptr;
		mp_wrapper.reset(new detail::WeightedErrorFunctionImpl<InputType, LabelType, OutputType, SearchPointType>(dataset, model, loss));
		this->m_features = mp_wrapper->features();
	}
	ErrorFunction(ErrorFunction const& op)
	: mp_wrapper(op.mp_wrapper->clone()){
		// the wrapper does not hold the regularizer, so copy it explicitly
		m_regularizer = op.m_regularizer;
		m_regularizationStrength = op.m_regularizationStrength;
		this->m_features = mp_wrapper->features();
	}
	ErrorFunction& operator=(ErrorFunction const& op){
		ErrorFunction copy(op);
		swap(copy.mp_wrapper, mp_wrapper);
		swap(copy.m_features, this->m_features);
		// also take over the regularizer settings from the copy
		m_regularizer = copy.m_regularizer;
		m_regularizationStrength = copy.m_regularizationStrength;
		return *this;
	}

	std::string name() const
	{ return "ErrorFunction"; }

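	///
	/// \brief Adds a regularizer R to the objective, so that evaluations
	/// return error(x) + factor * R(x).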
	void setRegularizer(double factor, FunctionType* regularizer){
		m_regularizer = regularizer;
		m_regularizationStrength = factor;
	}

	SearchPointType proposeStartingPoint() const{
		return mp_wrapper->proposeStartingPoint();
	}
	std::size_t numberOfVariables() const{
		return mp_wrapper->numberOfVariables();
	}

	void init(){
		mp_wrapper->setRng(this->mep_rng);
		mp_wrapper->init();
	}

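	///
	/// \brief Evaluates the error of the model on the training data,
	/// adding the regularization term if a regularizer is set.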
	double eval(SearchPointType const& input) const{
		++this->m_evaluationCounter;
		double value = mp_wrapper->eval(input);
		if(m_regularizer)
			value += m_regularizationStrength * m_regularizer->eval(input);
		return value;
	}

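	///
	/// \brief Evaluates the error and its derivative with respect to the
	/// parameters, adding the regularizer's contribution to both if one is set.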
	ResultType evalDerivative(SearchPointType const& input, FirstOrderDerivative& derivative) const{
		++this->m_evaluationCounter;
		double value = mp_wrapper->evalDerivative(input, derivative);
		if(m_regularizer){
			FirstOrderDerivative regularizerDerivative;
			value += m_regularizationStrength * m_regularizer->evalDerivative(input, regularizerDerivative);
			noalias(derivative) += m_regularizationStrength * regularizerDerivative;
		}
		return value;
	}
private:
	boost::scoped_ptr<detail::FunctionWrapperBase<SearchPointType> > mp_wrapper;
	FunctionType* m_regularizer;
	double m_regularizationStrength;
};

}
#endif