Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
examples
Supervised
LDATutorial.cpp
Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief       Linear Discriminant Analysis Tutorial Sample Code
 *
 *
 *
 * \author      C. Igel
 * \date        2011
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#include <shark/Data/Csv.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

#include <shark/Algorithms/Trainers/LDA.h>

#include <cstddef>
#include <cstdlib>
#include <exception>
#include <iostream>

using namespace shark;
using namespace std;

int
main
(
int
argc,
char
**argv) {
46
if
(argc < 2) {
47
cerr <<
"usage: "
<< argv[0] <<
" (filename)"
<< endl;
48
exit(EXIT_FAILURE);
49
}
50
// read data
51
ClassificationDataset
data;
52
try
{
53
importCSV
(data, argv[1],
LAST_COLUMN
,
' '
);
54
}
55
catch
(...) {
56
cerr <<
"unable to read data from file "
<< argv[1] << endl;
57
exit(EXIT_FAILURE);
58
}
59
60
cout <<
"overall number of data points: "
<< data.
numberOfElements
() <<
" "
61
<<
"number of classes: "
<<
numberOfClasses
(data) <<
" "
62
<<
"input dimension: "
<<
inputDimension
(data) << endl;
63
64
// split data into training and test set
65
ClassificationDataset
dataTest =
splitAtElement
(data, .5 * data.
numberOfElements
() );
66
cout <<
"training data points: "
<< data.
numberOfElements
() << endl;
67
cout <<
"test data points: "
<< dataTest.
numberOfElements
() << endl;
68
69
// define learning algorithm
70
LDA
ldaTrainer;
71
72
// define linear model
73
LinearClassifier<>
lda;
74
75
// train model
76
ldaTrainer.
train
(lda, data);
77
78
// evaluate classifier
79
Data<unsigned int>
prediction;
80
ZeroOneLoss<unsigned int>
loss;
81
82
prediction = lda(data.
inputs
());
83
cout <<
"LDA on training set accuracy: "
<< 1. - loss(data.
labels
(), prediction) << endl;
84
prediction = lda(dataTest.
inputs
());
85
cout <<
"LDA on test set accuracy: "
<< 1. - loss(dataTest.
labels
(), prediction) << endl;
86
}