FFNNBasicTutorial.cpp
Go to the documentation of this file.
1//the model
2#include <shark/Models/LinearModel.h>//single dense layer
3#include <shark/Models/ConcatenatedModel.h>//for stacking layers, provides operator>>
4//training the model
5#include <shark/ObjectiveFunctions/ErrorFunction.h>//error function, allows for minibatch training
6#include <shark/ObjectiveFunctions/Loss/CrossEntropy.h> // loss used for supervised training
7#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h> // loss used for evaluation of performance
8#include <shark/Algorithms/GradientDescent/Adam.h> //optimizer: simple gradient descent.
9#include <shark/Data/SparseData.h> //loading the dataset
10using namespace shark;
11
12int main(int argc, char **argv)
13{
14 if(argc < 2) {
15 std::cerr << "usage: " << argv[0] << " path/to/mnist_subset.libsvm" << std::endl;
16 return 1;
17 }
18 std::size_t hidden1 = 200;
19 std::size_t hidden2 = 100;
20 std::size_t iterations = 1000;
21
22 std::size_t batchSize = 256;
24 importSparseData( data, argv[1], 0, batchSize );
25 data.shuffle(); //shuffle data randomly
26 auto test = splitAtElement(data, 70 * data.numberOfElements() / 100);//split a test set
27 std::size_t numClasses = numberOfClasses(data);
28 std::size_t inputDim = inputDimension(data);
29 //We use a dense linear model with rectifier activations
31
32 //build the network
33 DenseLayer layer1(inputDim,hidden1, true);
34 DenseLayer layer2(hidden1,hidden2, true);
35 LinearModel<RealVector> output(hidden2,numClasses, true);
36 auto network = layer1 >> layer2 >> output;
37 //create the supervised problem.
39 ErrorFunction<> error(data, &network, &loss, true);//enable minibatch training
40
41 //optimize the model
42 std::cout<<"training network"<<std::endl;
43 initRandomNormal(network,0.001);
44 Adam<> optimizer;
45 error.init();
46 optimizer.init(error);
47 for(std::size_t i = 0; i != iterations; ++i){
48 optimizer.step(error);
49 std::cout<<i<<" "<<optimizer.solution().value<<std::endl;
50 }
51 network.setParameterVector(optimizer.solution().point);
52
53 //evaluation
55 Data<RealVector> predictionTrain = network(data.inputs());
56 std::cout << "classification error,train: " << loss01.eval(data.labels(), predictionTrain) << std::endl;
57
58 Data<RealVector> prediction = network(test.inputs());
59 std::cout << "classification error,test: " << loss01.eval(test.labels(), prediction) << std::endl;
60
61}
62