//===========================================================================
/*!
 *
 *
 * \brief Kernel function that applies a model as a transformation to its inputs before evaluating an inner kernel.
 *
 *
 *
 * \author T. Glasmachers
 * \date 2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================

#ifndef SHARK_MODELS_KERNELS_MODEL_KERNEL_H
#define SHARK_MODELS_KERNELS_MODEL_KERNEL_H


#include <shark/Models/Kernels/AbstractKernelFunction.h>
#include <shark/LinAlg/Base.h>
#include <shark/Models/AbstractModel.h>
#include <vector>
#include <boost/scoped_ptr.hpp>

namespace shark {

namespace detail{
template<class InputType, class IntermediateType>
class ModelKernelImpl : public AbstractKernelFunction<InputType>
{
private:
    typedef AbstractKernelFunction<InputType> base_type;
public:
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::ConstInputReference ConstInputReference;
    typedef typename base_type::ConstBatchInputReference ConstBatchInputReference;
    typedef AbstractKernelFunction<IntermediateType> Kernel;
    typedef AbstractModel<InputType,IntermediateType> Model;
private:
    struct InternalState: public State{
        boost::shared_ptr<State> kernelStateX1X2;
        boost::shared_ptr<State> kernelStateX2X1;
        boost::shared_ptr<State> modelStateX1;
        boost::shared_ptr<State> modelStateX2;
        typename Model::BatchOutputType intermediateX1;
        typename Model::BatchOutputType intermediateX2;
    };
public:

    ModelKernelImpl(Kernel* kernel, Model* model):mpe_kernel(kernel),mpe_model(model){
        if(kernel->hasFirstParameterDerivative()
        && kernel->hasFirstInputDerivative()
        && model->hasFirstParameterDerivative())
            this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "ModelKernel"; }

    std::size_t numberOfParameters()const{
        return mpe_kernel->numberOfParameters() + mpe_model->numberOfParameters();
    }
    RealVector parameterVector() const{
        return mpe_kernel->parameterVector() | mpe_model->parameterVector();
    }
    void setParameterVector(RealVector const& newParameters){
        SIZE_CHECK(newParameters.size() == numberOfParameters());
        std::size_t kParams = mpe_kernel->numberOfParameters();
        mpe_kernel->setParameterVector(subrange(newParameters,0,kParams));
        mpe_model->setParameterVector(subrange(newParameters,kParams,newParameters.size()));
    }

    boost::shared_ptr<State> createState()const{
        InternalState* s = new InternalState();
        boost::shared_ptr<State> sharedState(s);//create the shared_ptr now so the destructor is called in case of exception
        s->kernelStateX1X2 = mpe_kernel->createState();
        s->kernelStateX2X1 = mpe_kernel->createState();
        s->modelStateX1 = mpe_model->createState();
        s->modelStateX2 = mpe_model->createState();
        return sharedState;
    }

    double eval(ConstInputReference x1, ConstInputReference x2) const{
        auto mx1 = (*mpe_model)(x1);
        auto mx2 = (*mpe_model)(x2);
        return mpe_kernel->eval(mx1,mx2);
    }

    void eval(ConstBatchInputReference x1, ConstBatchInputReference x2, RealMatrix& result, State& state) const{
        InternalState& s = state.toState<InternalState>();
        mpe_model->eval(x1,s.intermediateX1,*s.modelStateX1);
        mpe_model->eval(x2,s.intermediateX2,*s.modelStateX2);
        //evaluate the kernel in both argument orders: the second call writes the
        //returned matrix, the first one only fills the state needed later for the
        //derivative wrt the second argument
        mpe_kernel->eval(s.intermediateX2,s.intermediateX1,result,*s.kernelStateX2X1);
        mpe_kernel->eval(s.intermediateX1,s.intermediateX2,result,*s.kernelStateX1X2);
    }

    void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const{
        auto mx1 = (*mpe_model)(batchX1);
        auto mx2 = (*mpe_model)(batchX2);
        mpe_kernel->eval(mx1,mx2,result);
    }

    void weightedParameterDerivative(
        ConstBatchInputReference batchX1,
        ConstBatchInputReference batchX2,
        RealMatrix const& coefficients,
        State const& state,
        RealVector& gradient
    ) const{
        gradient.resize(numberOfParameters());
        InternalState const& s = state.toState<InternalState>();

        //compute derivative of the kernel wrt. its parameters
        RealVector kernelGrad;
        mpe_kernel->weightedParameterDerivative(
            s.intermediateX1,s.intermediateX2,
            coefficients,*s.kernelStateX1X2,kernelGrad
        );
        //compute derivative of the kernel wrt its left and right arguments
        typename Model::BatchOutputType inputDerivativeX1, inputDerivativeX2;
        mpe_kernel->weightedInputDerivative(
            s.intermediateX1,s.intermediateX2,
            coefficients,*s.kernelStateX1X2,inputDerivativeX1
        );
        mpe_kernel->weightedInputDerivative(
            s.intermediateX2,s.intermediateX1,
            trans(coefficients),*s.kernelStateX2X1,inputDerivativeX2
        );

        //chain rule: propagate the kernel's input derivatives through the model
        //to obtain the derivative wrt the model parameters
        RealVector modelGradX1,modelGradX2;
        mpe_model->weightedParameterDerivative(batchX1,s.intermediateX1, inputDerivativeX1,*s.modelStateX1,modelGradX1);
        mpe_model->weightedParameterDerivative(batchX2,s.intermediateX2, inputDerivativeX2,*s.modelStateX2,modelGradX2);
        noalias(gradient) = kernelGrad | (modelGradX1+modelGradX2);
    }

    void read(InArchive& ar){
        SHARK_RUNTIME_CHECK(mpe_kernel, "The kernel function is NULL; the kernel must be constructed before reading it in");
        SHARK_RUNTIME_CHECK(mpe_model, "The model is NULL; the model must be constructed before reading it in");
        ar >> *mpe_kernel;
        ar >> *mpe_model;
    }

    void write(OutArchive& ar) const{
        ar << *mpe_kernel;
        ar << *mpe_model;
    }

private:
    Kernel* mpe_kernel;
    Model* mpe_model;
};
}

/// \brief Kernel function that uses a model as a transformation function for another kernel.
///
/// Using an AbstractModel \f$ f: X \rightarrow X' \f$ and an inner kernel
/// \f$k: X' \times X' \rightarrow \mathbb{R} \f$, this class defines another kernel
/// \f$K: X \times X \rightarrow \mathbb{R}\f$ via
/// \f[
/// K(x,y) = k(f(x),f(y))
/// \f]
/// If the inner kernel \f$k\f$ supports both the input and the parameter derivative,
/// and the model also supports the parameter derivative, then the kernel \f$K\f$
/// supports the first parameter derivative as well, given by
/// \f[
/// \frac{\partial}{\partial \theta} K(x,y) =
/// \frac{\partial}{\partial f(x)} k(f(x),f(y)) \frac{\partial}{\partial \theta} f(x)
/// + \frac{\partial}{\partial f(y)} k(f(x),f(y)) \frac{\partial}{\partial \theta} f(y)
/// \f]
/// This requires the derivative of the kernel with respect to both of its arguments
/// which, due to a limitation of the current kernel interface, requires computing
/// both \f$k(f(x),f(y))\f$ and \f$k(f(y),f(x))\f$.
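///
/// Usage sketch (a hypothetical example; it assumes a LinearModel as the
/// transformation \f$f\f$ and a GaussianRbfKernel as the inner kernel \f$k\f$,
/// but any matching AbstractModel/AbstractKernelFunction pair works):
/// \code
/// LinearModel<RealVector> model(10, 2);            // f: R^10 -> R^2
/// GaussianRbfKernel<RealVector> inner(0.5);        // k on the transformed space
/// ModelKernel<RealVector> kernel(&inner, &model);  // K(x,y) = k(f(x),f(y))
/// double k12 = kernel(x1, x2);                     // x1, x2: 10-dimensional inputs
/// \endcode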
/// \ingroup kernels
template<class InputType=RealVector>
class ModelKernel: public AbstractKernelFunction<InputType>{
private:
    typedef AbstractKernelFunction<InputType> base_type;
public:
    typedef typename base_type::BatchInputType BatchInputType;
    typedef typename base_type::ConstInputReference ConstInputReference;
    typedef typename base_type::ConstBatchInputReference ConstBatchInputReference;

    template<class IntermediateType>
    ModelKernel(
        AbstractKernelFunction<IntermediateType>* kernel,
        AbstractModel<InputType,IntermediateType>* model
    ):m_wrapper(new detail::ModelKernelImpl<InputType,IntermediateType>(kernel,model)){
        SHARK_RUNTIME_CHECK(kernel, "The kernel function is not allowed to be NULL");
        SHARK_RUNTIME_CHECK(model, "The model is not allowed to be NULL");
        if(m_wrapper->hasFirstParameterDerivative())
            this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
    }

    /// \brief From INameable: return the class name.
    std::string name() const
    { return "ModelKernel"; }

    /// \brief Returns the number of parameters.
    std::size_t numberOfParameters()const{
        return m_wrapper->numberOfParameters();
    }
    ///\brief Returns the concatenated parameters of kernel and model.
    RealVector parameterVector() const{
        return m_wrapper->parameterVector();
    }
    ///\brief Sets the concatenated parameters of kernel and model.
    void setParameterVector(RealVector const& newParameters){
        m_wrapper->setParameterVector(newParameters);
    }

    ///\brief Returns the internal state object used for eval and the derivatives.
    boost::shared_ptr<State> createState()const{
        return m_wrapper->createState();
    }

    ///\brief Computes K(x,y) for a single input pair.
    double eval(ConstInputReference x1, ConstInputReference x2) const{
        return m_wrapper->eval(x1,x2);
    }

    /// \brief For two batches X1 and X2 computes the matrix k_ij=K(X1_i,X2_j) and stores the state for the derivatives.
    void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state) const{
        m_wrapper->eval(batchX1,batchX2,result,state);
    }
    /// \brief For two batches X1 and X2 computes the matrix k_ij=K(X1_i,X2_j).
    void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const{
        m_wrapper->eval(batchX1,batchX2,result);
    }

    ///\brief After a call to eval with state, computes the derivative wrt all parameters of the kernel and the model.
    ///
    /// The derivative is computed over the whole kernel matrix k_ij created by eval and summed up using the coefficients c;
    /// thus this call returns \f$ \sum_{i,j} c_{ij} \frac{\partial}{\partial \theta} k(x^1_i,x^2_j)\f$.
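    ///
    /// A minimal calling sketch (hypothetical names; X1, X2 are input batches and C
    /// is a coefficient matrix with one entry per kernel matrix element):
    /// \code
    /// boost::shared_ptr<State> state = kernel.createState();
    /// RealMatrix K;
    /// kernel.eval(X1, X2, K, *state);  // fills K and the state
    /// RealVector grad;
    /// kernel.weightedParameterDerivative(X1, X2, C, *state, grad);
    /// \endcode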
    void weightedParameterDerivative(
        ConstBatchInputReference batchX1,
        ConstBatchInputReference batchX2,
        RealMatrix const& coefficients,
        State const& state,
        RealVector& gradient
    ) const{
        m_wrapper->weightedParameterDerivative(batchX1,batchX2,coefficients,state,gradient);
    }

    ///\brief Stores the kernel to an Archive.
    void write(OutArchive& ar) const{
        ar << *m_wrapper;
    }
    ///\brief Reads the kernel from an Archive.
    void read(InArchive& ar){
        ar >> *m_wrapper;
    }
private:
    boost::scoped_ptr<AbstractKernelFunction<InputType> > m_wrapper;
};

}
#endif