35#ifndef SHARK_MODELS_KERNELS_MODEL_KERNEL_H
36#define SHARK_MODELS_KERNELS_MODEL_KERNEL_H
43#include <boost/scoped_ptr.hpp>
48template<
class InputType,
class IntermediateType>
49class ModelKernelImpl :
public AbstractKernelFunction<InputType>
52 typedef AbstractKernelFunction<InputType> base_type;
57 typedef AbstractKernelFunction<IntermediateType> Kernel;
58 typedef AbstractModel<InputType,IntermediateType> Model;
60 struct InternalState:
public State{
61 boost::shared_ptr<State> kernelStateX1X2;
62 boost::shared_ptr<State> kernelStateX2X1;
63 boost::shared_ptr<State> modelStateX1;
64 boost::shared_ptr<State> modelStateX2;
70 ModelKernelImpl(Kernel* kernel, Model* model):mpe_kernel(kernel),mpe_model(model){
71 if(kernel->hasFirstParameterDerivative()
72 && kernel->hasFirstInputDerivative()
73 && model->hasFirstParameterDerivative())
78 std::string
name()
const
79 {
return "ModelKernel"; }
82 return mpe_kernel->numberOfParameters() + mpe_model->numberOfParameters();
85 return mpe_kernel->parameterVector() | mpe_model->parameterVector();
89 std::size_t kParams =mpe_kernel->numberOfParameters();
90 mpe_kernel->setParameterVector(subrange(newParameters,0,kParams));
91 mpe_model->setParameterVector(subrange(newParameters,kParams,newParameters.size()));
95 InternalState* s =
new InternalState();
96 boost::shared_ptr<State> sharedState(s);
97 s->kernelStateX1X2 = mpe_kernel->createState();
98 s->kernelStateX2X1 = mpe_kernel->createState();
99 s->modelStateX1 = mpe_model->createState();
100 s->modelStateX2 = mpe_model->createState();
/// \brief Evaluates the kernel on a single pair of inputs.
///
/// Both inputs are first mapped through the wrapped model (mpe_model), and the
/// wrapped kernel (mpe_kernel) is then evaluated on the two model outputs.
/// \param x1 first input point
/// \param x2 second input point
/// \return the kernel value computed on the model outputs for x1 and x2
104 double eval(ConstInputReference x1, ConstInputReference x2)
const{
105 auto mx1 = (*mpe_model)(x1);
106 auto mx2= (*mpe_model)(x2);
107 return mpe_kernel->eval(mx1,mx2);
/// \brief Batch evaluation that also fills \a state for later derivative computation.
///
/// Both batches are mapped through the model. The model outputs are stored in
/// the InternalState as intermediateX1 / intermediateX2; NOTE(review): those
/// members are not visible in this listing, so confirm their declaration.
/// Each model call uses its own model state (modelStateX1 / modelStateX2).
/// The kernel is then evaluated twice:
/// - first as (X2, X1), which fills kernelStateX2X1; the input derivative
///   with respect to the second batch is later computed with that state;
/// - then as (X1, X2), which fills kernelStateX1X2 and produces the kernel
///   matrix left in \a result.
/// \param x1 first batch of inputs
/// \param x2 second batch of inputs
/// \param result receives the kernel matrix between the transformed x1 and x2
/// \param state must be the InternalState created by createState()
110 void eval(ConstBatchInputReference x1, ConstBatchInputReference x2, RealMatrix& result, State& state)
const{
111 InternalState& s=state.toState<InternalState>();
112 mpe_model->eval(x1,s.intermediateX1,*s.modelStateX1);
113 mpe_model->eval(x2,s.intermediateX2,*s.modelStateX2);
// Reversed-order evaluation exists only to populate kernelStateX2X1. It writes
// into result, which the next call reuses. Presumably the next call overwrites
// the value rather than accumulating into it; confirm against the kernel's
// eval contract.
114 mpe_kernel->eval(s.intermediateX2,s.intermediateX1,result,*s.kernelStateX2X1);
115 mpe_kernel->eval(s.intermediateX1,s.intermediateX2,result,*s.kernelStateX1X2);
/// \brief Batch evaluation without state.
///
/// Both batches are mapped through the model, and the stateless kernel batch
/// evaluation is then applied to the two model outputs.
/// \param batchX1 first batch of inputs
/// \param batchX2 second batch of inputs
/// \param result receives the kernel matrix between the transformed batches
119 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
120 auto mx1 = (*mpe_model)(batchX1);
121 auto mx2 = (*mpe_model)(batchX2);
// Returning a void expression from a void function is valid C++.
122 return mpe_kernel->eval(mx1,mx2,result);
126 ConstBatchInputReference batchX1,
127 ConstBatchInputReference batchX2,
128 RealMatrix
const& coefficients,
133 InternalState
const& s=state.toState<InternalState>();
136 RealVector kernelGrad;
137 mpe_kernel->weightedParameterDerivative(
138 s.intermediateX1,s.intermediateX2,
139 coefficients,*s.kernelStateX1X2,kernelGrad
143 mpe_kernel->weightedInputDerivative(
144 s.intermediateX1,s.intermediateX2,
145 coefficients,*s.kernelStateX1X2,inputDerivativeX1
147 mpe_kernel->weightedInputDerivative(
148 s.intermediateX2,s.intermediateX1,
149 trans(coefficients),*s.kernelStateX2X1,inputDerivativeX2
153 RealVector modelGradX1,modelGradX2;
154 mpe_model->weightedParameterDerivative(batchX1,s.intermediateX1, inputDerivativeX1,*s.modelStateX1,modelGradX1);
155 mpe_model->weightedParameterDerivative(batchX2,s.intermediateX2, inputDerivativeX2,*s.modelStateX2,modelGradX2);
156 noalias(gradient) = kernelGrad | (modelGradX1+modelGradX2);
160 SHARK_RUNTIME_CHECK(mpe_kernel,
"The kernel function is NULL, kernel needs to be constructed prior to read in");
161 SHARK_RUNTIME_CHECK(mpe_model,
"The model is NULL, model needs to be constructed prior to read in");
197template<
class InputType=RealVector>
206 template<
class IntermediateType>
210 ):m_wrapper(new detail::ModelKernelImpl<
InputType,IntermediateType>(kernel,model)){
213 if(m_wrapper->hasFirstParameterDerivative())
219 {
return "ModelKernel"; }
223 return m_wrapper->numberOfParameters();
227 return m_wrapper->parameterVector();
231 m_wrapper->setParameterVector(newParameters);
236 return m_wrapper->createState();
241 return m_wrapper->eval(x1,x2);
246 return m_wrapper->eval(batchX1,batchX2,result,state);
250 m_wrapper->eval(batchX1,batchX2,result);
260 RealMatrix
const& coefficients,
264 m_wrapper->weightedParameterDerivative(batchX1,batchX2,coefficients,state,gradient);
276 boost::scoped_ptr<AbstractKernelFunction<InputType> > m_wrapper;