📄 knearestneighbor.java
字号:
package fasbir.classifiers;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Enumeration;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
/**
* <p>Description: k nearest neighbor classifier<br>
* For nominal attributes, KNearestNeighbor employs the VDM metric.<br>
* Ref: Z.-H. Zhou and Y. Yu. Ensembling Local Learners Through Multi-Modal
* Perturbation. IEEE Transactions on Systems, Man, and Cybernetics - Part B, in press. <br>
* KNearestNeighbor employs projected-distance sorting to accelerate the searching process.
* </p>
* @author Y. Yu (yuy@lamda.nju.edu.cn), LAMDA Group (http://lamda.nju.edu.cn/)
* @version 1.0
*/
public class KNearestNeighbor extends Classifier {

    /**
     * KNNSample stores a training instance together with its precalculated
     * projected distance (the distance from the instance to the fixed
     * reference point {@code m_refpoint}).
     * Declared static: it reads no enclosing-instance state, and a non-static
     * Serializable inner class would drag the whole classifier into every
     * serialized sample.
     */
    protected static class KNNSample implements Serializable {
        public Instance instance;
        public double precalc;
    }

    /**
     * SampleComparator orders two KNNSample objects by their projected
     * distance ({@code precalc}), ascending. Raw Comparator signature is kept
     * for backward compatibility with existing subclasses/callers.
     */
    protected static class SampleComparator implements Comparator, Serializable {
        public int compare(Object o1, Object o2) {
            double d1 = ((KNNSample) o1).precalc;
            double d2 = ((KNNSample) o2).precalc;
            if (d1 > d2)
                return 1;
            else if (d1 < d2)
                return -1;
            else
                return 0;
        }
        public boolean equals(Object o) {
            // null-safe: two comparators are interchangeable iff same class
            return o != null && this.getClass().equals(o.getClass());
        }
    }

    /** Shared comparator used for sorting and binary-searching samples. */
    protected SampleComparator m_comparator = new SampleComparator();
    /** Training set, sorted ascending by projected distance. */
    protected KNNSample[] m_knnsamples;
    /** Number of neighbors to use for voting. */
    protected int m_k;
    /** Reference point (all zeros) used to compute projected distances. */
    protected double[] m_refpoint;
    /** Number of predictive attributes (class attribute excluded). */
    protected int m_sampleDimension;
    protected int m_numClasses;
    /** Per-attribute max and min values, used for [0,1] normalization. */
    protected double[] m_maxValue, m_minValue;
    /** VDM statistics: m_probability[attr][value][class] = P(class | value). */
    protected double[][][] m_probability;
    /** True for nominal attributes (compared with VDM, not Euclidean). */
    protected boolean[] m_nominalAttribute;

    /**
     * Sets the number of neighbors used for classification.
     * @param k int number of neighbors
     */
    public void setK(int k) {
        this.m_k = k;
    }

    /**
     * Returns the normalized coordinates of the given instance: numeric
     * attributes are min-max scaled to [0,1]; nominal attributes keep their
     * raw index value (VDM handles them in distanceCalc).
     * NOTE(review): assumes the class attribute is the last attribute, so
     * predictive attributes occupy indices 0..m_sampleDimension-1 — confirm
     * against callers' dataset setup.
     * @param ins Instance
     * @return double[] normalized coordinates
     */
    protected double[] getDimensionValues(Instance ins) {
        double[] point = new double[m_sampleDimension];
        for (int d = 0; d < m_sampleDimension; d++)
            if (!m_nominalAttribute[d])
                if (m_maxValue[d] == m_minValue[d])
                    point[d] = 0; // constant attribute carries no information
                else
                    point[d] = (ins.value(d) - m_minValue[d]) / (m_maxValue[d] - m_minValue[d]);
            else
                point[d] = ins.value(d);
        return point;
    }

    /**
     * Calculates the distance between two normalized points: squared
     * differences for numeric attributes, VDM (sum of absolute class-
     * probability differences) for nominal attributes, then a square root
     * over the accumulated total. Attributes with a missing (NaN) value in
     * either point are skipped.
     * @param point1 double[] the first point (from getDimensionValues)
     * @param point2 double[] the second point
     * @return double the distance
     */
    protected double distanceCalc(double[] point1, double[] point2) {
        double dist = 0;
        for (int d = 0; d < m_sampleDimension; d++) {
            // deal with missing value: skip the attribute entirely
            if (Double.isNaN(point1[d]) || Double.isNaN(point2[d]))
                continue;
            if (m_nominalAttribute[d]) {
                // VDM: sum over classes of |P(c|v1) - P(c|v2)|
                double tmp = 0;
                for (int i = 0; i < m_numClasses; i++) {
                    tmp += Math.abs(
                        m_probability[d][(int) (point1[d] - m_minValue[d])][i]
                        - m_probability[d][(int) (point2[d] - m_minValue[d])][i]
                    );
                }
                dist += Math.abs(tmp);
            }
            else {
                dist += (point1[d] - point2[d]) * (point1[d] - point2[d]);
            }
        }
        dist = Math.sqrt(dist);
        return dist;
    }

