Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
LinAlg
GaussianKernelMatrix.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Efficient special case if the kernel is Gaussian and the inputs are sparse vectors
6
*
7
*
8
* \par
9
*
10
*
11
*
12
* \author T. Glasmachers
13
* \date 2007-2012
14
*
15
*
16
* \par Copyright 1995-2017 Shark Development Team
17
*
18
* <BR><HR>
19
* This file is part of Shark.
20
* <https://shark-ml.github.io/Shark/>
21
*
22
* Shark is free software: you can redistribute it and/or modify
23
* it under the terms of the GNU Lesser General Public License as published
24
* by the Free Software Foundation, either version 3 of the License, or
25
* (at your option) any later version.
26
*
27
* Shark is distributed in the hope that it will be useful,
28
* but WITHOUT ANY WARRANTY; without even the implied warranty of
29
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30
* GNU Lesser General Public License for more details.
31
*
32
* You should have received a copy of the GNU Lesser General Public License
33
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
34
*
35
*/
36
//===========================================================================
37
38
39
#ifndef SHARK_LINALG_GAUSSIANKERNELMATRIX_H
40
#define SHARK_LINALG_GAUSSIANKERNELMATRIX_H
41
42
#include <
shark/Data/Dataset.h
>
43
#include <
shark/LinAlg/Base.h
>
44
45
#include <vector>
46
#include <cmath>
47
48
49
namespace
shark
{
50
51
52
///\brief Efficient special case if the kernel is Gaussian and the inputs are sparse vectors
53
template
<
class
T,
class
CacheType>
54
class
GaussianKernelMatrix
55
{
56
public
:
57
58
typedef
CacheType
QpFloatType
;
59
typedef
T
InputType
;
60
61
/// Constructor
62
/// \param gamma bandwidth parameter of Gaussian kernel
63
/// \param data data evaluated by the kernel function
64
GaussianKernelMatrix
(
65
double
gamma,
66
Data<InputType>
const
& data
67
)
68
:
m_squaredNorms
(data.numberOfElements())
69
,
m_gamma
(gamma)
70
,
m_accessCounter
( 0 )
71
{
72
std::size_t elements = data.
numberOfElements
();
73
x
.resize(elements);
74
PointerType
iter=data.
elements
().begin();
75
for
(std::size_t i = 0; i != elements; ++i,++iter){
76
x
[i]=iter;
77
m_squaredNorms
(i) =inner_prod(*
x
[i],*
x
[i]);
//precompute the norms
78
}
79
}
80
81
/// return a single matrix entry
82
QpFloatType
operator ()
(std::size_t i, std::size_t j)
const
83
{
return
entry
(i, j); }
84
85
/// return a single matrix entry
86
QpFloatType
entry
(std::size_t i, std::size_t j)
const
87
{
88
++
m_accessCounter
;
89
double
distance =
m_squaredNorms
(i)-2*inner_prod(*
x
[i], *
x
[j])+
m_squaredNorms
(j);
90
return
(
QpFloatType
)std::exp(-
m_gamma
* distance);
91
}
92
93
/// \brief Computes the i-th row of the kernel matrix.
94
///
95
///The entries start,...,end of the i-th row are computed and stored in storage.
96
///There must be enough room for this operation preallocated.
97
void
row
(std::size_t i, std::size_t start,std::size_t end,
QpFloatType
* storage)
const
98
{
99
typename
ConstProxyReference<T>::type xi = *
x
[i];
100
m_accessCounter
+=end-start;
101
SHARK_PARALLEL_FOR
(
int
j = start; j < (int) end; j++)
102
{
103
double
distance =
m_squaredNorms
(i)-2*inner_prod(xi, *
x
[j])+
m_squaredNorms
(j);
104
storage[j-start] = std::exp(-
m_gamma
* distance);
105
}
106
}
107
108
/// \brief Computes the kernel-matrix
109
template
<
class
M>
110
void
matrix
(
111
blas::matrix_expression<M, blas::cpu_tag> & storage
112
)
const
{
113
for
(std::size_t i = 0; i !=
size
(); ++i){
114
row
(i,0,
size
(),&storage()(i,0));
115
}
116
}
117
118
/// swap two variables
119
void
flipColumnsAndRows
(std::size_t i, std::size_t j){
120
using
std::swap;
121
swap
(
x
[i],
x
[j]);
122
swap
(
m_squaredNorms
[i],
m_squaredNorms
[j]);
123
}
124
125
/// return the size of the quadratic matrix
126
std::size_t
size
()
const
127
{
return
x
.size(); }
128
129
/// query the kernel access counter
130
unsigned
long
long
getAccessCount
()
const
131
{
return
m_accessCounter
; }
132
133
/// reset the kernel access counter
134
void
resetAccessCount
()
135
{
m_accessCounter
= 0; }
136
137
protected
:
138
139
//~ typedef blas::sparse_vector_adaptor<typename T::value_type const,std::size_t> PointerType;
140
typedef
typename
Data<InputType>::const_element_range::iterator
PointerType
;
141
/// Array of data pointers for kernel evaluations
142
std::vector<PointerType>
x
;
143
144
RealVector
m_squaredNorms
;
145
146
double
m_gamma
;
147
148
/// counter for the kernel accesses
149
mutable
unsigned
long
long
m_accessCounter
;
150
};
151
152
}
153
#endif