logistic_regression_SAG.cpp
Go to the documentation of this file.
4
5#include <shark/Core/Timer.h>
6#include <iostream>
7using namespace shark;
8using namespace std;
9
10
// Trains `model` on `data` with `trainer` (configured for `epochs` epochs),
// measures the wall-clock duration of trainer.train() with shark::Timer, and
// prints the loss of the trained decision function plus the measured time.
//   data   - labeled dataset; InputType is the input representation (the
//            caller instantiates this for both dense and sparse data).
//   alpha  - forwarded to the trainer setup (presumably a step size or
//            regularization constant — TODO confirm against the trainer's
//            constructor, which is not visible here).
//   epochs - number of passes the trainer is configured to run.
// NOTE(review): the declarations of `model`, `loss` and `trainer` (original
// listing lines 13, 14 and 17) are absent from this rendering; confirm their
// types in the real source before relying on the comments below.
11template<class InputType>
12void run(LabeledData<InputType,unsigned int> const& data, double alpha, unsigned int epochs){
15
16
18 trainer.setEpochs(epochs);
19
// Only the train() call is inside the timed region.
20 Timer time;
21 trainer.train(model, data);
22 double time_taken = time.stop();
23
// Evaluated on the same `data` that was used for training, so this is a
// training-set loss, not a held-out estimate.
24 cout << "Cross-Entropy: " << loss(data.labels(),model.decisionFunction()(data.inputs()))<<std::endl;
// Units of time_taken are whatever Timer::stop() returns (presumably
// seconds — TODO confirm).
25 cout << "Time:\n" << time_taken << endl;
26}
27int main(int argc, char **argv) {
28 ClassificationDataset data_dense;
29 importSparseData(data_dense, "mnist",0,8192);
30 data_dense = transformLabels(data_dense, [](unsigned int y){ return y%2;});
32 importSparseData(data_sparse, "rcv1_train.binary",0,8192);
33
34 double alpha = 0.1;
35 run(data_dense, alpha, 200);
36 run(data_sparse, alpha, 2000);
37
38}