Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Algorithms
Trainers
Perceptron.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Perceptron
6
*
7
*
8
*
9
* \author O. Krause
10
* \date 2010
11
*
12
*
13
* \par Copyright 1995-2017 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <https://shark-ml.github.io/Shark/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
//===========================================================================
34
#ifndef SHARK_ALGORITHMS_TRAINERS_PERCEPTRON_H
35
#define SHARK_ALGORITHMS_TRAINERS_PERCEPTRON_H
36
37
38
#include <
shark/Models/Kernels/KernelExpansion.h
>
39
#include <
shark/Algorithms/Trainers/AbstractTrainer.h
>
40
41
namespace
shark
{
42
43
/// \brief Perceptron online learning algorithm
44
/// \ingroup supervised_trainer
45
template
<
class
InputType>
46
class
Perceptron
:
public
AbstractTrainer
<KernelClassifier<InputType>,unsigned int >
47
{
48
public
:
49
/// \brief Constructor.
50
///
51
/// @param kernel is the (Mercer) kernel function.
52
/// @param maxTimesPattern defines the maximum number of times the data is processed before the algorithms stopps.
53
Perceptron
(
AbstractKernelFunction<InputType>
* kernel, std::size_t maxTimesPattern = 10000)
54
:mpe_kernel(kernel),m_maxTimesPattern(maxTimesPattern){}
55
56
/// \brief From INameable: return the class name.
57
std::string
name
()
const
58
{
return
"Perceptron"
; }
59
60
void
train
(
KernelClassifier<InputType>
& classifier,
LabeledData<InputType, unsigned int>
const
& dataset){
61
std::size_t patterns = dataset.
numberOfElements
();
62
KernelExpansion<InputType>
& model= classifier.
decisionFunction
();
63
model.
setStructure
(mpe_kernel,dataset.
inputs
(),
false
,1);
64
model.
alpha
().clear();
65
66
bool
err;
67
std::size_t iter = 0;
68
do
{
69
err =
false
;
70
for
(std::size_t i = 0; i != patterns; i++){
71
double
result = model(dataset.
element
(i).input)(0);
72
//perceptron learning rule with modified target from -1;1
73
double
label = dataset.
element
(i).label*2.0-1;
74
if
( result * label <= 0.0){
75
model.
alpha
(i,0) += label;
76
err =
true
;
77
}
78
}
79
if
(iter > m_maxTimesPattern * patterns)
break
;
// probably non-separable data
80
iter++;
81
}
while
(err);
82
}
83
private
:
84
AbstractKernelFunction<InputType>
* mpe_kernel;
85
std::size_t m_maxTimesPattern;
//< maximum number of times a training is processed
86
};
87
88
89
}
90
#endif