VariationalAutoencoderError.h
/*!
 *
 * \brief Variational-autoencoder error function
 *
 * \author O.Krause
 * \date 2017
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
#ifndef SHARK_OBJECTIVEFUNCTIONS_VARIATIONAL_AUTOENCODER_ERROR_H
#define SHARK_OBJECTIVEFUNCTIONS_VARIATIONAL_AUTOENCODER_ERROR_H

#include <shark/Core/Random.h>
#include <shark/Data/Dataset.h>
#include <shark/Models/AbstractModel.h>
#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <shark/ObjectiveFunctions/Loss/AbstractLoss.h>

namespace shark{

/// \brief Computes the variational autoencoder error function
///
/// We want to optimize a model \f$ p(x) = \int p(x|z) p(z) dz \f$ where we choose p(z) as a multivariate normal distribution
/// and p(x|z) is an arbitrary model, e.g. a deep neural network. The naive solution is to sample from p(z) and compute the sample
/// average. This fails when p(z|x) is a very localized distribution, because we might need many samples from p(z) to find one which
/// is likely under p(z|x). As p(z|x) is assumed to be intractable to compute, we introduce a second model q(z|x) that approximates
/// p(z|x), and we train it such that it learns the unknown p(z|x). For this, a variational lower bound on the likelihood is maximized:
/// \f[ \log p(x) \geq E_{q(z|x)}[\log p(x|z)] - KL[q(z|x) || p(z)] \f]
/// The first term explains the name variational autoencoder: we first sample z given x using the encoder model q and then decode
/// z to obtain an estimate for x. The only difference from normal autoencoders is that z is now probabilistic. The second term
/// ensures that q learns p(z|x), assuming we have enough modeling capacity to actually learn it.
/// See https://arxiv.org/abs/1606.05908 for more background.
///
/// Implementation notice: we assume q(z|x) to be a set of independent Gaussian distributions parameterized as
/// \f$ q(z | \mu(x), \log \sigma^2(x)) \f$.
/// The provided encoder model q must therefore have twice as many outputs as the decoder has inputs:
/// the second half of the outputs is interpreted as the log of the variance. So if z should be a 100-dimensional variable, q must
/// have 200 outputs. The loss function used to compare the decoder's reconstructions to the data is arbitrary; a SquaredLoss works
/// well, but other losses, e.g. pixel probabilities, can be used. The reconstruction term can be weighted relative to the KL term
/// via the constructor argument lambda (default 1).
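///
/// A minimal usage sketch (the concrete model types, dimensions, and optimizer below are
/// illustrative assumptions, not requirements of this class):
/// \code
/// std::size_t inputDim = 784, latentDim = 100;
/// Data<RealVector> data; // fill with unlabeled training data
/// // the encoder outputs [mu, log sigma^2], hence 2 * latentDim values
/// LinearModel<RealVector> encoder(inputDim, 2 * latentDim);
/// LinearModel<RealVector> decoder(latentDim, inputDim);
/// SquaredLoss<RealVector> loss;
/// VariationalAutoencoderError<RealVector> error(data, &encoder, &decoder, &loss);
/// error.init();
/// Adam<> optimizer;
/// optimizer.init(error);
/// for(std::size_t i = 0; i != 100; ++i) optimizer.step(error); // noisy minibatch steps
/// \endcode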
/// \ingroup objfunctions

template<class SearchPointType>
class VariationalAutoencoderError : public AbstractObjectiveFunction<SearchPointType, double>
{
private:
	typedef typename SearchPointType::device_type device_type;
	typedef typename SearchPointType::value_type value_type;
	typedef blas::matrix<value_type, blas::row_major, device_type> MatrixType;
public:
	typedef AbstractModel<SearchPointType, SearchPointType, SearchPointType> ModelType;
	typedef Data<SearchPointType> DatasetType;
	typedef AbstractLoss<SearchPointType, SearchPointType> LossType;

	VariationalAutoencoderError(
		DatasetType const& data,
		ModelType* encoder,
		ModelType* decoder,
		LossType* visible_loss,
		double lambda = 1.0
	):mep_decoder(decoder), mep_encoder(encoder), mep_loss(visible_loss), m_data(data), m_lambda(lambda){
		if(mep_decoder->hasFirstParameterDerivative() && mep_encoder->hasFirstParameterDerivative())
			this->m_features |= this->HAS_FIRST_DERIVATIVE;
		this->m_features |= this->CAN_PROPOSE_STARTING_POINT;
		this->m_features |= this->IS_NOISY;
	}

	/// \brief From INameable: return the class name.
	std::string name() const
	{ return "VariationalAutoencoderError"; }

	/// \brief Returns the current parameter vector of decoder and encoder as a starting point.
	SearchPointType proposeStartingPoint() const{
		return mep_decoder->parameterVector() | mep_encoder->parameterVector();
	}

	std::size_t numberOfVariables()const{
		return mep_decoder->numberOfParameters() + mep_encoder->numberOfParameters();
	}

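	/// \brief Draws a latent sample z for every point in the batch using the reparameterization trick,
	/// z = mu(x) + sigma(x) * epsilon with epsilon ~ N(0,I), which keeps z differentiable in the encoder outputs.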
	MatrixType sampleZ(SearchPointType const& parameters, MatrixType const& batch) const{
		mep_decoder->setParameterVector(subrange(parameters,0,mep_decoder->numberOfParameters()));
		mep_encoder->setParameterVector(subrange(parameters,mep_decoder->numberOfParameters(), numberOfVariables()));

		MatrixType hiddenResponse = (*mep_encoder)(batch);
		auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
		auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
		//sample a random point from the latent distribution
		MatrixType epsilon = blas::normal(*this->mep_rng,mu.size1(), mu.size2(), value_type(0.0), value_type(1.0), device_type());
		return mu + exp(0.5*log_var) * epsilon;
	}

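	/// \brief Evaluates the error on a single randomly drawn minibatch.
	///
	/// Both the minibatch and the latent samples are random, so the returned value is a noisy
	/// estimate of the objective; the function is flagged IS_NOISY accordingly.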
	double eval(SearchPointType const& parameters) const{
		SIZE_CHECK(parameters.size() == numberOfVariables());
		this->m_evaluationCounter++;
		mep_decoder->setParameterVector(subrange(parameters,0,mep_decoder->numberOfParameters()));
		mep_encoder->setParameterVector(subrange(parameters,mep_decoder->numberOfParameters(), numberOfVariables()));

		auto const& batch = m_data.batch(random::discrete(*this->mep_rng, std::size_t(0), m_data.numberOfBatches() -1));
		MatrixType hiddenResponse = (*mep_encoder)(batch);
		auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
		auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
		//compute the Kullback-Leibler divergence term
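		//closed form for independent Gaussians: KL[N(mu,sigma^2) || N(0,I)] = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)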
		double klError = 0.5 * (sum(exp(log_var)) + sum(sqr(mu)) - mu.size1() * mu.size2() - sum(log_var));
		//sample a random point from the latent distribution via the reparameterization trick
		MatrixType epsilon = blas::normal(*this->mep_rng,mu.size1(), mu.size2(), value_type(0.0), value_type(1.0), device_type());
		MatrixType z = mu + exp(0.5*log_var) * epsilon;
		//reconstruct and compute reconstruction error
		MatrixType reconstruction = (*mep_decoder)(z);
		return (m_lambda * (*mep_loss)(batch, reconstruction) + klError) / batch.size1();
	}
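
	/// \brief Computes the error and its derivative on a single randomly drawn minibatch.
	///
	/// The derivative is with respect to the concatenated parameter vector [decoder parameters, encoder parameters].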
	double evalDerivative(
		SearchPointType const& parameters,
		SearchPointType & derivative
	) const{
		SIZE_CHECK(parameters.size() == numberOfVariables());
		this->m_evaluationCounter++;
		mep_decoder->setParameterVector(subrange(parameters,0,mep_decoder->numberOfParameters()));
		mep_encoder->setParameterVector(subrange(parameters,mep_decoder->numberOfParameters(), numberOfVariables()));

		boost::shared_ptr<State> stateEncoder = mep_encoder->createState();
		boost::shared_ptr<State> stateDecoder = mep_decoder->createState();
		auto const& batch = m_data.batch(random::discrete(*this->mep_rng, std::size_t(0), m_data.numberOfBatches() -1));
		MatrixType hiddenResponse;
		mep_encoder->eval(batch,hiddenResponse,*stateEncoder);
		auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
		auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
		//compute the Kullback-Leibler divergence term
		double klError = 0.5 * (sum(exp(log_var)) + sum(sqr(mu)) - mu.size1() * mu.size2() - sum(log_var));
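		//gradient of the KL term: dKL/dmu = mu and dKL/dlog_var = 0.5*(exp(log_var) - 1), stored in the encoder output layout [mu, log_var]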
		MatrixType klDerivative = mu | (0.5 * exp(log_var) - 0.5);
		MatrixType epsilon = blas::normal(*this->mep_rng,mu.size1(), mu.size2(), value_type(0.0), value_type(1.0), device_type());
		MatrixType z = mu + exp(0.5*log_var) * epsilon;
		MatrixType reconstructions;
		mep_decoder->eval(z,reconstructions, *stateDecoder);

		//compute loss derivative
		MatrixType lossDerivative;
		double recError = m_lambda * mep_loss->evalDerivative(batch,reconstructions,lossDerivative);
		lossDerivative *= m_lambda;
		//backpropagate the error from the reconstruction loss to the decoder
		SearchPointType derivativeDecoder;
		MatrixType backpropDecoder;
		mep_decoder->weightedDerivatives(z,reconstructions, lossDerivative,*stateDecoder, derivativeDecoder, backpropDecoder);

		//compute the coefficients of the backprop from the decoder and the KL-term
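		//chain rule through z = mu + exp(0.5*log_var)*epsilon: dz/dmu = 1 and dz/dlog_var = 0.5*exp(0.5*log_var)*epsilon = 0.5*(z - mu)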
		MatrixType backprop=(backpropDecoder | (backpropDecoder * 0.5*(z - mu))) + klDerivative;
		SearchPointType derivativeEncoder;
		mep_encoder->weightedParameterDerivative(batch,hiddenResponse, backprop,*stateEncoder, derivativeEncoder);

		derivative.resize(numberOfVariables());
		noalias(derivative) = derivativeDecoder|derivativeEncoder;
		derivative /= batch.size1();
		return (recError + klError) / batch.size1();
	}

private:
	ModelType* mep_decoder;
	ModelType* mep_encoder;
	LossType* mep_loss;
	DatasetType m_data;
	double m_lambda;
};

}
#endif