// RBM.h
/*!
 *
 *
 * \brief -
 *
 * \author -
 * \date -
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_UNSUPERVISED_RBM_RBM_H
#define SHARK_UNSUPERVISED_RBM_RBM_H

#include <shark/Models/AbstractModel.h>
#include <shark/Unsupervised/RBM/Energy.h>
#include <shark/Unsupervised/RBM/Impl/AverageEnergyGradient.h>

#include <sstream>
#include <boost/serialization/string.hpp>
39namespace shark{
40
41///\brief stub for the RBM class. at the moment it is just a holder of the parameter set and the Energy.
42template<class VisibleLayerT,class HiddenLayerT, class randomT>
43class RBM : public AbstractModel<RealVector, RealVector>{
44private:
46public:
47 typedef HiddenLayerT HiddenType; ///< type of the hidden layer
48 typedef VisibleLayerT VisibleType; ///< type of the visible layer
49 typedef randomT randomType;
50 typedef Energy<RBM<VisibleType,HiddenType,randomT> > EnergyType;///< Type of the energy function
51 typedef detail::AverageEnergyGradient<RBM> GradientType;///< Type of the gradient calculator
52
55
56private:
57 /// \brief The weight matrix connecting hidden and visible layer.
58 RealMatrix m_weightMatrix;
59
60 ///The layer of hidden Neurons
61 HiddenType m_hiddenNeurons;
62
63 ///The Layer of visible Neurons
64 VisibleType m_visibleNeurons;
65
66 randomType* mpe_rng;
67 bool m_forward;
68 bool m_evalMean;
69
70 ///\brief Evaluates the input by propagating the visible input to the hidden neurons.
71 ///
72 ///@param patterns batch of states of visible units
73 ///@param outputs batch of (expected) states of hidden units
74 void evalForward(BatchInputType const& state,BatchOutputType& output)const{
75 std::size_t batchSize=state.size1();
76 typename HiddenType::StatisticsBatch statisticsBatch(batchSize,numberOfHN());
77 RealMatrix inputBatch(batchSize,numberOfHN());
78 output.resize(state.size1(),numberOfHN());
79
80 energy().inputHidden(inputBatch,state);
81 hiddenNeurons().sufficientStatistics(inputBatch,statisticsBatch,blas::repeat(1.0,batchSize));
82
83 if(m_evalMean){
84 noalias(output) = hiddenNeurons().mean(statisticsBatch);
85 }
86 else{
87 hiddenNeurons().sample(statisticsBatch,output,0.0,*mpe_rng);
88 }
89 }
90
91 ///\brief Evaluates the input by propagating the hidden input to the visible neurons.
92 ///
93 ///@param patterns batch of states of hidden units
94 ///@param outputs batch of (expected) states of visible units
95 void evalBackward(BatchInputType const& state,BatchOutputType& output)const{
96 std::size_t batchSize = state.size1();
97 typename VisibleType::StatisticsBatch statisticsBatch(batchSize,numberOfVN());
98 RealMatrix inputBatch(batchSize,numberOfVN());
99 output.resize(batchSize,numberOfVN());
100
101 energy().inputVisible(inputBatch,state);
102 visibleNeurons().sufficientStatistics(inputBatch,statisticsBatch,blas::repeat(1.0,batchSize));
103
104 if(m_evalMean){
105 noalias(output) = visibleNeurons().mean(statisticsBatch);
106 }
107 else{
108 visibleNeurons().sample(statisticsBatch,output,0.0,*mpe_rng);
109 }
110 }
111public:
112 RBM(randomType& rng):mpe_rng(&rng),m_forward(true),m_evalMean(true)
113 { }
114
115 /// \brief From INameable: return the class name.
116 std::string name() const
117 { return "RBM"; }
118
119 ///\brief Returns the total number of parameters of the model.
120 std::size_t numberOfParameters()const {
121 std::size_t parameters = numberOfVN()*numberOfHN();
122 parameters += m_hiddenNeurons.numberOfParameters();
123 parameters += m_visibleNeurons.numberOfParameters();
124 return parameters;
125 }
126
127 ///\brief Returns the parameters of the Model as parameter vector.
128 RealVector parameterVector () const {
129 return to_vector(m_weightMatrix)
130 | m_hiddenNeurons.parameterVector()
131 | m_visibleNeurons.parameterVector();
132 };
133
134 ///\brief Sets the parameters of the model.
135 ///
136 /// @param newParameters vector of parameters
137 void setParameterVector(const RealVector& newParameters) {
138 std::size_t endW = numberOfVN()*numberOfHN();
139 std::size_t endH = endW + m_hiddenNeurons.numberOfParameters();
140 std::size_t endV = endH + m_visibleNeurons.numberOfParameters();
141 noalias(to_vector(m_weightMatrix)) = subrange(newParameters,0,endW);
142 m_hiddenNeurons.setParameterVector(subrange(newParameters,endW,endH));
143 m_visibleNeurons.setParameterVector(subrange(newParameters,endH,endV));
144 }
145
146 ///\brief Creates the structure of the RBM.
147 ///
148 ///@param hiddenNeurons number of hidden neurons.
149 ///@param visibleNeurons number of visible neurons.
150 void setStructure(std::size_t visibleNeurons,std::size_t hiddenNeurons){
151 m_weightMatrix.resize(hiddenNeurons,visibleNeurons);
152 m_weightMatrix.clear();
153
154 m_hiddenNeurons.resize(hiddenNeurons);
155 m_visibleNeurons.resize(visibleNeurons);
156 }
157
158 ///\brief Returns the layer of hidden neurons.
160 return m_hiddenNeurons;
161 }
162 ///\brief Returns the layer of hidden neurons.
164 return m_hiddenNeurons;
165 }
166 ///\brief Returns the layer of visible neurons.
168 return m_visibleNeurons;
169 }
170 ///\brief Returns the layer of visible neurons.
172 return m_visibleNeurons;
173 }
174
175 ///\brief Returns the weight matrix connecting the layers.
176 RealMatrix& weightMatrix(){
177 return m_weightMatrix;
178 }
179 ///\brief Returns the weight matrix connecting the layers.
180 RealMatrix const& weightMatrix()const{
181 return m_weightMatrix;
182 }
183
184 ///\brief Returns the energy function of the RBM.
186 return EnergyType(*this);
187 }
188
189 ///\brief Returns the random number generator associated with this RBM.
191 return *mpe_rng;
192 }
193
194 ///\brief Sets the type of evaluation, eval will perform.
195 ///
196 ///Eval performs its operation based on the state of this function.
197 ///There are two ways to pass data through an rbm: either forward, setting the states of the
198 ///visible neurons and sample the hidden states or backwards, where the state of the hidden is fixed and the visible
199 ///are sampled.
200 ///Instead of the state of the hidden/visible, one often wants the mean of the state \f$ E_{p(h|v)}\left(h\right)\f$.
201 ///By default, the RBM uses the forward evaluation and returns the mean of the state
202 ///
203 ///@param forward whether the forward view should be used false=backwards
204 ///@param evalMean whether the mean state should be returned. false=a sample is returned
205 void evaluationType(bool forward,bool evalMean){
206 m_forward = forward;
207 m_evalMean = evalMean;
208 }
209
211 if(m_forward){
212 return numberOfHN();
213 }else{
214 return numberOfVN();
215 }
216 }
217
219 if(m_forward){
220 return numberOfVN();
221 }else{
222 return numberOfHN();
223 }
224 }
225
226 boost::shared_ptr<State> createState()const{
227 return boost::shared_ptr<State>(new EmptyState());
228 }
229
230 ///\brief Passes information through/samples from an RBM in a forward or backward way.
231 ///
232 ///Eval performs its operation based on the given evaluation type.
233 ///There are two ways to pass data through an RBM: either forward, setting the states of the
234 ///visible neurons and sample the hidden states or backwards, where the state of the hidden is fixed and the visible
235 ///are sampled.
236 ///Instead of the state of the hidden/visible, one often wants the mean of the state \f$ E_{p(h|v)}\left(h\right)\f$.
237 ///By default, the RBM uses the forward evaluation and returns the mean of the state,
238 ///but other evaluation modes can be set by evaluationType().
239 ///
240 ///@param patterns the batch of (visible or hidden) inputs
241 ///@param outputs the batch of (visible or hidden) outputs
242 void eval(BatchInputType const& patterns,BatchOutputType& outputs)const{
243 if(m_forward){
244 evalForward(patterns,outputs);
245 }
246 else{
247 evalBackward(patterns,outputs);
248 }
249 }
250
251
252 void eval(BatchInputType const& patterns, BatchOutputType& outputs, State& state)const{
253 eval(patterns,outputs);
254 }
255
256 ///\brief Calculates the input of the hidden neurons given the state of the visible in a batch-vise fassion.
257 ///
258 ///@param inputs the batch of vectors the input of the hidden neurons is stored in
259 ///@param visibleStates the batch of states of the visible neurons
260 void inputHidden(RealMatrix& inputs, RealMatrix const& visibleStates)const{
261 SIZE_CHECK(visibleStates.size1() == inputs.size1());
262 SIZE_CHECK(inputs.size2() == m_hiddenNeurons.size());
263 SIZE_CHECK( visibleStates.size2() == m_visibleNeurons.size());
264
265 noalias(inputs) = prod(m_visibleNeurons.phi(visibleStates),trans(m_weightMatrix));
266 }
267
268
269 ///\brief Calculates the input of the visible neurons given the state of the hidden.
270 ///
271 ///@param inputs the vector the input of the visible neurons is stored in
272 ///@param hiddenStates the state of the hidden neurons
273 void inputVisible(RealMatrix& inputs, RealMatrix const& hiddenStates)const{
274 SIZE_CHECK(hiddenStates.size1() == inputs.size1());
275 SIZE_CHECK(inputs.size2() == m_visibleNeurons.size());
276
277 noalias(inputs) = prod(m_hiddenNeurons.phi(hiddenStates),m_weightMatrix);
278 }
279
280 using base_type::eval;
281
282
283 ///\brief Returns the number of hidden Neurons.
284 std::size_t numberOfHN()const{
285 return m_hiddenNeurons.size();
286 }
287 ///\brief Returns the number of visible Neurons.
288 std::size_t numberOfVN()const{
289 return m_visibleNeurons.size();
290 }
291
292 /// \brief Reads the network from an archive.
293 void read(InArchive& archive){
294 archive >> m_weightMatrix;
295 archive >> m_hiddenNeurons;
296 archive >> m_visibleNeurons;
297
298 //serialization of the rng is a bit...complex
299 //let's hope that we can remove this hack one time. But we really can't ignore the state of the rng.
300 std::string str;
301 archive>> str;
302 std::stringstream stream(str);
303 stream>> *mpe_rng;
304 }
305
306 /// \brief Writes the network to an archive.
307 void write(OutArchive& archive) const{
308 archive << m_weightMatrix;
309 archive << m_hiddenNeurons;
310 archive << m_visibleNeurons;
311
312 std::stringstream stream;
313 stream <<*mpe_rng;
314 std::string str = stream.str();
315 archive <<str;
316 }
317
318};
319
320}
321
322#endif