EnergyStoringTemperedMarkovChain.h
Go to the documentation of this file.
1/*!
2 *
3 *
4 * \brief -
5 *
6 * \author -
7 * \date -
8 *
9 *
10 * \par Copyright 1995-2017 Shark Development Team
11 *
12 * <BR><HR>
13 * This file is part of Shark.
14 * <https://shark-ml.github.io/Shark/>
15 *
16 * Shark is free software: you can redistribute it and/or modify
17 * it under the terms of the GNU Lesser General Public License as published
18 * by the Free Software Foundation, either version 3 of the License, or
19 * (at your option) any later version.
20 *
21 * Shark is distributed in the hope that it will be useful,
22 * but WITHOUT ANY WARRANTY; without even the implied warranty of
23 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 * GNU Lesser General Public License for more details.
25 *
26 * You should have received a copy of the GNU Lesser General Public License
27 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
28 *
29 */
30#ifndef SHARK_UNSUPERVISED_RBM_SAMPLING_ESTEMPEREDMARKOVCHAIN_H
31#define SHARK_UNSUPERVISED_RBM_SAMPLING_ESTEMPEREDMARKOVCHAIN_H
32
33
35namespace shark{
36
37
38///\brief Implements parallel tempering but also stores additional statistics on the energy differences
39template<class Operator>
41private:
42 typedef typename Operator::HiddenSample HiddenSample;
43 typedef typename Operator::VisibleSample VisibleSample;
44public:
45
46 ///\brief The MarkovChain can't be used to compute several samples at once.
47 ///
48 /// The tempered Markov chain already uses its batch capabilities to compute the samples for all temperatures
49 /// at the same time. Also, it is much more powerful when all samples are drawn one after another for a higher mixing rate.
50 static const bool computesBatch = false;
51
52 ///\brief The type of the RBM the operator is working with.
53 typedef typename Operator::RBM RBM;
54
55 ///\brief A batch of samples containing hidden and visible samples as well as the energies.
57
58 ///\brief Mutable reference to an element of the batch.
59 typedef typename SampleBatch::reference reference;
60
61 ///\brief Immutable reference to an element of the batch.
62 typedef typename SampleBatch::const_reference const_reference;
63
64private:
65
67
68 bool m_storeEnergyDifferences;
69 bool m_integrateEnergyDifferences;
70 std::vector<RealVector> m_energyDiffUp;
71 std::vector<RealVector> m_energyDiffDown;
72
73public:
75 bool integrateEnergyDifferences = true
76 ):m_chain(rbm)
77 , m_integrateEnergyDifferences(integrateEnergyDifferences)
78 , m_storeEnergyDifferences(true){}
79
80 const Operator& transitionOperator()const{
81 return m_chain.transitionOperator();
82 }
83 Operator& transitionOperator(){
84 return m_chain.transitionOperator();
85 }
86
87 void setNumberOfTemperatures(std::size_t temperatures){
88 m_chain.setNumberOfTemperatures(temperatures);
89 }
90 void setUniformTemperatureSpacing(std::size_t temperatures){
91 m_chain.setUniformTemperatureSpacing(temperatures);
92 }
93
94 /// \brief Returns the number Of temperatures.
95 std::size_t numberOfTemperatures()const{
96 return m_chain.numberOfTemperatures();
97 }
98
99 void setBatchSize(std::size_t batchSize){
100 SHARK_RUNTIME_CHECK(batchSize == 1, "Markov chain can only compute batches of size 1.");
101 }
///\brief Returns the batch size of the chain, which is fixed to 1.
///
/// NOTE(review): unlike numberOfTemperatures() this getter is not const-qualified;
/// consider adding const once all callers are checked.
std::size_t batchSize()
{
	std::size_t const fixedBatchSize = 1;
	return fixedBatchSize;
}
105
106 void setBeta(std::size_t i, double beta){
107 m_chain.setBeta(i,beta);
108 }
109
110 double beta(std::size_t i)const{
111 return m_chain.beta(i);
112 }
113
114 RealVector const& beta()const{
115 return m_chain.beta();
116 }
117
118 ///\brief Returns the current state of the chain for beta = 1.
120 return m_chain.sample();
121 }
122 ///\brief Returns the current state of the chain for all beta values.
123 SampleBatch const& samples()const{
124 return m_chain.samples();
125 }
126
127 /// \brief Returns the current batch of samples of the Markov chain.
129 return m_chain.samples();
130 }
131
132 ///\brief Initializes the markov chain using samples drawn uniformly from the set.
133 ///
134 /// @param dataSet the data set
135 void initializeChain(Data<RealVector> const& dataSet){
136 m_chain.initializeChain(dataSet);
137 }
138
139 /// \brief Initializes with data points from a batch of points
140 ///
141 /// @param sampleData the data set
142 void initializeChain(RealMatrix const& sampleData){
143 m_chain.initializeChain(sampleData);
144 }
145 //updates the chain using the current sample
146 void step(unsigned int k){
147 m_chain.step(k);
148
149 if(!storeEnergyDifferences()) return;
150
151 typename RBM::EnergyType energy = transitionOperator().rbm()->energy();
152 std::size_t numChains = beta().size();
153 //create diff beta vectors
154 RealVector betaUp(numChains);
155 RealVector betaDown(numChains);
156 betaUp(0) = 1.0;
157 betaDown(numChains-1) = 0.0;
158 for(std::size_t i = 0; i != numChains-1; ++i){
159 betaDown(i) = beta()(i+1);
160 betaUp(i+1) = beta()(i);
161 }
162
163 RealVector energyDiffUp(numChains);
164 RealVector energyDiffDown(numChains);
165 if(!m_integrateEnergyDifferences){
166 noalias(energyDiffUp) = samples().energy*(betaUp-beta());
167 noalias(energyDiffDown) = samples().energy*(betaDown-beta());
168 }
169 else{
170 //calculate the first term: -E(state,beta) thats the same for both matrices
171 energy.inputVisible(samples().visible.input, samples().hidden.state);
172 noalias(energyDiffDown) = energy.logUnnormalizedProbabilityHidden(
173 samples().hidden.state,
174 samples().visible.input,
175 beta()
176 );
177 noalias(energyDiffUp) = energyDiffDown;
178
179 //now add the new term
180 noalias(energyDiffUp) -= energy.logUnnormalizedProbabilityHidden(
181 samples().hidden.state,
182 samples().visible.input,
183 betaUp
184 );
185 noalias(energyDiffDown) -= energy.logUnnormalizedProbabilityHidden(
186 samples().hidden.state,
187 samples().visible.input,
188 betaDown
189 );
190 }
191 m_energyDiffUp.push_back(energyDiffUp);
192 m_energyDiffDown.push_back(energyDiffDown);
193 }
194
195 RealMatrix getUpDifferences()const{
196 RealMatrix diffUp(beta().size(),m_energyDiffUp.size());
197 for(std::size_t i = 0; i != m_energyDiffUp.size(); ++i){
198 noalias(column(diffUp,i)) = m_energyDiffUp[i];
199 }
200 return diffUp;
201 }
202 RealMatrix getDownDifferences()const{
203 RealMatrix diffDown(beta().size(),m_energyDiffDown.size());
204 for(std::size_t i = 0; i != m_energyDiffDown.size(); ++i){
205 noalias(column(diffDown,i)) = m_energyDiffDown[i];
206 }
207 return diffDown;
208 }
209
211 m_energyDiffUp.clear();
212 m_energyDiffDown.clear();
213 }
214
216 return m_storeEnergyDifferences;
217 }
218
219 //is called after the weights of the rbm got updated.
220 //this allows the chains to store intermediate results
221 void update(){
222 m_chain.update();
223 }
224};
225
226}
227#endif