Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Algorithms
GradientDescent
SteepestDescent.h
Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief       Steepest descent (gradient descent with an optional momentum term).
 *
 *
 *
 * \author      O. Krause
 * \date        2010
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#ifndef SHARK_ML_OPTIMIZER_STEEPESTDESCENT_H
35
#define SHARK_ML_OPTIMIZER_STEEPESTDESCENT_H
36
37
#include <
shark/Algorithms/AbstractSingleObjectiveOptimizer.h
>
38
39
namespace
shark
{
40
41
///@brief Standard steepest descent.
42
/// \ingroup gradientopt
43
template
<
class
SearchPo
int
Type = RealVector>
44
class
SteepestDescent
:
public
AbstractSingleObjectiveOptimizer
<SearchPointType>
45
{
46
public
:
47
typedef
AbstractObjectiveFunction<SearchPointType,double>
ObjectiveFunctionType
;
48
SteepestDescent
() {
49
this->
m_features
|= this->
REQUIRES_FIRST_DERIVATIVE
;
50
51
m_learningRate = 0.1;
52
m_momentum = 0.0;
53
}
54
55
/// \brief From INameable: return the class name.
56
std::string
name
()
const
57
{
return
"SteepestDescent"
; }
58
59
void
init
(
ObjectiveFunctionType
const
& objectiveFunction,
SearchPointType
const
& startingPoint) {
60
this->
checkFeatures
(objectiveFunction);
61
SHARK_RUNTIME_CHECK
(startingPoint.size() == objectiveFunction.
numberOfVariables
(),
"Initial starting point and dimensionality of function do not agree"
);
62
63
m_path.resize(startingPoint.size());
64
m_path.clear();
65
this->
m_best
.
point
= startingPoint;
66
this->
m_best
.
value
= objectiveFunction.
evalDerivative
(this->
m_best
.
point
,m_derivative);
67
}
68
using
AbstractSingleObjectiveOptimizer
<
SearchPointType
>
::init
;
69
70
/*!
71
* \brief get learning rate
72
*/
73
double
learningRate
()
const
{
74
return
m_learningRate;
75
}
76
77
/*!
78
* \brief set learning rate
79
*/
80
void
setLearningRate
(
double
learningRate
) {
81
m_learningRate =
learningRate
;
82
}
83
84
/*!
85
* \brief get momentum parameter
86
*/
87
double
momentum
()
const
{
88
return
m_momentum;
89
}
90
91
/*!
92
* \brief set momentum parameter
93
*/
94
void
setMomentum
(
double
momentum
) {
95
m_momentum =
momentum
;
96
}
97
/*!
98
* \brief updates searchdirection and then does simple gradient descent
99
*/
100
void
step
(
ObjectiveFunctionType
const
& objectiveFunction) {
101
m_path = -m_learningRate * m_derivative + m_momentum * m_path;
102
this->
m_best
.
point
+=m_path;
103
this->
m_best
.
value
= objectiveFunction.
evalDerivative
(this->
m_best
.
point
,m_derivative);
104
}
105
virtual
void
read
(
InArchive
& archive )
106
{
107
archive>>m_path;
108
archive>>m_learningRate;
109
archive>>m_momentum;
110
}
111
112
virtual
void
write
(
OutArchive
& archive )
const
113
{
114
archive<<m_path;
115
archive<<m_learningRate;
116
archive<<m_momentum;
117
}
118
119
private
:
120
SearchPointType
m_path;
121
SearchPointType
m_derivative;
122
double
m_learningRate;
123
double
m_momentum;
124
};
125
126
}
127
#endif
128