    /**
     * Trains the kNN classifier: stores all instances, collects
     * normalization information (per-attribute min/max), collects VDM
     * statistics for nominal attributes, precalculates each instance's
     * projected distance, and sorts the instances by it.
     * @param instances Instances the training set
     * @throws Exception
     */
    public void buildClassifier(Instances instances) throws java.lang.Exception {
        m_numClasses = instances.numClasses();
        m_sampleDimension = instances.numAttributes() - 1;
        m_refpoint = new double[m_sampleDimension];
        m_knnsamples = new KNNSample[instances.numInstances()];
        this.m_maxValue = new double[m_sampleDimension];
        this.m_minValue = new double[m_sampleDimension];
        this.m_probability = new double[m_sampleDimension][][];
        this.m_nominalAttribute = new boolean[m_sampleDimension];
        // initialization
        // BUGFIX: loop previously ran to m_sampleDimension-1, leaving the
        // last attribute's nominal flag unset and its min/max at 0.0, which
        // corrupted normalization whenever its values were all positive
        // (min stuck at 0) or all negative (max stuck at 0).
        for (int i = 0; i < m_sampleDimension; i++) {
            m_nominalAttribute[i] = instances.attribute(i).isNominal();
            // pre-set so the first observed value always replaces them
            m_maxValue[i] = Double.NEGATIVE_INFINITY;
            m_minValue[i] = Double.POSITIVE_INFINITY;
        }
        // pass 1: find per-attribute min and max for normalization
        for (Enumeration enumer = instances.enumerateInstances(); enumer.hasMoreElements(); ) {
            Instance ins = (Instance) enumer.nextElement();
            for (int i = 0; i < m_sampleDimension; i++) {
                if (m_nominalAttribute[i]) {
                    // nominal values are indices 0..#values-1
                    m_minValue[i] = 0;
                    m_maxValue[i] = instances.numDistinctValues(i) - 1;
                }
                else {
                    if (ins.value(i) < m_minValue[i])
                        m_minValue[i] = ins.value(i);
                    if (ins.value(i) > m_maxValue[i])
                        m_maxValue[i] = ins.value(i);
                }
            }
        }
        // pass 2: collect VDM statistics for nominal attributes
        for (int i = 0; i < m_sampleDimension; i++) {
            // one probability row per distinct nominal value
            if (m_nominalAttribute[i])
                m_probability[i] = new double[(int) (m_maxValue[i] - m_minValue[i] + 1)][m_numClasses];
        }
        for (Enumeration enumer = instances.enumerateInstances(); enumer.hasMoreElements(); ) {
            Instance ins = (Instance) enumer.nextElement();
            for (int d = 0; d < m_sampleDimension; d++) {
                if (m_nominalAttribute[d]) {
                    double x = ins.value(d) - m_minValue[d];
                    double c = ins.classValue();
                    m_probability[d][(int) x][(int) c]++;
                }
            }
        }
        // normalize counts into conditional probabilities P(class | value)
        for (int d = 0; d < m_sampleDimension; d++)
            if (m_nominalAttribute[d]) {
                for (int x = 0; x < (int) (m_maxValue[d] - m_minValue[d] + 1); x++) {
                    double countInClass = 0;
                    for (int c = 0; c < m_numClasses; c++)
                        countInClass += m_probability[d][x][c];
                    if (countInClass == 0)
                        countInClass = 1; // avoid 0/0 for unseen values
                    for (int c = 0; c < m_numClasses; c++)
                        m_probability[d][x][c] /= countInClass;
                }
            }
        // pass 3: re-store each sample with its projected distance, then sort
        int splindex = 0;
        for (Enumeration enumer = instances.enumerateInstances(); enumer.hasMoreElements(); ) {
            KNNSample knnspl = new KNNSample();
            knnspl.instance = (Instance) enumer.nextElement();
            knnspl.precalc = distanceCalc(getDimensionValues(knnspl.instance), m_refpoint);
            m_knnsamples[splindex++] = knnspl;
        }
        Arrays.sort(m_knnsamples, m_comparator);
    }

    /**
     * Finds the index of the candidate neighbor with the largest distance
     * (i.e. largest {@code precalc}), the one to evict next.
     * @param knndist KNNSample[] current candidate neighbors
     * @return int index of the farthest candidate
     */
    protected int findMaxKnnSite(KNNSample[] knndist) {
        int maxSite = 0;
        for (int i = 1; i < knndist.length; i++)
            if (knndist[i].precalc > knndist[maxSite].precalc)
                maxSite = i;
        return maxSite;
    }

