Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Unsupervised
RBM
Sampling
MarkovChain.h
Go to the documentation of this file.
1
/*!
2
*
3
*
4
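* \brief Declares the MarkovChain class template, which applies a transition operator to a batch of hidden/visible samples.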
5
*
6
* \author -
7
* \date -
8
*
9
*
10
* \par Copyright 1995-2017 Shark Development Team
11
*
12
* <BR><HR>
13
* This file is part of Shark.
14
* <https://shark-ml.github.io/Shark/>
15
*
16
* Shark is free software: you can redistribute it and/or modify
17
* it under the terms of the GNU Lesser General Public License as published
18
* by the Free Software Foundation, either version 3 of the License, or
19
* (at your option) any later version.
20
*
21
* Shark is distributed in the hope that it will be useful,
22
* but WITHOUT ANY WARRANTY; without even the implied warranty of
23
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24
* GNU Lesser General Public License for more details.
25
*
26
* You should have received a copy of the GNU Lesser General Public License
27
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
28
*
29
*/
30
#ifndef SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
31
#define SHARK_UNSUPERVISED_RBM_SAMPLING_MARKOVCHAIN_H
32
33
#include <
shark/Data/Dataset.h
>
34
#include <
shark/Core/Random.h
>
35
#include <
shark/Unsupervised/RBM/Tags.h
>
36
#include "Impl/SampleTypes.h"
37
namespace
shark
{
38
39
/// \brief A single Markov chain.
40
///
41
/// You can run the Markov chain for some sampling steps by applying a transition operator.
42
template
<
class
Operator>
43
class
MarkovChain
{
44
private
:
45
typedef
typename
Operator::HiddenSample HiddenSample;
46
typedef
typename
Operator::VisibleSample VisibleSample;
47
public
:
48
49
///\brief The MarkovChain can be used to compute several samples at once.
50
static
const
bool
computesBatch
=
true
;
51
52
///\brief The type of the RBM the operator is working with.
53
typedef
typename
Operator::RBM
RBM
;
54
///\brief A batch of samples containing hidden and visible samples as well as the energies.
55
typedef
typename
Batch<detail::MarkovChainSample<HiddenSample,VisibleSample>
>::type
SampleBatch
;
56
57
///\brief Mutable reference to an element of the batch.
58
typedef
typename
SampleBatch::reference
reference
;
59
60
///\brief Immutable reference to an element of the batch.
61
typedef
typename
SampleBatch::const_reference
const_reference
;
62
private
:
63
///\brief The batch of samples containing the state of the visible and the hidden units.
64
SampleBatch
m_samples;
65
///\brief The transition operator.
66
Operator m_operator;
67
public
:
68
69
/// \brief Constructor.
70
MarkovChain
(
RBM
* rbm):m_operator(rbm){}
71
72
73
/// \brief Sets the number of parallel samples to be evaluated
74
void
setBatchSize
(std::size_t
batchSize
){
75
std::size_t visibles=m_operator.rbm()->numberOfVN();
76
std::size_t hiddens=m_operator.rbm()->numberOfHN();
77
m_samples=
SampleBatch
(
batchSize
,visibles,hiddens);
78
}
79
std::size_t
batchSize
(){
80
return
m_samples.size();
81
}
82
83
/// \brief Initializes with data points drawn uniform from the set.
84
///
85
/// @param dataSet the data set
86
void
initializeChain
(
Data<RealVector>
const
& dataSet){
87
std::size_t visibles=m_operator.rbm()->numberOfVN();
88
RealMatrix sampleData(m_samples.size(),visibles);
89
90
for
(std::size_t i = 0; i != m_samples.size(); ++i){
91
noalias(row(sampleData,i)) = dataSet.
element
(
random::discrete
(m_operator.rbm()->rng(),std::size_t(0),dataSet.
numberOfElements
()-1));
92
}
93
initializeChain
(sampleData);
94
}
95
96
/// \brief Initializes with data points from a batch of points
97
///
98
/// @param sampleData Data set
99
void
initializeChain
(RealMatrix
const
& sampleData){
100
m_operator.createSample(m_samples.hidden,m_samples.visible,sampleData);
101
}
102
103
/// \brief Runs the chain for a given number of steps.
104
///
105
/// @param numberOfSteps the number of steps
106
void
step
(
unsigned
int
numberOfSteps){
107
m_operator.stepVH(m_samples.hidden,m_samples.visible,numberOfSteps,blas::repeat(1.0,
batchSize
()));
108
}
109
110
/// \brief Returns the current sample of the Markov chain.
111
const_reference
sample
()
const
{
112
return
const_reference
(m_samples,0);
113
}
114
115
/// \brief Returns the current batch of samples of the Markov chain.
116
SampleBatch
const
&
samples
()
const
{
117
return
m_samples;
118
}
119
120
/// \brief Returns the current batch of samples of the Markov chain.
121
SampleBatch
&
samples
(){
122
return
m_samples;
123
}
124
125
/// \brief Returns the transition operator of the Markov chain.
126
Operator
const
&
transitionOperator
()
const
{
127
return
m_operator;
128
}
129
130
/// \brief Returns the transition operator of the Markov chain.
131
Operator&
transitionOperator
(){
132
return
m_operator;
133
}
134
};
135
136
}
137
#endif