Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
examples
Benchmark
shark
nearest_neighbours.cpp
Go to the documentation of this file.
1
#include <
shark/Data/SparseData.h
>
2
#include <
shark/ObjectiveFunctions/Loss/ZeroOneLoss.h
>
3
#include <shark/Models/NearestNeighborClassifier.h>
4
#include <
shark/Algorithms/NearestNeighbors/TreeNearestNeighbors.h
>
5
#include <
shark/Algorithms/NearestNeighbors/SimpleNearestNeighbors.h
>
6
#include <
shark/Models/Trees/KDTree.h
>
7
#include <
shark/Models/Kernels/LinearKernel.h
>
8
9
#include <
shark/Core/Timer.h
>
10
#include <iostream>
11
using namespace
shark
;
12
using namespace
std;
13
14
int
main
(
int
argc,
char
**argv) {
15
LabeledData<RealVector,unsigned int>
data;
16
importSparseData
(data,
"cod-rna"
,0,8192);
17
18
LabeledData<RealVector,unsigned int>
mnist;
19
importSparseData
(mnist,
"mnist"
,0,8192);
20
//~ {
21
//~ Timer time;
22
//~ KDTree<RealVector> kdtree(data.inputs());
23
//~ TreeNearestNeighbors<RealVector,unsigned int> algorithmKD(data,&kdtree);
24
//~ NearestNeighborClassifier<RealVector> model(&algorithmKD, 10);
25
//~ ZeroOneLoss<> loss;
26
//~ double error = loss(data.labels(),model(data.inputs()));
27
//~ double time_taken = time.stop();
28
29
//~ cout << "kdtree: "<< time_taken <<" "<< error<<std::endl;
30
//~ }
31
32
{
33
Timer
time;
34
LinearKernel<RealVector>
euclideanKernel;
35
SimpleNearestNeighbors<RealVector,unsigned int>
simpleAlgorithm(data,&euclideanKernel);
36
NearestNeighborClassifier<RealVector> model(&simpleAlgorithm, 10);
37
ZeroOneLoss<>
loss;
38
double
error = loss(data.
labels
(),model(data.
inputs
()));
39
double
time_taken = time.
stop
();
40
41
cout <<
"brute-force: "
<< time_taken <<
" "
<< error<<std::endl;
42
}
43
44
{
45
Timer
time;
46
LinearKernel<RealVector>
euclideanKernel;
47
SimpleNearestNeighbors<RealVector,unsigned int>
simpleAlgorithm(mnist,&euclideanKernel);
48
NearestNeighborClassifier<RealVector> model(&simpleAlgorithm, 10);
49
ZeroOneLoss<>
loss;
50
double
error = loss(mnist.
labels
(),model(mnist.
inputs
()));
51
double
time_taken = time.
stop
();
52
53
cout <<
"brute-force-mnist: "
<< time_taken <<
" "
<< error<<std::endl;
54
}
55
56
}