35#ifndef SHARK_MODELS_KERNELS_SUBRANGE_KERNEL_H
36#define SHARK_MODELS_KERNELS_SUBRANGE_KERNEL_H
/// \brief Adapter that evaluates a wrapped kernel only on the index range [m_start, m_end) of its inputs.
///
/// Every visible kernel operation forwards to the wrapped `m_kernel`. Single inputs are restricted with
/// `blas::subrange(x, m_start, m_end)`, and batches with `columns(batch, m_start, m_end)`. The range
/// width is `m_end - m_start` (see the size of `temp` in weightedInputDerivative).
///
/// NOTE(review): this section is a sparse excerpt. The embedded numbers skip lines (e.g. 48 -> 54,
/// 137 -> 142), so the following are not present here:
///  - access specifiers and closing braces,
///  - parts of the parameter lists,
///  - the statements guarded by the constructor's `if`s,
///  - the remaining members (e.g. m_start / m_end declarations).
/// Restore these from the upstream header before compiling.
45template<
class InputType>
46class SubrangeKernelWrapper :
public AbstractKernelFunction<InputType>{
48 typedef AbstractKernelFunction<InputType> base_type;
/// Wraps `kernel` so that it only sees indices [start, end) of each input.
/// The visible conditions query the wrapped kernel's hasFirstParameterDerivative() /
/// hasFirstInputDerivative(), presumably to mirror those capabilities on the wrapper. The guarded
/// statements are not in this excerpt.
/// NOTE(review): `kernel` is stored as a raw pointer and no delete is visible here; presumably the
/// caller keeps ownership and must keep the kernel alive for the wrapper's lifetime. Confirm.
54 SubrangeKernelWrapper(AbstractKernelFunction<InputType>* kernel,std::size_t start, std::size_t end)
55 :m_kernel(kernel),m_start(start),m_end(end){
56 if(kernel->hasFirstParameterDerivative())
58 if(kernel->hasFirstInputDerivative())
/// Returns the fixed class name "SubrangeKernelWrapper".
63 std::string
name()
const
64 {
return "SubrangeKernelWrapper"; }
// The wrapper has no parameters of its own: parameter get/set, parameter count and state creation
// are delegated unchanged to the wrapped kernel. Only the bodies are visible in this excerpt.
67 return m_kernel->parameterVector();
71 m_kernel->setParameterVector(newParameters);
75 return m_kernel->numberOfParameters();
80 return m_kernel->createState();
/// Evaluates the wrapped kernel on the sub-vectors of x1 and x2 restricted to [m_start, m_end).
83 double eval(ConstInputReference x1, ConstInputReference x2)
const{
84 return m_kernel->eval(blas::subrange(x1,m_start,m_end),blas::subrange(x2,m_start,m_end));
/// Batch evaluation on the column block [m_start, m_end) of both batches.
/// The wrapped kernel writes `result` and uses the caller-provided `state`.
87 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state)
const{
88 m_kernel->eval(columns(batchX1,m_start,m_end),columns(batchX2,m_start,m_end),result,state);
/// Batch evaluation on the column block [m_start, m_end) of both batches, without an explicit state.
91 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result)
const{
92 m_kernel->eval(columns(batchX1,m_start,m_end),columns(batchX2,m_start,m_end),result);
// weightedParameterDerivative (signature line not in this excerpt):
// - forwards to the wrapped kernel with column-restricted batches;
// - the parameters are the wrapped kernel's own (see the parameter forwarding above), so the
//   parameter gradient needs no re-indexing;
// - the remaining forwarded arguments are not visible here.
96 ConstBatchInputReference batchX1,
97 ConstBatchInputReference batchX2,
98 RealMatrix
const& coefficients,
102 m_kernel->weightedParameterDerivative(
103 columns(batchX1,m_start,m_end),
104 columns(batchX2,m_start,m_end),
// weightedInputDerivative (signature line not in this excerpt):
// - the wrapped kernel only produces a derivative for the m_end - m_start selected columns (into
//   `temp`);
// - that block is then written into columns [m_start, m_end) of the full-width `gradient`.
// Lines omitted from this excerpt (e.g. 126) may clear `gradient` first. Confirm, because otherwise
// columns outside the range keep whatever values they had on entry.
111 ConstBatchInputReference batchX1,
112 ConstBatchInputReference batchX2,
113 RealMatrix
const& coefficientsX2,
115 BatchInputType& gradient
// NOTE(review): `temp` takes its row count from gradient.size1() *before* ensure_size() below resizes
// `gradient`. If the caller passes an empty/unsized gradient, temp gets the wrong number of rows.
// batchX1.size1() looks like the intended row count. Confirm and fix.
117 BatchInputType temp(gradient.size1(),m_end-m_start);
118 m_kernel->weightedInputDerivative(
119 columns(batchX1,m_start,m_end),
120 columns(batchX2,m_start,m_end),
// NOTE(review): the gradient is taken w.r.t. batchX1, yet its column count comes from batchX2.size2().
// This only matches when both batches share the input dimension; batchX1.size2() would state the
// intent directly.
125 ensure_size(gradient,batchX1.size1(),batchX2.size2());
127 noalias(columns(gradient,m_start,m_end)) = temp;
/// Wrapped kernel that all evaluations are forwarded to; a raw pointer with no ownership handling
/// visible in this excerpt.
137 AbstractKernelFunction<InputType>* m_kernel;
/// \brief Storage for one SubrangeKernelWrapper per (kernel, range) pair.
///
/// Referenced elsewhere as `detail::SubrangeKernelBase`.
///
/// NOTE(review): sparse excerpt. The class braces, access specifiers, the closing of the constructor
/// loop and the return statement of makeKernelVector() are not present in this view.
142template<
class InputType>
143class SubrangeKernelBase
/// Builds the wrappers: kernels[i] is paired with ranges[i], using ranges[i].first as start and
/// ranges[i].second as end.
/// - `Kernels` needs size() and operator[].
/// - `Ranges` needs operator[] yielding something with .first / .second, e.g. a
///   std::vector<std::pair<std::size_t,std::size_t>> (presumably; verify against callers).
/// The loop runs over kernels.size() only, so `ranges` must have at least as many entries. No check
/// is visible here.
/// NOTE(review): m_kernelWrappers could be reserve()d to kernels.size() before the loop to avoid
/// repeated reallocation.
147 template<
class Kernels,
class Ranges>
148 SubrangeKernelBase(Kernels
const& kernels, Ranges
const& ranges){
150 for(std::size_t i = 0; i != kernels.size(); ++i){
151 m_kernelWrappers.push_back(
152 SubrangeKernelWrapper<InputType>(kernels[i],ranges[i].first,ranges[i].second)
/// Builds a vector of pointers to the stored wrappers, in the same order as m_kernelWrappers. The
/// return statement is not in this excerpt; presumably it returns `kernels`.
/// NOTE(review): the pointers alias elements of the std::vector m_kernelWrappers, so they dangle if:
///  - the vector reallocates (e.g. a later push_back), or
///  - the object holding it is copied or moved — a previously returned pointer list still refers to
///    the original object's elements.
/// No copy-control is visible in this excerpt.
157 std::vector<AbstractKernelFunction<InputType>* > makeKernelVector(){
158 std::vector<AbstractKernelFunction<InputType>* > kernels(m_kernelWrappers.size());
159 for(std::size_t i = 0; i != m_kernelWrappers.size(); ++i)
160 kernels[i] = & m_kernelWrappers[i];
/// One wrapper per (kernel, range) pair passed to the constructor.
164 std::vector<SubrangeKernelWrapper <InputType> > m_kernelWrappers;
/// \brief Kernel whose component kernels each act on their own index range of the input.
///
/// Each component kernel is wrapped so that it only sees its range. The wrappers are then handed to
/// `InnerKernel` (default: WeightedSumKernel<InputType>), which is constructed from the vector of
/// wrapper pointers.
///
/// NOTE(review): sparse excerpt. The class-name line (presumably `class SubrangeKernel`), the
/// remaining base(s), the name() signature and the constructor's name/parameter line are not present
/// in this view.
189template<
class InputType,
class InnerKernel=WeightedSumKernel<InputType> >
191:
private detail::SubrangeKernelBase<InputType>
// SubrangeKernelBase is listed first among the bases on purpose. Base classes are initialized in
// declaration order (not mem-initializer order), so the wrappers exist before InnerKernel's
// initializer calls makeKernelVector(). Do not move it behind InnerKernel in the base list.
195 typedef detail::SubrangeKernelBase<InputType> base_type1;
/// Returns the fixed class name "SubrangeKernel".
200 {
return "SubrangeKernel"; }
/// Constructs one wrapper per (kernel, range) pair in the private base, then passes pointers to those
/// wrappers to InnerKernel.
/// The component kernels are referenced by pointer rather than copied, so they presumably must
/// outlive this object. Confirm.
202 template<
class Kernels,
class Ranges>
204 : base_type1(kernels,ranges)
205 , InnerKernel(base_type1::makeKernelVector()){}