// Fragment of a function body; the enclosing signature is outside this excerpt.
// Feeds `batch` through the encoder, splits the response column-wise into two
// halves (mu and log_var) and returns mu + exp(0.5*log_var) * epsilon, where
// epsilon is a freshly drawn random matrix.
102 MatrixType hiddenResponse = (*mep_encoder)(batch);
// Columns [0, size2()/2) of the encoder response.
103 auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
// Columns [size2()/2, size2()) of the encoder response.
104 auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
// Random matrix with the same dimensions as mu, drawn from mep_rng.
// NOTE(review): the 0.0 / 1.0 arguments are presumably mean and variance of a
// normal distribution -- confirm against the blas::normal signature.
106 MatrixType epsilon = blas::normal(*this->
mep_rng,mu.size1(), mu.size2(), value_type(0.0), value_type(1.0), device_type());
// If log_var holds log-variances (as the name suggests), exp(0.5*log_var) is the
// standard deviation. NOTE(review): `*` is presumably element-wise here -- confirm.
107 return mu + exp(0.5*log_var) * epsilon;
// Fragment of a function body; the enclosing signature is outside this excerpt.
// Computes (m_lambda * loss(batch, reconstruction) + klError) / batch.size1(),
// where the reconstruction is decoded from a randomly perturbed encoder output.
117 MatrixType hiddenResponse = (*mep_encoder)(batch);
// First and second half of the encoder response columns.
118 auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
119 auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
// 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var) over all entries. This has the form
// of the closed-form KL divergence between N(mu, diag(exp(log_var))) and N(0, I),
// assuming log_var holds log-variances.
121 double klError = 0.5 * (sum(exp(log_var)) + sum(
sqr(mu)) - mu.size1() * mu.size2() - sum(log_var));
// Random matrix shaped like mu, drawn from mep_rng.
// NOTE(review): 0.0 / 1.0 are presumably mean and variance -- confirm against blas::normal.
123 MatrixType epsilon = blas::normal(*this->
mep_rng,mu.size1(), mu.size2(), value_type(0.0), value_type(1.0), device_type());
// Shift-and-scale sample: z = mu + exp(0.5*log_var) * epsilon.
124 MatrixType z = mu + exp(0.5*log_var) * epsilon;
// Decode the sample.
126 MatrixType reconstruction = (*mep_decoder)(z);
// Weighted reconstruction loss plus KL term, divided by the number of rows of batch.
127 return (m_lambda * (*mep_loss)(batch, reconstruction) + klError) / batch.size1();
// Fragment of a function body; the enclosing signature is outside this excerpt.
// Computes the same quantity as (m_lambda * loss + klError) / batch.size1() and
// writes a gradient into `derivative`.
// NOTE(review): `derivativeDecoder` and `derivativeEncoder` are used below but are
// declared outside the visible lines; the listing numbers jump (e.g. 166 -> 171),
// so some statements of the original function are missing from this excerpt.
// State objects handed to the eval / weightedDerivatives calls below.
140 boost::shared_ptr<State> stateEncoder = mep_encoder->
createState();
141 boost::shared_ptr<State> stateDecoder = mep_decoder->
createState();
// Encoder forward pass; the result is written into hiddenResponse.
143 MatrixType hiddenResponse;
144 mep_encoder->
eval(batch,hiddenResponse,*stateEncoder);
// First and second half of the encoder response columns.
145 auto const& mu = columns(hiddenResponse,0,hiddenResponse.size2()/2);
146 auto const& log_var = columns(hiddenResponse,hiddenResponse.size2()/2, hiddenResponse.size2());
// 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var) over all entries.
148 double klError = 0.5 * (sum(exp(log_var)) + sum(
sqr(mu)) - mu.size1() * mu.size2() - sum(log_var));
// Derivative of klError: d/dmu = mu, d/dlog_var = 0.5*exp(log_var) - 0.5.
// NOTE(review): `|` presumably concatenates the two halves column-wise so the
// layout matches hiddenResponse -- confirm.
149 MatrixType klDerivative = mu | (0.5 * exp(log_var) - 0.5);
// Random matrix shaped like mu, drawn from mep_rng.
// NOTE(review): 0.0 / 1.0 are presumably mean and variance -- confirm against blas::normal.
150 MatrixType epsilon = blas::normal(*this->
mep_rng,mu.size1(), mu.size2(), value_type(0.0), value_type(1.0), device_type());
// Shift-and-scale sample: z = mu + exp(0.5*log_var) * epsilon.
151 MatrixType z = mu + exp(0.5*log_var) * epsilon;
// Decoder forward pass on the sample.
152 MatrixType reconstructions;
153 mep_decoder->
eval(z,reconstructions, *stateDecoder);
// Loss value and its derivative w.r.t. the reconstructions; both are scaled by
// m_lambda so that the derivative matches the scaled error.
157 MatrixType lossDerivative;
158 double recError = m_lambda * mep_loss->
evalDerivative(batch,reconstructions,lossDerivative);
159 lossDerivative *= m_lambda;
// Decoder backward pass: fills derivativeDecoder and backpropDecoder.
// NOTE(review): backpropDecoder is presumably the derivative w.r.t. the decoder
// input z -- confirm against the model interface.
162 MatrixType backpropDecoder;
163 mep_decoder->
weightedDerivatives(z,reconstructions, lossDerivative,*stateDecoder, derivativeDecoder, backpropDecoder);
// Chain rule through z = mu + exp(0.5*log_var)*epsilon:
//   dz/dmu = 1, dz/dlog_var = 0.5*exp(0.5*log_var)*epsilon = 0.5*(z - mu).
// The KL derivative is then added on top.
// NOTE(review): `backprop` is not consumed in the visible lines; it is presumably
// passed to an encoder derivative call that is missing from this excerpt.
166 MatrixType backprop=(backpropDecoder | (backpropDecoder * 0.5*(z - mu))) + klDerivative;
// Gradient is the decoder part followed by the encoder part, divided by the
// number of rows of batch -- the same divisor as the returned error.
171 noalias(derivative) = derivativeDecoder|derivativeEncoder;
172 derivative /= batch.size1();
173 return (recError + klError) / batch.size1();