Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Statistics
Distributions
MultiVariateNormalDistribution.h
Go to the documentation of this file.
/*!
 *
 *
 * \brief       Implements a multi-variate normal distribution with zero mean.
 *
 *
 *
 * \author      T.Voss, O.Krause
 * \date        2016
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
32
#ifndef SHARK_STATISTICS_MULTIVARIATENORMALDISTRIBUTION_H
33
#define SHARK_STATISTICS_MULTIVARIATENORMALDISTRIBUTION_H
#include <cstddef>
#include <utility>

#include <shark/LinAlg/Base.h>
#include <shark/Core/Random.h>
37
namespace
shark
{
38
39
/// \brief Implements a multi-variate normal distribution with zero mean.
40
class
MultiVariateNormalDistribution
{
41
public
:
42
///\brief Result type of a sampling operation.
43
///
44
/// The first element is the result of sampling this distribution, the
45
/// second element is the original standard-normally distributed vector drawn
46
/// for sampling purposes.
47
typedef
std::pair<RealVector,RealVector> result_type;
48
49
/// \brief Constructor
50
/// \param [in] Sigma covariance matrix
51
MultiVariateNormalDistribution
(RealMatrix
const
& Sigma ) {
52
m_covarianceMatrix = Sigma;
53
update
();
54
}
55
56
/// \brief Constructor
57
MultiVariateNormalDistribution
(){}
58
59
/// \brief Stores/Restores the distribution from the supplied archive.
60
/// \param [in,out] ar The archive to read from/write to.
61
/// \param [in] version Currently unused.
62
template
<
typename
Archive>
63
void
serialize
( Archive & ar,
const
std::size_t version ) {
64
ar & BOOST_SERIALIZATION_NVP( m_covarianceMatrix );
65
ar & BOOST_SERIALIZATION_NVP( m_decomposition );
66
}
67
68
/// \brief Resizes the distribution. Updates both eigenvectors and eigenvalues.
69
/// \param [in] size The new size of the distribution
70
void
resize
( std::size_t size ) {
71
m_covarianceMatrix = blas::identity_matrix<double>( size );
72
update
();
73
}
74
75
/// \brief Accesses the covariance matrix defining the distribution.
76
RealMatrix
const
&
covarianceMatrix
()
const
{
77
return
m_covarianceMatrix;
78
}
79
80
/// \brief Accesses a mutable reference to the covariance matrix
81
/// defining the distribution. Allows for l-value semantics.
82
///
83
/// ATTENTION: If the reference is altered, update needs to be called manually.
84
RealMatrix&
covarianceMatrix
() {
85
return
m_covarianceMatrix;
86
}
87
88
/// \brief Sets the covariance matrix and updates the internal variables. This is expensive
89
void
setCovarianceMatrix
(RealMatrix
const
& matrix){
90
covarianceMatrix
() = matrix;
91
update
();
92
}
93
94
/// \brief Accesses an immutable reference to the eigenvectors of the covariance matrix.
95
RealMatrix
const
&
eigenVectors
()
const
{
96
return
m_decomposition.Q();
97
}
98
99
/// \brief Accesses an immutable reference to the eigenvalues of the covariance matrix.
100
RealVector
const
&
eigenValues
()
const
{
101
return
m_decomposition.D();
102
}
103
104
/// \brief Samples the distribution.
105
/// \param [in] rng Random number generator.
106
template
<
class
randomType>
107
result_type
operator()
(randomType& rng)
const
{
108
RealVector z( m_covarianceMatrix.size1() );
109
110
for
( std::size_t i = 0; i < z.size(); i++ ) {
111
z( i ) =
random::gauss
(rng, 0., 1. );
112
}
113
114
RealVector result = m_decomposition.Q() % to_diagonal(sqrt(max(
eigenValues
(),0))) % z;
115
return
std::make_pair( result, z );
116
}
117
118
/// \brief Calculates the evd of the current covariance matrix.
119
void
update
() {
120
m_decomposition.decompose(m_covarianceMatrix);
121
}
122
123
private
:
124
RealMatrix m_covarianceMatrix;
///< Covariance matrix of the mutation distribution.
125
blas::symm_eigenvalue_decomposition<RealMatrix> m_decomposition;
/// < Eigenvalue decomposition of the covarianceMatrix
126
};
127
128
/// \brief Multivariate normal distribution with zero mean using a cholesky decomposition
129
class
MultiVariateNormalDistributionCholesky
{
130
public
:
131
/// \brief Result type of a sampling operation.
132
///
133
/// The first element is the result of sampling this distribution, the
134
/// second element is the original standard-normally distributed vector drawn
135
/// for sampling purposes.
136
typedef
std::pair<RealVector,RealVector> result_type;
137
138
/// \brief Constructor
139
/// \param [in] covariance Covariance matrix.
140
MultiVariateNormalDistributionCholesky
( RealMatrix
const
&
covariance
){
141
setCovarianceMatrix
(
covariance
);
142
}
143
144
MultiVariateNormalDistributionCholesky
(){}
145
146
/// \brief Stores/Restores the distribution from the supplied archive.
147
///\param [in,out] ar Archive to read from/write to.
148
///\param [in] version Currently unused.
149
template
<
typename
Archive>
150
void
serialize
( Archive & ar,
const
std::size_t version ) {
151
ar & BOOST_SERIALIZATION_NVP( m_cholesky);
152
}
153
154
/// \brief Resizes the distribution. Updates both eigenvectors and eigenvalues.
155
/// \param [in] size The new size of the distribution
156
void
resize
( std::size_t
size
) {
157
m_cholesky = blas::identity_matrix<double>(
size
);
158
}
159
160
/// \brief Returns the size of the created vectors
161
std::size_t
size
()
const
{
162
return
m_cholesky.lower_factor().size1();
163
}
164
165
/// \brief Returns the matrix holding the lower cholesky factor A
166
blas::matrix<double,blas::column_major>
const
&
lowerCholeskyFactor
()
const
{
167
return
m_cholesky.lower_factor();
168
}
169
170
171
/// \brief Sets the new covariance matrix by computing the new cholesky dcomposition
172
void
setCovarianceMatrix
(RealMatrix
const
& matrix){
173
m_cholesky.decompose(matrix);
174
}
175
176
/// \brief Updates the covariance matrix of the distribution to C<- alpha*C+beta * vv^T
177
void
rankOneUpdate
(
double
alpha,
double
beta, RealVector
const
& v){
178
m_cholesky.update(alpha,beta,v);
179
}
180
181
template
<
class
randomType,
class
Vector1,
class
Vector2>
182
void
generate
(randomType& rng, Vector1& y, Vector2& z)
const
{
183
z.resize(
size
());
184
y.resize(
size
());
185
for
( std::size_t i = 0; i !=
size
(); i++ ) {
186
z( i ) =
random::gauss
(rng, 0, 1 );
187
}
188
noalias(y) = blas::triangular_prod<blas::lower>(m_cholesky.lower_factor(),z);
189
}
190
191
/// \brief Samples the distribution.
192
///
193
/// Returns a vector pair (y,z) where y=Lz and, L is the lower cholesky factor and z is a vector
194
/// of normally distributed numbers. Thus y is the real sampled point.
195
///
196
/// \param [in] rng Random number generator.
197
template
<
class
randomType>
198
result_type
operator()
(randomType& rng)
const
{
199
result_type result;
200
RealVector& z = result.second;
201
RealVector& y = result.first;
202
generate
(rng,y,z);
203
return
result;
204
}
205
206
private
:
207
blas::cholesky_decomposition<blas::matrix<double,blas::column_major> > m_cholesky;
///< The lower cholesky factor (actually any is okay as long as it is the left)
208
};
209
210
211
}
212
213
#endif