Adam.h
//===========================================================================
/*!
 *
 * \brief Adam
 *
 * \author O. Krause
 * \date 2017
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#ifndef SHARK_ML_OPTIMIZER_ADAM_H
#define SHARK_ML_OPTIMIZER_ADAM_H

//base class of all single objective optimizers
#include <shark/Algorithms/AbstractSingleObjectiveOptimizer.h>

namespace shark{

/// \brief Adaptive Moment Estimation Algorithm (Adam)
///
/// Performs stochastic gradient descent, maintaining exponentially decaying averages of the
/// gradient and of its elementwise second moment, which are used to adapt a separate step size
/// for each coordinate.
/// \ingroup gradientopt
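///
/// A minimal usage sketch (assuming a differentiable objective `loss`, for example a
/// shark::ErrorFunction, a starting point `startingPoint`, and the solution() accessor
/// provided by the optimizer base class):
/// \code
/// Adam<> optimizer;
/// optimizer.setEta(0.001);                      // learning rate
/// optimizer.init(loss, startingPoint);          // startingPoint: RealVector with the initial weights
/// for(std::size_t i = 0; i != 1000; ++i){
///     optimizer.step(loss);                     // one Adam update using the current gradient
/// }
/// RealVector best = optimizer.solution().point; // best point found so far
/// double value = optimizer.solution().value;    // objective value at that point
/// \endcode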
template<class SearchPointType = RealVector>
class Adam : public AbstractSingleObjectiveOptimizer<SearchPointType>
{
public:
    typedef typename AbstractSingleObjectiveOptimizer<SearchPointType>::ObjectiveFunctionType ObjectiveFunctionType;

    Adam() {
        //this algorithm needs the first derivative of the objective function
        this->m_features |= this->REQUIRES_FIRST_DERIVATIVE;

        m_beta1 = 0.9;
        m_beta2 = 0.999;
        m_epsilon = 1.e-8;
        m_eta = 0.001;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "Adam"; }

    /// \brief Initialize the optimizer with an objective function and a starting point.
    void init(ObjectiveFunctionType const& objectiveFunction, SearchPointType const& startingPoint) {
        this->checkFeatures(objectiveFunction);
        SHARK_RUNTIME_CHECK(startingPoint.size() == objectiveFunction.numberOfVariables(), "Initial starting point and dimensionality of function do not agree");

        //initialize long term averages
        m_avgGrad = SearchPointType(startingPoint.size(),0.0);
        m_secondMoment = SearchPointType(startingPoint.size(),0.0);
        m_counter = 0;

        //set point to the current starting point and evaluate value and derivative there
        this->m_best.point = startingPoint;
        this->m_best.value = objectiveFunction.evalDerivative(this->m_best.point,m_derivative);
    }
    //bring the base-class init overload (without explicit starting point) back into scope
    using AbstractSingleObjectiveOptimizer<SearchPointType>::init;

    /// \brief get learning rate eta
    double eta() const {
        return m_eta;
    }

    /// \brief set learning rate eta
    void setEta(double eta) {
        SHARK_RUNTIME_CHECK(eta > 0, "eta must be positive.");
        m_eta = eta;
    }

    /// \brief get gradient averaging parameter beta1
    double beta1() const {
        return m_beta1;
    }

    /// \brief set gradient averaging parameter beta1
    void setBeta1(double beta1) {
        SHARK_RUNTIME_CHECK(beta1 > 0, "beta1 must be positive.");
        m_beta1 = beta1;
    }

    /// \brief get gradient averaging parameter beta2
    double beta2() const {
        return m_beta2;
    }

    /// \brief set gradient averaging parameter beta2
    void setBeta2(double beta2) {
        SHARK_RUNTIME_CHECK(beta2 > 0, "beta2 must be positive.");
        m_beta2 = beta2;
    }

    /// \brief get minimum noise estimate epsilon
    double epsilon() const {
        return m_epsilon;
    }

    /// \brief set minimum noise estimate epsilon
    void setEpsilon(double epsilon) {
        SHARK_RUNTIME_CHECK(epsilon > 0, "epsilon must be positive.");
        m_epsilon = epsilon;
    }

    /// \brief Performs a step of the optimization.
    ///
    /// First the running estimates of the gradient and of its second moment are updated via
    /// \f[ g_t = \beta_1 g_{t-1} + (1-\beta_1) \frac{\partial}{\partial x} f(x_{t-1})\f]
    /// \f[ v_t = \beta_2 v_{t-1} + (1-\beta_2) \left(\frac{\partial}{\partial x} f(x_{t-1})\right)^2\f]
    ///
    /// The step is then performed as
    /// \f[ x_{t} = x_{t-1} - \eta \cdot g_t \cdot (\sqrt{v_t} + \epsilon)^{-1} \f]
    /// where a bias correction is applied to counteract the initialization of the averages at 0
    /// during the first iterations.
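    ///
    /// Concretely, both running averages are rescaled before the update, mirroring the bias
    /// terms computed in the code below:
    /// \f[ \hat{g}_t = \frac{g_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t} \f]
    /// so that the step actually taken is
    /// \f$ x_{t} = x_{t-1} - \eta \cdot \hat{g}_t \cdot (\sqrt{\hat{v}_t} + \epsilon)^{-1} \f$.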
    void step(ObjectiveFunctionType const& objectiveFunction) {
        //update long term averages of the gradient and its second moment
        noalias(m_avgGrad) = m_beta1 * m_avgGrad + (1-m_beta1) * m_derivative;
        noalias(m_secondMoment) = m_beta2 * m_secondMoment + (1-m_beta2) * sqr(m_derivative);
        //for the first few iterations, we need bias correction
        ++m_counter;
        double bias1 = 1 - std::pow(m_beta1,m_counter);
        double bias2 = 1 - std::pow(m_beta2,m_counter);

        noalias(this->m_best.point) -= (m_eta/bias1) * m_avgGrad/(m_epsilon + sqrt(m_secondMoment/bias2));
        this->m_best.value = objectiveFunction.evalDerivative(this->m_best.point,m_derivative);
    }

    virtual void read( InArchive & archive ){
        archive>>m_avgGrad;
        archive>>m_secondMoment;
        archive>>m_counter;
        archive>>m_derivative;
        archive>>this->m_best;

        archive>>m_beta1;
        archive>>m_beta2;
        archive>>m_epsilon;
        archive>>m_eta;
    }

    virtual void write( OutArchive & archive ) const
    {
        archive<<m_avgGrad;
        archive<<m_secondMoment;
        archive<<m_counter;
        archive<<m_derivative;
        archive<<this->m_best;

        archive<<m_beta1;
        archive<<m_beta2;
        archive<<m_epsilon;
        archive<<m_eta;
    }

private:
    SearchPointType m_avgGrad;      ///< exponentially decaying average of the gradient
    SearchPointType m_secondMoment; ///< exponentially decaying average of the squared gradient
    unsigned int m_counter;         ///< number of steps performed, used for bias correction
    SearchPointType m_derivative;   ///< gradient of the objective at the current point

    double m_beta1;   ///< decay rate of the gradient average
    double m_beta2;   ///< decay rate of the second moment average
    double m_epsilon; ///< stabilization constant added to the denominator
    double m_eta;     ///< learning rate
};

}
#endif