35#ifndef SHARK_MODELS_NEARESTNEIGHBOR_H
36#define SHARK_MODELS_NEARESTNEIGHBOR_H
46template <
class InputType,
class LabelType>
47class BaseNearestNeighbor :
public AbstractModel<InputType, RealVector>
50 typedef AbstractNearestNeighbors<InputType,LabelType> NearestNeighbors;
51 typedef AbstractModel<InputType, RealVector> base_type;
59 BaseNearestNeighbor(NearestNeighbors
const* algorithm, std::size_t outputDimensions,
unsigned int neighbors = 3)
60 : m_algorithm(algorithm)
61 , m_outputDimensions(outputDimensions)
62 , m_neighbors(neighbors)
67 std::string
name()
const
68 {
return "Internal"; }
71 return m_algorithm->inputShape();
74 return Shape(m_outputDimensions);
78 unsigned int neighbors()
const{
83 void setNeighbors(
unsigned int neighbors){
84 m_neighbors=neighbors;
87 bool uniformWeights()
const{
90 bool& uniformWeights(){
96 RealVector parameters(1);
97 parameters(0) = m_neighbors;
104 m_neighbors = (
unsigned int)newParameters(0);
113 return boost::shared_ptr<State>(
new EmptyState());
117 void eval(BatchInputType
const& patterns, BatchOutputType& outputs)
const {
118 std::size_t numPatterns =
batchSize(patterns);
119 std::vector<typename NearestNeighbors::DistancePair> neighbors = m_algorithm->getNeighbors(patterns, m_neighbors);
121 outputs.resize(numPatterns, m_outputDimensions);
124 for(std::size_t p = 0; p != numPatterns;++p)
127 for ( std::size_t k = 0; k != m_neighbors; ++k)
131 double d = neighbors[p*m_neighbors+k].key;
132 if (d < 1e-100) w = 1e100;
135 updatePrediction(outputs, p, w, neighbors[p*m_neighbors+k].value);
138 row(outputs, p) /= wsum;
142 void eval(BatchInputType
const& patterns, BatchOutputType& outputs, State&)
const {
143 eval(patterns,outputs);
149 archive & m_neighbors;
150 archive & m_outputDimensions;
156 archive & m_neighbors;
157 archive & m_outputDimensions;
162 void updatePrediction(RealMatrix& outputs, std::size_t p,
double w,
unsigned int const label)
const{
163 outputs(p, label) += w;
166 void updatePrediction(RealMatrix& outputs, std::size_t p,
double w, blas::vector<T>
const& label)
const{
167 noalias(row(outputs,p)) += w * label;
169 NearestNeighbors
const* m_algorithm;
172 std::size_t m_outputDimensions;
175 unsigned int m_neighbors;
193template <
class InputType,
class LabelType>
198 typedef detail::BaseNearestNeighbor<InputType,LabelType> base_type;
211 : base_type(algorithm,
labelDimension(algorithm->dataset()), neighbors)
216 {
return "NearestNeighbor"; }
225 this->decisionFunction().uniformWeights() = (dw ==
UNIFORM);
230template <
class InputType>
253 {
return "NearestNeighbor"; }
257 return this->decisionFunction().neighbors();
262 this->decisionFunction().setNeighbors(neighbors);
272 this->decisionFunction().uniformWeights() = (dw ==
UNIFORM);