ContrastiveDivergence.h
/*!
 *
 *
 * \brief -
 *
 * \author -
 * \date -
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H
#define SHARK_UNSUPERVISED_RBM_GRADIENTAPPROXIMATIONS_CONTRASTIVEDIVERGENCE_H

#include <shark/Data/Dataset.h>
#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <algorithm>

namespace shark{
/// \brief Implements k-step Contrastive Divergence as described by Hinton et al. (2006).
///
/// k-step Contrastive Divergence approximates the gradient by initializing a Gibbs
/// chain with a training example and running it for k steps.
/// The sample obtained after k steps is then used to approximate the mean of the RBM distribution in the gradient.
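///
/// The derivative returned by evalDerivative() is the difference between the
/// statistics gathered from the chain state after k steps (model average) and the
/// statistics gathered from the training data (empirical average).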
template<class Operator>
class ContrastiveDivergence: public SingleObjectiveFunction{
public:
	typedef typename Operator::RBM RBM;

	/// \brief The constructor.
	///
	/// @param rbm pointer to the RBM which shall be trained
	ContrastiveDivergence(RBM* rbm)
	: mpe_rbm(rbm), m_operator(rbm)
	, m_k(1), m_numBatches(0), m_regularizer(0){
		SHARK_ASSERT(rbm != NULL);

		m_features.reset(HAS_VALUE);
		m_features |= HAS_FIRST_DERIVATIVE;
		m_features |= CAN_PROPOSE_STARTING_POINT;
	};

	/// \brief From INameable: return the class name.
	std::string name() const
	{ return "ContrastiveDivergence"; }

	/// \brief Sets the training data.
	///
	/// @param data the training data
	void setData(UnlabeledData<RealVector> const& data){
		m_data = data;
	}

	/// \brief Sets the value of k, the number of steps of the Gibbs chain.
	///
	/// @param k the number of steps
	void setK(unsigned int k){
		m_k = k;
	}

	SearchPointType proposeStartingPoint() const{
		return mpe_rbm->parameterVector();
	}

	/// \brief Returns the number of variables of the RBM.
	///
	/// @return the number of variables of the RBM
	std::size_t numberOfVariables()const{
		return mpe_rbm->numberOfParameters();
	}

	/// \brief Returns the number of batches of the dataset that are used in every iteration.
	///
	/// If it is less than the total number of batches, the batches are chosen at random. If it is 0, all batches are used.
	std::size_t numBatches()const{
		return m_numBatches;
	}

	/// \brief Returns a reference to the number of batches of the dataset that are used in every iteration.
	///
	/// If it is less than the total number of batches, the batches are chosen at random. If it is 0, all batches are used.
	std::size_t& numBatches(){
		return m_numBatches;
	}
103
104 void setRegularizer(double factor, SingleObjectiveFunction* regularizer){
105 m_regularizer = regularizer;
106 m_regularizationStrength = factor;
107 }
108
	/// \brief Computes the CD-k approximation of the log-likelihood gradient.
	///
	/// @param parameter the current parameters of the RBM
	/// @param derivative on return, holds the CD-k approximation of the log-likelihood gradient
	double evalDerivative( SearchPointType const& parameter, FirstOrderDerivative& derivative ) const{
		mpe_rbm->setParameterVector(parameter);
		derivative.resize(mpe_rbm->numberOfParameters());
		derivative.clear();

		std::size_t batchesForTraining = m_numBatches > 0? m_numBatches: m_data.numberOfBatches();
		std::size_t elements = 0;
		//get the batches for this iteration
		std::vector<std::size_t> batchIds(m_data.numberOfBatches());
		{
			for(std::size_t i = 0; i != m_data.numberOfBatches(); ++i){
				batchIds[i] = i;
			}
			std::shuffle(batchIds.begin(),batchIds.end(),mpe_rbm->rng());
			for(std::size_t i = 0; i != batchesForTraining; ++i){
				elements += m_data.batch(batchIds[i]).size1();
			}
		}

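		//split the selected batches as evenly as possible over the available threads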
		std::size_t threads = std::min<std::size_t>(batchesForTraining,SHARK_NUM_THREADS);
		std::size_t numBatches = batchesForTraining/threads;

		SHARK_PARALLEL_FOR(int t = 0; t < (int)threads; ++t){
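			//per-thread accumulators for the empirical (data) statistics and the model statistics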
			typename RBM::GradientType empiricalAverage(mpe_rbm);
			typename RBM::GradientType modelAverage(mpe_rbm);

			std::size_t threadElements = 0;

			std::size_t batchStart = t*numBatches;
			std::size_t batchEnd = (t == (int)threads-1)? batchesForTraining : batchStart+numBatches;
			for(std::size_t i = batchStart; i != batchEnd; ++i){
				RealMatrix const& batch = m_data.batch(batchIds[i]);
				threadElements += batch.size1();

				//create the sample batches for evaluation
				typename Operator::HiddenSampleBatch hiddenBatch(batch.size1(),mpe_rbm->numberOfHN());
				typename Operator::VisibleSampleBatch visibleBatch(batch.size1(),mpe_rbm->numberOfVN());
				visibleBatch.state = batch;
				m_operator.precomputeHidden(hiddenBatch,visibleBatch,blas::repeat(1.0,batch.size1()));
				m_operator.sampleHidden(hiddenBatch);
				empiricalAverage.addVH(hiddenBatch,visibleBatch);

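				//negative phase: run the Gibbs chain for k steps, starting from the training data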
				for(std::size_t step = 0; step != m_k; ++step){
					m_operator.precomputeVisible(hiddenBatch, visibleBatch, blas::repeat(1.0,batch.size1()));
					m_operator.sampleVisible(visibleBatch);
					m_operator.precomputeHidden(hiddenBatch, visibleBatch, blas::repeat(1.0,batch.size1()));
					if(step != m_k-1){
						m_operator.sampleHidden(hiddenBatch);
					}
				}
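				//the chain state after k steps serves as an approximate sample of the model distribution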
				modelAverage.addVH(hiddenBatch,visibleBatch);
			}

			SHARK_CRITICAL_REGION{
				double weight = threadElements/double(elements);
				noalias(derivative) += weight*(modelAverage.result() - empiricalAverage.result());
			}
		}

		if(m_regularizer){
			FirstOrderDerivative regularizerDerivative;
			m_regularizer->evalDerivative(parameter,regularizerDerivative);
			noalias(derivative) += m_regularizationStrength*regularizerDerivative;
		}

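		//only the gradient is approximated; the value of the objective function is not computed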
		return std::numeric_limits<double>::quiet_NaN();
	}

private:
	UnlabeledData<RealVector> m_data;
	RBM* mpe_rbm;
	Operator m_operator;
	unsigned int m_k; ///< number of Gibbs sampling steps
	std::size_t m_numBatches; ///< number of batches used in every iteration. 0 means all.

	SingleObjectiveFunction* m_regularizer;
	double m_regularizationStrength;
};

}

#endif
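
Usage sketch (not part of the header). The snippet below shows how this objective function is typically combined with a gradient-based optimizer. It assumes BinaryRBM, the BinaryCD typedef (ContrastiveDivergence instantiated with the binary Gibbs operator), initRandomUniform, dataDimension, and SteepestDescent from other parts of Shark; exact names and signatures may differ between Shark versions.

#include <shark/Unsupervised/RBM/BinaryRBM.h>
#include <shark/Algorithms/GradientDescent/SteepestDescent.h>

using namespace shark;

//trains a binary RBM on the given data with CD-1 (assumed API, see note above)
void trainRBM(UnlabeledData<RealVector> const& data, std::size_t numHidden){
	//as many visible units as the data has dimensions
	BinaryRBM rbm(random::globalRng);
	rbm.setStructure(dataDimension(data), numHidden);
	initRandomUniform(rbm, -0.1, 0.1);

	//CD-1 objective; numBatches() stays 0, so all batches are used in every iteration
	BinaryCD cd(&rbm);
	cd.setK(1);
	cd.setData(data);

	//plain gradient descent; evalDerivative returns NaN as function value by design
	SteepestDescent<> optimizer;
	optimizer.setLearningRate(0.1);
	cd.init();
	optimizer.init(cd);
	for(std::size_t i = 0; i != 100; ++i){
		optimizer.step(cd);
	}
}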