JaakkolaHeuristic.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Jaakkola's heuristic and related quantities for Gaussian kernel selection
6 *
7 *
8 *
9 * \author T. Glasmachers, O. Krause, C. Igel
10 * \date 2010
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35
36#ifndef SHARK_ALGORITHMS_JAAKKOLAHEURISTIC_H
37#define SHARK_ALGORITHMS_JAAKKOLAHEURISTIC_H
38
39
40#include <shark/Data/Dataset.h>
42
43#include <boost/range/adaptor/filtered.hpp>
44#include <algorithm>
45
46namespace shark{
47
48
49/// \brief Jaakkola's heuristic and related quantities for Gaussian kernel selection
50///
51/// \par
52/// Jaakkola's heuristic method for setting the width parameter of the
53/// Gaussian radial basis function kernel is to pick a quantile (usually
54/// the median) of the distribution of Euclidean distances between points
55/// having different labels. The present implementation computes the kernel
56/// width \f$ \sigma \f$ and the bandwidth
57/// \f[ \gamma = \frac{1}{2 \sigma^2} \f]
58/// based on the median or on any other quantile of the empirical
59/// distribution.
60///
61/// By default, only the distance to the closest point with different
62/// label is considered. This behavior can be turned off by an option
63/// of the constructor. This is faster and in accordance with the
64/// original paper.
66{
67public:
68 /// Constructor
69 /// \param dataset vector-valued input data
70 /// \param nearestFalseNeighbor if true, only the nearest neighboring point with different label is considered (default true)
71 template<class InputType>
72 JaakkolaHeuristic(LabeledData<InputType,unsigned int> const& dataset, bool nearestFalseNeighbor = true)
73 {
    // Fills m_stat with squared distances between differently labeled points
    // and leaves it sorted in ascending order.
    // NOTE(review): `Elements` is used below, but its typedef is not visible in
    // this listing (listing line 74 is missing). Presumably it is the const
    // element range type of the dataset -- confirm against the original header.
75 typedef typename ConstProxyReference<InputType const>::type Element;
76 Elements elements = dataset.elements();
77 if(!nearestFalseNeighbor) {
    // All-pairs mode: visit every unordered pair (it, itIn) with itIn after it,
    // skip pairs with equal labels, and append the squared distance of the rest.
78 for(typename Elements::iterator it = elements.begin(); it != elements.end(); ++it){
79 Element x = it->input;
80 typename Elements::iterator itIn = it;
81 itIn++;
82 for (; itIn != elements.end(); itIn++) {
83 if (itIn->label == it->label) continue;
84 Element y = itIn->input;
85 double dist = distanceSqr(x,y);
86 m_stat.push_back(dist);
87 }
88 }
89
90 } else {
    // Nearest-false-neighbor mode: one m_stat entry per dataset element,
    // initialised to the largest double and then lowered to the smallest squared
    // distance to any element carrying a different label.
    // NOTE(review): an element without any differently labeled counterpart
    // (e.g. a single-class dataset) appears to keep the initial max() value.
91 std::size_t classes = numberOfClasses(dataset);
92 std::size_t dim = inputDimension(dataset);
93 m_stat.resize(dataset.numberOfElements());
94 std::fill(m_stat.begin(),m_stat.end(), std::numeric_limits<double>::max());
    // m_stat is filled class by class, not in dataset order: blockStart is the
    // offset of the first entry belonging to the current left batch.
95 std::size_t blockStart = 0;
96 for(std::size_t c = 0; c != classes; ++c){
97
98 typename Elements::iterator leftIt = elements.begin();
99 typename Elements::iterator end = elements.end();
100 while(leftIt != end){
101 //todo: use a filter on the iterator
102 //create the next batch containing only elements of class c as left argument to distanceSqr
    // The batch is constructed with (512, dim), presumably 512 rows of dim
    // columns; only the first leftElements rows are assigned.
103 typename Batch<InputType>::type leftBatch(512, dim);
104 std::size_t leftElements = 0;
105 while(leftElements < 512 && leftIt != end){
106 if(leftIt->label == c){
107 row(leftBatch,leftElements) = leftIt->input;
108 ++leftElements;
109 }
110 ++leftIt;
111 }
112 //now go through all elements and again create batches, this time of all elements which are not of class c
113 typename Elements::iterator rightIt = elements.begin();
114 while(rightIt != end){
115 typename Batch<InputType>::type rightBatch(512, dim);
116 std::size_t rightElements = 0;
117 while(rightElements < 512 && rightIt != end){
118 if(rightIt->label != c){
119 row(rightBatch,rightElements) = rightIt->input;
120 ++rightElements;
121 }
122 ++rightIt;
123 }
124
125 //now compute distances and update shortest distance
    // Distances are computed for the full batches; only the first leftElements
    // rows and the first rightElements columns of the result are read.
    // NOTE(review): rightElements can be 0 here (e.g. when all remaining
    // elements belong to class c), making the subrange empty. The behaviour of
    // min() on an empty range is not visible in this file -- confirm.
126 RealMatrix distances = distanceSqr(leftBatch,rightBatch);
127 for(std::size_t i = 0; i != leftElements;++i){
128 m_stat[blockStart+i]=std::min(min(subrange(row(distances,i),0,rightElements)),m_stat[blockStart+i]);
129 }
130 }
131 blockStart+= leftElements;
132 }
133 }
134
    // Commented-out element-by-element variant of the nearest-false-neighbor
    // computation, left in place by the original authors.
135 //~ for(typename Elements::iterator it = elements.begin(); it != elements.end(); ++it){
136 //~ double minDistSqr = std::numeric_limits<double>::max();//0;
137 //~ Element x = it->input;
138 //~ for (typename Elements::iterator itIn = elements.begin(); itIn != elements.end(); itIn++) {
139 //~ if (itIn->label == it->label) continue;
140 //~ Element y = itIn->input;
141 //~ double dist = distanceSqr(x,y);
142 //~ //if( (minDistSqr == 0) || (dist < minDistSqr)) minDistSqr = dist;
143 //~ if(dist < minDistSqr) minDistSqr = dist;
144 //~ }
145 //~ m_stat.push_back(minDistSqr);
146 //~ }
147
148 }
    // Sort ascending so that quantiles can later be read off by index.
149 std::sort(m_stat.begin(), m_stat.end());
150 }
151
152 /// Compute the given quantile (usually median)
153 /// of the empirical distribution of Euclidean distances
154 /// of data pairs with different labels.
155 double sigma(double quantile = 0.5)
156 {
157 std::size_t ic = m_stat.size();
158 SHARK_ASSERT(ic > 0);
159
160 std::sort(m_stat.begin(), m_stat.end());
161
162 if (quantile < 0.0)
163 {
164 // TODO: find minimum
165 return std::sqrt(m_stat[0]);
166 }
167 if (quantile >= 1.0)
168 {
169 // TODO: find maximum
170 return std::sqrt(m_stat[ic-1]);
171 }
172 else
173 {
174 // TODO: partial sort!
175 double t = quantile * (ic - 1);
176 std::size_t i = (std::size_t)floor(t);
177 double rest = t - i;
178 return ((1.0 - rest) * std::sqrt(m_stat[i]) + rest * std::sqrt(m_stat[i+1]));
179 }
180 }
181
182 /// Compute the given quantile (usually the median)
183 /// of the empirical distribution of Euclidean distances
184 /// of data pairs with different labels converted into
185 /// a value usable as the gamma parameter of the GaussianRbfKernel.
186 double gamma(double quantile = 0.5)
187 {
188 double s = sigma(quantile);
189 return 0.5 / (s * s);
190 }
191
192
193private:
194 /// sorted squared distances between differently labeled points (all such pairs, or by default only each point's nearest one)
195 std::vector<double> m_stat;
196};
197
198}
199#endif