Perceptron.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Perceptron
6 *
7 *
8 *
9 * \author O. Krause
10 * \date 2010
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34#ifndef SHARK_ALGORITHMS_TRAINERS_PERCEPTRON_H
35#define SHARK_ALGORITHMS_TRAINERS_PERCEPTRON_H
36
37
40
41namespace shark{
42
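/// \brief Kernel perceptron online learning algorithm for binary classification.
///
/// Trains a KernelClassifier whose kernel expansion has one coefficient per
/// training pattern. Class labels are expected to be 0 or 1; internally they are
/// mapped to the perceptron targets -1 and +1.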
44/// \ingroup supervised_trainer
45template<class InputType>
46class Perceptron : public AbstractTrainer<KernelClassifier<InputType>,unsigned int >
47{
48public:
49 /// \brief Constructor.
50 ///
51 /// @param kernel is the (Mercer) kernel function.
52 /// @param maxTimesPattern defines the maximum number of times the data is processed before the algorithms stopps.
53 Perceptron(AbstractKernelFunction<InputType>* kernel, std::size_t maxTimesPattern = 10000)
54 :mpe_kernel(kernel),m_maxTimesPattern(maxTimesPattern){}
55
56 /// \brief From INameable: return the class name.
57 std::string name() const
58 { return "Perceptron"; }
59
61 std::size_t patterns = dataset.numberOfElements();
63 model.setStructure(mpe_kernel,dataset.inputs(),false,1);
64 model.alpha().clear();
65
66 bool err;
67 std::size_t iter = 0;
68 do {
69 err = false;
70 for (std::size_t i = 0; i != patterns; i++){
71 double result = model(dataset.element(i).input)(0);
72 //perceptron learning rule with modified target from -1;1
73 double label = dataset.element(i).label*2.0-1;
74 if ( result * label <= 0.0){
75 model.alpha(i,0) += label;
76 err = true;
77 }
78 }
79 if (iter > m_maxTimesPattern * patterns) break; // probably non-separable data
80 iter++;
81 } while (err);
82 }
83private:
85 std::size_t m_maxTimesPattern; //< maximum number of times a training is processed
86};
87
88
89}
90#endif