33#ifndef SHARK_MODELS_ENSEMBLE_H
34#define SHARK_MODELS_ENSEMBLE_H
/// \brief Implementation core of Ensemble: holds a list of member models with
/// one weight per member, and evaluates to a weighted pooling of the members'
/// outputs.
///
/// \tparam BaseModelType  the member model type, or a pointer to it; the
///                        pointee type is recovered with std::remove_pointer.
/// \tparam VectorType     output type of the ensemble (the output argument of
///                        the AbstractModel base, see the Base typedef).
42template<
class BaseModelType,
class VectorType>
43class EnsembleImpl:
public AbstractModel<
44 typename std::remove_pointer<BaseModelType>::type::InputType,
46 typename std::remove_pointer<BaseModelType>::type::ParameterVectorType
/// Underlying (non-pointer) type of the member models.
49 typedef typename std::remove_pointer<BaseModelType>::type ModelType;
/// Shorthand for the AbstractModel base. Input and parameter types come from
/// ModelType, and the output type is VectorType.
51 typedef AbstractModel<typename ModelType::InputType, VectorType, typename ModelType::ParameterVectorType> Base;
/// \brief Uniform access to a stored member, whether m_models holds models or
/// pointers to models.
///
/// Overload resolution picks one of the reference overloads when BaseModelType
/// is a model type, and the pointer overload when it is a pointer type. All
/// overloads yield a ModelType reference.
56 ModelType& derefIfPtr(ModelType& model)
const{
/// Const overload for a directly stored model.
59 ModelType
const& derefIfPtr(ModelType
const& model)
const{
/// Overload for a stored pointer. Yields a reference to the pointed-to model.
/// NOTE(review): presumably dereferences without a null check, so the stored
/// pointer must be valid.
62 ModelType& derefIfPtr(ModelType* model)
const{
/// Empty tag type. eval() uses it to select a pool() overload from
/// ModelType::OutputType (tag dispatch).
68 template<
class T>
struct tag{};
/// \brief Pooling for member models with vector-valued outputs.
///
/// For every member i, adds weight(i) * model(i)(patterns) to outputs, i.e.
/// outputs += sum_i w_i * f_i(patterns).
///
/// \param patterns  batch of inputs passed to every member model.
/// \param outputs   row-major result matrix, accumulated into (not overwritten).
70 template<
class InputBatch,
class T,
class Device>
71 void pool(InputBatch
const& patterns, blas::matrix<T, blas::row_major, Device>& outputs,
tag<blas::vector<T, Device> >)
const{
// Accumulates with +=, so outputs must already hold the starting values.
// NOTE(review): presumably the caller zeroes it first; confirm in eval().
72 for(std::size_t i = 0; i != numberOfModels(); i++){
73 noalias(outputs) += weight(i) * model(i)(patterns);
/// \brief Pooling for classifier member models (member OutputType is unsigned int).
///
/// Each member's predicted label for pattern p receives that member's weight.
/// As a result, outputs(p, c) accumulates the weighted vote count for class c,
/// added on top of whatever outputs held on entry.
///
/// \param patterns  batch of inputs passed to every member model.
/// \param outputs   per-pattern, per-class vote matrix, accumulated into.
77 template<
class InputBatch,
class OutputBatch>
78 void pool(InputBatch
const& patterns, OutputBatch& outputs, tag<unsigned int>)
const{
79 blas::vector<unsigned int> responses;
80 for(std::size_t i = 0; i != numberOfModels(); ++i){
81 model(i).eval(patterns, responses);
82 for(std::size_t p = 0; p != patterns.size1(); ++p){
// NOTE(review): responses(p) is used as a column index without a range check.
// Every predicted label is assumed to be smaller than the number of output columns.
83 outputs(p,responses(p)) += weight(i);
/// The member models, or pointers to them if BaseModelType is a pointer type.
89 std::vector<BaseModelType> m_models;
/// \brief Appends a member model together with its weight.
///
/// \param model   the model, or the pointer to it, that is copied into m_models.
/// \param weight  pooling weight of this member (default 1.0).
///
/// NOTE(review): when BaseModelType is a pointer type, only the pointer is
/// stored. Presumably the caller keeps ownership and must keep the model alive
/// while the ensemble uses it; confirm.
103 void addModel(BaseModelType
const& model,
double weight = 1.0){
105 m_models.push_back(model);
106 m_weights.push_back(weight);
/// Mutable access to the index-th member model. There is no bounds check,
/// since it indexes m_models with operator[].
115 ModelType& model(std::size_t index){
116 return derefIfPtr(m_models[index]);
/// Read-only access to the index-th member model. There is no bounds check.
119 ModelType
const& model(std::size_t index)
const{
120 return derefIfPtr(m_models[index]);
/// Read-only access to the weight of member i.
124 double const& weight(std::size_t i)
const{
/// Mutable access to the weight of member i.
129 double& weight(std::size_t i){
// Sum of all member weights.
135 return sum(m_weights);
/// Number of member models currently held by the ensemble.
139 std::size_t numberOfModels()
const{
140 return m_models.size();
/// Input shape of the ensemble: that of the first member, or an empty Shape
/// if there are no members. Only model(0) is consulted, so all members are
/// expected to agree.
144 Shape inputShape()
const{
145 return m_models.empty() ? Shape(): model(0).inputShape();
/// Output shape of the ensemble: that of the first member, or an empty Shape
/// if there are no members. Only model(0) is consulted.
148 Shape outputShape()
const{
149 return m_models.empty() ? Shape(): model(0).outputShape();
/// \brief Evaluates the ensemble on a batch of patterns.
///
/// Resizes outputs to (number of patterns) x outputShape().numElements(), then
/// fills it with the pool() overload selected by ModelType::OutputType.
/// NOTE(review): pool() accumulates with +=; confirm outputs is zeroed
/// between the resize and the pool() call.
153 void eval(BatchInputType
const& patterns, BatchOutputType& outputs)
const{
154 outputs.resize(patterns.size1(), outputShape().numElements());
156 pool(patterns,outputs, tag<typename ModelType::OutputType>());
/// Stateful overload. The state argument is ignored and evaluation is
/// delegated to the stateless eval().
158 void eval(BatchInputType
const& patterns, BatchOutputType& outputs, State&)
const{
159 eval(patterns,outputs);
// Deserialisation: reads the member count, resizes m_models to match, reads
// each member in the loop, then reads the weights.
// NOTE(review): if BaseModelType is a pointer type, resize() creates null
// pointers; confirm the per-member read sets them before they are used.
163 std::size_t numModels;
164 archive >> numModels;
165 m_models.resize(numModels);
166 for(std::size_t i = 0; i != numModels; ++i){
169 archive >> m_weights;
// Serialisation: writes the member count, each member in the loop, then the
// weights (the same order the read side expects).
172 std::size_t numModels = m_models.size();
173 archive << numModels;
174 for(std::size_t i = 0; i != numModels; ++i){
177 archive << m_weights;
/// \brief General EnsembleBase: the ensemble is the EnsembleImpl itself.
///
/// Used for any OutputType other than unsigned int and void, which have their
/// own specializations. impl() simply returns *this.
183template<
class ModelType,
class OutputType>
184struct EnsembleBase :
public detail::EnsembleImpl<ModelType, OutputType>{
/// Output type of a single member model.
186 typedef typename std::remove_pointer<ModelType>::type::OutputType ModelOutputType;
/// Access to the underlying EnsembleImpl, which is this object.
188 detail::EnsembleImpl<ModelType, OutputType>& impl(){
return *
this;};
189 detail::EnsembleImpl<ModelType, OutputType>
const& impl()
const{
return *
this;};
/// \brief EnsembleBase for classification (OutputType is unsigned int).
///
/// The EnsembleImpl is instantiated with the member model's ParameterVectorType
/// as its output, so it produces per-class scores. These are weighted votes
/// for classifier members, or weighted sums of member outputs for vector-valued
/// members. The scores are wrapped in a Classifier, which turns them into a
/// label (presumably the highest-scoring class). impl() returns the wrapped
/// decision function.
193template<
class BaseModelType>
194struct EnsembleBase<BaseModelType, unsigned int>
195:
public Classifier<detail::EnsembleImpl<BaseModelType, typename std::remove_pointer<BaseModelType>::type::ParameterVectorType> >{
/// Vector type used for the per-class pooled scores.
197 typedef typename std::remove_pointer<BaseModelType>::type::ParameterVectorType PoolingVectorType;
/// Access to the underlying EnsembleImpl (the classifier's decision function).
199 detail::EnsembleImpl<BaseModelType, PoolingVectorType>& impl()
200 {
return this->decisionFunction();}
201 detail::EnsembleImpl<BaseModelType, PoolingVectorType>
const& impl()
const
202 {
return this->decisionFunction();}
/// \brief EnsembleBase with OutputType = void: the output type is taken from
/// the member model's OutputType and resolved by the other EnsembleBase
/// templates. An unsigned int member output selects the classifier form.
206template<
class ModelType>
207struct EnsembleBase<ModelType, void>
208:
public EnsembleBase<ModelType, typename std::remove_pointer<ModelType>::type::OutputType>{};
/// \brief Weighted ensemble of models.
///
/// \tparam ModelType   member model type, or a pointer to it.
/// \tparam OutputType  output type of the ensemble. The default, void, uses
///                     the member model's OutputType.
///
/// The member functions below forward to impl(), the underlying EnsembleImpl.
251template<
class ModelType,
class OutputType =
void>
252class Ensemble:
public detail::EnsembleBase<ModelType, OutputType>{
// Returns the class name "Ensemble".
255 {
return "Ensemble"; }
// Delegates to impl().clearModels().
267 this->impl().clearModels();
// Number of member models (forwards to impl()).
272 return this->impl().numberOfModels();
/// Mutable access to member model i.
278 typename std::remove_pointer<ModelType>::type&
model(std::size_t i){
279 return this->impl().model(i);
/// Read-only access to member model i.
284 typename std::remove_pointer<ModelType>::type
const&
model(std::size_t i)
const{
285 return this->impl().model(i);
/// Read-only access to the weight of member i.
291 double const&
weight(std::size_t i)
const{
292 return this->impl().weight(i);
// Forwards to impl().weight(i).
299 return this->impl().weight(i);
// Sum of all member weights (forwards to impl()).
304 return this->impl().sumOfWeights();