SteepestDescent.h
//===========================================================================
/*!
 *
 *
 * \brief SteepestDescent
 *
 *
 *
 * \author O. Krause
 * \date 2010
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#ifndef SHARK_ML_OPTIMIZER_STEEPESTDESCENT_H
#define SHARK_ML_OPTIMIZER_STEEPESTDESCENT_H

#include <shark/Algorithms/AbstractSingleObjectiveOptimizer.h>

namespace shark{

///@brief Standard steepest descent.
/// \ingroup gradientopt
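///
/// Each call to step() performs one gradient descent update with momentum,
/// \f[ p_{t+1} = -\eta \, \nabla f(x_t) + \alpha \, p_t, \qquad x_{t+1} = x_t + p_{t+1}, \f]
/// where \f$\eta\f$ is the learning rate (default 0.1) and \f$\alpha\f$ the momentum parameter (default 0).
///
/// A minimal usage sketch; the benchmark function, its namespace shark::benchmarks and the
/// exact accessor names of the base optimizer are assumptions for illustration, not part of this file:
/// \code
/// shark::benchmarks::Rosenbrock f(2);       // any differentiable single-objective function
/// f.init();
/// shark::SteepestDescent<> optimizer;
/// optimizer.setLearningRate(0.001);
/// optimizer.setMomentum(0.9);
/// optimizer.init(f, f.proposeStartingPoint());
/// for(std::size_t i = 0; i != 1000; ++i)
///     optimizer.step(f);                    // best value so far: optimizer.solution().value
/// \endcode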
template<class SearchPointType = RealVector>
class SteepestDescent : public AbstractSingleObjectiveOptimizer<SearchPointType>
{
public:
    typedef typename AbstractSingleObjectiveOptimizer<SearchPointType>::ObjectiveFunctionType ObjectiveFunctionType;

    SteepestDescent() {
        // this optimizer requires the first derivative of the objective function
        this->m_features |= this->REQUIRES_FIRST_DERIVATIVE;

        m_learningRate = 0.1;
        m_momentum = 0.0;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "SteepestDescent"; }

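    /// \brief Initialize the optimizer with an objective function and a starting point.
    ///
    /// Resets the search direction to zero and evaluates the objective function
    /// and its derivative at the starting point.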
    void init(ObjectiveFunctionType const& objectiveFunction, SearchPointType const& startingPoint) {
        this->checkFeatures(objectiveFunction);
        SHARK_RUNTIME_CHECK(startingPoint.size() == objectiveFunction.numberOfVariables(), "Initial starting point and dimensionality of function do not agree");

        m_path.resize(startingPoint.size());
        m_path.clear();  // start with a zero search direction
        this->m_best.point = startingPoint;
        this->m_best.value = objectiveFunction.evalDerivative(this->m_best.point,m_derivative);
    }
    using AbstractSingleObjectiveOptimizer<SearchPointType>::init;

    /*!
     * \brief get learning rate
     */
    double learningRate() const {
        return m_learningRate;
    }

    /*!
     * \brief set learning rate
     */
    void setLearningRate(double learningRate) {
        m_learningRate = learningRate;
    }

    /*!
     * \brief get momentum parameter
     */
    double momentum() const {
        return m_momentum;
    }

    /*!
     * \brief set momentum parameter
     */
    void setMomentum(double momentum) {
        m_momentum = momentum;
    }
    /*!
     * \brief updates the search direction and then takes a simple gradient descent step
     */
    void step(ObjectiveFunctionType const& objectiveFunction) {
        // momentum-smoothed search direction: p <- -eta * gradient + alpha * p
        m_path = -m_learningRate * m_derivative + m_momentum * m_path;
        this->m_best.point += m_path;
        // evaluate the new point and store its gradient for the next step
        this->m_best.value = objectiveFunction.evalDerivative(this->m_best.point,m_derivative);
    }
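    /// \brief Read the optimizer state (search direction, learning rate, momentum) from an archive.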
    virtual void read( InArchive & archive )
    {
        archive>>m_path;
        archive>>m_learningRate;
        archive>>m_momentum;
    }

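    /// \brief Write the optimizer state (search direction, learning rate, momentum) to an archive.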
    virtual void write( OutArchive & archive ) const
    {
        archive<<m_path;
        archive<<m_learningRate;
        archive<<m_momentum;
    }

private:
    SearchPointType m_path;        ///< current (momentum-smoothed) search direction
    SearchPointType m_derivative;  ///< gradient of the objective at the current best point
    double m_learningRate;         ///< step size \f$\eta\f$
    double m_momentum;             ///< momentum parameter \f$\alpha\f$
};

}
#endif