NearestNeighborModel.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Nearest neighbor model for classification and regression
6 *
7 *
8 *
9 * \author T. Glasmachers, C. Igel, O.Krause
10 * \date 2012-2017
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35#ifndef SHARK_MODELS_NEARESTNEIGHBOR_H
36#define SHARK_MODELS_NEARESTNEIGHBOR_H
37
38
42
43namespace shark {
44
45namespace detail{
46template <class InputType, class LabelType>
47class BaseNearestNeighbor : public AbstractModel<InputType, RealVector>
48{
49public:
50 typedef AbstractNearestNeighbors<InputType,LabelType> NearestNeighbors;
51 typedef AbstractModel<InputType, RealVector> base_type;
54
55 /// \brief Constructor
56 ///
57 /// \param algorithm the used algorithm for nearest neighbor search
58 /// \param neighbors number of neighbors
59 BaseNearestNeighbor(NearestNeighbors const* algorithm, std::size_t outputDimensions, unsigned int neighbors = 3)
60 : m_algorithm(algorithm)
61 , m_outputDimensions(outputDimensions)
62 , m_neighbors(neighbors)
63 , m_uniform(true)
64 { }
65
66 /// \brief From INameable: return the class name.
67 std::string name() const
68 { return "Internal"; }
69
70 Shape inputShape() const{
71 return m_algorithm->inputShape();
72 }
73 Shape outputShape() const{
74 return Shape(m_outputDimensions);
75 }
76
77 /// return the number of neighbors
78 unsigned int neighbors() const{
79 return m_neighbors;
80 }
81
82 /// set the number of neighbors
83 void setNeighbors(unsigned int neighbors){
84 m_neighbors=neighbors;
85 }
86
87 bool uniformWeights() const{
88 return m_uniform;
89 }
90 bool& uniformWeights(){
91 return m_uniform;
92 }
93
94 /// get internal parameters of the model
95 virtual RealVector parameterVector() const{
96 RealVector parameters(1);
97 parameters(0) = m_neighbors;
98 return parameters;
99 }
100
101 /// set internal parameters of the model
102 virtual void setParameterVector(RealVector const& newParameters){
103 SHARK_RUNTIME_CHECK(newParameters.size() == 1,"Invalid number of parameters");
104 m_neighbors = (unsigned int)newParameters(0);
105 }
106
107 /// return the size of the parameter vector
108 virtual std::size_t numberOfParameters() const{
109 return 1;
110 }
111
112 boost::shared_ptr<State> createState()const{
113 return boost::shared_ptr<State>(new EmptyState());
114 }
115
116 /// soft k-nearest-neighbor prediction
117 void eval(BatchInputType const& patterns, BatchOutputType& outputs) const {
118 std::size_t numPatterns = batchSize(patterns);
119 std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns, m_neighbors);
120
121 outputs.resize(numPatterns, m_outputDimensions);
122 outputs.clear();
123
124 for(std::size_t p = 0; p != numPatterns;++p)
125 {
126 double wsum = 0.0;
127 for ( std::size_t k = 0; k != m_neighbors; ++k)
128 {
129 double w = 1.0;
130 if (!m_uniform){
131 double d = neighbors[p*m_neighbors+k].key;
132 if (d < 1e-100) w = 1e100;
133 else w = 1.0 / d;
134 }
135 updatePrediction(outputs, p, w, neighbors[p*m_neighbors+k].value);
136 wsum += w;
137 }
138 row(outputs, p) /= wsum;
139 }
140 }
141
142 void eval(BatchInputType const& patterns, BatchOutputType& outputs, State&) const {
143 eval(patterns,outputs);
144 }
145 using base_type::eval;
146
147 /// from ISerializable, reads a model from an archive
148 void read(InArchive& archive){
149 archive & m_neighbors;
150 archive & m_outputDimensions;
151 archive & m_uniform;
152 }
153
154 /// from ISerializable, writes a model to an archive
155 void write(OutArchive& archive) const{
156 archive & m_neighbors;
157 archive & m_outputDimensions;
158 archive & m_uniform;
159 }
160
161private:
162 void updatePrediction(RealMatrix& outputs, std::size_t p, double w, unsigned int const label) const{
163 outputs(p, label) += w;
164 }
165 template<class T>
166 void updatePrediction(RealMatrix& outputs, std::size_t p, double w, blas::vector<T> const& label)const{
167 noalias(row(outputs,p)) += w * label;
168 }
169 NearestNeighbors const* m_algorithm;
170
171 /// number of classes
172 std::size_t m_outputDimensions;
173
174 /// number of neighbors to be taken into account
175 unsigned int m_neighbors;
176
177 /// type of distance-based weights computation
178 bool m_uniform;
179};
180}
181
182/// \brief NearestNeighbor model for classification and regression
183///
184/// The classification, the model predicts a class label
185/// according to a local majority decision among its k
186/// nearest neighbors. It is not specified how ties are
187/// broken.
188///
189/// For Regression, the (weighted) mean of the k nearest
190/// neighbours is computed.
191///
192/// \ingroup models
193template <class InputType, class LabelType>
194class NearestNeighborModel: public detail::BaseNearestNeighbor<InputType,LabelType>
195{
196public:
198 typedef detail::BaseNearestNeighbor<InputType,LabelType> base_type;
199
200 /// \brief Type of distance-based weights.
202 UNIFORM, ///< uniform (= no) distance-based weights
203 ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
204 };
205
206 /// \brief Constructor
207 ///
208 /// \param algorithm the used algorithm for nearest neighbor search
209 /// \param neighbors number of neighbors
210 NearestNeighborModel(NearestNeighbors const* algorithm, unsigned int neighbors = 3)
211 : base_type(algorithm, labelDimension(algorithm->dataset()), neighbors)
212 { }
213
214 /// \brief From INameable: return the class name.
215 std::string name() const
216 { return "NearestNeighbor"; }
217
218 /// query the way distances enter as weights
220 return this->decisionFunction().uniformWeights() ? UNIFORM : ONE_OVER_DISTANCE;
221 }
222
223 /// set the way distances enter as weights
225 this->decisionFunction().uniformWeights() = (dw == UNIFORM);
226 }
227};
228
229
230template <class InputType>
231class NearestNeighborModel<InputType, unsigned int>: public Classifier<detail::BaseNearestNeighbor<InputType,unsigned int> >
232{
233public:
236
237 /// \brief Type of distance-based weights.
239 UNIFORM, ///< uniform (= no) distance-based weights
240 ONE_OVER_DISTANCE, ///< weight each neighbor's label with 1/distance
241 };
242
243 /// \brief Constructor
244 ///
245 /// \param algorithm the used algorithm for nearest neighbor search
246 /// \param neighbors number of neighbors
247 NearestNeighborModel(NearestNeighbors const* algorithm, unsigned int neighbors = 3)
248 : base_type(detail::BaseNearestNeighbor<InputType,unsigned int>(algorithm, numberOfClasses(algorithm->dataset()), neighbors))
249 { }
250
251 /// \brief From INameable: return the class name.
252 std::string name() const
253 { return "NearestNeighbor"; }
254
255 /// return the number of neighbors
256 unsigned int neighbors() const{
257 return this->decisionFunction().neighbors();
258 }
259
260 /// set the number of neighbors
261 void setNeighbors(unsigned int neighbors){
262 this->decisionFunction().setNeighbors(neighbors);
263 }
264
265 /// query the way distances enter as weights
267 return this->decisionFunction().uniformWeights() ? UNIFORM : ONE_OVER_DISTANCE;
268 }
269
270 /// set the way distances enter as weights
272 this->decisionFunction().uniformWeights() = (dw == UNIFORM);
273 }
274};
275
276
277}
278#endif