SimpleNearestNeighbors.h
Go to the documentation of this file.
1//===========================================================================
2/*!
3 *
4 *
5 * \brief Efficient brute force implementation of nearest neighbors.
6 *
7 *
8 *
9 * \author O.Krause
10 * \date 2012
11 *
12 *
13 * \par Copyright 1995-2017 Shark Development Team
14 *
15 * <BR><HR>
16 * This file is part of Shark.
17 * <https://shark-ml.github.io/Shark/>
18 *
19 * Shark is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU Lesser General Public License as published
21 * by the Free Software Foundation, either version 3 of the License, or
22 * (at your option) any later version.
23 *
24 * Shark is distributed in the hope that it will be useful,
25 * but WITHOUT ANY WARRANTY; without even the implied warranty of
26 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
27 * GNU Lesser General Public License for more details.
28 *
29 * You should have received a copy of the GNU Lesser General Public License
30 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
31 *
32 */
33//===========================================================================
34
35#ifndef SHARK_ALGORITHMS_NEARESTNEIGHBORS_SIMPLENEARESTNEIGHBORS_H
36#define SHARK_ALGORITHMS_NEARESTNEIGHBORS_SIMPLENEARESTNEIGHBORS_H
37
40#include <shark/Core/OpenMP.h>
41#include <algorithm>
42
43
44namespace shark {
45
46///\brief Brute force optimized nearest neighbor implementation
47///
48///Returns the labels and distances of the k nearest neighbors of a point
49/// The distance is measured using an arbitrary metric
50template<class InputType, class LabelType>
51class SimpleNearestNeighbors:public AbstractNearestNeighbors<InputType,LabelType>{
52private:
54public:
59
60 /// \brief Constructor.
61 ///
62 /// \par Construct a "brute force" nearest neighbors search algorithm
63 /// from data and a metric. Refer to the AbstractMetric class for details.
64 /// The "default" Euclidean metric is realized by providing a pointer to
65 /// an object of type LinearKernel<InputType>.
67 :m_dataset(dataset), mep_metric(metric){
68 this->m_inputShape=dataset.inputShape();
69 }
70
71 ///\brief Return the k nearest neighbors of the query point.
72 std::vector<DistancePair> getNeighbors(BatchInputType const& patterns, std::size_t k)const{
73 std::size_t numPatterns = batchSize(patterns);
74 std::size_t maxThreads = std::min(SHARK_NUM_THREADS,m_dataset.numberOfBatches());
75 //heaps of key value pairs (distance,classlabel). One heap for every pattern and thread.
76 //For memory alignment reasons, all heaps are stored in one continuous array
77 //the heaps are stored such, that for every pattern the heaps for every thread
78 //are forming one memory area. so later we can just merge all 4 heaps using make_heap
79 //be aware that the values created here allready form a heap since they are all
80 //identical maximum distance.
81 std::vector<DistancePair> heaps(k*numPatterns*maxThreads,DistancePair(std::numeric_limits<double>::max(),LabelType()));
82 typedef typename std::vector<DistancePair>::iterator iterator;
83 //iterate over all batches of the training set in parallel and let
84 //every thread do a KNN-Search on it's subset of data
85 SHARK_PARALLEL_FOR(int b = 0; b < (int)m_dataset.numberOfBatches(); ++b){
86 //evaluate distances between the points of the patterns and the batch
87 RealMatrix distances=mep_metric->featureDistanceSqr(patterns,m_dataset.batch(b).input);
88
89 //now update the heaps with the distances
90 for(std::size_t p = 0; p != numPatterns; ++p){
91 std::size_t batchSize = distances.size2();
92
93 //get current heap
94 std::size_t heap = p*maxThreads+SHARK_THREAD_NUM;
95 iterator heapStart=heaps.begin()+heap*k;
96 iterator heapEnd=heapStart+k;
97 iterator biggest=heapEnd-1;//position of biggest element
98
99 //update heap values using the new distances
100 for(std::size_t i = 0; i != batchSize; ++i){
101 if(biggest->key >= distances(p,i)){
102 //push the smaller neighbor in the heap and replace the biggest one
103 biggest->key=distances(p,i);
104 biggest->value=getBatchElement(m_dataset.batch(b).label,i);
105 std::push_heap(heapStart,heapEnd);
106 //pop biggest element, so that
107 //biggest is again the biggest element
108 std::pop_heap(heapStart,heapEnd);
109 }
110 }
111 }
112 }
113 std::vector<DistancePair> results(k*numPatterns);
114 //finally, we merge all threads in one heap which has the inverse ordering
115 //and create a class histogram over the smallest k neighbors
116 //std::cout<<"info "<<numPatterns<<" "<<maxThreads<<" "<<k<<std::endl;
117 SHARK_PARALLEL_FOR(int p = 0; p < (int)numPatterns; ++p){
118 //find range of the heaps for all threads
119 iterator heapStart=heaps.begin()+p*maxThreads*k;
120 iterator heapEnd=heapStart+maxThreads*k;
121 iterator neighborEnd=heapEnd-k;
122 iterator smallest=heapEnd-1;//position of biggest element
123 //create one single heap of the range with inverse ordering
124 //takes O(maxThreads*k)
125 std::make_heap(heapStart,heapEnd,std::greater<DistancePair>());
126
127 //create histogram from the neighbors
128 for(std::size_t i = 0;heapEnd!=neighborEnd;--heapEnd,--smallest,++i){
129 std::pop_heap(heapStart,heapEnd,std::greater<DistancePair>());
130 results[i+p*k].key = smallest->key;
131 results[i+p*k].value = smallest->value;
132 }
133 }
134 return results;
135 }
136
137 /// \brief Direct access to the underlying data set of nearest neighbor points.
139 return m_dataset;
140 }
141
142private:
143 Dataset m_dataset; ///< data set of nearest neighbor points
144 Metric const* mep_metric; ///< metric for measuring distances, usually given by a kernel function
145};
146
147
148}
149#endif