McSvmLinear.cpp
Go to the documentation of this file.
1
7
8using namespace shark;
9
10
11double const noise = 1.0;
12typedef CompressedRealVector VectorType;
13typedef CompressedRealMatrix MatrixType;
14
15
16// data generating distribution for our toy
17// multi-category classification problem
18/// @cond EXAMPLE_SYMBOLS
19class Problem : public LabeledDataDistribution<VectorType, unsigned int>
20{
21public:
22 void draw(VectorType& input, unsigned int& label)const
23 {
25 input.resize(1000002);
26 input.set_element(input.end(), 1000000, noise * random::gauss(random::globalRng) + 3.0 * std::cos((double)label));
27 input.set_element(input.end(), 1000001, noise * random::gauss(random::globalRng) + 3.0 * std::sin((double)label));
28 }
29};
30/// @endcond
31
32
33int main(int argc, char** argv)
34{
35 if (argc != 4)
36 {
37 std::cout << "required parameters: ell C epsilon" << std::endl;
38 return 1;
39 }
40
41 // experiment settings
42 unsigned int ell = std::atoi(argv[1]);
43 double C = std::atof(argv[2]);
44 double epsilon = std::atof(argv[3]);
45 unsigned int tests = 10000;
46 std::cout << "ell=" << ell << std::endl;
47 std::cout << "C=" << C << std::endl;
48 std::cout << "epsilon=" << epsilon << std::endl;
49
50 // generate a very simple dataset with a little noise
51 Problem problem;
52 LabeledData<VectorType, unsigned int> training = problem.generateDataset(ell);
53 LabeledData<VectorType, unsigned int> test = problem.generateDataset(tests);
54
55 // define the model
57
58 // train the machine
59 std::cout << "machine training ..." << std::endl;
60 LinearCSvmTrainer<VectorType> trainer(C, epsilon);
61 trainer.setMcSvmType(McSvm::OVA);
62 trainer.train(svm, training);
63 std::cout << "done." << std::endl;
64
65 // loss measuring classification errors
67
68 Data<unsigned int> output = svm(training.inputs());
69 double train_error = loss.eval(training.labels(), output);
70 output = svm(test.inputs());
71 double test_error = loss.eval(test.labels(), output);
72 std::cout <<"training error= "<< train_error <<" test error= "<< test_error<<std::endl;
73}