📄 knn.cpp
字号:
//====
// knn.cpp
// - k nearest neighbours (the classic case-based classifier)
// - returns the most likely category of a target according to
// its k nearest neighbours whose categories are known
// - input
// - array of known data points (double **p)
// - number of known data points (int n)
// - dimension of data points (int d)
// - array of data labels parallel to the array of known data points (int *c)
// - target point to be classified (double *x)
// - desired k value (int k)
// - 1, 3, and 5 seem to be frequent choices in the literature
// - in order to avoid ties, always use an odd number
// - an optional threshold for optimization (double t = 0)
// - when a neighbour has been found so near that it is almost
// identical to the target, it's probably safe to stop there
// and report that neighbour's category as the result
// - this trick may improve efficiency when the knowledge base
// is too large to be completely scanned for each target
// - default is don't stop unless a truly identical neighbour
// has been found, which should be safe in most cases
// - if you really don't want any optimization, use a negative
// value for this argument
// - see http://cs.smu.ca/~r_zhang/edc for the external function
// Notes
// - this package is provided as is with no warranty
// - the authors are not responsible for any damage caused
// either directly or indirectly by using this package
// - anybody is free to do whatever he/she wants with this
// package as long as this header section is preserved
// Created on 2005-07-11 by
// - Caliope Sandiford
// - Chitra Attaluri
// - Naureen Nizam
// - Roger Zhang (rogerz@cs.dal.ca)
// Modifications
// -
// Last compiled under Linux with gcc-3
//====
#include <list>
#include <map>
#include <cassert>
// Distance routine supplied by another translation unit (C linkage).
// It is called as squared_distance(p[j], x, d); the threshold handling below
// treats its result as a squared distance.
extern "C" double squared_distance(double*, double*, int);

//====
// knnLabel
// - classifies the target point x by a vote among its k nearest known points
// - parameters
//   - p : array of n known data points, each passed to squared_distance()
//         together with x and d
//   - n : number of known data points
//   - d : dimension handed to squared_distance()
//   - c : labels parallel to p (c[j] is the category of p[j])
//   - x : target point to classify
//   - k : number of neighbours that vote
//   - t : early-exit threshold; a positive value is squared before use
//         because it is compared against squared distances
// - returns
//   - c[j] of the first scanned point whose squared distance is <= t
//     (with the default t = 0 that means a distance of exactly zero;
//     a negative t disables this shortcut), otherwise
//   - the category holding the most of the k nearest neighbours; a tie in
//     count goes to the category with the smaller summed squared distance,
//     and if that is equal too, to the category that reached the count first
// - preconditions are checked with assert() only, so nothing validates the
//   arguments when NDEBUG is defined
//====
int knnLabel(double **p, int n, int d, int *c, double *x, int k, double t = 0)
{
    assert(k > 0 && k <= n && d > 0 && p && c && x);

    typedef std::pair<int, double> Profile;     // {index into p/c, squared distance}
    typedef std::list<Profile>     ProfileList;
    typedef std::pair<int, double> Tally;       // {vote count, summed squared distance}
    typedef std::map<int, Tally>   TallyMap;    // category => tally

    if (t > 0) {
        t *= t; // compare like with like: distances below are squared
    }

    // Convert k once so every size comparison is unsigned-vs-unsigned.
    const ProfileList::size_type limit = static_cast<ProfileList::size_type>(k);

    //====
    // pass 1: keep the (at most) k nearest points, sorted by ascending distance
    ProfileList nabors;
    for (int j = 0; j < n; ++j) {
        const double dist = squared_distance(p[j], x, d);
        if (dist <= t) { // near enough to the target: report its category now
            return c[j];
        }

        // Find the first recorded neighbour that is strictly farther away.
        // Using >= keeps points at equal distance in scan order.
        ProfileList::iterator pos = nabors.begin();
        while (pos != nabors.end() && dist >= pos->second) {
            ++pos;
        }

        // p[j] qualifies if it beats a recorded neighbour or the list has room.
        if (pos != nabors.end() || nabors.size() < limit) {
            nabors.insert(pos, Profile(j, dist));
            if (nabors.size() > limit) { // k+1 entries: drop the farthest
                nabors.pop_back();
            }
        }
    }

    //====
    // pass 2: every kept neighbour casts one vote for its category
    TallyMap votes;
    const Tally noVotes(0, 0.0);   // tally of a category nobody has voted for yet
    const Tally *best = &noVotes;  // tally of the current winner
    int winner = c[0];             // placeholder until the first vote is cast

    for (ProfileList::iterator it = nabors.begin(); it != nabors.end(); ++it) {
        const int label = c[it->first];
        Tally &tally = votes[label]; // single lookup; starts at {0, 0.0}
        ++tally.first;
        tally.second += it->second;

        // With equal counts, comparing distance sums is the same as comparing
        // class-average distances. When label is already the winner, tally
        // and *best are the same object, so neither test fires.
        if (tally.first > best->first ||
            (tally.first == best->first && tally.second < best->second)) {
            winner = label;
            best = &tally; // map element references survive later insertions
        }
    }
    return winner;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -