Shark machine learning library
include/shark/ObjectiveFunctions/CrossValidationError.h
//===========================================================================
/*!
 *
 *
 * \brief       cross-validation error for selection of hyper-parameters
 *
 *
 *
 * \author      T. Glasmachers, O. Krause
 * \date        2007-2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#ifndef SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
#define SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H

#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <shark/Algorithms/Trainers/AbstractTrainer.h>
#include <shark/Algorithms/AbstractSingleObjectiveOptimizer.h>
#include <shark/ObjectiveFunctions/AbstractCost.h>
#include <shark/Data/CVDatasetTools.h>

namespace shark {


///
/// \brief Cross-validation error for selection of hyper-parameters.
///
/// \par
/// The cross-validation error is useful for evaluating
/// how well a model performs on a problem. It is regularly
/// used for model selection.
///
/// \par
/// In Shark, the cross-validation procedure is abstracted
/// as follows:
/// First, the given point is written into an IParameterizable
/// object (such as a regularizer or a trainer). Then a model
/// is trained with a trainer with the given settings on a
/// number of folds and evaluated on the corresponding validation
/// sets with a cost function. The average cost function value
/// over all folds is returned.
///
/// \par
/// Thus, the cross-validation procedure requires a "meta"
/// IParameterizable object, a model, a trainer, a data set,
/// and a cost function.
/// \ingroup objfunctions
template<class ModelTypeT, class LabelTypeT = typename ModelTypeT::OutputType>
class CrossValidationError : public AbstractObjectiveFunction< RealVector, double >
{
public:
    typedef typename ModelTypeT::InputType InputType;
    typedef typename ModelTypeT::OutputType OutputType;
    typedef LabelTypeT LabelType;
    typedef LabeledData<InputType, LabelType> DatasetType;
    typedef CVFolds<DatasetType> FoldsType;
    typedef ModelTypeT ModelType;
    typedef AbstractTrainer<ModelType, LabelType> TrainerType;
    typedef AbstractCost<LabelType, OutputType> CostType;
private:
    typedef SingleObjectiveFunction base_type;

    FoldsType m_folds;
    IParameterizable<>* mep_meta;
    ModelType* mep_model;
    TrainerType* mep_trainer;
    CostType* mep_cost;

public:

    CrossValidationError(
        FoldsType const& dataFolds,
        IParameterizable<>* meta,
        ModelType* model,
        TrainerType* trainer,
        CostType* cost)
    : m_folds(dataFolds)
    , mep_meta(meta)
    , mep_model(model)
    , mep_trainer(trainer)
    , mep_cost(cost)
    { }

    /// \brief From INameable: return the class name.
    std::string name() const
    {
        return "CrossValidationError<"
            + mep_model->name() + ","
            + mep_trainer->name() + ","
            + mep_cost->name() + ">";
    }

    std::size_t numberOfVariables() const{
        return mep_meta->numberOfParameters();
    }

    /// Evaluate the cross-validation error:
    /// train sub-models, evaluate objective,
    /// return the average.
    double eval(RealVector const& parameters) const {
        this->m_evaluationCounter++;
        mep_meta->setParameterVector(parameters);

        double ret = 0.0;
        for (size_t setID=0; setID != m_folds.size(); ++setID) {
            DatasetType train = m_folds.training(setID);
            DatasetType validation = m_folds.validation(setID);
            mep_trainer->train(*mep_model, train);
            Data<OutputType> output = (*mep_model)(validation.inputs());
            ret += mep_cost->eval(validation.labels(), output);
        }
        return ret / m_folds.size();
    }
};


}
#endif
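
For reference, a minimal usage sketch follows. Only the CrossValidationError interface shown above (the constructor, numberOfVariables(), and eval()) is taken from this header; the concrete kernel, trainer, loss, and dataset classes, their include paths, and the trainer's parameter layout are assumptions based on common Shark SVM examples and may differ between Shark versions.

// Usage sketch; assumptions are flagged in the comments below.
#include <shark/ObjectiveFunctions/CrossValidationError.h>
#include <shark/Data/CVDatasetTools.h>                  // CVFolds, createCVSameSizeBalanced
#include <shark/Models/Kernels/GaussianRbfKernel.h>     // assumed include path
#include <shark/Algorithms/Trainers/CSvmTrainer.h>      // assumed include path
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>  // assumed include path

using namespace shark;

// Five-fold cross-validation error of a kernel SVM at one candidate
// hyper-parameter setting. ClassificationDataset is
// LabeledData<RealVector, unsigned int>; KernelClassifier is assumed to be
// made available through the trainer header.
double fiveFoldError(ClassificationDataset const& data)
{
    GaussianRbfKernel<> kernel(0.5);                       // kernel with initial bandwidth
    KernelClassifier<RealVector> svm;                      // model retrained on each fold
    CSvmTrainer<RealVector> trainer(&kernel, 1.0, true);   // assumed ctor: (kernel, C, offset)
    ZeroOneLoss<unsigned int> loss;                        // cost on the validation folds

    // Split the data into 5 folds of roughly equal size and class balance.
    CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(data, 5);

    // The trainer serves both as the "meta" IParameterizable object holding
    // the hyper-parameters and as the trainer applied to each fold.
    CrossValidationError<KernelClassifier<RealVector>, unsigned int> cvError(
        folds, &trainer, &svm, &trainer, &loss);

    // Evaluate the averaged validation error at the trainer's current
    // hyper-parameter vector; its layout is defined by the trainer
    // (e.g. the regularization parameter C, possibly kernel parameters).
    RealVector params = trainer.parameterVector();
    return cvError.eval(params);
}

In practice the objective would usually be handed to a single-objective optimizer such as a grid search rather than evaluated at a single point, since each call to eval() retrains the model once per fold.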