Shark machine learning library
shark/ObjectiveFunctions/Loss/AbstractLoss.h
//===========================================================================
/*!
 *
 *
 * \brief      super class of all loss functions
 *
 *
 *
 * \author     T. Glasmachers
 * \date       2010-2011
 * \file
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */

#ifndef SHARK_OBJECTIVEFUNCTIONS_LOSS_ABSTRACTLOSS_H
#define SHARK_OBJECTIVEFUNCTIONS_LOSS_ABSTRACTLOSS_H

#include <shark/ObjectiveFunctions/AbstractCost.h>
#include <shark/LinAlg/Base.h>
#include <shark/Core/Traits/ProxyReferenceTraits.h>
namespace shark {

/// \defgroup lossfunctions Loss Functions
/// \brief Loss functions define loss values between a model prediction and a given label.

/// \brief Loss function interface
///
/// \par
/// In statistics and machine learning, a loss function encodes
/// the severity of getting a label wrong. This is an important
/// special case of a cost function (see AbstractCost), where
/// the cost is computed as the average loss over a set, also
/// known as the (empirical) risk.
///
/// \par
/// It is generally agreed that loss values are non-negative,
/// and that the loss of a correct prediction is zero. This rule
/// is not formally checked, but instead left to the various
/// sub-classes.
///
/// \ingroup lossfunctions
template<class LabelT, class OutputT = LabelT>
class AbstractLoss : public AbstractCost<LabelT, OutputT>
{
public:
    typedef AbstractCost<LabelT, OutputT> base_type;
    typedef OutputT OutputType;
    typedef LabelT LabelType;
    typedef RealMatrix MatrixType;

    typedef typename Batch<OutputType>::type BatchOutputType;
    typedef typename Batch<LabelType>::type BatchLabelType;

    /// \brief Const references to LabelType
    typedef typename ConstProxyReference<LabelType const>::type ConstLabelReference;
    /// \brief Const references to OutputType
    typedef typename ConstProxyReference<OutputType const>::type ConstOutputReference;

    AbstractLoss(){
        this->m_features |= base_type::IS_LOSS_FUNCTION;
    }

    /// \brief evaluate the loss for a batch of targets and a prediction
    ///
    /// \param target      target values
    /// \param prediction  predictions, typically made by a model
    virtual double eval(BatchLabelType const& target, BatchOutputType const& prediction) const = 0;

    /// \brief evaluate the loss for a target and a prediction
    ///
    /// \param target      target value
    /// \param prediction  prediction, typically made by a model
    virtual double eval(ConstLabelReference target, ConstOutputReference prediction) const{
        BatchLabelType labelBatch = Batch<LabelType>::createBatch(target, 1);
        getBatchElement(labelBatch, 0) = target;
        BatchOutputType predictionBatch = Batch<OutputType>::createBatch(prediction, 1);
        getBatchElement(predictionBatch, 0) = prediction;
        return eval(labelBatch, predictionBatch);
    }

    /// \brief evaluate the loss and its derivative for a target and a prediction
    ///
    /// \param target      target value
    /// \param prediction  prediction, typically made by a model
    /// \param gradient    the gradient of the loss function with respect to the prediction
    virtual double evalDerivative(ConstLabelReference target, ConstOutputReference prediction, OutputType& gradient) const{
        BatchLabelType labelBatch = Batch<LabelType>::createBatch(target, 1);
        getBatchElement(labelBatch, 0) = target;
        BatchOutputType predictionBatch = Batch<OutputType>::createBatch(prediction, 1);
        getBatchElement(predictionBatch, 0) = prediction;
        BatchOutputType gradientBatch = Batch<OutputType>::createBatch(gradient, 1);
        double ret = evalDerivative(labelBatch, predictionBatch, gradientBatch);
        gradient = getBatchElement(gradientBatch, 0);
        return ret;
    }

    /// \brief evaluate the loss and its first and second derivative for a target and a prediction
    ///
    /// \param target      target value
    /// \param prediction  prediction, typically made by a model
    /// \param gradient    the gradient of the loss function with respect to the prediction
    /// \param hessian     the hessian of the loss function with respect to the prediction
    virtual double evalDerivative(
        ConstLabelReference target, ConstOutputReference prediction,
        OutputType& gradient, MatrixType& hessian
    ) const{
        SHARK_FEATURE_EXCEPTION_DERIVED(HAS_SECOND_DERIVATIVE);
        return 0.0;  // dead code, prevent warning
    }

    /// \brief evaluate the loss and the derivative w.r.t. the prediction
    ///
    /// \par
    /// The default implementation throws an exception.
    /// If you override this method, don't forget to set
    /// the flag HAS_FIRST_DERIVATIVE.
    /// \param target      target values
    /// \param prediction  predictions, typically made by a model
    /// \param gradient    the gradient of the loss function with respect to the prediction
    virtual double evalDerivative(BatchLabelType const& target, BatchOutputType const& prediction, BatchOutputType& gradient) const
    {
        SHARK_FEATURE_EXCEPTION_DERIVED(HAS_FIRST_DERIVATIVE);
        return 0.0;  // dead code, prevent warning
    }

    /// from AbstractCost
    ///
    /// \param targets      target values
    /// \param predictions  predictions, typically made by a model
    double eval(Data<LabelType> const& targets, Data<OutputType> const& predictions) const{
        SIZE_CHECK(predictions.numberOfElements() == targets.numberOfElements());
        SIZE_CHECK(predictions.numberOfBatches() == targets.numberOfBatches());
        int numBatches = (int) targets.numberOfBatches();
        double error = 0;
        SHARK_PARALLEL_FOR(int i = 0; i < numBatches; ++i){
            double batchError = eval(targets.batch(i), predictions.batch(i));
            SHARK_CRITICAL_REGION{
                error += batchError;
            }
        }
        return error / targets.numberOfElements();
    }

    /// \brief evaluate the loss for a target and a prediction
    ///
    /// \par
    /// convenience operator
    ///
    /// \param target      target value
    /// \param prediction  prediction, typically made by a model
    double operator () (LabelType const& target, OutputType const& prediction) const
    { return eval(target, prediction); }

    double operator () (BatchLabelType const& target, BatchOutputType const& prediction) const
    { return eval(target, prediction); }

    using base_type::operator();
};


}
#endif
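
The listing shows that a concrete loss only has to override the pure virtual batch eval(); the single-element eval(), the Data-based eval(), and the operator() overloads are all inherited. Below is a minimal sketch of such a subclass. It assumes Shark's convention that Batch<RealVector>::type is RealMatrix with one example per row; the class MeanAbsoluteLoss itself is a hypothetical example, not part of the library.

// MeanAbsoluteLoss: a hypothetical example (not part of Shark) that
// demonstrates the one required override, the batch eval(). Everything
// else -- single-element eval(), eval() over a Data set, operator() --
// is inherited from AbstractLoss.
#include <shark/ObjectiveFunctions/Loss/AbstractLoss.h>
#include <cmath>

namespace shark {

class MeanAbsoluteLoss : public AbstractLoss<RealVector, RealVector>
{
public:
    // Shadow the inherited base_type so that the using-declaration
    // below pulls in the eval() overloads of AbstractLoss, following
    // the pattern of the library's own loss classes.
    typedef AbstractLoss<RealVector, RealVector> base_type;
    using base_type::eval;

    std::string name() const
    { return "MeanAbsoluteLoss"; }

    // Sum of absolute deviations over the batch. For RealVector labels,
    // BatchLabelType is RealMatrix with one example per row; the
    // inherited eval(Data...) divides the summed batch losses by the
    // number of elements, yielding the mean absolute error.
    double eval(BatchLabelType const& target, BatchOutputType const& prediction) const{
        SIZE_CHECK(target.size1() == prediction.size1());
        SIZE_CHECK(target.size2() == prediction.size2());
        double error = 0.0;
        for(std::size_t i = 0; i != target.size1(); ++i)
            for(std::size_t j = 0; j != target.size2(); ++j)
                error += std::abs(prediction(i, j) - target(i, j));
        return error;
    }
};

}

With this in place, loss(y, yhat) scores a single prediction through the inherited convenience operator, and loss.eval(targets, predictions) averages the loss over a whole Data set using the parallelized loop above. Supporting evalDerivative() would additionally require setting the HAS_FIRST_DERIVATIVE feature flag and overriding the batch derivative overload, as the doc comment in the header notes.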