Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
LinAlg
RegularizedKernelMatrix.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Kernel Gram matrix with modified diagonal
6
*
7
*
8
* \par
9
*
10
*
11
*
12
* \author T. Glasmachers
13
* \date 2007-2012
14
*
15
*
16
* \par Copyright 1995-2017 Shark Development Team
17
*
18
* <BR><HR>
19
* This file is part of Shark.
20
* <https://shark-ml.github.io/Shark/>
21
*
22
* Shark is free software: you can redistribute it and/or modify
23
* it under the terms of the GNU Lesser General Public License as published
24
* by the Free Software Foundation, either version 3 of the License, or
25
* (at your option) any later version.
26
*
27
* Shark is distributed in the hope that it will be useful,
28
* but WITHOUT ANY WARRANTY; without even the implied warranty of
29
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30
* GNU Lesser General Public License for more details.
31
*
32
* You should have received a copy of the GNU Lesser General Public License
33
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
34
*
35
*/
36
//===========================================================================
37
38
39
#ifndef SHARK_LINALG_REGULARIZEDKERNELMATRIX_H
40
#define SHARK_LINALG_REGULARIZEDKERNELMATRIX_H
41
42
#include <
shark/Data/Dataset.h
>
43
#include <
shark/LinAlg/Base.h
>
44
45
#include <vector>
46
#include <cmath>
47
48
49
namespace
shark
{
50
51
52
///
53
/// \brief Kernel Gram matrix with modified diagonal
54
///
55
/// \par
56
/// Regularized version of KernelMatrix. The regularization
57
/// is achieved by adding a vector to the matrix diagonal.
58
/// In particular, this is useful for support vector machines
59
/// with 2-norm penalty term.
60
///
61
template
<
class
InputType,
class
CacheType>
62
class
RegularizedKernelMatrix
63
{
64
private
:
65
typedef
KernelMatrix<InputType,CacheType>
Matrix
;
66
public
:
67
typedef
typename
Matrix::QpFloatType
QpFloatType
;
68
69
/// Constructor
70
/// \param kernelfunction kernel function
71
/// \param data data to evaluate the kernel function
72
/// \param diagModification vector d of diagonal modifiers
73
RegularizedKernelMatrix
(
74
AbstractKernelFunction<InputType>
const
& kernelfunction,
75
Data<InputType>
const
& data,
76
const
RealVector& diagModification
77
):
m_matrix
(kernelfunction,data),
m_diagMod
(diagModification){
78
SIZE_CHECK
(
size
() == diagModification.size());
79
}
80
81
/// return a single matrix entry
82
QpFloatType
operator ()
(std::size_t i, std::size_t j)
const
83
{
return
entry
(i, j); }
84
85
/// return a single matrix entry
86
QpFloatType
entry
(std::size_t i, std::size_t j)
const
87
{
88
QpFloatType
ret =
m_matrix
(i,j);
89
if
(i == j) ret += (
QpFloatType
)
m_diagMod
(i);
90
return
ret;
91
}
92
93
/// \brief Computes the i-th row of the kernel matrix.
94
///
95
///The entries start,...,end of the i-th row are computed and stored in storage.
96
///There must be enough room for this operation preallocated.
97
void
row
(std::size_t k, std::size_t start,std::size_t end,
QpFloatType
* storage)
const
{
98
m_matrix
.
row
(k,start,end,storage);
99
//apply regularization
100
if
(k >= start && k < end){
101
storage[k-start] += (
QpFloatType
)
m_diagMod
(k);
102
}
103
}
104
105
/// \brief Computes the kernel-matrix
106
template
<
class
M>
107
void
matrix
(
108
blas::matrix_expression<M, blas::cpu_tag> & storage
109
)
const
{
110
m_matrix
.
matrix
(storage);
111
for
(std::size_t k = 0; k !=
size
(); ++k){
112
storage()(k,k) += (
QpFloatType
)
m_diagMod
(k);
113
}
114
}
115
116
/// swap two variables
117
void
flipColumnsAndRows
(std::size_t i, std::size_t j){
118
m_matrix
.
flipColumnsAndRows
(i,j);
119
std::swap(
m_diagMod
(i),
m_diagMod
(j));
120
}
121
122
/// return the size of the quadratic matrix
123
std::size_t
size
()
const
124
{
return
m_matrix
.
size
(); }
125
126
/// query the kernel access counter
127
unsigned
long
long
getAccessCount
()
const
128
{
return
m_matrix
.
getAccessCount
(); }
129
130
/// reset the kernel access counter
131
void
resetAccessCount
()
132
{
m_matrix
.
resetAccessCount
(); }
133
134
protected
:
135
Matrix
m_matrix
;
136
RealVector
m_diagMod
;
137
};
138
139
}
140
#endif