// CrossValidationError.h
//===========================================================================
/*!
 *
 *
 * \brief       cross-validation error for selection of hyper-parameters
 *
 *
 *
 * \author      T. Glasmachers, O. Krause
 * \date        2007-2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

36#ifndef SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
37#define SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
38
44
45namespace shark {
46
47
48///
49/// \brief Cross-validation error for selection of hyper-parameters.
50///
51/// \par
52/// The cross-validation error is useful for evaluating
53/// how well a model performs on a problem. It is regularly
54/// used for model selection.
55///
56/// \par
57/// In Shark, the cross-validation procedure is abstracted
58/// as follows:
59/// First, the given point is written into an IParameterizable
60/// object (such as a regularizer or a trainer). Then a model
61/// is trained with a trainer with the given settings on a
62/// number of folds and evaluated on the corresponding validation
63/// sets with a cost function. The average cost function value
64/// over all folds is returned.
65///
66/// \par
67/// Thus, the cross-validation procedure requires a "meta"
68/// IParameterizable object, a model, a trainer, a data set,
69/// and a cost function.
70/// \ingroup objfunctions
71template<class ModelTypeT, class LabelTypeT = typename ModelTypeT::OutputType>
72class CrossValidationError : public AbstractObjectiveFunction< RealVector, double >
73{
74public:
75 typedef typename ModelTypeT::InputType InputType;
76 typedef typename ModelTypeT::OutputType OutputType;
77 typedef LabelTypeT LabelType;
80 typedef ModelTypeT ModelType;
83private:
85
86
87 FoldsType m_folds;
88 IParameterizable<>* mep_meta;
89 ModelType* mep_model;
90 TrainerType* mep_trainer;
91 CostType* mep_cost;
92
93public:
94
96 FoldsType const& dataFolds,
98 ModelType* model,
99 TrainerType* trainer,
100 CostType* cost)
101 : m_folds(dataFolds)
102 , mep_meta(meta)
103 , mep_model(model)
104 , mep_trainer(trainer)
105 , mep_cost(cost)
106 { }
107
108 /// \brief From INameable: return the class name.
109 std::string name() const
110 {
111 return "CrossValidationError<"
112 + mep_model->name() + ","
113 + mep_trainer->name() + ","
114 + mep_cost->name() + ">";
115 }
116
117 std::size_t numberOfVariables()const{
118 return mep_meta->numberOfParameters();
119 }
120
121 /// Evaluate the cross-validation error:
122 /// train sub-models, evaluate objective,
123 /// return the average.
124 double eval(RealVector const& parameters) const {
125 this->m_evaluationCounter++;
126 mep_meta->setParameterVector(parameters);
127
128 double ret = 0.0;
129 for (size_t setID=0; setID != m_folds.size(); ++setID) {
130 DatasetType train = m_folds.training(setID);
131 DatasetType validation = m_folds.validation(setID);
132 mep_trainer->train(*mep_model, train);
133 Data<OutputType> output = (*mep_model)(validation.inputs());
134 ret += mep_cost->eval(validation.labels(), output);
135 }
136 return ret / m_folds.size();
137 }
138};
139
140
141}
142#endif