    /**
     * Finds the k nearest neighbors of the given instance. Starting from the
     * query's position in the projected-distance order (binary search), the
     * method scans outwards in both directions, pruning a direction once no
     * sample there can beat the current k-th best distance (triangle
     * inequality on projected distances).
     * NOTE(review): if k exceeds the number of training instances, unfilled
     * slots keep a null instance — callers should ensure k &lt;=
     * numInstances.
     * @param instance Instance the query
     * @param k int number of nearest neighbors to find
     * @return Instance[] the k neighbors, nearest first
     * @throws Exception if the instance's dimension does not match training
     */
    public Instance[] findKNN(Instance instance, int k) throws Exception {
        if ((instance.numAttributes() - 1) != m_sampleDimension)
            throw new Exception("unequal dimension of instance");
        m_refpoint = new double[m_sampleDimension];
        // hoist: the query's normalized coordinates are loop-invariant
        // (previously recomputed on every candidate comparison)
        double[] queryPoint = getDimensionValues(instance);
        KNNSample clsspl = new KNNSample();
        clsspl.instance = instance;
        clsspl.precalc = distanceCalc(queryPoint, m_refpoint);
        // locate the query's insertion point in the sorted sample array
        int locate = Arrays.binarySearch(m_knnsamples, clsspl, m_comparator);
        if (locate < 0)
            locate = -locate - 1;
        int lsearch = locate - 1, rsearch = locate; // -1 means "direction exhausted"
        if (rsearch == m_knnsamples.length)
            rsearch = -1;
        KNNSample[] knnspl = new KNNSample[k];
        for (int i = 0; i < k; i++) {
            knnspl[i] = new KNNSample();
            knnspl[i].precalc = Double.POSITIVE_INFINITY;
        }
        double maxknnDist = Double.POSITIVE_INFINITY; // current k-th best distance
        int maxknnSite = 0;                           // index of that candidate
        while (lsearch >= 0 || rsearch >= 0) {
            // skip the query itself if it is part of the training set
            if (lsearch >= 0 && m_knnsamples[lsearch].instance == instance)
                lsearch--;
            if (rsearch >= 0 && m_knnsamples[rsearch].instance == instance)
                rsearch++;
            if (rsearch >= m_knnsamples.length)
                rsearch = -1;
            if (lsearch >= 0) {
                double dist = distanceCalc(getDimensionValues(m_knnsamples[lsearch].instance), queryPoint);
                // (removed a leftover debug no-op: if(isNaN(dist)) dist = dist;)
                if (dist < maxknnDist) {
                    // replace the current worst candidate
                    knnspl[maxknnSite].precalc = dist;
                    knnspl[maxknnSite].instance = m_knnsamples[lsearch].instance;
                    maxknnSite = findMaxKnnSite(knnspl);
                    maxknnDist = knnspl[maxknnSite].precalc;
                }
                lsearch--;
                // prune left: projected distance already too small to matter
                if (lsearch >= 0 && m_knnsamples[lsearch].precalc < clsspl.precalc - maxknnDist)
                    lsearch = -1;
            }
            if (rsearch >= 0) {
                double dist = distanceCalc(getDimensionValues(m_knnsamples[rsearch].instance), queryPoint);
                if (dist < maxknnDist) {
                    knnspl[maxknnSite].precalc = dist;
                    knnspl[maxknnSite].instance = m_knnsamples[rsearch].instance;
                    maxknnSite = findMaxKnnSite(knnspl);
                    maxknnDist = knnspl[maxknnSite].precalc;
                }
                rsearch++;
                if (rsearch == m_knnsamples.length)
                    rsearch = -1;
                // prune right: projected distance already too large to matter
                if (rsearch >= 0 && m_knnsamples[rsearch].precalc > clsspl.precalc + maxknnDist)
                    rsearch = -1;
            }
        }
        Arrays.sort(knnspl, m_comparator);
        Instance[] kNNs = new Instance[k];
        for (int i = 0; i < k; i++)
            kNNs[i] = knnspl[i].instance;
        return kNNs;
    }

    /**
     * Classifies an instance by majority vote among its k nearest neighbors;
     * ties are broken in favor of the lower class index.
     * @param instance Instance the instance to classify
     * @return double the winning class index
     * @throws Exception propagated from findKNN (dimension mismatch)
     */
    public double classifyInstance(Instance instance) throws Exception {
        Instance[] kNNs = this.findKNN(instance, m_k);
        int[] rate = new int[m_numClasses];
        for (int i = 0; i < m_k; i++)
            rate[(int) kNNs[i].classValue()]++;
        int max = 0;
        for (int i = 1; i < m_numClasses; i++)
            if (rate[i] > rate[max])
                max = i;
        return max;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -