ProductKernel.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Product of kernel functions.
6 *
7 *
8 *
9 * \author T. Glasmachers, O.Krause
10 * \date 2012
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35#ifndef SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
36#define SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
37
38
40
41namespace shark{
42
43
44///
45/// \brief Product of kernel functions.
46///
47/// \par
48/// The product of any number of kernels is again a valid kernel.
49/// This class supports a kernel af the form
50/// \f$ k(x, x') = k_1(x, x') \cdot k_2(x, x') \cdot \dots \cdot k_n(x, x') \f$
51/// for any number of base kernels. All kernels need to be defined
52/// on the same input space.
53///
54/// \par
55/// Derivatives are currently not implemented. Only the plain
56/// kernel value can be computed. Everyone is free to add this
57/// functionality :)
58///
59/// \ingroup kernels
60template<class InputType>
61class ProductKernel : public AbstractKernelFunction<InputType>
62{
63private:
65public:
70 /// \brief Default constructor.
72 // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
73 // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
74 // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
75 // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
76 this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
77 }
78
79 /// \brief Constructor for a product of two kernels.
80 ProductKernel(SubKernel* k1, SubKernel* k2){
81 // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
82 // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
83 // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
84 // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
85 this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
86 addKernel(k1);
87 addKernel(k2);
88 }
89 ProductKernel(std::vector<SubKernel*> kernels){
90 // this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
91 // this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
92 // this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
93 // this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
94 this->m_features |= base_type::IS_NORMALIZED; // an "empty" product is a normalized kernel (k(x, x) = 1).
95 for(std::size_t i = 0; i != kernels.size(); ++i)
96 addKernel(kernels[i]);
97 }
98
99 /// \brief From INameable: return the class name.
100 std::string name() const
101 { return "ProductKernel"; }
102
103 /// \brief Add one more kernel to the expansion.
104 ///
105 /// \param k The pointer is expected to remain valid during the lifetime of the ProductKernel object.
106 ///
108 SHARK_ASSERT(k != NULL);
109
110 m_kernels.push_back(k);
112 if (! k->isNormalized()) this->m_features.reset(base_type::IS_NORMALIZED); // products of normalized kernels are normalized.
113 }
114
115 RealVector parameterVector() const{
116 RealVector ret(m_numberOfParameters);
117 std::size_t pos = 0;
118 for(auto kernel: m_kernels){
119 auto const& params = kernel->parameterVector();
120 noalias(subrange(ret,pos, pos + params.size())) = params;
121 pos += params.size();
122 }
123 return ret;
124 }
125
126 void setParameterVector(RealVector const& newParameters){
127 SIZE_CHECK(newParameters.size() == m_numberOfParameters);
128
129 std::size_t pos = 0;
130 for(auto kernel: m_kernels){
131 std::size_t numParams = kernel->numberOfParameters();
132 kernel->setParameterVector(subrange(newParameters,pos, pos + numParams));
133 pos += numParams;
134 }
135 }
136
137 std::size_t numberOfParameters() const{
139 }
140
141 /// \brief evaluates the kernel function
142 ///
143 /// This function returns the product of all sub-kernels.
145 double prod = 1.0;
146 for (std::size_t i=0; i<m_kernels.size(); i++)
147 prod *= m_kernels[i]->eval(x1, x2);
148 return prod;
149 }
150
151 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result) const{
152 std::size_t sizeX1 = batchSize(batchX1);
153 std::size_t sizeX2 = batchSize(batchX2);
154
155 //evaluate first kernel to initialize the result
156 m_kernels[0]->eval(batchX1,batchX2,result);
157
158 RealMatrix kernelResult(sizeX1,sizeX2);
159 for(std::size_t i = 1; i != m_kernels.size(); ++i){
160 m_kernels[i]->eval(batchX1,batchX2,kernelResult);
161 noalias(result) *= kernelResult;
162 }
163 }
164
165 void eval(ConstBatchInputReference batchX1, ConstBatchInputReference batchX2, RealMatrix& result, State& state) const{
166 eval(batchX1,batchX2,result);
167 }
168
169 /// From ISerializable.
170 void read(InArchive& ar){
171 for(std::size_t i = 0;i != m_kernels.size(); ++i ){
172 ar >> *m_kernels[i];
173 }
175 }
176
177 /// From ISerializable.
178 void write(OutArchive& ar) const{
179 for(std::size_t i = 0;i != m_kernels.size(); ++i ){
180 ar << const_cast<AbstractKernelFunction<InputType> const&>(*m_kernels[i]);//prevent serialization warning
181 }
183 }
184
185protected:
186 std::vector<SubKernel*> m_kernels; ///< vector of sub-kernels
187 std::size_t m_numberOfParameters; ///< total number of parameters in the product (this is redundant information)
188};
189
190
191}
192#endif