VariationalAutoencoder.cpp
Go to the documentation of this file.
#include <cstddef>  // std::size_t
#include <iostream> // std::cout, std::cerr
#include <string>   // std::to_string

#include <shark/Data/SparseData.h>//for reading in the images as sparseData/Libsvm format
#include <shark/Data/Pgm.h>//for printing out reconstructions
#include <shark/Models/LinearModel.h>//single dense layer
#include <shark/Models/ConcatenatedModel.h>//for stacking layers
#include <shark/Algorithms/GradientDescent/Adam.h>// The Adam optimization algorithm
#include <shark/ObjectiveFunctions/Loss/SquaredLoss.h> //squared loss function (can also be cross-entropy for greyscale images)
#include <shark/ObjectiveFunctions/VariationalAutoencoderError.h> //variational autoencoder error function
using namespace shark;

10int main(int argc, char **argv)
11{
12 if(argc < 2) {
13 std::cerr << "usage: " << argv[0] << " path/to/mnist_subset.libsvm" << std::endl;
14 return 1;
15 }
16
17 //Step1: load data
19 importSparseData( data, argv[1] , 784 , 50);
20
21 //Step 2: define model
22 //build encoder network
23 //note that the output layer must be linear and must have twice the number of outputs than the decoder inputs
24 //as we have to model mean and variance for each decoder-input.
26 LinearModel<FloatVector, LinearNeuron> encoder2(encoder1.outputShape(),2 * 300, true);
27 auto encoder = encoder1 >> encoder2;
28
29 //build decoder network
30 //MNIST is scaled between 0 and 1 so a sigmoid output makes predicting compeltely black and completely white pixels easier
31 LinearModel<FloatVector, RectifierNeuron> decoder1(300, 500, true);
32 LinearModel<FloatVector, LogisticNeuron> decoder2(decoder1.outputShape(), data.inputShape(), true);
33 auto decoder = decoder1 >> decoder2;
34
36 double lambda = 1.0;
37 VariationalAutoencoderError<FloatVector> error(data.inputs(), &encoder, &decoder,&loss, lambda);
38
39 //Step 4 set up optimizer and run optimization
40 std::size_t iterations = 20000;
41 Adam<FloatVector> optimizer;
42 optimizer.setEta(0.001);
43 initRandomNormal(encoder,0.0001);
44 initRandomNormal(decoder,0.0001);
45 error.init();
46 optimizer.init(error);
47 std::cout<<"Optimizing model "<<std::endl;
48 for(std::size_t i = 0; i <= iterations; ++i){
49 optimizer.step(error);
50 if(i % 100 == 0){
51 //create some reconstructions for evaluation
52 auto const& batch = data.batch(0).input;
53 RealMatrix reconstructed = decoder(error.sampleZ(optimizer.solution().point, batch));
54
55 std::cout<<i<<" "<<optimizer.solution().value<<" "<<loss(batch, reconstructed)/batch.size1()<<std::endl;
56 //store reconstructions
57 exportFiltersToPGMGrid("reconstructed"+std::to_string(i), reconstructed,28,28);
58 }
59 }
60}