// OptimizationTrainer.h
//===========================================================================
/*!
 *
 *
 * \brief Model training by means of a general purpose optimization procedure.
 *
 *
 *
 * \author T. Glasmachers
 * \date 2011-2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#ifndef SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H
#define SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H

#include <shark/Algorithms/AbstractSingleObjectiveOptimizer.h>
#include <shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h>
#include <shark/Algorithms/Trainers/AbstractTrainer.h>
#include <shark/Core/ResultSets.h>
#include <shark/ObjectiveFunctions/ErrorFunction.h>

45namespace shark {
46
47
48///
49/// \brief Wrapper for training schemes based on (iterative) optimization.
50///
51/// \par
52/// The OptimizationTrainer class is designed to allow for
53/// model training via iterative minimization of a
54/// loss function, such as in neural network
55/// "backpropagation" training.
56/// \ingroup supervised_trainer
57template <class Model, class LabelTypeT = typename Model::OutputType>
58class OptimizationTrainer : public AbstractTrainer<Model,LabelTypeT>
59{
61
62public:
64 typedef typename Model::OutputType OutputType;
66 typedef Model ModelType;
67 typedef typename ModelType::ParameterVectorType ParameterVectorType;
68
72
74 LossType* loss,
75 OptimizerType* optimizer,
76 StoppingCriterionType* stoppingCriterion)
77 : mep_loss(loss), mep_optimizer(optimizer), mep_stoppingCriterion(stoppingCriterion)
78 {
79 SHARK_RUNTIME_CHECK(loss != nullptr, "Loss function must not be NULL");
80 SHARK_RUNTIME_CHECK(optimizer != nullptr, "optimizer must not be NULL");
81 SHARK_RUNTIME_CHECK(stoppingCriterion != nullptr, "Stopping Criterion must not be NULL");
82 }
83
84 /// \brief From INameable: return the class name.
85 std::string name() const
86 {
87 return "OptimizationTrainer<"
88 + mep_loss->name() + ","
89 + mep_optimizer->name() + ">";
90 }
91
92 void train(ModelType& model, LabeledData<InputType, LabelType> const& dataset) {
93 ErrorFunction<ParameterVectorType> error(dataset, &model, mep_loss);
94 error.init();
95 mep_optimizer->init(error);
97 do {
98 mep_optimizer->step(error);
99 }
101 model.setParameterVector(mep_optimizer->solution().point);
102 }
103
104 void read( InArchive & archive )
105 {}
106
107 void write( OutArchive & archive ) const
108 {}
109
110protected:
114};
115
116
117}
#endif