Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Models
Kernels
ProductKernel.h
Go to the documentation of this file.
1
//===========================================================================
/*!
 *
 *
 * \brief       Product of kernel functions.
 *
 *
 *
 * \author      T. Glasmachers, O.Krause
 * \date        2012
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
34
35
#ifndef SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
36
#define SHARK_MODELS_KERNELS_PRODUCTKERNEL_H
37
38
39
#include <
shark/Models/Kernels/AbstractKernelFunction.h
>
40
41
namespace
shark
{
42
43
44
///
45
/// \brief Product of kernel functions.
46
///
47
/// \par
48
/// The product of any number of kernels is again a valid kernel.
49
/// This class supports a kernel af the form
50
/// \f$ k(x, x') = k_1(x, x') \cdot k_2(x, x') \cdot \dots \cdot k_n(x, x') \f$
51
/// for any number of base kernels. All kernels need to be defined
52
/// on the same input space.
53
///
54
/// \par
55
/// Derivatives are currently not implemented. Only the plain
56
/// kernel value can be computed. Everyone is free to add this
57
/// functionality :)
58
///
59
/// \ingroup kernels
60
template
<
class
InputType>
61
class
ProductKernel
:
public
AbstractKernelFunction
<InputType>
62
{
63
private
:
64
typedef
AbstractKernelFunction<InputType>
base_type
;
65
public
:
66
typedef
AbstractKernelFunction<InputType>
SubKernel
;
67
typedef
typename
base_type::BatchInputType
BatchInputType
;
68
typedef
typename
base_type::ConstInputReference
ConstInputReference
;
69
typedef
typename
base_type::ConstBatchInputReference
ConstBatchInputReference
;
70
/// \brief Default constructor.
71
ProductKernel
(){
72
// this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
73
// this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
74
// this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
75
// this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
76
this->
m_features
|=
base_type::IS_NORMALIZED
;
// an "empty" product is a normalized kernel (k(x, x) = 1).
77
}
78
79
/// \brief Constructor for a product of two kernels.
80
ProductKernel
(SubKernel* k1,
SubKernel
* k2){
81
// this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
82
// this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
83
// this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
84
// this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
85
this->
m_features
|=
base_type::IS_NORMALIZED
;
// an "empty" product is a normalized kernel (k(x, x) = 1).
86
addKernel
(k1);
87
addKernel
(k2);
88
}
89
ProductKernel
(std::vector<SubKernel*> kernels){
90
// this->m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
91
// this->m_features |= base_type::HAS_SECOND_PARAMETER_DERIVATIVE;
92
// this->m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
93
// this->m_features |= base_type::HAS_SECOND_INPUT_DERIVATIVE;
94
this->
m_features
|=
base_type::IS_NORMALIZED
;
// an "empty" product is a normalized kernel (k(x, x) = 1).
95
for
(std::size_t i = 0; i != kernels.size(); ++i)
96
addKernel
(kernels[i]);
97
}
98
99
/// \brief From INameable: return the class name.
100
std::string
name
()
const
101
{
return
"ProductKernel"
; }
102
103
/// \brief Add one more kernel to the expansion.
104
///
105
/// \param k The pointer is expected to remain valid during the lifetime of the ProductKernel object.
106
///
107
void
addKernel
(
SubKernel
* k){
108
SHARK_ASSERT
(k != NULL);
109
110
m_kernels
.push_back(k);
111
m_numberOfParameters
+= k->
numberOfParameters
();
112
if
(! k->
isNormalized
()) this->
m_features
.
reset
(
base_type::IS_NORMALIZED
);
// products of normalized kernels are normalized.
113
}
114
115
RealVector
parameterVector
()
const
{
116
RealVector ret(
m_numberOfParameters
);
117
std::size_t pos = 0;
118
for
(
auto
kernel:
m_kernels
){
119
auto
const
& params = kernel->parameterVector();
120
noalias(subrange(ret,pos, pos + params.size())) = params;
121
pos += params.size();
122
}
123
return
ret;
124
}
125
126
void
setParameterVector
(RealVector
const
& newParameters){
127
SIZE_CHECK
(newParameters.size() ==
m_numberOfParameters
);
128
129
std::size_t pos = 0;
130
for
(
auto
kernel:
m_kernels
){
131
std::size_t numParams = kernel->numberOfParameters();
132
kernel->setParameterVector(subrange(newParameters,pos, pos + numParams));
133
pos += numParams;
134
}
135
}
136
137
std::size_t
numberOfParameters
()
const
{
138
return
m_numberOfParameters
;
139
}
140
141
/// \brief evaluates the kernel function
142
///
143
/// This function returns the product of all sub-kernels.
144
double
eval
(
ConstInputReference
x1,
ConstInputReference
x2)
const
{
145
double
prod = 1.0;
146
for
(std::size_t i=0; i<
m_kernels
.size(); i++)
147
prod *=
m_kernels
[i]->
eval
(x1, x2);
148
return
prod;
149
}
150
151
void
eval
(
ConstBatchInputReference
batchX1,
ConstBatchInputReference
batchX2, RealMatrix& result)
const
{
152
std::size_t sizeX1 =
batchSize
(batchX1);
153
std::size_t sizeX2 =
batchSize
(batchX2);
154
155
//evaluate first kernel to initialize the result
156
m_kernels
[0]->eval(batchX1,batchX2,result);
157
158
RealMatrix kernelResult(sizeX1,sizeX2);
159
for
(std::size_t i = 1; i !=
m_kernels
.size(); ++i){
160
m_kernels
[i]->eval(batchX1,batchX2,kernelResult);
161
noalias(result) *= kernelResult;
162
}
163
}
164
165
void
eval
(
ConstBatchInputReference
batchX1,
ConstBatchInputReference
batchX2, RealMatrix& result,
State
& state)
const
{
166
eval
(batchX1,batchX2,result);
167
}
168
169
/// From ISerializable.
170
void
read
(
InArchive
& ar){
171
for
(std::size_t i = 0;i !=
m_kernels
.size(); ++i ){
172
ar >> *
m_kernels
[i];
173
}
174
ar >>
m_numberOfParameters
;
175
}
176
177
/// From ISerializable.
178
void
write
(
OutArchive
& ar)
const
{
179
for
(std::size_t i = 0;i !=
m_kernels
.size(); ++i ){
180
ar << const_cast<AbstractKernelFunction<InputType>
const
&>(*
m_kernels
[i]);
//prevent serialization warning
181
}
182
ar <<
m_numberOfParameters
;
183
}
184
185
protected
:
186
std::vector<SubKernel*>
m_kernels
;
///< vector of sub-kernels
187
std::size_t
m_numberOfParameters
;
///< total number of parameters in the product (this is redundant information)
188
};
189
190
191
}
192
#endif