Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
ObjectiveFunctions
AbstractCost.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief cost function for quantitative judgement of deviations of predictions from target values
6
* \file
7
*
8
*
9
* \author T. Glasmachers
10
* \date 2011
11
* \file
12
*
13
* \par Copyright 1995-2017 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <https://shark-ml.github.io/Shark/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
34
#ifndef SHARK_OBJECTIVEFUNCTIONS_ABSTRACTCOST_H
35
#define SHARK_OBJECTIVEFUNCTIONS_ABSTRACTCOST_H
36
37
38
#include <
shark/LinAlg/Base.h
>
39
#include <
shark/Core/INameable.h
>
40
#include <
shark/Core/Flags.h
>
41
#include <
shark/Data/Dataset.h
>
42
43
namespace
shark
{
44
45
/// \defgroup costfunctions Cost functions
46
/// \brief Defines cost functions used for optimization
47
///
48
/// Unlike \ref lossfunctions, cost functions are defined on a set of points.
49
50
51
/// \brief Cost function interface
52
///
53
/// \par
54
/// In Shark a cost function encodes the severity of a deviation
55
/// of predictions from targets. This concept is more general than
56
/// that of a loss function, because it does not necessarily amount
57
/// to (uniformly) averaging a loss function over samples.
58
/// In general, the loss depends on the true (training) label and
59
/// the prediction in a not necessarily symmetric way. Also, in
60
/// the most general case predictions can be in a different format
61
/// than labels. E.g., the model prediction could be a probability
62
/// distribution, while the label is a single value.
63
///
64
/// \par
65
/// The concept of an AbstractCost function is different from that
66
/// encoded by the ErrorFunction class. A cost function compares
67
/// model predictions to labels. It does not know about the model
68
/// making the predictions, and thus it can not handle LabeledData
69
/// directly. However, it is one of the components necessary to
70
/// process LabeledData in an ErrorFunction.
71
/// \ingroup costfunctions
72
template
<
class
LabelT,
class
OutputT = LabelT>
73
class
AbstractCost
:
public
INameable
74
{
75
public
:
76
typedef
OutputT
OutputType
;
77
typedef
LabelT
LabelType
;
78
typedef
typename
Batch<OutputType>::type
BatchOutputType
;
79
typedef
typename
Batch<LabelType>::type
BatchLabelType
;
80
81
virtual
~AbstractCost
()
82
{ }
83
84
/// list of features a cost function can have
85
enum
Feature
{
86
HAS_FIRST_DERIVATIVE
= 1,
87
HAS_SECOND_DERIVATIVE
= 2,
88
IS_LOSS_FUNCTION
= 4,
89
};
90
91
SHARK_FEATURE_INTERFACE
;
92
93
/// returns true when the first parameter derivative is implemented
94
bool
hasFirstDerivative
()
const
{
95
return
m_features
&
HAS_FIRST_DERIVATIVE
;
96
}
97
//~ /// returns true when the second parameter derivative is implemented
98
//~ bool hasSecondDerivative() const{
99
//~ return m_features & HAS_SECOND_DERIVATIVE;
100
//~ }
101
102
/// returns true when the cost function is in fact a loss function
103
bool
isLossFunction
()
const
{
104
return
m_features
&
IS_LOSS_FUNCTION
;
105
}
106
107
/// Evaluates the cost of predictions, given targets.
108
/// \param targets target values
109
/// \param predictions predictions, typically made by a model
110
virtual
double
eval
(
Data<LabelType>
const
& targets,
Data<OutputType>
const
& predictions)
const
= 0;
111
112
/// Evaluates the cost of predictions, given targets.
113
/// \param targets Targets of the predictions
114
/// \param predictions Predictions to be compared with the targets
115
/// \return The costs (e.g., error, regret)
116
double
operator ()
(
Data<LabelType>
const
& targets,
Data<OutputType>
const
& predictions)
const
117
{
return
eval
(targets, predictions); }
118
};
119
120
121
}
122
#endif