Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
examples
Supervised
RFTutorial.cpp
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Random Forest Tutorial Sample Code
6
*
7
* This file is part of the "Random Forest" tutorial.
8
* It requires some toy sample data that comes with the library.
9
*
10
*
11
*
12
* \author K. N. Hansen
13
* \date 2012
14
*
15
*
16
* \par Copyright 1995-2017 Shark Development Team
17
*
18
* <BR><HR>
19
* This file is part of Shark.
20
* <https://shark-ml.github.io/Shark/>
21
*
22
* Shark is free software: you can redistribute it and/or modify
23
* it under the terms of the GNU Lesser General Public License as published
24
* by the Free Software Foundation, either version 3 of the License, or
25
* (at your option) any later version.
26
*
27
* Shark is distributed in the hope that it will be useful,
28
* but WITHOUT ANY WARRANTY; without even the implied warranty of
29
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30
* GNU Lesser General Public License for more details.
31
*
32
* You should have received a copy of the GNU Lesser General Public License
33
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
34
*
35
*/
36
//===========================================================================
37
38
#include <shark/Data/Csv.h> //importing the file
#include <shark/Algorithms/Trainers/RFTrainer.h> //the random forest trainer
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h> //zero one loss for evaluation

#include <cstddef>
#include <cstdlib>
#include <exception>
#include <iostream>
43
44
using namespace
std;
45
using namespace
shark
;
46
47
48
int
main
() {
49
50
//*****************LOAD AND PREPARE DATA***********************//
51
//Read Sample data set C.csv
52
53
ClassificationDataset
data;
54
importCSV
(data,
"data/C.csv"
,
LAST_COLUMN
,
' '
);
55
56
//Split the dataset into a training and a test dataset
57
ClassificationDataset
dataTest =
splitAtElement
(data,311);
58
59
cout <<
"Training set - number of data points: "
<< data.
numberOfElements
()
60
<<
" number of classes: "
<<
numberOfClasses
(data)
61
<<
" input dimension: "
<<
inputDimension
(data) << endl;
62
63
cout <<
"Test set - number of data points: "
<< dataTest.
numberOfElements
()
64
<<
" number of classes: "
<<
numberOfClasses
(dataTest)
65
<<
" input dimension: "
<<
inputDimension
(dataTest) << endl;
66
67
//Generate a random forest
68
RFTrainer<unsigned int>
trainer;
69
RFClassifier<unsigned int>
model;
70
trainer.train(model, data);
71
72
// evaluate Random Forest classifier
73
ZeroOneLoss<>
loss;
74
auto
prediction = model(data.
inputs
());
75
cout <<
"Random Forest on training set accuracy: "
<< 1. - loss.
eval
(data.
labels
(), prediction) << endl;
76
77
prediction = model(dataTest.
inputs
());
78
cout <<
"Random Forest on test set accuracy: "
<< 1. - loss.
eval
(dataTest.
labels
(), prediction) << endl;
79
}