RankingSvmTrainer.h
Documentation source listing of this header file. Note: each code line below carries its original source line number fused to its front, and some original lines are absent from this extraction (visible as gaps in that numbering).
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Support Vector Machine Trainer for the ranking-SVM
6 *
7 *
8 * \author T. Glasmachers
9 * \date 2016
10 *
11 *
12 * \par Copyright 1995-2017 Shark Development Team
13 *
14 * <BR><HR>
15 * This file is part of Shark.
16 * <https://shark-ml.github.io/Shark/>
17 *
18 * Shark is free software: you can redistribute it and/or modify
19 * it under the terms of the GNU Lesser General Public License as published
20 * by the Free Software Foundation, either version 3 of the License, or
21 * (at your option) any later version.
22 *
23 * Shark is distributed in the hope that it will be useful,
24 * but WITHOUT ANY WARRANTY; without even the implied warranty of
25 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26 * GNU Lesser General Public License for more details.
27 *
28 * You should have received a copy of the GNU Lesser General Public License
29 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
30 *
31 */
32//===========================================================================
33
34
35#ifndef SHARK_ALGORITHMS_RANKINGSVMTRAINER_H
36#define SHARK_ALGORITHMS_RANKINGSVMTRAINER_H
37
38
46
47namespace shark {
48
49
50///
51/// \brief Training of an SVM for ranking.
52///
53/// A ranking SVM trains a function (linear or linear in a kernel
54/// induced feature space, RKHS) with the aim that the function values
55/// are consistent with given pairwise rankings. I.e., given are pairs
56/// (a, b) of points, and the task of SVM training is to find a
57/// function f such that f(a) < f(b). More exactly, the hard margin
58/// ranking SVM aims for f(b) - f(a) >= 1 while minimizing the squared
59/// RKHS norm of f. The soft-margin SVM relates the constraint analogous
60/// to a standard C-SVM.
61///
62/// The trained model is a real-valued function. To predict the ranking
63/// of a pair of points the function is applied to both points. The one
64/// with smaller function value is ranked as being "smaller", i.e., if f
65/// is the trained model and a and b are data points, then the following
66/// code computes the ranking:
67///
68/// bool a_better_than_b = (f(a) < f(b));
69///
70/// \ingroup supervised_trainer
71template <class InputType, class CacheType = float>
72class RankingSvmTrainer : public AbstractSvmTrainer< InputType, unsigned int, KernelExpansion<InputType> >
73{
74private:
// NOTE(review): original line 75 is missing from this listing; given the
// uses of base_type below (constructor initializer, m_kernel, m_shrinking,
// sparsify), it presumably declares the base-class convenience typedef,
// e.g. typedef AbstractSvmTrainer<InputType, unsigned int,
// KernelExpansion<InputType> > base_type; -- confirm against upstream.
76
77public:
78 /// \brief Convenience typedefs:
79 /// this and many of the below typedefs build on the class template type CacheType.
80 /// Simply changing that one template parameter CacheType thus allows to flexibly
81 /// switch between using float or double as type for caching the kernel values.
82 /// The default is float, offering sufficient accuracy in the vast majority
83 /// of cases, at a memory cost of only four bytes. However, the template
84 /// parameter makes it easy to use double instead, (e.g., in case high
85 /// accuracy training is needed).
86 typedef CacheType QpFloatType;
87
// NOTE(review): original line 88 is missing; the constructor below takes a
// KernelType*, so this line presumably declares the kernel typedef, e.g.
// typedef AbstractKernelFunction<InputType> KernelType; -- confirm.
89
90 //! Constructor
91 //! \param kernel kernel function to use for training and prediction
92 //! \param C regularization parameter - always the 'true' value of C, even when unconstrained is set
93 //! \param unconstrained when a C-value is given via setParameter, should it be piped through the exp-function before using it in the solver?
94 RankingSvmTrainer(KernelType* kernel, double C, bool unconstrained = false)
95 : base_type(kernel, C, false, unconstrained)
96 { }
97
98 /// \brief From INameable: return the class name.
99 std::string name() const
100 { return "RankingSvmTrainer"; }
101
102 /// \brief Train the ranking SVM.
103 ///
104 /// This variant of the train function assumes that all pairs of
105 /// points should be ranked according to the order they appear in
106 /// the data set.
107 void train(KernelExpansion<InputType>& function, Data<InputType> const& dataset)
108 {
109 // create all pairs
// Enumerates every unordered pair (j, i) with j < i, so each pair's first
// element precedes its second in the data set; the pairwise-train overload
// below then aims for f(first) < f(second), i.e. data-set order.
110 std::size_t n = dataset.numberOfElements();
111 std::vector<std::pair<std::size_t, std::size_t>> pairs;
112 for (std::size_t i=0; i<n; i++) {
113 for (std::size_t j=0; j<i; j++) {
114 pairs.push_back(std::make_pair(j, i));
115 }
116 }
117 train(function, dataset, pairs);
118 }
119
120 /// \brief Train the ranking SVM.
121 ///
122 /// This variant of the train function uses integer labels to define
123 /// pairwise rankings. It is trained on all pairs of data points
124 /// with different label, aiming for a smaller function value for
125 /// the point with smaller label.
// NOTE(review): original line 126 (this overload's signature) is missing
// from this extraction; the body calls dataset.labels() and
// dataset.inputs(), so it presumably reads
// void train(KernelExpansion<InputType>& function,
//            LabeledData<InputType, unsigned int> const& dataset)
// -- confirm against the upstream header.
127 {
128 std::vector<std::pair<std::size_t, std::size_t>> pairs;
129 std::size_t i = 0;
130 for (auto const& yi : dataset.labels().elements()) {
131 std::size_t j = 0;
132 for (auto const& yj : dataset.labels().elements()) {
// Only pairs with j < i are considered (each unordered pair once);
// the pair is stored with the smaller-labeled point first.
133 if (j >= i) break;
134 if (yi < yj) pairs.push_back(std::make_pair(i, j));
135 else if (yi > yj) pairs.push_back(std::make_pair(j, i));
// Equal labels produce no pair: ties impose no ranking constraint.
136 j++;
137 }
138 i++;
139 }
140 train(function, dataset.inputs(), pairs);
141 }
142
143 /// \brief Train the ranking SVM.
144 ///
145 /// This variant of the train function works with explicitly given
146 /// pairs of data points. Each pair is identified by the indices of
147 /// the training points in the data set.
148 void train(KernelExpansion<InputType>& function, Data<InputType> const& dataset, std::vector<std::pair<std::size_t, std::size_t>> const& pairs)
149 {
// Kernel expansion without bias/offset term (third argument false):
// the ranking objective is invariant to a constant offset of f.
150 function.setStructure(base_type::m_kernel, dataset, false);
151 DifferenceKernelMatrix<InputType, QpFloatType> dm(*function.kernel(), dataset, pairs);
152
// NOTE(review): original lines 153, 155 and 160 are missing from this
// extraction. Line 153 is presumably the branch condition (precomputed
// vs. cached kernel matrix, e.g. if (QpConfig::precomputeKernel())),
// and lines 155/160 presumably declare the local `matrix` used below as
// a PrecomputedMatrix<...> / CachedMatrix<...> wrapper around dm
// -- confirm against the upstream header.
154 {
156 trainInternal(function, dataset, pairs, matrix);
157 }
158 else
159 {
161 trainInternal(function, dataset, pairs, matrix);
162 }
163 }
164
165private:
// Shared back-end for all train() overloads: sets up the box-constrained
// QP over one dual variable per ranking pair, solves it, and writes the
// resulting coefficients back into the kernel expansion.
166 template <typename MatrixType>
167 void trainInternal(KernelExpansion<InputType>& function, Data<InputType> const& dataset, std::vector<std::pair<std::size_t, std::size_t>> const& pairs, MatrixType& matrix)
168 {
// NOTE(review): original line 169 is missing; it presumably constructs
// the quadratic problem `qp` from `matrix`, e.g.
// GeneralQuadraticProblem<MatrixType> qp(matrix); -- confirm.
170 qp.linear = RealVector(qp.dimensions(), 1.0);
// Standard C-SVM box constraints on the dual variables: 0 <= alpha_i <= C.
171 qp.boxMin = RealVector(qp.dimensions(), 0.0);
172 qp.boxMax = RealVector(qp.dimensions(), this->C());
// NOTE(review): original line 173 is missing; it presumably defines
// ProblemType (a shrinking QP problem over qp) used on the next line
// -- confirm against upstream.
174 ProblemType problem(qp, base_type::m_shrinking);
175
176 QpSolver<ProblemType> solver(problem);
// NOTE(review): original line 177 is missing; it presumably invokes the
// solver, e.g. solver.solve(base_type::stoppingCondition(),
// &base_type::solutionProperties()); -- confirm.
178 RealVector alpha = problem.getUnpermutedAlpha();
// Translate the per-pair dual solution into per-point expansion
// coefficients: a pair (first, second) with dual value a represents the
// constraint f(second) - f(first) >= margin, contributing -a to the
// coefficient of `first` and +a to that of `second`.
179 RealVector coeff(dataset.numberOfElements(), 0.0);
180 SIZE_CHECK(pairs.size() == alpha.size());
181 for (std::size_t i=0; i<alpha.size(); i++)
182 {
183 double a = alpha(i);
184 coeff(pairs[i].first) -= a;
185 coeff(pairs[i].second) += a;
186 }
// Single real-valued output: write coefficients into column 0.
187 blas::column(function.alpha(),0) = coeff;
188
// Optionally drop zero-coefficient support vectors from the expansion.
189 if (base_type::sparsify()) function.sparsify();
190 }
191};
192
193
194}
195#endif