Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
examples
Unsupervised
KMeansTutorial.cpp
Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief       k-means Clustering Tutorial Sample Code, requires the data
 * set faithful.csv
 *
 *
 *
 * \author      C. Igel
 * \date        2011
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark.  If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
35
36
#include <shark/Data/Csv.h> //load the csv file
#include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> //normalize

#include <shark/Algorithms/KMeans.h> //k-means algorithm
#include <shark/Models/Clustering/HardClusteringModel.h> //model performing hard clustering of points

#include <cstdlib> //EXIT_FAILURE
#include <fstream> //ofstream
#include <iostream>
43
44
using namespace
shark
;
45
using namespace
std;
46
47
int
main
(
int
argc,
char
**argv) {
48
if
(argc < 2) {
49
cerr <<
"usage: "
<< argv[0] <<
" (filename)"
<< endl;
50
exit(EXIT_FAILURE);
51
}
52
// read data
53
UnlabeledData<RealVector>
data;
54
try
{
55
importCSV
(data, argv[1],
' '
);
56
}
57
catch
(...) {
58
cerr <<
"unable to read data from file "
<< argv[1] << endl;
59
exit(EXIT_FAILURE);
60
}
61
std::size_t elements = data.
numberOfElements
();
62
63
// write statistics of input data
64
cout <<
"number of data points: "
<< elements <<
" dimensions: "
<<
dataDimension
(data) << endl;
65
66
// normalize data
67
Normalizer<>
normalizer;
68
NormalizeComponentsUnitVariance<>
normalizingTrainer(
true
);
//zero mean
69
normalizingTrainer.
train
(normalizer, data);
70
data = normalizer(data);
71
72
// compute centroids using k-means clustering
73
Centroids
centroids;
74
size_t
iterations =
kMeans
(data, 2, centroids);
75
// report number of iterations by the clustering algorithm
76
cout <<
"iterations: "
<< iterations << endl;
77
78
// write cluster centers/centroids
79
Data<RealVector>
const
& c = centroids.
centroids
();
80
cout<<c<<std::endl;
81
82
// cluster data
83
HardClusteringModel<RealVector>
model(¢roids);
84
Data<unsigned>
clusters = model(data);
85
86
// write results to files
87
ofstream c1(
"cl1.csv"
);
88
ofstream c2(
"cl2.csv"
);
89
ofstream cc(
"clc.csv"
);
90
for
(std::size_t i=0; i != elements; i++) {
91
if
(clusters.
element
(i))
92
c1 << data.
element
(i)(0) <<
" "
<< data.
element
(i)(1) << endl;
93
else
94
c2 << data.
element
(i)(0) <<
" "
<< data.
element
(i)(1) << endl;
95
}
96
cc << c.
element
(0)(0) <<
" "
<< c.
element
(0)(1) << endl;
97
cc << c.
element
(1)(0) <<
" "
<< c.
element
(1)(1) << endl;
98
}