Shark machine learning library
Installation
Tutorials
Benchmarks
Documentation
Quick references
Class list
Global functions
include
shark
Algorithms
Trainers
LogisticRegression.h
Go to the documentation of this file.
1
//===========================================================================
2
/*!
3
*
4
*
5
* \brief Logistic Regression
6
*
7
*
8
*
9
* \author O.Krause
10
* \date 2017
11
*
12
*
13
* \par Copyright 1995-2017 Shark Development Team
14
*
15
* <BR><HR>
16
* This file is part of Shark.
17
* <https://shark-ml.github.io/Shark/>
18
*
19
* Shark is free software: you can redistribute it and/or modify
20
* it under the terms of the GNU Lesser General Public License as published
21
* by the Free Software Foundation, either version 3 of the License, or
22
* (at your option) any later version.
23
*
24
* Shark is distributed in the hope that it will be useful,
25
* but WITHOUT ANY WARRANTY; without even the implied warranty of
26
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27
* GNU Lesser General Public License for more details.
28
*
29
* You should have received a copy of the GNU Lesser General Public License
30
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
31
*
32
*/
33
//===========================================================================
34
35
36
#ifndef SHARK_ALGORITHMS_TRAINERS_LOGISTICREGRESSION_H
37
#define SHARK_ALGORITHMS_TRAINERS_LOGISTICREGRESSION_H
38
39
#include <
shark/Models/LinearModel.h
>
40
#include <
shark/Algorithms/Trainers/AbstractWeightedTrainer.h
>
41
42
43
namespace
shark
{
44
45
/// \brief Trainer for Logistic regression
46
///
47
/// Logistic regression solves the following optimization problem:
48
/// \f[ \min_{w,b} \sum_i u_i l(y_i,f(x_i^Tw+b)) +\lambda_1 |w|_1 +\lambda_2 |w|^2_2 \f]
49
/// Where \f$l\f$ is the cross-entropy loss and \f$u_i\f$ are individual weuights for each point(assumed to be 1).
50
/// Logistic regression is one of the most well known
51
/// machine learning algorithms for classification using linear models.
52
///
53
/// The solver is based on LBFGS for the case where no l1-regularization is used. Otherwise
54
/// the problem is transformed into a constrained problem and the constrined-LBFGS algorithm
55
/// is used. This is one of the most efficient solvers for logistic regression as long as the
56
/// number of data points is not too large.
57
/// \ingroup supervised_trainer
58
template
<
class
InputVectorType = RealVector>
59
class
LogisticRegression
:
public
AbstractWeightedTrainer
<LinearClassifier<InputVectorType> >,
public
IParameterizable
<>
60
{
61
private
:
62
typedef
AbstractWeightedTrainer<LinearClassifier<InputVectorType>
>
base_type
;
63
public
:
64
typedef
typename
base_type::ModelType
ModelType
;
65
typedef
typename
base_type::DatasetType
DatasetType
;
66
typedef
typename
base_type::WeightedDatasetType
WeightedDatasetType
;
67
68
/// \brief Constructor.
69
///
70
/// \param lambda1 value of the 1-norm regularization parameter (see class description)
71
/// \param lambda2 value of the 2-norm regularization parameter (see class description)
72
/// \param bias whether to train with bias or not
73
/// \param accuracy stopping criterion for the iterative solver, maximal gradient component of the objective function (see class description)
74
LogisticRegression
(
double
lambda1
= 0,
double
lambda2
= 0,
bool
bias =
true
,
double
accuracy
= 1.e-8)
75
: m_bias(bias){
76
setLambda1
(
lambda1
);
77
setLambda2
(
lambda2
);
78
setAccuracy
(
accuracy
);
79
}
80
81
/// \brief From INameable: return the class name.
82
std::string
name
()
const
83
{
return
"LogisticRegression"
; }
84
85
86
/// \brief Return the current setting of the l1-regularization parameter.
87
double
lambda1
()
const
{
88
return
m_lambda1;
89
}
90
91
/// \brief Return the current setting of the l2-regularization parameter.
92
double
lambda2
()
const
{
93
return
m_lambda2;
94
}
95
96
/// \brief Set the l1-regularization parameter.
97
void
setLambda1
(
double
lambda){
98
SHARK_RUNTIME_CHECK
(lambda >= 0.0,
"Lambda1 must be positive"
);
99
m_lambda1 = lambda;
100
}
101
102
/// \brief Set the l2-regularization parameter.
103
void
setLambda2
(
double
lambda){
104
SHARK_RUNTIME_CHECK
(lambda >= 0.0,
"Lambda2 must be positive"
);
105
m_lambda2 = lambda;
106
}
107
/// \brief Return the current setting of the accuracy (maximal gradient component of the optimization problem).
108
double
accuracy
()
const
{
109
return
m_accuracy;
110
}
111
112
/// \brief Set the accuracy (maximal gradient component of the optimization problem).
113
void
setAccuracy
(
double
accuracy
){
114
SHARK_RUNTIME_CHECK
(
accuracy
> 0.0,
"Accuracy must be positive"
);
115
m_accuracy =
accuracy
;
116
}
117
118
/// \brief Get the regularization parameters lambda1 and lambda2 through the IParameterizable interface.
119
RealVector
parameterVector
()
const
{
120
return
{m_lambda1,m_lambda2};
121
}
122
123
/// \brief Set the regularization parameters lambda1 and lambda2 through the IParameterizable interface.
124
void
setParameterVector
(RealVector
const
& param){
125
SIZE_CHECK
(param.size() == 2);
126
setLambda1
(param(0));
127
setLambda2
(param(1));
128
}
129
130
/// \brief Return the number of parameters (one in this case).
131
size_t
numberOfParameters
()
const
{
132
return
2;
133
}
134
135
/// \brief Train a linear model with logistic regression.
136
void
train
(
ModelType
& model,
DatasetType
const
& dataset);
137
138
/// \brief Train a linear model with logistic regression using weights.
139
void
train
(
ModelType
& model,
WeightedDatasetType
const
& dataset);
140
private
:
141
bool
m_bias;
///< whether to train with the bias parameter or not
142
double
m_lambda1;
///< l1-regularization parameter
143
double
m_lambda2;
///< l2-regularization parameter
144
double
m_accuracy;
///< gradient accuracy
145
};
146
147
// Explicit instantiation declarations (C++11 "extern template"): the
// definitions for these four common vector types are compiled once in a
// library translation unit; declaring them extern here suppresses redundant
// implicit instantiation in every client that includes this header.
extern template class LogisticRegression<RealVector>;
extern template class LogisticRegression<FloatVector>;
extern template class LogisticRegression<CompressedRealVector>;
extern template class LogisticRegression<CompressedFloatVector>;
152
153
}
154
#endif