Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Algorithms
Trainers
AbstractWeightedTrainer.h
Go to the documentation of this file.
//===========================================================================
/*!
 *
 * \brief Abstract Trainer Interface for trainers that support weighting
 *
 * \author O. Krause
 * \date 2014
 *
 * \par Copyright 1995-2017 Shark Development Team
 *
 * <BR><HR>
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
 *
 * Shark is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published
 * by the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Shark is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
 *
 */
//===========================================================================
#ifndef SHARK_ALGORITHMS_TRAINERS_ABSTRACT_WEIGHTED_TRAINER_H
35
#define SHARK_ALGORITHMS_TRAINERS_ABSTRACT_WEIGHTED_TRAINER_H
36
37
#include <
shark/Data/WeightedDataset.h
>
38
#include <
shark/Algorithms/Trainers/AbstractTrainer.h
>
39
40
namespace
shark
{
41
42
43
/// \brief Superclass of weighted supervised learning algorithms
44
///
45
/// \par
46
/// AbstractWeightedTrainer is the super class of all trainers
47
/// that support weighted datasets. Weights are interpreted here
48
/// as the importance of a sample. unweighted training assumes
49
/// that all samples have the same importance, or weight.
50
/// The higher the weight, the more important a point. Weight
51
/// 0 is the same as if the point would not be part of the dataset.
52
/// Negative weights are not allowed.
53
///
54
/// When all weights are integral values there is a simple interpretation
55
/// of the weights as the multiplicity of a point. Thus training
56
/// with a dataset with duplicate points is the same as counting the duplicates
57
/// and run the algorithm with a weighted dataset where all points are unique and
58
/// have their weight is the multiplicity.
59
/// \ingroup supervised_trainer
60
template
<
class
Model,
class
LabelTypeT =
typename
Model::OutputType>
61
class
AbstractWeightedTrainer
:
public
AbstractTrainer
<Model,LabelTypeT>
62
{
63
private
:
64
typedef
AbstractTrainer<Model,LabelTypeT>
base_type
;
65
public
:
66
typedef
typename
base_type::ModelType
ModelType
;
67
typedef
typename
base_type::InputType
InputType
;
68
typedef
typename
base_type::LabelType
LabelType
;
69
typedef
typename
base_type::DatasetType
DatasetType
;
70
typedef
WeightedLabeledData<InputType, LabelType>
WeightedDatasetType
;
71
72
/// \brief Executes the algorithm and trains a model on the given weighted data.
73
virtual
void
train
(
ModelType
& model,
WeightedDatasetType
const
& dataset) = 0;
74
75
/// \brief Executes the algorithm and trains a model on the given unweighted data.
76
///
77
/// This method behaves as using train with a weighted dataset where all weights are equal.
78
/// The default implementation just creates such a dataset and executes the weighted
79
/// version of the algorithm.
80
virtual
void
train
(
ModelType
& model,
DatasetType
const
& dataset){
81
train
(model,
WeightedDatasetType
(dataset, 1.0));
82
}
83
};
84
85
86
/// \brief Superclass of weighted unsupervised learning algorithms
87
///
88
/// \par
89
/// AbstractWeightedUnsupervisedTrainer is the super class of all trainers
90
/// that support weighted datasets. See AbstractWeightedTrainer for more information on
91
/// the weights.
92
/// \see AbstractWeightedTrainer
93
/// \ingroup unsupervised_trainer
94
template
<
class
Model>
95
class
AbstractWeightedUnsupervisedTrainer
:
public
AbstractUnsupervisedTrainer
<Model>
96
{
97
private
:
98
typedef
AbstractUnsupervisedTrainer<Model>
base_type
;
99
public
:
100
typedef
typename
base_type::ModelType
ModelType
;
101
typedef
typename
base_type::InputType
InputType
;
102
typedef
typename
base_type::DatasetType
DatasetType
;
103
typedef
WeightedUnlabeledData<InputType>
WeightedDatasetType
;
104
105
/// \brief Excecutes the algorithm and trains a model on the given weighted data.
106
virtual
void
train
(
ModelType
& model,
WeightedDatasetType
const
& dataset) = 0;
107
108
/// \brief Excecutes the algorithm and trains a model on the given undata.
109
///
110
/// This method behaves as using train with a weighted dataset where all weights are equal.
111
/// The default implementation just creates such a dataset and executes the weighted
112
/// version of the algorithm.
113
virtual
void
train
(
ModelType
& model,
DatasetType
const
& dataset){
114
train
(model,
WeightedDatasetType
(dataset, 1.0));
115
}
116
};
117
118
119
}
120
#endif