MarkovChain.h
Go to the documentation of this file.
/*!
 *
 *
 * \brief       -
 *
 * \author      -
 * \date        -
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
30#ifndef SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
31#define SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
32
#include <stdexcept>

#include <shark/Data/Dataset.h>
#include <shark/Core/Random.h>
#include "Impl/SampleTypes.h"
37namespace shark{
38
/// \brief A single Markov chain.
///
/// The chain is advanced by a chosen number of sampling steps by applying its transition operator.
42template<class Operator>
44private:
45 typedef typename Operator::HiddenSample HiddenSample;
46 typedef typename Operator::VisibleSample VisibleSample;
47public:
48
49 ///\brief The MarkovChain can be used to compute several samples at once.
50 static const bool computesBatch = true;
51
52 ///\brief The type of the RBM the operator is working with.
53 typedef typename Operator::RBM RBM;
54 ///\brief A batch of samples containing hidden and visible samples as well as the energies.
56
57 ///\brief Mutable reference to an element of the batch.
58 typedef typename SampleBatch::reference reference;
59
60 ///\brief Immutable reference to an element of the batch.
61 typedef typename SampleBatch::const_reference const_reference;
62private:
63 ///\brief The batch of samples containing the state of the visible and the hidden units.
64 SampleBatch m_samples;
65 ///\brief The transition operator.
66 Operator m_operator;
67public:
68
69 /// \brief Constructor.
70 MarkovChain(RBM* rbm):m_operator(rbm){}
71
72
73 /// \brief Sets the number of parallel samples to be evaluated
74 void setBatchSize(std::size_t batchSize){
75 std::size_t visibles=m_operator.rbm()->numberOfVN();
76 std::size_t hiddens=m_operator.rbm()->numberOfHN();
77 m_samples=SampleBatch(batchSize,visibles,hiddens);
78 }
79 std::size_t batchSize(){
80 return m_samples.size();
81 }
82
83 /// \brief Initializes with data points drawn uniform from the set.
84 ///
85 /// @param dataSet the data set
86 void initializeChain(Data<RealVector> const& dataSet){
87 std::size_t visibles=m_operator.rbm()->numberOfVN();
88 RealMatrix sampleData(m_samples.size(),visibles);
89
90 for(std::size_t i = 0; i != m_samples.size(); ++i){
91 noalias(row(sampleData,i)) = dataSet.element(random::discrete(m_operator.rbm()->rng(),std::size_t(0),dataSet.numberOfElements()-1));
92 }
93 initializeChain(sampleData);
94 }
95
96 /// \brief Initializes with data points from a batch of points
97 ///
98 /// @param sampleData Data set
99 void initializeChain(RealMatrix const& sampleData){
100 m_operator.createSample(m_samples.hidden,m_samples.visible,sampleData);
101 }
102
103 /// \brief Runs the chain for a given number of steps.
104 ///
105 /// @param numberOfSteps the number of steps
106 void step(unsigned int numberOfSteps){
107 m_operator.stepVH(m_samples.hidden,m_samples.visible,numberOfSteps,blas::repeat(1.0,batchSize()));
108 }
109
110 /// \brief Returns the current sample of the Markov chain.
112 return const_reference(m_samples,0);
113 }
114
115 /// \brief Returns the current batch of samples of the Markov chain.
116 SampleBatch const& samples()const{
117 return m_samples;
118 }
119
120 /// \brief Returns the current batch of samples of the Markov chain.
122 return m_samples;
123 }
124
125 /// \brief Returns the transition operator of the Markov chain.
126 Operator const& transitionOperator()const{
127 return m_operator;
128 }
129
130 /// \brief Returns the transition operator of the Markov chain.
132 return m_operator;
133 }
134};
135
136}
137#endif