MultiTaskSvm.cpp
Go to the documentation of this file.
1
7
8using namespace shark;
9using namespace std;
10
11
12// RealVector input with task index
14
15
16// Multi-task problem with up to three tasks.
17class MultiTaskProblem : public LabeledDataDistribution<InputType, unsigned int>
18{
19public:
20 MultiTaskProblem()
21 {
22 m_task[0] = true;
23 m_task[1] = true;
24 m_task[2] = true;
25 }
26
27 void setTasks(bool task0, bool task1, bool task2)
28 {
29 m_task[0] = task0;
30 m_task[1] = task1;
31 m_task[2] = task2;
32 }
33
34 void draw(InputType& input, unsigned int& label) const
35 {
36 size_t taskindex = 0;
37 do {
38 taskindex = random::uni(random::globalRng, 0, 2);
39 } while (! m_task[taskindex]);
41 double x2 = 3.0 * random::gauss(random::globalRng);
42 unsigned int y = (x1 > 0.0) ? 1 : 0;
43 double alpha = 0.05 * M_PI * taskindex;
44 input.input.resize(2);
45 input.input(0) = cos(alpha) * x1 - sin(alpha) * x2;
46 input.input(1) = sin(alpha) * x1 + cos(alpha) * x2;
47 input.task = taskindex;
48 label = y;
49 }
50
51protected:
52 bool m_task[3];
53};
54
55
56int main(int argc, char** argv)
57{
58 // experiment settings
59 unsigned int ell_train = 1000; // number of training data point from tasks 0 and 1
60 unsigned int ell_test = 1000; // number of test data points from task 2
61 double C = 1.0; // regularization parameter
62 double gamma = 0.5; // kernel bandwidth parameter
63
64 // generate data
65 MultiTaskProblem problem;
66 problem.setTasks(true, true, false);
67 LabeledData<InputType, unsigned int> training = problem.generateDataset(ell_train);
68 problem.setTasks(false, false, true);
69 LabeledData<InputType, unsigned int> test = problem.generateDataset(ell_test);
70
71 // merge all inputs into a single data object
72 Data<InputType> data(ell_train + ell_test);
73 for (size_t i=0; i<ell_train; i++)
74 data.element(i) = training.inputs().element(i);
75 for (size_t i=0; i<ell_test; i++)
76 data.element(ell_train + i) = test.inputs().element(i);
77
78 // create kernel objects
79 GaussianRbfKernel<RealVector> inputKernel(gamma); // Gaussian kernel on inputs
80 GaussianTaskKernel<RealVector> taskKernel( // task similarity kernel
81 data, // all inputs with task indices, no labels
82 3, // total number of tasks
83 inputKernel, // base kernel for input similarity
84 gamma); // bandwidth for task similarity kernel
85 MultiTaskKernel<RealVector> multiTaskKernel(&inputKernel, &taskKernel);
86
87 // train the SVM
89 CSvmTrainer<InputType> trainer(&multiTaskKernel, C,false);
90 cout << "training ..." << endl;
91 trainer.train(ke, training);
92 cout << "done." << endl;
93
95 Data<RealVector> output;
96
97 // evaluate training performance
98 double trainError = loss.eval(training.labels(), ke(training.inputs()));
99 cout << "training error:\t" << trainError << endl;
100
101 // evaluate its transfer performance
102 double testError = loss.eval(test.labels(), ke(test.inputs()));
103 cout << "test error:\t" << testError << endl;
104}