RFClassifier.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Random Forest Classifier.
6 *
7 *
8 *
9 * \author K. N. Hansen, O.Krause, J. Kremer
10 * \date 2011-2012
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35#ifndef SHARK_MODELS_TREES_RFCLASSIFIER_H
36#define SHARK_MODELS_TREES_RFCLASSIFIER_H
37
#include <shark/Models/Trees/CARTree.h>
#include <shark/Models/Ensemble.h>
#include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include <shark/Data/DataView.h>
43
44namespace shark {
45///
46/// \brief Random Forest Classifier.
47///
48/// \par
49/// The Random Forest Classifier predicts a class label
50/// using the Random Forest algorithm as described in<br/>
51/// Random Forests. Leo Breiman. Machine Learning, 1(45), pages 5-32. Springer, 2001.<br/>
52///
53/// \par
54/// It is an ensemble learner that uses multiple decision trees built
55/// using the CART methodology. The trees are created using bagging
56/// which allows the use the out-of-bag error estimates for an approximately
57/// unbiased estimate of the test-error as well as unbiased feature-importance
58/// estimates using feature permutation.
59/// \ingroup models
60template<class LabelType>
61class RFClassifier : public Ensemble<CARTree<LabelType> >{
62private:
63 //OOB-Error for regression
64 template<class VectorType>
65 double doComputeOOBerror(
66 UIntMatrix const& oobPoints, LabeledData<VectorType, VectorType> const& data
67 ){
68 double OOBerror = 0;
69 //aquire votes for every element
71 VectorType input(inputDimension(data));
72 std::size_t elem = 0;
73 for(auto const& point: data.elements()){
74 noalias(input) = point.input;
75 mean.clear();
76 double oobWeightSum = 0;
77 for(std::size_t m = 0; m != this->numberOfModels();++m){
78 if(oobPoints(m,elem)){
79 oobWeightSum += this->weight(m);
80 noalias(mean) += this->weight(m) * this->model(m)(input);
81 }
82 }
83 mean /= oobWeightSum;
84 OOBerror += 0.5 * norm_sqr(point.label - mean);
85 ++elem;
86 }
87 OOBerror /= data.numberOfElements();
88 return OOBerror;
89 }
90
91 //OOB-Error for Classification
92 template<class VectorType>
93 double doComputeOOBerror(
94 UIntMatrix const& oobPoints, LabeledData<VectorType, unsigned int> const& data
95 ){
96 double OOBerror = 0;
97 //aquire votes for every element
98 RealVector votes(numberOfClasses(data));
99 RealVector input(inputDimension(data));
100 std::size_t elem = 0;
101 for(auto const& point: data.elements()){
102 noalias(input) = point.input;
103 votes.clear();
104 for(std::size_t m = 0; m != this->numberOfModels();++m){
105 if(oobPoints(m,elem)){
106 unsigned int label = this->model(m)(input);
107 votes(label) += this->weight(m);
108 }
109 }
110 OOBerror += (arg_max(votes) != point.label);
111 ++elem;
112 }
113 OOBerror /= data.numberOfElements();
114 return OOBerror;
115 }
116
117 //loss for regression
118 double loss(RealMatrix const& labels, RealMatrix const& predictions) const{
120 return loss.eval(labels, predictions);
121 }
122 //loss for classification
123 double loss(UIntVector const& labels, UIntVector const& predictions) const{
125 return loss.eval(labels, predictions);
126 }
127
128public:
129
130 /// \brief From INameable: return the class name.
131 std::string name() const
132 { return "RFClassifier"; }
133
134
135 /// \brief Returns the computed out-of-bag-error of the forest
136 double OOBerror() const {
137 return m_OOBerror;
138 }
139
140 /// \brief Returns the computed feature importances of the forest
141 RealVector const& featureImportances()const{
142 return m_featureImportances;
143 }
144
145 /// \brief Counts how often attributes are used
146 UIntVector countAttributes() const {
147 std::size_t n = this->numberOfModels();
148 if(!n) return UIntVector();
149 UIntVector r = this->model(0).countAttributes();
150 for(std::size_t i=1; i< n; i++ ) {
151 noalias(r) += this->model(i).countAttributes();
152 }
153 return r;
154 }
155
156 /// Compute oob error, given an oob dataset
157 void computeOOBerror(std::vector<std::vector<std::size_t> > const& oobIndices, LabeledData<RealVector, LabelType> const& data){
158 UIntMatrix oobMatrix(oobIndices.size(), data.numberOfElements(),0);
159 for(std::size_t i = 0; i != oobMatrix.size1(); ++i){
160 for(auto index: oobIndices[i])
161 oobMatrix(i,index) = 1;
162 }
163 m_OOBerror = this->doComputeOOBerror(oobMatrix,data);
164 }
165
166 /// Compute feature importances, given an oob dataset
167 ///
168 /// For each tree, extracts the out-of-bag-samples indicated by oobIndices. The feature importance is defined
169 /// as the average change of loss (Squared loss or accuracy depending on label type) when the feature is permuted across the oob samples of a tree.
170 void computeFeatureImportances(std::vector<std::vector<std::size_t> > const& oobIndices, LabeledData<RealVector, LabelType> const& data, random::rng_type& rng){
171 std::size_t inputs = inputDimension(data);
172 m_featureImportances.resize(inputs);
174
175 for(std::size_t m = 0; m != this->numberOfModels();++m){
176 auto batch = subBatch(view, oobIndices[m]);
177 double errorBefore = this->loss(batch.label,this->model(m)(batch.input));
178 for(std::size_t i=0; i!=inputs;++i) {
179 RealVector vOld= column(batch.input,i);
180 RealVector v = vOld;
181 std::shuffle(v.begin(), v.end(), rng);
182 noalias(column(batch.input,i)) = v;
183 double errorAfter = this->loss(batch.label,this->model(m)(batch.input));
184 noalias(column(batch.input,i)) = vOld;
185 m_featureImportances(i) += this->weight(m) * (errorAfter - errorBefore) / batch.size();
186 }
187 }
188 m_featureImportances /= this->sumOfWeights();
189 }
190
191private:
192 double m_OOBerror; ///< oob error for the forest
193 RealVector m_featureImportances; ///< feature importances for the forest
194
195};
196
197
198}
199#endif