Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Models
Kernels
ScaledKernel.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief A kernel function that wraps a member kernel and multiplies it by a scalar.
6
*
7
*
8
*
9
* \author M. Tuma, T. Glasmachers, O. Krause
10
* \date 2012
11
*
12
*
13
* \par Copyright 1995-2017 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <https://shark-ml.github.io/Shark/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
//===========================================================================
34
35
#ifndef SHARK_MODELS_KERNELS_SCALED_KERNEL_H
36
#define SHARK_MODELS_KERNELS_SCALED_KERNEL_H
37
38
39
#include <
shark/Models/Kernels/AbstractKernelFunction.h
>
40
namespace
shark
{
41
42
43
/// \brief Scaled version of a kernel function
44
///
45
/// For a positive definite kernel k, the scaled kernel
46
/// \f[ \tilde k(x_1, x_2) := c k(x_1, x_2) \f]
47
/// is again a positive definite kernel function as long as \f$ c > 0 \f$.
48
/// \ingroup kernels
49
template
<
class
InputType=RealVector>
50
class
ScaledKernel
:
public
AbstractKernelFunction
<InputType>
51
{
52
private
:
53
typedef
AbstractKernelFunction<InputType>
base_type
;
54
public
:
55
typedef
typename
base_type::BatchInputType
BatchInputType
;
56
typedef
typename
base_type::ConstInputReference
ConstInputReference
;
57
typedef
typename
base_type::ConstBatchInputReference
ConstBatchInputReference
;
58
59
ScaledKernel
(
AbstractKernelFunction<InputType>
*
base
,
double
factor
= 1.0 )
60
:
m_base
(
base
),
61
m_factor
(
factor
)
62
{
63
RANGE_CHECK
(
factor
> 0 );
64
SHARK_ASSERT
(
base
!= NULL );
65
if
(
m_base
->
hasFirstInputDerivative
() )
66
this->
m_features
|=
base_type::HAS_FIRST_INPUT_DERIVATIVE
;
67
if
(
m_base
->
hasFirstParameterDerivative
() )
68
this->
m_features
|=
base_type::HAS_FIRST_PARAMETER_DERIVATIVE
;
69
}
70
71
/// \brief From INameable: return the class name.
72
std::string
name
()
const
73
{
return
"ScaledKernel"
; }
74
75
RealVector
parameterVector
()
const
{
76
return
m_base
->
parameterVector
();
77
}
78
void
setParameterVector
(RealVector
const
& newParameters) {
79
m_base
->
setParameterVector
(newParameters);
80
}
81
82
std::size_t
numberOfParameters
()
const
{
83
return
m_base
->
numberOfParameters
();
84
}
85
86
///\brief creates the internal state of the kernel
87
boost::shared_ptr<State>
createState
()
const
{
88
return
m_base
->
createState
();
89
}
90
91
const
double
factor
() {
92
return
m_factor
;
93
}
94
void
setFactor
(
double
f ) {
95
RANGE_CHECK
( f > 0 );
96
m_factor
= f;
97
}
98
99
const
base_type
*
base
()
const
{
100
return
m_base
;
101
}
102
103
double
eval
(
ConstInputReference
x1,
ConstInputReference
x2)
const
{
104
SIZE_CHECK
(x1.size() == x2.size());
105
return
m_factor
*
m_base
->
eval
(x1, x2);
106
}
107
108
void
eval
(
ConstBatchInputReference
x1,
ConstBatchInputReference
x2, RealMatrix& result)
const
{
109
m_base
->
eval
(x1, x2,result);
110
result *=
m_factor
;
111
}
112
113
void
eval
(
ConstBatchInputReference
x1,
ConstBatchInputReference
x2, RealMatrix& result,
State
& state)
const
{
114
m_base
->
eval
(x1, x2,result,state);
115
result *=
m_factor
;
116
}
117
118
/// calculates the weighted derivate w.r.t. the parameters of the base kernel
119
void
weightedParameterDerivative
(
120
ConstBatchInputReference
batchX1,
121
ConstBatchInputReference
batchX2,
122
RealMatrix
const
& coefficients,
123
State
const
& state,
124
RealVector& gradient
125
)
const
{
126
m_base
->
weightedParameterDerivative
( batchX1, batchX2, coefficients, state, gradient );
127
gradient *=
m_factor
;
128
}
129
/// calculates the weighted derivate w.r.t. argument \f$ x_1 \f$
130
void
weightedInputDerivative
(
131
ConstBatchInputReference
batchX1,
132
ConstBatchInputReference
batchX2,
133
RealMatrix
const
& coefficientsX2,
134
State
const
& state,
135
BatchInputType
& gradient
136
)
const
{
137
SIZE_CHECK
(coefficientsX2.size1() ==
batchSize
(batchX1));
138
SIZE_CHECK
(coefficientsX2.size2() ==
batchSize
(batchX2));
139
m_base
->
weightedInputDerivative
( batchX1, batchX2, coefficientsX2, state, gradient );
140
gradient *=
m_factor
;
141
}
142
143
void
read
(
InArchive
& ar){
144
ar >>
m_factor
;
145
ar >> *
m_base
;
146
}
147
148
/// \brief The kernel does not serialize anything
149
void
write
(
OutArchive
& ar)
const
{
150
ar <<
m_factor
;
151
//const cast needed to prevent warning
152
ar << const_cast<AbstractKernelFunction<InputType>
const
&>(*m_base);
153
}
154
155
protected
:
156
AbstractKernelFunction<InputType>
*
m_base
;
///< kernel to scale
157
double
m_factor
;
///< scaling factor
158
};
159
160
/// Scaled kernel for RealVector inputs (the default InputType of ScaledKernel).
typedef ScaledKernel<RealVector> DenseScaledKernel;
/// Scaled kernel for CompressedRealVector inputs.
typedef ScaledKernel<CompressedRealVector> CompressedScaledKernel;
162
163
}
164
#endif