TreeNearestNeighbors.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Efficient Nearest neighbor queries.
6 *
7 *
8 *
9 * \author T. Glasmachers
10 * \date 2011
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35#ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H
36#define SHARK_ALGORITHMS_NEARESTNEIGHBORS_TREENEARESTNEIGHBORS_H
37
38
39#include <boost/intrusive/rbtree.hpp>
42#include <shark/Data/DataView.h>
43namespace shark {
44
45
46///
47/// \brief Iterative nearest neighbors query.
48///
49/// \par
50/// The IterativeNNQuery class (Iterative Nearest Neighbor
51/// Query) allows the nearest neighbors of a reference point
52/// to be queried iteratively. Given the reference point, a
53/// query is set up that returns the nearest neighbor first,
54/// then the second nearest neighbor, and so on.
55/// Thus, nearest neighbor queries are treated in an "online"
56/// fashion. The algorithm follows the paper (generalized to
57/// arbitrary space-partitioning trees):
58///
59/// \par
60/// Strategies for efficient incremental nearest neighbor search.
61/// A. J. Broder. Pattern Recognition 23(1/2), pp 171-178, 1990.
62///
63/// \par
64/// The algorithm is based on traversing a BinaryTree that
65/// partitions the space into nested cells. The triangle
66/// inequality is applied to exclude cells from the search.
67/// Furthermore, candidate points are cached in a queue,
68/// such that subsequent queries profit from points that
69/// could not be excluded this way, but that did not turn
70/// out to be the (current) nearest neighbor.
71///
72/// \par
73/// The tree must have a bucket size of one, but leaf nodes
74/// with multiple copies of the same point are allowed.
75/// This means that the space partitioning must be carried
76/// out to the finest possible scale.
77///
78/// The data must be stored in a random access container. This means that elements
79/// have O(1) access time. This is crucial for the performance of the tree lookup.
80/// When data is stored in a Data<T>, a View should be chosen as template parameter.
81template <class DataContainer>
83{
84public:
85 typedef typename DataContainer::value_type value_type;
88 typedef std::pair<double, std::size_t> result_type;
89
90 /// create a new query
91 /// \param tree Underlying space-partitioning tree (this is assumed to persist for the lifetime of the query object).
92 /// \param data Container holding the stored data which is referenced by the tree
93 /// \param point Point whose nearest neighbors are to be found.
94 IterativeNNQuery(tree_type const* tree, DataContainer const& data, value_type const& point)
95 : m_data(data)
96 , m_reference(point)
97 , m_nextIndex(0)
98 , mp_trace(NULL)
99 , mep_head(NULL)
100 , m_squaredRadius(0.0)
101 , m_neighbors(0)
102 {
103 // Initialize the recursion trace: descend to the
104 // leaf covering the reference point and queue it.
105 // The parent of this leaf becomes the "head".
106 mp_trace = new TraceNode(tree, NULL, m_reference);
107 TraceNode* tn = mp_trace;
108 while (tree->hasChildren())
109 {
110 tn->createLeftNode(tree, m_data, m_reference);
111 tn->createRightNode(tree, m_data, m_reference);
112 bool left = tree->isLeft(m_reference);
113 tn = (left ? tn->mep_left : tn->mep_right);
114 tree = (left ? tree->left() : tree->right());
115 }
116 mep_head = tn->mep_parent;
117 insertIntoQueue((TraceLeaf*)tn);
118 m_squaredRadius = mp_trace->squaredRadius(m_reference);
119 }
120
121 /// destroy the query object and its internal data structures
123 m_queue.clear();
124 delete mp_trace;
125 }
126
127
128 /// return the number of neighbors already found
129 std::size_t neighbors() const {
130 return m_neighbors;
131 }
132
133 /// find and return the next nearest neighbor
134 result_type next() {
135 SHARK_RUNTIME_CHECK(m_neighbors < mp_trace->m_tree->size(), "No more neighbors available");
136
137 assert(! m_queue.empty());
138
139 // Check whether the current node has points
140 // left, or whether it should be discarded.
141 if (m_neighbors > 0){
142 TraceLeaf& q = *m_queue.begin();
143 if (m_nextIndex < q.m_tree->size()){
144 return getNextPoint(q);
145 }
146 else
147 m_queue.erase(q);
148 }
149 if (m_queue.empty() || (*m_queue.begin()).m_squaredPtDistance > m_squaredRadius){
150 // enqueue more points
151 TraceNode* tn = mep_head;
152 while (tn != NULL){
153 enqueue(tn);
154 if (tn->m_status == COMPLETE) mep_head = tn->mep_parent;
155 tn = tn->mep_parent;
156 }
157
158 // re-compute the radius
159 m_squaredRadius = mp_trace->squaredRadius(m_reference);
160 }
161 m_nextIndex = 0;
162 ++m_neighbors;
163 return getNextPoint(*m_queue.begin());
164 }
165
166 /// return the size of the queue,
167 /// which is a measure of the
168 /// overhead of the search
169 std::size_t queuesize() const{
170 return m_queue.size();
171 }
172
173private:
174
/// Search status of a TraceNode object: how many of the points
/// below the node have been moved into the candidate queue so far.
enum Status
{
	NONE = 0,     ///< no points of this node have been queued yet
	PARTIAL = 1,  ///< some of the points of this node have been queued
	COMPLETE = 2  ///< all points of this node have been queued
};
182
183 /// The TraceNode class builds up a tree during
184 /// the search. This tree covers only those parts
185 /// of the space partirioning tree that need to be
186 /// traversed in order to find the next nearest
187 /// neighbor.
188 class TraceNode
189 {
190 public:
191 /// Constructor
192 TraceNode(tree_type const* tree, TraceNode* parent, value_type const& reference)
193 : m_tree(tree)
194 , m_status(NONE)
195 , mep_parent(parent)
196 , mep_left(NULL)
197 , mep_right(NULL)
198 , m_squaredDistance(tree->squaredDistanceLowerBound(reference))
199 { }
200
201 /// Destructor
202 virtual ~TraceNode()
203 {
204 if (mep_left != NULL) delete mep_left;
205 if (mep_right != NULL) delete mep_right;
206 }
207
208 void createLeftNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
209 if (tree->left()->hasChildren())
210 mep_left = new TraceNode(tree->left(), this, reference);
211 else
212 mep_left = new TraceLeaf(tree->left(), this, data, reference);
213 }
214 void createRightNode(tree_type const* tree, DataContainer const& data, value_type const& reference){
215 if (tree->right()->hasChildren())
216 mep_right = new TraceNode(tree->right(), this, reference);
217 else
218 mep_right = new TraceLeaf(tree->right(), this, data, reference);
219 }
220
221 /// Compute the squared distance of the area not
222 /// yet covered by the queue to the reference point.
223 /// This is also referred to as the squared "radius"
224 /// of the area covered by the queue (in fact, it is
225 /// the radius of the largest sphere around the
226 /// reference point that fits into the covered area).
227 double squaredRadius(value_type const& ref) const{
228 if (m_status == NONE) return m_squaredDistance;
229 else if (m_status == PARTIAL)
230 {
231 double l = mep_left->squaredRadius(ref);
232 double r = mep_right->squaredRadius(ref);
233 return std::min(l, r);
234 }
235 else return 1e100;
236 }
237
238 /// node of the tree
239 tree_type const* m_tree;
240
241 /// status of the search
242 Status m_status;
243
244 /// parent node
245 TraceNode* mep_parent;
246
247 /// "left" child
248 TraceNode* mep_left;
249
250 /// "right" child
251 TraceNode* mep_right;
252
253 /// squared distance of the box to the reference point
254 double m_squaredDistance;
255 };
256
257 /// hook type for intrusive container
258 typedef boost::intrusive::set_base_hook<> HookType;
259
260 /// Leaves of the three have three roles:
261 /// (1) they are tree nodes holding exactly one point
262 /// (possibly multiple copies of this point),
263 /// (2) they know the distance of their point to the
264 /// reference point,
265 /// (3) they can be added to the candidates queue.
266 class TraceLeaf : public TraceNode, public HookType
267 {
268 public:
269 /// Constructor
270 TraceLeaf(tree_type const* tree, TraceNode* parent, DataContainer const& data, value_type const& ref)
271 : TraceNode(tree, parent, ref){
272 //check whether the tree uses a differen metric than a linear one.
273 if(tree->kernel() != NULL)
274 m_squaredPtDistance = tree->kernel()->featureDistanceSqr(data[tree->index(0)], ref);
275 else
276 m_squaredPtDistance = distanceSqr(data[tree->index(0)], ref);
277 }
278
279 /// Destructor
280 ~TraceLeaf() { }
281
282
283 /// Comparison by distance, ties are broken arbitrarily,
284 /// but deterministically, by tree node pointer.
285 inline bool operator < (TraceLeaf const& rhs) const{
286 if (m_squaredPtDistance == rhs.m_squaredPtDistance)
287 return (this->m_tree < rhs.m_tree);
288 else
289 return (m_squaredPtDistance < rhs.m_squaredPtDistance);
290 }
291
292 /// Squared distance of the single point in the leaf to the reference point.
293 double m_squaredPtDistance;
294 };
295
296 /// insert a point into the queue
297 void insertIntoQueue(TraceLeaf* leaf){
298 m_queue.insert_unique(*leaf);
299
300 // traverse up the tree, updating the state
301 TraceNode* tn = leaf;
302 tn->m_status = COMPLETE;
303 while (true){
304 TraceNode* par = tn->mep_parent;
305 if (par == NULL) break;
306 if (par->m_status == NONE){
307 par->m_status = PARTIAL;
308 break;
309 }
310 else if (par->m_status == PARTIAL){
311 if (par->mep_left == tn){
312 if (par->mep_right->m_status == COMPLETE) par->m_status = COMPLETE;
313 else break;
314 }
315 else{
316 if (par->mep_left->m_status == COMPLETE) par->m_status = COMPLETE;
317 else break;
318 }
319 }
320 tn = par;
321 }
322 }
323
324 result_type getNextPoint(TraceLeaf const& leaf){
325 double dist = std::sqrt(leaf.m_squaredPtDistance);
326 std::size_t index = leaf.m_tree->index(m_nextIndex);
327 ++m_nextIndex;
328 return std::make_pair(dist,index);
329 }
330
331 /// Recursively descend the node and enqueue
332 /// all points in cells intersecting the
333 /// current bounding sphere.
334 void enqueue(TraceNode* tn){
335 // check whether this node needs to be enqueued
336 if (tn->m_status == COMPLETE) return;
337 if (! m_queue.empty() && tn->m_squaredDistance >= (*m_queue.begin()).m_squaredPtDistance) return;
338
339 const tree_type* tree = tn->m_tree;
340 if (tree->hasChildren()){
341 // extend the tree at need
342 if (tn->mep_left == NULL){
343 tn->createLeftNode(tree,m_data,m_reference);
344 }
345 if (tn->mep_right == NULL){
346 tn->createRightNode(tree,m_data,m_reference);
347 }
348
349 // first descend into the closer sub-tree
350 if (tree->isLeft(m_reference))
351 {
352 // left first
353 enqueue(tn->mep_left);
354 enqueue(tn->mep_right);
355 }
356 else
357 {
358 // right first
359 enqueue(tn->mep_right);
360 enqueue(tn->mep_left);
361 }
362 }
363 else
364 {
365 TraceLeaf* leaf = (TraceLeaf*)tn;
366 insertIntoQueue(leaf);
367 }
368 }
369
370 /// the queue is a self-balancing tree of sorted entries
371 typedef boost::intrusive::rbtree<TraceLeaf> QueueType;
372
373
374 ///\brief Datastorage to lookup the points referenced by the space partitioning tree.
375 DataContainer const& m_data;
376
377 /// reference point for this query
378 value_type m_reference;
379
380 /// queue of candidates
381 QueueType m_queue;
382
383 /// index of the next not yet returned element
384 /// of the current leaf.
385 std::size_t m_nextIndex;
386
387 /// recursion trace tree
388 TraceNode* mp_trace;
389
390 /// "head" of the trace tree. This is the
391 /// node containing the reference point,
392 /// but so high up in the tree that it is
393 /// not fully queued yet.
394 TraceNode* mep_head;
395
396 /// squared radius of the "covered" area
397 double m_squaredRadius;
398
399 /// number of neighbors already returned
400 std::size_t m_neighbors;
401};
402
403
404///\brief Nearest Neighbors implementation using binary trees
405///
406/// Returns the labels and distances of the k nearest neighbors of a point.
407template<class InputType, class LabelType>
408class TreeNearestNeighbors:public AbstractNearestNeighbors<InputType,LabelType>
409{
410private:
412
413public:
418
420 : m_dataset(dataset)
421 , m_inputs(dataset.inputs())
422 , m_labels(dataset.labels())
423 , mep_tree(tree)
424 {
425 this->m_inputShape = dataset.inputShape();
426 }
427
428 ///\brief returns the k nearest neighbors of the point
429 std::vector<DistancePair> getNeighbors(BatchInputType const& patterns, std::size_t k)const{
430 std::size_t numPoints = batchSize(patterns);
431 std::vector<DistancePair> results(k*numPoints);
432 for(std::size_t p = 0; p != numPoints; ++p){
433 IterativeNNQuery<DataView<Data<InputType> const> > query(mep_tree, m_inputs, row(patterns, p));
434 //find the neighbors using the queries
435 for(std::size_t i = 0; i != k; ++i){
436 typename IterativeNNQuery<DataView<Data<InputType> const> >::result_type result = query.next();
437 results[i+p*k].key=result.first;
438 results[i+p*k].value= m_labels[result.second];
439 }
440 }
441 return results;
442 }
443
445 return m_dataset;
446 }
447
448private:
449 Dataset const& m_dataset;
450 DataView<Data<InputType> const> m_inputs;
451 DataView<Data<LabelType> const> m_labels;
452 Tree const* mep_tree;
453
454};
455
456
457}
458#endif