Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Models
RBFLayer.h
Go to the documentation of this file.
1
/*!
 *
 *
 * \brief Implements a radial basis function layer.
 *
 *
 *
 * \author O. Krause
 * \date 2014
 *
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
32
#ifndef SHARK_MODELS_RBFLayer_H
33
#define SHARK_MODELS_RBFLayer_H
34
35
#include <
shark/Core/DLLSupport.h
>
36
#include <
shark/Models/AbstractModel.h
>
37
#include <boost/math/constants/constants.hpp>
38
namespace
shark
{
39
40
/// \brief Implements a layer of radial basis functions in a neural network.
41
///
42
/// A Radial basis function layer as modeled in shark is a set of N
43
/// Gaussian distributions \f$ p(x|i) \f$.
44
/// \f[
45
/// p(x|i) = e^{\gamma_i*\|x-m_i\|^2}
46
/// \f]
47
/// and the layer transforms an input x to a vector \f$(p(x|1),\dots,p(x|N)\f$.
48
/// The \f$\gamma_i\f$ govern the width of the Gaussians, while the
49
/// vectors \f$ m_i \f$ set the centers of every Gaussian distribution.
50
///
51
/// RBF networks profit much from good guesses on the centers and
52
/// kernel function parameters. In case of a Gaussian kernel a call
53
/// to k-Means or the EM-algorithm can be used to get a good
54
/// initialisation for the network.
55
///
56
/// \ingroup models
57
class
RBFLayer
:
public
AbstractModel
<RealVector,RealVector>
58
{
59
private
:
60
struct
InternalState:
public
State
{
61
RealMatrix norm2;
62
63
void
resize(std::size_t numPatterns, std::size_t numNeurons){
64
norm2.resize(numPatterns,numNeurons);
65
}
66
};
67
68
public
:
69
/// \brief Creates an empty Radial Basis Function layer.
70
SHARK_EXPORT_SYMBOL
RBFLayer
();
71
72
/// \brief Creates a layer of a Radial Basis Function Network.
73
///
74
/// This method creates a Radial Basis Function Network (RBFN) with
75
/// \em numInput input neurons and \em numOutput output neurons.
76
///
77
/// \param numInput Number of input neurons, equal to dimensionality of
78
/// input space.
79
/// \param numOutput Number of output neurons, equal to dimensionality of
80
/// output space and number of gaussian distributions
81
SHARK_EXPORT_SYMBOL
RBFLayer
(std::size_t numInput, std::size_t numOutput);
82
83
/// \brief From INameable: return the class name.
84
std::string
name
()
const
85
{
return
"RBFLayer"
; }
86
87
///\brief Returns the current parameter vector. The amount and order of weights depend on the training parameters.
88
///
89
///The format of the parameter vector is \f$ (m_1,\dots,m_k,\log(\gamma_1),\dots,\log(\gamma_k))\f$
90
///if training of one or more parameters is deactivated, they are removed from the parameter vector
91
SHARK_EXPORT_SYMBOL
RealVector
parameterVector
()
const
;
92
93
///\brief Sets the new internal parameters.
94
SHARK_EXPORT_SYMBOL
void
setParameterVector
(RealVector
const
& newParameters);
95
96
///\brief Returns the number of parameters which are currently enabled for training.
97
SHARK_EXPORT_SYMBOL
std::size_t
numberOfParameters
()
const
;
98
99
///\brief Returns the number of input neurons.
100
Shape
inputShape
()
const
{
101
return
m_centers
.size2();
102
}
103
104
///\brief Returns the number of output neurons.
105
Shape
outputShape
()
const
{
106
return
m_centers
.size1();
107
}
108
109
boost::shared_ptr<State>
createState
()
const
{
110
return
boost::shared_ptr<State>(
new
InternalState());
111
}
112
113
114
/// \brief Configures a Radial Basis Function Network.
115
///
116
/// This method initializes the structure of the Radial Basis Function Network (RBFN) with
117
/// \em numInput input neurons, \em numOutput output neurons and \em numHidden
118
/// hidden neurons.
119
///
120
/// \param numInput Number of input neurons, equal to dimensionality of
121
/// input space.
122
/// \param numOutput Number of output neurons (basis functions), equal to dimensionality of
123
/// output space.
124
SHARK_EXPORT_SYMBOL
void
setStructure
(std::size_t numInput, std::size_t numOutput);
125
126
127
using
AbstractModel
<RealVector,RealVector>
::eval
;
128
SHARK_EXPORT_SYMBOL
void
eval
(
BatchInputType
const
& patterns,
BatchOutputType
& outputs,
State
& state)
const
;
129
130
131
SHARK_EXPORT_SYMBOL
void
weightedParameterDerivative
(
132
BatchInputType
const
& pattern,
BatchOutputType
const
& outputs,
133
BatchOutputType
const
& coefficients,
State
const
& state, RealVector& gradient
134
)
const
;
135
136
///\brief Enables or disables parameters for learning.
137
///
138
/// \param centers whether the centers should be trained
139
/// \param width whether the distribution width should be trained
140
SHARK_EXPORT_SYMBOL
void
setTrainingParameters
(
bool
centers
,
bool
width);
141
142
///\brief Returns the center values of the neurons.
143
BatchInputType
const
&
centers
()
const
{
144
return
m_centers
;
145
}
146
///\brief Sets the center values of the neurons.
147
BatchInputType
&
centers
(){
148
return
m_centers
;
149
}
150
151
///\brief Returns the width parameter of the Gaussian functions
152
RealVector
const
&
gamma
()
const
{
153
return
m_gamma
;
154
}
155
156
/// \brief sets the width parameters - the gamma values - of the distributions.
157
SHARK_EXPORT_SYMBOL
void
setGamma
(RealVector
const
&
gamma
);
158
159
/// From ISerializable, reads a model from an archive
160
SHARK_EXPORT_SYMBOL
void
read
(
InArchive
& archive );
161
162
/// From ISerializable, writes a model to an archive
163
SHARK_EXPORT_SYMBOL
void
write
(
OutArchive
& archive )
const
;
164
protected
:
165
//====model parameters
166
167
///\brief The center points. The i-th element corresponds to the center of neuron number i
168
RealMatrix
m_centers
;
169
170
///\brief stores the width parameters of the Gaussian functions
171
RealVector
m_gamma
;
172
173
/// \brief the logarithm of the normalization constant for every distribution
174
RealVector
m_logNormalization
;
175
176
//=====training parameters
177
///enables learning of the center points of the neurons
178
bool
m_trainCenters
;
179
///enables learning of the width parameters.
180
bool
m_trainWidth
;
181
182
183
184
};
185
}
186
187
#endif
188