Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Algorithms
Trainers
OptimizationTrainer.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Model training by means of a general purpose optimization procedure.
6
*
7
*
8
*
9
* \author T. Glasmachers
10
* \date 2011-2012
11
*
12
*
13
* \par Copyright 1995-2017 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <https://shark-ml.github.io/Shark/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
//===========================================================================
34
35
#ifndef SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H
36
#define SHARK_ALGORITHMS_TRAINERS_OPTIMIZATIONTRAINER_H
37
38
#include <
shark/Algorithms/AbstractSingleObjectiveOptimizer.h
>
39
#include <
shark/Core/ResultSets.h
>
40
#include <
shark/Models/AbstractModel.h
>
41
#include <
shark/ObjectiveFunctions/ErrorFunction.h
>
42
#include <
shark/Algorithms/Trainers/AbstractTrainer.h
>
43
#include <
shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h
>
44
45
namespace
shark
{
46
47
48
///
49
/// \brief Wrapper for training schemes based on (iterative) optimization.
50
///
51
/// \par
52
/// The OptimizationTrainer class is designed to allow for
53
/// model training via iterative minimization of a
54
/// loss function, such as in neural network
55
/// "backpropagation" training.
56
/// \ingroup supervised_trainer
57
template
<
class
Model,
class
LabelTypeT =
typename
Model::OutputType>
58
class
OptimizationTrainer
:
public
AbstractTrainer
<Model,LabelTypeT>
59
{
60
typedef
AbstractTrainer<Model,LabelTypeT>
base_type
;
61
62
public
:
63
typedef
typename
base_type::InputType
InputType
;
64
typedef
typename
Model::OutputType
OutputType
;
65
typedef
typename
base_type::LabelType
LabelType
;
66
typedef
Model
ModelType
;
67
typedef
typename
ModelType::ParameterVectorType
ParameterVectorType
;
68
69
typedef
AbstractSingleObjectiveOptimizer< ParameterVectorType >
OptimizerType
;
70
typedef
AbstractLoss< LabelType, OutputType >
LossType
;
71
typedef
AbstractStoppingCriterion<SingleObjectiveResultSet<ParameterVectorType>
>
StoppingCriterionType
;
72
73
OptimizationTrainer
(
74
LossType
* loss,
75
OptimizerType
* optimizer,
76
StoppingCriterionType
* stoppingCriterion)
77
:
mep_loss
(loss),
mep_optimizer
(optimizer),
mep_stoppingCriterion
(stoppingCriterion)
78
{
79
SHARK_RUNTIME_CHECK
(loss !=
nullptr
,
"Loss function must not be NULL"
);
80
SHARK_RUNTIME_CHECK
(optimizer !=
nullptr
,
"optimizer must not be NULL"
);
81
SHARK_RUNTIME_CHECK
(stoppingCriterion !=
nullptr
,
"Stopping Criterion must not be NULL"
);
82
}
83
84
/// \brief From INameable: return the class name.
85
std::string
name
()
const
86
{
87
return
"OptimizationTrainer<"
88
+
mep_loss
->
name
() +
","
89
+
mep_optimizer
->
name
() +
">"
;
90
}
91
92
void
train
(
ModelType
& model,
LabeledData<InputType, LabelType>
const
& dataset) {
93
ErrorFunction<ParameterVectorType>
error(dataset, &model,
mep_loss
);
94
error.
init
();
95
mep_optimizer
->
init
(error);
96
mep_stoppingCriterion
->
reset
();
97
do
{
98
mep_optimizer
->
step
(error);
99
}
100
while
(!
mep_stoppingCriterion
->
stop
(
mep_optimizer
->
solution
()));
101
model.setParameterVector(
mep_optimizer
->
solution
().
point
);
102
}
103
104
void
read
(
InArchive
& archive )
105
{}
106
107
void
write
(
OutArchive
& archive )
const
108
{}
109
110
protected
:
111
LossType
*
mep_loss
;
112
OptimizerType
*
mep_optimizer
;
113
StoppingCriterionType
*
mep_stoppingCriterion
;
114
};
115
116
117
}
118
#endif