Ensemble.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Implements the Ensemble Model that can be used to merge predictions from weighted models
6 *
7 * \author O. Krause
8 * \date 2018
9 *
10 *
11 * \par Copyright 1995-2017 Shark Development Team
12 *
13 * <BR><HR>
14 * This file is part of Shark.
15 * <https://shark-ml.github.io/Shark/>
16 *
17 * Shark is free software: you can redistribute it and/or modify
18 * it under the terms of the GNU Lesser General Public License as published
19 * by the Free Software Foundation, either version 3 of the License, or
20 * (at your option) any later version.
21 *
22 * Shark is distributed in the hope that it will be useful,
23 * but WITHOUT ANY WARRANTY; without even the implied warranty of
24 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
25 * GNU Lesser General Public License for more details.
26 *
27 * You should have received a copy of the GNU Lesser General Public License
28 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
29 *
30 */
31//===========================================================================
32
33#ifndef SHARK_MODELS_ENSEMBLE_H
34#define SHARK_MODELS_ENSEMBLE_H
35
38#include <type_traits>
39namespace shark {
40
41namespace detail{
42template<class BaseModelType, class VectorType>
43class EnsembleImpl: public AbstractModel<
44 typename std::remove_pointer<BaseModelType>::type::InputType,
45 VectorType,
46 typename std::remove_pointer<BaseModelType>::type::ParameterVectorType
47>{
48public:
49 typedef typename std::remove_pointer<BaseModelType>::type ModelType;
50private:
51 typedef AbstractModel<typename ModelType::InputType, VectorType, typename ModelType::ParameterVectorType> Base;
52
53 // the following functions are returning a reference to the model
54 // independent of whether a pointer to the model or the model itself
55 // is stored.
56 ModelType& derefIfPtr(ModelType& model)const{
57 return model;
58 }
59 ModelType const& derefIfPtr(ModelType const& model)const{
60 return model;
61 }
62 ModelType& derefIfPtr(ModelType* model)const{
63 return *model;
64 }
65
66
67 //implements the pooling operation which creates a vector from the model responses to the given patterns
68 template<class T> struct tag{};
69
70 template<class InputBatch, class T, class Device>
71 void pool(InputBatch const& patterns, blas::matrix<T, blas::row_major, Device>& outputs, tag<blas::vector<T, Device> >)const{
72 for(std::size_t i = 0; i != numberOfModels(); i++){
73 noalias(outputs) += weight(i) * model(i)(patterns);
74 }
75 outputs /= sumOfWeights();
76 }
77 template<class InputBatch, class OutputBatch>
78 void pool(InputBatch const& patterns, OutputBatch& outputs, tag<unsigned int>)const{
79 blas::vector<unsigned int> responses;
80 for(std::size_t i = 0; i != numberOfModels(); ++i){
81 model(i).eval(patterns, responses);
82 for(std::size_t p = 0; p != patterns.size1(); ++p){
83 outputs(p,responses(p)) += weight(i);
84 }
85 }
86 outputs /= sumOfWeights();
87 }
88
89 std::vector<BaseModelType> m_models;
90 RealVector m_weights;
91public:
92 typedef typename Base::BatchInputType BatchInputType;
93 typedef typename Base::BatchOutputType BatchOutputType;
95
96 ParameterVectorType parameterVector() const {
97 return {};
98 }
99 void setParameterVector(ParameterVectorType const& param) {
100 SHARK_ASSERT(param.size() == 0);
101 }
102
103 void addModel(BaseModelType const& model, double weight = 1.0){
104 SHARK_RUNTIME_CHECK(weight > 0, "Weights must be positive");
105 m_models.push_back(model);
106 m_weights.push_back(weight);
107 }
108
109 /// \brief Removes all models from the ensemble
110 void clearModels(){
111 m_models.clear();
112 m_weights.clear();
113 }
114
115 ModelType& model(std::size_t index){
116 return derefIfPtr(m_models[index]);
117 }
118
119 ModelType const& model(std::size_t index)const{
120 return derefIfPtr(m_models[index]);
121 }
122
123 /// \brief Returns the weight of the i-th model.
124 double const& weight(std::size_t i)const{
125 return m_weights[i];
126 }
127
128 /// \brief Returns the weight of the i-th model.
129 double& weight(std::size_t i){
130 return m_weights[i];
131 }
132
133 /// \brief Returns the total sum of weights used for averaging
134 double sumOfWeights() const{
135 return sum(m_weights);
136 }
137
138 /// \brief Returns the number of models.
139 std::size_t numberOfModels()const{
140 return m_models.size();
141 }
142
143 ///\brief Returns the expected shape of the input
144 Shape inputShape() const{
145 return m_models.empty() ? Shape(): model(0).inputShape();
146 }
147 ///\brief Returns the shape of the output
148 Shape outputShape() const{
149 return m_models.empty() ? Shape(): model(0).outputShape();
150 }
151
152 using Base::eval;
153 void eval(BatchInputType const& patterns, BatchOutputType& outputs)const{
154 outputs.resize(patterns.size1(), outputShape().numElements());
155 outputs.clear();
156 pool(patterns,outputs, tag<typename ModelType::OutputType>());
157 }
158 void eval(BatchInputType const& patterns, BatchOutputType& outputs, State&)const{
159 eval(patterns,outputs);
160 }
161
162 void read(InArchive& archive){
163 std::size_t numModels;
164 archive >> numModels;
165 m_models.resize(numModels);
166 for(std::size_t i = 0; i != numModels; ++i){
167 archive >> model(i);
168 }
169 archive >> m_weights;
170 }
171 void write(OutArchive& archive)const{
172 std::size_t numModels = m_models.size();
173 archive << numModels;
174 for(std::size_t i = 0; i != numModels; ++i){
175 archive << model(i);
176 }
177 archive << m_weights;
178 }
179};
180
181//the following creates an ensemble base depending on whether the ensemble should be a classifier or not.
182
183template<class ModelType, class OutputType>
184struct EnsembleBase : public detail::EnsembleImpl<ModelType, OutputType>{
185private:
186 typedef typename std::remove_pointer<ModelType>::type::OutputType ModelOutputType;
187protected:
188 detail::EnsembleImpl<ModelType, OutputType>& impl(){ return *this;};
189 detail::EnsembleImpl<ModelType, OutputType> const& impl() const{ return *this;};
190};
191
192//if the output type is unsigned int, this is a classifier
193template<class BaseModelType>
194struct EnsembleBase<BaseModelType, unsigned int>
195: public Classifier<detail::EnsembleImpl<BaseModelType, typename std::remove_pointer<BaseModelType>::type::ParameterVectorType> >{
196private:
197 typedef typename std::remove_pointer<BaseModelType>::type::ParameterVectorType PoolingVectorType;
198protected:
199 detail::EnsembleImpl<BaseModelType, PoolingVectorType>& impl()
200 { return this->decisionFunction();}
201 detail::EnsembleImpl<BaseModelType, PoolingVectorType> const& impl() const
202 { return this->decisionFunction();}
203};
204
205//if the OutputType is void, this is treated as choosing it as the OutputType of the model
206template<class ModelType>
207struct EnsembleBase<ModelType, void>
208: public EnsembleBase<ModelType, typename std::remove_pointer<ModelType>::type::OutputType>{};
209}
210
211/// \brief Represents en weighted ensemble of models.
212///
213/// In an ensemble, each model computes a response for an input independently. The responses are then pooled
214/// to form a single label. The hope is that models in an ensemble do not produce the same type of errors
215/// and thus the averaged response is more reliable. An example for this is AdaBoost, where a series
216/// of weak models is trained and weighted to create one final prediction.
217///
218/// There are two orthogonal aspects to consider in the Ensemble. The pooling function, which is chosen
219/// based on the output type of the ensemble models, and the mapping of the output of the pooling function
220/// to the model output.
221///
222/// If the ensemble consists of models returning vectors, pooling is implemented
223/// using weighted averaging. If the models return class labels, those are first transformed
224/// into a one-hot encoding before averaging. Thus the output can be interpreted
225/// as the probability of a class label when picking a member of the emsemble randomly with probability
226/// proportional to its weights.
227///
228/// The final mapping to the output is based on the OutputType template parameter, which by default
229/// is the same as the output type of the model. If it is unsigned int, the Ensemble is treated as Classifier
230/// with decision function being the result of the pooling function (i.e. the class with maximum response in
231/// the weighted average is chosen). In this case, Essemble is derived from Classifier<>.
232/// Otherwise the weighted average is returned.
233///
234/// Note that there is a decision in algorihm design tot ake for classifiers:
235/// We can either let each member of the Ensemble predict
236/// a class-label and then pool the labels as described above, or we can create an ensemble of
237/// decision functions and weight them into one decision function to produce the class-label.
238/// Those approaches will lead to different results. For example if the underlying models
239/// produce class probabilities, the class with the largest average probability
240/// might not be the same as the class with most votes from the individual models.
241///
242/// Models are added using addModel.
243/// The ModelType is allowed to be either a concrete model like LinearModel<>, in which
244/// case a copy of each added model is stored. If the ModelType is a pointer, for example
245/// AbstractModel<...>*, only pointers are stored and all added models
246/// must outlive the lifetime of the ensemble. This also entails differences in serialization.
247/// In the first case, the model can be serialized completely without any setup. In the second
248/// case before deserializing, the models must be constructed and added.
249///
250/// \ingroup models
251template<class ModelType, class OutputType = void>
252class Ensemble: public detail::EnsembleBase<ModelType, OutputType>{
253public:
254 std::string name() const
255 { return "Ensemble"; }
256
257 /// \brief Adds a new model to the ensemble.
258 ///
259 /// \param model the new model
260 /// \param weight weight of the model. must be > 0
261 void addModel(ModelType const& model, double weight = 1.0){
262 this->impl().addModel(model,weight);
263 }
264
265 /// \brief Removes all models from the ensemble
267 this->impl().clearModels();
268 }
269
270 /// \brief Returns the number of models.
271 std::size_t numberOfModels()const{
272 return this->impl().numberOfModels();
273 }
274
275 /// \brief Returns a reference to the i-th model.
276 ///
277 /// \param i model index.
278 typename std::remove_pointer<ModelType>::type& model(std::size_t i){
279 return this->impl().model(i);
280 }
281 /// \brief Returns a const reference to the i-th model.
282 ///
283 /// \param i model index.
284 typename std::remove_pointer<ModelType>::type const& model(std::size_t i)const{
285 return this->impl().model(i);
286 }
287
288 /// \brief Returns the weight of the i-th model.
289 ///
290 /// \param i model index.
291 double const& weight(std::size_t i)const{
292 return this->impl().weight(i);
293 }
294
295 /// \brief Returns the weight of the i-th model.
296 ///
297 /// \param i model index.
298 double& weight(std::size_t i){
299 return this->impl().weight(i);
300 }
301
302 /// \brief Returns the total sum of weights used for averaging
303 double sumOfWeights() const{
304 return this->impl().sumOfWeights();
305 }
306
307};
308
309}
310#endif