// McSvm.cpp
// (Source listing recovered from the generated documentation page for this file.)
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include <tuple>
#include <utility>

#include <shark/LinAlg/Base.h>
#include <shark/Core/Random.h>
#include <shark/Models/Kernels/GaussianRbfKernel.h>
#include <shark/Models/Kernels/KernelExpansion.h>
#include <shark/Algorithms/Trainers/CSvmTrainer.h>
#include <shark/Data/DataDistribution.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

using namespace shark;

16// data generating distribution for our toy
17// multi-category classification problem
18/// @cond EXAMPLE_SYMBOLS
19class Problem : public LabeledDataDistribution<RealVector, unsigned int>
20{
21public:
22 void draw(RealVector& input, unsigned int& label)const
23 {
25 input.resize(1);
26 input(0) = random::gauss(random::globalRng) + 3.0 * label;
27 }
28};
29/// @endcond
30
31int main()
32{
33 // experiment settings
34 unsigned int ell = 30;
35 unsigned int tests = 100;
36 double C = 10.0;
37 double gamma = 0.5;
38
39 // generate a very simple dataset with a little noise
40 Problem problem;
41 ClassificationDataset training = problem.generateDataset(ell);
42 ClassificationDataset test = problem.generateDataset(tests);
43
44 // kernel function
45 GaussianRbfKernel<> kernel(gamma);
46
47 // SVM kernel classifiers
49
50 // loss measuring classification errors
52
53 // There are 9 trainers for multi-class SVMs in Shark which can train with or without bias:
54 std::tuple<std::string,McSvm,bool> machines[18] ={
55 std::make_tuple("OVA", McSvm::OVA,false),
56 std::make_tuple("CS", McSvm::CS,false),
57 std::make_tuple("WW",McSvm::WW,false),
58 std::make_tuple("LLW",McSvm::LLW,false),
59 std::make_tuple("ADM",McSvm::ADM,false),
60 std::make_tuple("ATS",McSvm::ATS,false),
61 std::make_tuple("ATM",McSvm::ATM,false),
62 std::make_tuple("MMR",McSvm::MMR,false),
63 std::make_tuple("ReinforcedSvm",McSvm::ReinforcedSvm,false),
64 std::make_tuple("OVA", McSvm::OVA,true),
65 std::make_tuple("CS", McSvm::CS,true),
66 std::make_tuple("WW",McSvm::WW,true),
67 std::make_tuple("LLW",McSvm::LLW,true),
68 std::make_tuple("ADM",McSvm::ADM,true),
69 std::make_tuple("ATS",McSvm::ATS,true),
70 std::make_tuple("ATM",McSvm::ATM,true),
71 std::make_tuple("MMR",McSvm::MMR,true),
72 std::make_tuple("ReinforcedSvm",McSvm::ReinforcedSvm,true)
73 };
74
75 std::printf("SHARK multi-class SVM example - training 18 machines:\n");
76 for (int i=0; i<18; i++)
77 {
78 CSvmTrainer<RealVector> trainer(&kernel, C, std::get<2>(machines[i]));
79 trainer.setMcSvmType(std::get<1>(machines[i]));
80 trainer.train(svm, training);
81 Data<unsigned int> output = svm(training.inputs());
82 double train_error = loss.eval(training.labels(), output);
83 output = svm(test.inputs());
84 double test_error = loss.eval(test.labels(), output);
85
86 std::cout<<std::get<0>(machines[i])<<(trainer.trainOffset()? " w bias ":" w/o bias");
87 std::cout<<"\ttraining error="<<train_error;
88 std::cout<<"\ttest error="<<test_error;
89 std::cout<<"\titerations="<<trainer.solutionProperties().iterations;
90 std::cout<<"\ttime="<<trainer.solutionProperties().seconds<<std::endl;
91
92
93 }
94}