//===========================================================================
/*!
 *
 *
 * \brief Random Forest Trainer
 *
 *
 *
 * \author K. N. Hansen, J. Kremer
 * \date 2011-2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================


#ifndef SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H
#define SHARK_ALGORITHMS_TRAINERS_RFTRAINER_H

#include <shark/Models/Trees/RFClassifier.h>
#include <shark/Algorithms/Trainers/AbstractWeightedTrainer.h>
#include <shark/Algorithms/Trainers/Impl/CART.h>

#include <cmath>
#include <vector>
#include <limits>

namespace shark {

/// \brief Random Forest
///
/// Random Forest is an ensemble learner that builds multiple binary decision trees.
/// The trees are built using a variant of the CART methodology.
///
/// Typically 100+ trees are built, and classification/regression is done by combining
/// the results generated by each tree. Typically a majority vote is used in the
/// classification case, and the mean is used in the regression case.
///
/// Each tree is built based on a random subset of the total dataset. Furthermore,
/// at each split only a random subset of the attributes is investigated for
/// the best split.
///
/// The node impurity is measured by the Gini criterion in the classification
/// case, and by the total sum of squared errors in the regression case.
///
/// After growing a maximum sized tree, the tree is added to the ensemble
/// without pruning.
///
/// For detailed information about Random Forest, see "Random Forests",
/// L. Breiman, Machine Learning 45(1), 2001.
/// \ingroup supervised_trainer
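///
/// A minimal usage sketch (assuming a ClassificationDataset named data has been
/// loaded elsewhere; the unweighted train() overload is inherited from
/// AbstractWeightedTrainer):
/// \code
/// ClassificationDataset data;                   // assumed to be filled with training data
/// RFClassifier<unsigned int> forest;
/// RFTrainer<unsigned int> trainer(true, true);  // feature importances + OOB error
/// trainer.setNTrees(200);
/// trainer.train(forest, data);
/// \endcode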
template<class LabelType>
class RFTrainer;

template<>
class RFTrainer<unsigned int>
: public AbstractWeightedTrainer<RFClassifier<unsigned int> >, public IParameterizable<RealVector>
{
public:
    /// Constructor. The two flags control whether feature importances and the
    /// out-of-bag (OOB) error are computed during training.
    RFTrainer(bool computeFeatureImportances = false, bool computeOOBerror = false){
        m_computeFeatureImportances = computeFeatureImportances;
        m_computeOOBerror = computeOOBerror;
        m_numTrees = 100;
        m_min_samples_leaf = 1;
        m_min_split = 2 * m_min_samples_leaf;
        m_max_depth = 10000;
        m_min_impurity_split = 1e-10;
        m_epsilon = 1e-10;
        m_max_features = 0;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "RFTrainer"; }

    /// Set the number of random attributes to investigate at each node.
    ///
    /// Default is 0, which is translated to sqrt(inputDim(data)) during training.
    void setMTry(std::size_t mtry) { m_max_features = mtry; }

    /// Set the number of trees to grow. (default 100)
    void setNTrees(std::size_t numTrees) {m_numTrees = numTrees;}

    /// Set the minimum number of samples a node must contain to be split. (default 2)
    void setMinSplit(std::size_t numSamples) {m_min_split = numSamples;}

    /// Set the maximum depth of the tree. (default 10000)
    void setMaxDepth(std::size_t maxDepth) {m_max_depth = maxDepth;}

    /// Controls when a node is considered pure. If set to 1, a node is pure
    /// when it only consists of a single sample. (default 1)
    void setNodeSize(std::size_t nodeSize) { m_min_samples_leaf = nodeSize; }

    /// The minimum impurity below which a node is considered pure. (default 1.e-10)
    void minImpurity(double impurity) {m_min_impurity_split = impurity;}

    /// The minimum distance between two feature values to be considered different. (default 1.e-10)
    void epsilon(double distance) {m_epsilon = distance;}

    /// Return the parameter vector.
    RealVector parameterVector() const{return RealVector();}

    /// Set the parameter vector.
    void setParameterVector(RealVector const& newParameters){
        SHARK_ASSERT(newParameters.size() == 0);
    }

    /// Train a random forest for classification.
    void train(RFClassifier<unsigned int>& model, WeightedLabeledData<RealVector, unsigned int> const& dataset){
        model.clearModels();

        //set up the tree builder
        CART::TreeBuilder<unsigned int,CART::ClassificationCriterion> builder;
        builder.m_min_samples_leaf = m_min_samples_leaf;
        builder.m_min_split = m_min_split;
        builder.m_max_depth = m_max_depth;
        builder.m_min_impurity_split = m_min_impurity_split;
        builder.m_epsilon = m_epsilon;
        //mtry = 0 is translated to the classification default, sqrt(#features)
        builder.m_max_features = m_max_features? m_max_features: std::sqrt(inputDimension(dataset));

        //copy data into a single batch for easier lookup
        blas::matrix<double, blas::column_major> data_train = createBatch<RealVector>(dataset.inputs().elements().begin(),dataset.inputs().elements().end());
        auto labels_train = createBatch<unsigned int>(dataset.labels().elements().begin(),dataset.labels().elements().end());
        auto weights_train = createBatch<double>(dataset.weights().elements().begin(),dataset.weights().elements().end());

        //set up one seed per tree so that each thread has its own rng
        std::vector<unsigned int> seeds(m_numTrees);
        for (auto& seed: seeds) {
            seed = random::discrete(random::globalRng, 0u,std::numeric_limits<unsigned int>::max());
        }

        //out-of-bag complement indices of each tree; needed for the OOB error and feature importances
        std::vector<std::vector<std::size_t> > complements;

        //generate the trees in parallel
        SHARK_PARALLEL_FOR(int t = 0; t < m_numTrees; ++t){
            random::rng_type rng(seeds[t]);

            //set up the bootstrap sample of the dataset for this tree
            CART::Bootstrap<blas::matrix<double, blas::column_major>, UIntVector> bootstrap(rng, data_train,labels_train, weights_train);
            auto const& tree = builder.buildTree(rng, bootstrap);

            SHARK_CRITICAL_REGION{//the ensemble is shared between threads
                model.addModel(tree);
                complements.push_back(std::move(bootstrap.complement));
            }
        }

        if(m_computeOOBerror)
            model.computeOOBerror(complements, dataset.data());

        if(m_computeFeatureImportances)
            model.computeFeatureImportances(complements, dataset.data(), random::globalRng);
    }


private:
    bool m_computeFeatureImportances;///< set true if the feature importances should be computed
    bool m_computeOOBerror;///< set true if the OOB error should be computed

    long m_numTrees; ///< number of trees in the forest
    std::size_t m_max_features;///< number of attributes to randomly test at each inner node
    std::size_t m_min_samples_leaf; ///< minimum number of samples in a leaf node
    std::size_t m_min_split; ///< minimum number of samples a node must contain to be split
    std::size_t m_max_depth;///< maximum depth of the tree
    double m_epsilon;///< minimum difference between two feature values to be considered different
    double m_min_impurity_split;///< stops splitting when the impurity falls below this threshold
};


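/// \brief Random Forest regression trainer.
///
/// Specialization of RFTrainer for real-valued labels. Node impurity is
/// measured by the mean squared error, and the forest prediction is the
/// mean of the tree predictions.
///
/// A minimal usage sketch (assuming a RegressionDataset named data has been
/// loaded elsewhere):
/// \code
/// RegressionDataset data;             // assumed to be filled with training data
/// RFClassifier<RealVector> forest;
/// RFTrainer<RealVector> trainer;
/// trainer.setNTrees(500);
/// trainer.train(forest, data);
/// \endcode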
template<>
class RFTrainer<RealVector>
: public AbstractWeightedTrainer<RFClassifier<RealVector> >, public IParameterizable<RealVector>
{
public:
    /// Constructor. The two flags control whether feature importances and the
    /// out-of-bag (OOB) error are computed during training.
    RFTrainer(bool computeFeatureImportances = false, bool computeOOBerror = false){
        m_computeFeatureImportances = computeFeatureImportances;
        m_computeOOBerror = computeOOBerror;
        m_numTrees = 100;
        m_min_samples_leaf = 1;
        m_min_split = 2 * m_min_samples_leaf;
        m_max_depth = 10000;
        m_min_impurity_split = 1e-10;
        m_epsilon = 1e-10;
        m_max_features = 0;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "RFTrainer"; }

    /// Set the number of random attributes to investigate at each node.
    ///
    /// Default is 0, which is translated to inputDim(data)/3 during training.
    void setMTry(std::size_t mtry) { m_max_features = mtry; }

    /// Set the number of trees to grow. (default 100)
    void setNTrees(std::size_t numTrees) {m_numTrees = numTrees;}

    /// Set the minimum number of samples a node must contain to be split. (default 2)
    void setMinSplit(std::size_t numSamples) {m_min_split = numSamples;}

    /// Set the maximum depth of the tree. (default 10000)
    void setMaxDepth(std::size_t maxDepth) {m_max_depth = maxDepth;}

    /// Controls when a node is considered pure. If set to 1, a node is pure
    /// when it only consists of a single sample. (default 1)
    void setNodeSize(std::size_t nodeSize) { m_min_samples_leaf = nodeSize; }

    /// The minimum impurity below which a node is considered pure. (default 1.e-10)
    void minImpurity(double impurity) {m_min_impurity_split = impurity;}

    /// The minimum distance between two feature values to be considered different. (default 1.e-10)
    void epsilon(double distance) {m_epsilon = distance;}

    /// Return the parameter vector.
    RealVector parameterVector() const{ return RealVector();}

    /// Set the parameter vector.
    void setParameterVector(RealVector const& newParameters){
        SHARK_ASSERT(newParameters.size() == 0);
    }

    /// Train a random forest for regression.
    void train(RFClassifier<RealVector>& model, WeightedLabeledData<RealVector, RealVector> const& dataset){
        model.clearModels();

        //set up the tree builder
        CART::TreeBuilder<RealVector,CART::MSECriterion> builder;
        builder.m_min_samples_leaf = m_min_samples_leaf;
        builder.m_min_split = m_min_split;
        builder.m_max_depth = m_max_depth;
        builder.m_min_impurity_split = m_min_impurity_split;
        builder.m_epsilon = m_epsilon;
        //mtry = 0 is translated to the regression default, #features/3
        builder.m_max_features = m_max_features? m_max_features: inputDimension(dataset)/3;

        //copy data into a single batch for easier lookup
        blas::matrix<double, blas::column_major> data_train = createBatch<RealVector>(dataset.inputs().elements().begin(),dataset.inputs().elements().end());
        auto labels_train = createBatch<RealVector>(dataset.labels().elements().begin(),dataset.labels().elements().end());
        auto weights_train = createBatch<double>(dataset.weights().elements().begin(),dataset.weights().elements().end());

        //set up one seed per tree so that each thread has its own rng
        std::vector<unsigned int> seeds(m_numTrees);
        for (auto& seed: seeds) {
            seed = random::discrete(random::globalRng, 0u,std::numeric_limits<unsigned int>::max());
        }

        //out-of-bag complement indices of each tree; needed for the OOB error and feature importances
        std::vector<std::vector<std::size_t> > complements;

        //generate the trees in parallel
        SHARK_PARALLEL_FOR(int t = 0; t < m_numTrees; ++t){
            random::rng_type rng{seeds[t]};

            //set up the bootstrap sample of the dataset for this tree
            CART::Bootstrap<blas::matrix<double, blas::column_major>, RealMatrix> bootstrap(rng, data_train,labels_train, weights_train);
            auto const& tree = builder.buildTree(rng, bootstrap);

            SHARK_CRITICAL_REGION{//the ensemble is shared between threads
                model.addModel(tree);
                complements.push_back(std::move(bootstrap.complement));
            }
        }

        if(m_computeOOBerror)
            model.computeOOBerror(complements, dataset.data());

        if(m_computeFeatureImportances)
            model.computeFeatureImportances(complements, dataset.data(), random::globalRng);
    }


private:
    bool m_computeFeatureImportances;///< set true if the feature importances should be computed
    bool m_computeOOBerror;///< set true if the OOB error should be computed

    long m_numTrees; ///< number of trees in the forest
    std::size_t m_max_features;///< number of attributes to randomly test at each inner node
    std::size_t m_min_samples_leaf; ///< minimum number of samples in a leaf node
    std::size_t m_min_split; ///< minimum number of samples a node must contain to be split
    std::size_t m_max_depth;///< maximum depth of the tree
    double m_epsilon;///< minimum difference between two feature values to be considered different
    double m_min_impurity_split;///< stops splitting when the impurity falls below this threshold
};


}
#endif