/// \brief Constructs the trainer, stores the two option flags and sets default tree hyper-parameters.
///
/// \param computeFeatureImportances copied into m_computeFeatureImportances; defaults to false.
/// \param computeOOBerror copied into m_computeOOBerror; defaults to false.
///
// NOTE(review): the leading numbers (81, 82, ...) look like line numbers carried over
// from a source listing; the gaps (84, 87, 89) suggest further statements and the
// closing brace exist but are not shown here -- confirm against the real header.
81 RFTrainer(
bool computeFeatureImportances =
false,
bool computeOOBerror =
false){
// Remember which optional computations were requested.
82 m_computeFeatureImportances = computeFeatureImportances;
83 m_computeOOBerror = computeOOBerror;
// Default: a leaf may contain as few as one sample.
85 m_min_samples_leaf = 1;
// Default split threshold is derived from the leaf minimum (2 * 1 = 2 here).
86 m_min_split = 2 * m_min_samples_leaf;
// Default impurity threshold handed to the tree builder.
88 m_min_impurity_split = 1e-10;
// NOTE(review): this is an excerpt of a function body; the enclosing signature, the
// loop over trees, and the definitions of `dataset`, `labels_train` and `t` are not
// visible here. Comments below describe only what the shown statements do.

// Configure a CART tree builder (unsigned int labels, ClassificationCriterion) by
// copying the trainer's hyper-parameters into it.
136 CART::TreeBuilder<unsigned int,CART::ClassificationCriterion> builder;
137 builder.m_min_samples_leaf = m_min_samples_leaf;
138 builder.m_min_split = m_min_split;
139 builder.m_max_depth = m_max_depth;
140 builder.m_min_impurity_split = m_min_impurity_split;
141 builder.m_epsilon = m_epsilon;
// A zero m_max_features means "not set": fall back to sqrt(input dimension).
// NOTE(review): std::sqrt yields a floating-point value; presumably it is truncated
// when stored in builder.m_max_features -- confirm the field's type.
142 builder.m_max_features = m_max_features? m_max_features: std::sqrt(
inputDimension(dataset));
// Gather every input element of the dataset into one column-major matrix of doubles.
145 blas::matrix<double, blas::column_major> data_train = createBatch<RealVector>(dataset.
inputs().
elements().begin(),dataset.
inputs().
elements().end());
// Gather every per-element weight of the dataset into one batch in the same way.
147 auto weights_train = createBatch<double>(dataset.weights().elements().begin(),dataset.weights().elements().end());
// One seed slot per tree (m_numTrees entries).
150 std::vector<unsigned int> seeds(m_numTrees);
// Iterates over the seeds by reference; the loop body is not visible in this
// excerpt -- presumably it assigns each seed, TODO confirm.
151 for (
auto& seed: seeds) {
// Collects the `complement` index list of each bootstrap (moved in below).
// NOTE(review): presumably these are the samples left out of each bootstrap -- confirm.
155 std::vector<std::vector<std::size_t> > complements;
// Random number generator seeded from seeds[t]; `t` is declared outside the visible
// lines (presumably the tree index).
159 random::rng_type rng(seeds[t]);
// Build a bootstrap object from the rng, the input matrix, the labels and the
// weights; `labels_train` is defined outside the visible lines.
162 CART::Bootstrap<blas::matrix<double, blas::column_major>, UIntVector>
bootstrap(rng, data_train,labels_train, weights_train);
// Grow one tree from this bootstrap using the configured builder.
163 auto const& tree = builder.buildTree(rng,
bootstrap);
// Keep the bootstrap's complement indices; moved rather than copied, so
// bootstrap.complement must not be read after this point.
167 complements.push_back(std::move(
bootstrap.complement));
// Optional steps guarded by the flags set in the constructor; the guarded
// statements are not visible in this excerpt.
171 if(m_computeOOBerror)
174 if(m_computeFeatureImportances)