Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
Examples
Supervised
CSvmGridSearchTutorial.cpp
Go to the documentation of this file.
1
#include <
shark/Models/Kernels/GaussianRbfKernel.h
>
2
#include <
shark/ObjectiveFunctions/Loss/ZeroOneLoss.h
>
3
#include <
shark/Algorithms/Trainers/CSvmTrainer.h
>
4
#include <
shark/Data/DataDistribution.h
>
5
6
#include <
shark/ObjectiveFunctions/CrossValidationError.h
>
7
#include <
shark/Algorithms/DirectSearch/GridSearch.h
>
8
#include <
shark/Algorithms/JaakkolaHeuristic.h
>
9
10
using namespace
shark
;
11
using namespace
std;
12
13
14
int
main
()
15
{
16
// problem definition
17
Chessboard
prob;
18
ClassificationDataset
dataTrain = prob.
generateDataset
(200);
19
ClassificationDataset
dataTest= prob.
generateDataset
(10000);
20
21
// SVM setup
22
GaussianRbfKernel<>
kernel(0.5,
true
);
//unconstrained?
23
KernelClassifier<RealVector>
svm;
24
bool
offset =
true
;
25
bool
unconstrained =
true
;
26
CSvmTrainer<RealVector>
trainer(&kernel, 1.0, offset,unconstrained);
27
28
// cross-validation error
29
const
unsigned
int
K = 5;
// number of folds
30
ZeroOneLoss<unsigned int>
loss;
31
CVFolds<ClassificationDataset>
folds =
createCVSameSizeBalanced
(dataTrain, K);
32
CrossValidationError<KernelClassifier<RealVector>
,
unsigned
int
> cvError(
33
folds, &trainer, &svm, &trainer, &loss
34
);
35
36
37
// find best parameters
38
39
// use Jaakkola's heuristic as a starting point for the grid-search
40
JaakkolaHeuristic
ja(dataTrain);
41
double
ljg = log(ja.
gamma
());
42
cout <<
"Tommi Jaakkola says gamma = "
<< ja.
gamma
() <<
" and ln(gamma) = "
<< ljg << endl;
43
44
GridSearch
grid;
45
vector<double> min(2);
46
vector<double> max(2);
47
vector<size_t> sections(2);
48
// kernel parameter gamma
49
min[0] = ljg-4.; max[0] = ljg+4; sections[0] = 9;
50
// regularization parameter C
51
min[1] = 0.0; max[1] = 10.0; sections[1] = 11;
52
grid.
configure
(min, max, sections);
53
grid.
step
(cvError);
54
55
// train model on the full dataset
56
trainer.
setParameterVector
(grid.
solution
().point);
57
trainer.
train
(svm, dataTrain);
58
cout <<
"grid.solution().point "
<< grid.
solution
().point << endl;
59
cout <<
"C =\t"
<< trainer.
C
() << endl;
60
cout <<
"gamma =\t"
<< kernel.
gamma
() << endl;
61
62
// evaluate
63
Data<unsigned int>
output = svm(dataTrain.
inputs
());
64
double
train_error = loss.
eval
(dataTrain.
labels
(), output);
65
cout <<
"training error:\t"
<< train_error << endl;
66
output = svm(dataTest.
inputs
());
67
double
test_error = loss.
eval
(dataTest.
labels
(), output);
68
cout <<
"test error: \t"
<< test_error << endl;
69
}