Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
examples
Benchmark
shark
logistic_regression_LBFGS.cpp
Go to the documentation of this file.
#include <iostream>
#include <string>

#include <shark/Data/SparseData.h>
#include <shark/ObjectiveFunctions/Loss/CrossEntropy.h>
#include <shark/ObjectiveFunctions/Regularizer.h>

#include <shark/Algorithms/GradientDescent/LBFGS.h>
#include <shark/ObjectiveFunctions/ErrorFunction.h>
#include <shark/Models/LinearModel.h>

#include <shark/Core/Timer.h>

using namespace shark;
using namespace std;
14
int
main
(
int
argc,
char
**argv) {
15
ClassificationDataset
data;
16
importSparseData
(data,
"mnist"
,0,8192);
17
double
alpha = 0.1;
18
CrossEntropy<unsigned int, RealVector>
loss;
19
LinearClassifier<>
model;
20
21
//Setting up the problem
22
model.
decisionFunction
().setStructure(
inputDimension
(data),
numberOfClasses
(data),
true
);
23
TwoNormRegularizer
regularizer;
24
ErrorFunction
error(data,&model.
decisionFunction
(),&loss);
25
error.
setRegularizer
(alpha,®ularizer);
26
27
//solving
28
Timer
time;
29
LBFGS
optimizer;
30
optimizer.
init
(error);
31
while
(error.
evaluationCounter
()<200){
32
optimizer.
step
(error);
33
}
34
model.
setParameterVector
(optimizer.
solution
().point);
35
double
time_taken = time.
stop
();
36
37
cout <<
"Cross-Entropy: "
<< loss(data.
labels
(),model.
decisionFunction
()(data.
inputs
()))<<std::endl;
38
cout <<
"Time:\n"
<< time_taken << endl;
39
}