// CSvmGridSearchTutorial.cpp
// Shark tutorial: selecting C-SVM hyperparameters (gamma, C) by grid search
// over a cross-validation error estimate.
// NOTE(review): the include block was lost in the documentation scrape
// (only stray line numbers remained); reconstructed from the types used
// below — verify against the original Shark tutorial source.
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/ObjectiveFunctions/CrossValidationError.h>
#include <shark/Algorithms/DirectSearch/GridSearch.h>
#include <shark/Algorithms/JaakkolaHeuristic.h>
#include <shark/Data/DataDistribution.h>
10using namespace shark;
11using namespace std;
12
13
14int main()
15{
16 // problem definition
17 Chessboard prob;
18 ClassificationDataset dataTrain = prob.generateDataset(200);
19 ClassificationDataset dataTest= prob.generateDataset(10000);
20
21 // SVM setup
22 GaussianRbfKernel<> kernel(0.5, true); //unconstrained?
24 bool offset = true;
25 bool unconstrained = true;
26 CSvmTrainer<RealVector> trainer(&kernel, 1.0, offset,unconstrained);
27
28 // cross-validation error
29 const unsigned int K = 5; // number of folds
33 folds, &trainer, &svm, &trainer, &loss
34 );
35
36
37 // find best parameters
38
39 // use Jaakkola's heuristic as a starting point for the grid-search
40 JaakkolaHeuristic ja(dataTrain);
41 double ljg = log(ja.gamma());
42 cout << "Tommi Jaakkola says gamma = " << ja.gamma() << " and ln(gamma) = " << ljg << endl;
43
44 GridSearch grid;
45 vector<double> min(2);
46 vector<double> max(2);
47 vector<size_t> sections(2);
48 // kernel parameter gamma
49 min[0] = ljg-4.; max[0] = ljg+4; sections[0] = 9;
50 // regularization parameter C
51 min[1] = 0.0; max[1] = 10.0; sections[1] = 11;
52 grid.configure(min, max, sections);
53 grid.step(cvError);
54
55 // train model on the full dataset
56 trainer.setParameterVector(grid.solution().point);
57 trainer.train(svm, dataTrain);
58 cout << "grid.solution().point " << grid.solution().point << endl;
59 cout << "C =\t" << trainer.C() << endl;
60 cout << "gamma =\t" << kernel.gamma() << endl;
61
62 // evaluate
63 Data<unsigned int> output = svm(dataTrain.inputs());
64 double train_error = loss.eval(dataTrain.labels(), output);
65 cout << "training error:\t" << train_error << endl;
66 output = svm(dataTest.inputs());
67 double test_error = loss.eval(dataTest.labels(), output);
68 cout << "test error: \t" << test_error << endl;
69}