LDATutorial.cpp
Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief Linear Discriminant Analysis Tutorial Sample Code
 *
 *
 *
 * \author C. Igel
 * \date 2011
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
34
#include <shark/Data/Csv.h>
#include <shark/Algorithms/Trainers/LDA.h>
#include <shark/Models/LinearClassifier.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

#include <cstdlib>
#include <exception>
#include <iostream>

using namespace shark;
using namespace std;
44
45int main(int argc, char **argv) {
46 if(argc < 2) {
47 cerr << "usage: " << argv[0] << " (filename)" << endl;
48 exit(EXIT_FAILURE);
49 }
50 // read data
52 try {
53 importCSV(data, argv[1], LAST_COLUMN, ' ');
54 }
55 catch (...) {
56 cerr << "unable to read data from file " << argv[1] << endl;
57 exit(EXIT_FAILURE);
58 }
59
60 cout << "overall number of data points: " << data.numberOfElements() << " "
61 << "number of classes: " << numberOfClasses(data) << " "
62 << "input dimension: " << inputDimension(data) << endl;
63
64 // split data into training and test set
65 ClassificationDataset dataTest = splitAtElement(data, .5 * data.numberOfElements() );
66 cout << "training data points: " << data.numberOfElements() << endl;
67 cout << "test data points: " << dataTest.numberOfElements() << endl;
68
69 // define learning algorithm
70 LDA ldaTrainer;
71
72 // define linear model
74
75 // train model
76 ldaTrainer.train(lda, data);
77
78 // evaluate classifier
79 Data<unsigned int> prediction;
81
82 prediction = lda(data.inputs());
83 cout << "LDA on training set accuracy: " << 1. - loss(data.labels(), prediction) << endl;
84 prediction = lda(dataTest.inputs());
85 cout << "LDA on test set accuracy: " << 1. - loss(dataTest.labels(), prediction) << endl;
86}