Shark machine learning library
examples/Supervised/KNNTutorial.cpp
//===========================================================================
/*!
 *
 * \brief Nearest Neighbor Tutorial Sample Code
 *
 * \author C. Igel
 * \date 2011
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#include <shark/Data/Csv.h>
#include <shark/Models/NearestNeighborModel.h>
#include <shark/Algorithms/NearestNeighbors/TreeNearestNeighbors.h>
#include <shark/Models/Trees/KDTree.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Data/DataView.h>
#include <iostream>

using namespace shark;
using namespace std;

int main(int argc, char **argv) {
    if(argc < 2) {
        cerr << "usage: " << argv[0] << " (filename)" << endl;
        exit(EXIT_FAILURE);
    }
    // read data
    ClassificationDataset data;
    try {
        importCSV(data, argv[1], LAST_COLUMN, ' ');
    }
    catch (...) {
        cerr << "unable to read data from file " << argv[1] << endl;
        exit(EXIT_FAILURE);
    }

    cout << "number of data points: " << data.numberOfElements()
         << " number of classes: " << numberOfClasses(data)
         << " input dimension: " << inputDimension(data) << endl;

    // split data into training and test set
    ClassificationDataset dataTest = splitAtElement(data, static_cast<std::size_t>(.5 * data.numberOfElements()));
    cout << "training data points: " << data.numberOfElements() << endl;
    cout << "test data points: " << dataTest.numberOfElements() << endl;

    // create a binary search tree and initialize the search algorithm - a fast tree search
    KDTree<RealVector> tree(data.inputs());
    TreeNearestNeighbors<RealVector, unsigned int> algorithm(data, &tree);
    // instantiate the classifier
    const unsigned int K = 1; // number of neighbors for kNN
    NearestNeighborModel<RealVector, unsigned int> KNN(&algorithm, K);

    // evaluate classifier
    ZeroOneLoss<unsigned int> loss;
    Data<unsigned int> prediction = KNN(data.inputs());
    cout << K << "-KNN on training set accuracy: " << 1. - loss.eval(data.labels(), prediction) << endl;
    prediction = KNN(dataTest.inputs());
    cout << K << "-KNN on test set accuracy: " << 1. - loss.eval(dataTest.labels(), prediction) << endl;
}
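
The importCSV call above is made with LAST_COLUMN and a space separator, so the tutorial expects a plain text file with one sample per row: space-separated feature values followed by the integer class label in the last column. The sketch below writes such a toy file; the filename data.csv and all values are purely illustrative and not part of the tutorial.

#include <fstream>

// Write a minimal space-separated dataset in the format KNNTutorial reads:
// feature values first, integer class label in the last column.
int main() {
    std::ofstream out("data.csv"); // hypothetical filename, passed to the tutorial as argv[1]
    out << "1.0 2.0 0\n"
           "1.1 1.9 0\n"
           "5.0 6.1 1\n"
           "5.2 5.9 1\n";
    return 0;
}

Run on such a file, the tutorial first prints the number of data points, the number of classes, and the input dimension, then the 1-NN accuracy on the training and test halves of the split.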