Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
examples
Benchmark
shark
logistic_regression_SAG.cpp
Go to the documentation of this file.
1
#include <
shark/Data/SparseData.h
>
2
#include <
shark/ObjectiveFunctions/Loss/CrossEntropy.h
>
3
#include <
shark/Algorithms/Trainers/LinearSAGTrainer.h
>
4
5
#include <
shark/Core/Timer.h
>
6
#include <exception>
#include <iostream>
7
using namespace
shark
;
8
using namespace
std;
9
10
11
template
<
class
InputType>
12
void
run
(
LabeledData<InputType,unsigned int>
const
& data,
double
alpha,
unsigned
int
epochs){
13
CrossEntropy<unsigned int, RealVector>
loss;
14
LinearClassifier<InputType>
model;
15
16
17
LinearSAGTrainer<InputType,unsigned int>
trainer(&loss,alpha);
18
trainer.
setEpochs
(epochs);
19
20
Timer
time;
21
trainer.
train
(model, data);
22
double
time_taken = time.
stop
();
23
24
cout <<
"Cross-Entropy: "
<< loss(data.
labels
(),model.
decisionFunction
()(data.
inputs
()))<<std::endl;
25
cout <<
"Time:\n"
<< time_taken << endl;
26
}
27
int
main
(
int
argc,
char
**argv) {
28
ClassificationDataset
data_dense;
29
importSparseData
(data_dense,
"mnist"
,0,8192);
30
data_dense =
transformLabels
(data_dense, [](
unsigned
int
y){
return
y%2;});
31
LabeledData<CompressedRealVector,unsigned int>
data_sparse;
32
importSparseData
(data_sparse,
"rcv1_train.binary"
,0,8192);
33
34
double
alpha = 0.1;
35
run
(data_dense, alpha, 200);
36
run
(data_sparse, alpha, 2000);
37
38
}