Shark machine learning library
examples/Supervised/KNNCrossValidationTutorial.cpp
//===========================================================================
/*!
 *
 * \brief Nearest Neighbor Tutorial Sample Code
 *
 * \author C. Igel
 * \date 2011
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#include <shark/Data/Csv.h>
#include <shark/Models/NearestNeighborModel.h>
#include <shark/Algorithms/NearestNeighbors/SimpleNearestNeighbors.h>
#include <shark/Models/Kernels/LinearKernel.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Data/CVDatasetTools.h>
#include <iostream>
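// The headers above provide, in order: CSV import (importCSV), the k-nearest-
// neighbor model, the brute-force nearest-neighbor search, the linear kernel
// used as distance measure, the 0-1 classification loss, and the
// cross-validation utilities (CVFolds, createCVSameSizeBalanced).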
using namespace shark;
using namespace std;

int main(int argc, char **argv) {
    if(argc < 2) {
        cerr << "usage: " << argv[0] << " (filename)" << endl;
        exit(EXIT_FAILURE);
    }
    // read data
    ClassificationDataset data;
    try {
        importCSV(data, argv[1], LAST_COLUMN, ' ');
    }
    catch (...) {
        cerr << "unable to read data from file " << argv[1] << endl;
        exit(EXIT_FAILURE);
    }

    cout << "number of data points: " << data.numberOfElements()
         << " number of classes: " << numberOfClasses(data)
         << " input dimension: " << inputDimension(data) << endl;

    // split data into training and test set
    ClassificationDataset dataTest = splitAtElement(data, .5 * data.numberOfElements());
    cout << "training data points: " << data.numberOfElements() << endl;
    cout << "test data points: " << dataTest.numberOfElements() << endl;

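    // Note: splitAtElement modifies data in place so that it keeps the first
    // half of the elements (the training set) and returns the remaining half,
    // which serves as the test set here.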
    // create 10 CV folds
    const unsigned int NFolds = 10;
    CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(data, NFolds);

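    // Each fold i provides folds.training(i) (the other nine folds merged) and
    // folds.validation(i) (the held-out fold); createCVSameSizeBalanced aims for
    // folds of roughly equal size with balanced class frequencies.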
//we have 5 different values of k to test
76
unsigned
int
k[]={1,3,5,7,9};
77
unsigned
int
numParameters = 5;
78
79
    ZeroOneLoss<unsigned int> loss; // loss for evaluation
    LinearKernel<> metric;          // linear distance measure

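    // The 0-1 loss measures the fraction of misclassified points; the linear
    // kernel defines the inner-product-based distance used by the neighbor
    // search, which for a linear kernel coincides with Euclidean distance.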
    // find the best number of neighbors using CV
    unsigned int best_k = 0;
    double best_error = 2; // maximum 0-1 loss is 1
    // for every parameter...
    for (std::size_t p = 0; p != numParameters; ++p){
        double error = 0;
        // calculate CV error
        for (std::size_t i = 0; i != NFolds; ++i){
            SimpleNearestNeighbors<RealVector, unsigned int> algorithm(folds.training(i), &metric);
            NearestNeighborModel<RealVector, unsigned int> KNN(&algorithm, k[p]);
            error += loss(folds.validation(i).labels(), KNN(folds.validation(i).inputs()));
        }
        error /= NFolds;
        // print the CV error for the current parameter
        std::cout << k[p] << " " << error << std::endl;
        // if the error is better, we keep it
        if (error < best_error){
            best_k = k[p];
            best_error = error;
        }
    }
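    // At this point best_k holds the neighborhood size with the lowest average
    // validation error over the 10 folds; it is now used with the full training
    // set and evaluated once on the held-out test data.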
    // evaluate the best parameter found on the test set, using the full training set
    SimpleNearestNeighbors<RealVector, unsigned int> algorithm(data, &metric);
    NearestNeighborModel<RealVector, unsigned int> KNN(&algorithm, best_k);
    std::cout << "NearestNeighbors: " << loss(dataTest.labels(), KNN(dataTest.inputs())) << '\n';
    std::cout << "K: " << best_k << std::endl;
}
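Usage note (the executable name below is illustrative): the program expects a whitespace-separated data file whose last column contains the class labels, e.g. ./KNNCrossValidationTutorial mydata.csv. It prints the cross-validation error for each candidate k, followed by the test error and the selected k.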