// KMeansTutorial.cpp
// Go to the documentation of this file.
//===========================================================================
/*!
 *
 *
 * \brief k-means Clustering Tutorial Sample Code, requires the data
 * set faithful.csv
 *
 *
 *
 * \author C. Igel
 * \date 2011
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#include <shark/Data/Csv.h> //load the csv file
#include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> //normalize the data

#include <shark/Algorithms/KMeans.h> //k-means algorithm
#include <shark/Models/Clustering/HardClusteringModel.h>//model performing hard clustering of points

#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>

using namespace shark;
using namespace std;

47int main(int argc, char **argv) {
48 if(argc < 2) {
49 cerr << "usage: " << argv[0] << " (filename)" << endl;
50 exit(EXIT_FAILURE);
51 }
52 // read data
54 try {
55 importCSV(data, argv[1], ' ');
56 }
57 catch (...) {
58 cerr << "unable to read data from file " << argv[1] << endl;
59 exit(EXIT_FAILURE);
60 }
61 std::size_t elements = data.numberOfElements();
62
63 // write statistics of input data
64 cout << "number of data points: " << elements << " dimensions: " << dataDimension(data) << endl;
65
66 // normalize data
67 Normalizer<> normalizer;
68 NormalizeComponentsUnitVariance<> normalizingTrainer(true);//zero mean
69 normalizingTrainer.train(normalizer, data);
70 data = normalizer(data);
71
72 // compute centroids using k-means clustering
73 Centroids centroids;
74 size_t iterations = kMeans(data, 2, centroids);
75 // report number of iterations by the clustering algorithm
76 cout << "iterations: " << iterations << endl;
77
78 // write cluster centers/centroids
79 Data<RealVector> const& c = centroids.centroids();
80 cout<<c<<std::endl;
81
82 // cluster data
83 HardClusteringModel<RealVector> model(&centroids);
84 Data<unsigned> clusters = model(data);
85
86 // write results to files
87 ofstream c1("cl1.csv");
88 ofstream c2("cl2.csv");
89 ofstream cc("clc.csv");
90 for(std::size_t i=0; i != elements; i++) {
91 if(clusters.element(i))
92 c1 << data.element(i)(0) << " " << data.element(i)(1) << endl;
93 else
94 c2 << data.element(i)(0) << " " << data.element(i)(1) << endl;
95 }
96 cc << c.element(0)(0) << " " << c.element(0)(1) << endl;
97 cc << c.element(1)(0) << " " << c.element(1)(1) << endl;
98}