⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 knearestneighbor.java

📁 Description: FASBIR (Filtered Attribute Subspace based Bagging with Injected Randomness) is a variant
💻 JAVA
字号:
package fasbir.classifiers;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Enumeration;

import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * <p>Description: k nearest neighbor classifier<br>
 * For nominal attributes, KNearestNeighbor employs the VDM metric.<br>
 *     Ref: Z.-H. Zhou and Y. Yu. Ensembling Local Learners Through Multi-Modal
 *          Perturbation. IEEE Transactions on Systems, Man, and Cybernetics - Part B, in press. <br>
 * KNearestNeighbor employs projected-distance sorting to accelerate the searching process.
 * </p>
 * @author Y. Yu (yuy@lamda.nju.edu.cn), LAMDA Group (http://lamda.nju.edu.cn/)
 * @version 1.0
 */
public class KNearestNeighbor extends Classifier {
    /**
     * KNNSample stores a training instance together with its precalculated
     * projected distance (the distance from the instance to the origin
     * reference point, used to narrow the neighbor search).
     * Declared static so a serialized sample does not drag the enclosing
     * classifier instance into the stream.
     */
    protected static class KNNSample implements Serializable {
        public Instance instance;
        public double precalc;
    }

    /**
     * SampleComparator orders two KNNSample objects by their projected
     * distance. The raw (non-generic) Comparator form is kept to match the
     * style of the surrounding codebase.
     */
    protected static class SampleComparator implements Comparator, Serializable {
        public int compare(Object o1, Object o2) {
            double d1 = ( (KNNSample) o1).precalc;
            double d2 = ( (KNNSample) o2).precalc;
            if (d1 > d2)
                return 1;
            else if (d1 < d2)
                return -1;
            else
                return 0;
        }

        public boolean equals(Object o) {
            // FIX: equals(null) must return false per the Object.equals
            // contract; the original dereferenced o unconditionally and
            // threw NullPointerException.
            return o != null && this.getClass().equals(o.getClass());
        }
    }

    protected SampleComparator m_comparator = new SampleComparator();
    protected KNNSample[] m_knnsamples;  // training set, sorted by projected distance
    protected int m_k;                   // number of neighbors consulted when voting
    protected double[] m_refpoint;       // reference point (origin) for projected distances
    protected int m_sampleDimension;     // number of predictive attributes (class excluded)
    protected int m_numClasses;

    protected double[] m_maxValue, m_minValue; // store max and min value of each attribute (for min-max normalization)
    protected double[][][] m_probability;      // per attribute, per discrete value: class-conditional probabilities (VDM)
    protected boolean[] m_nominalAttribute;    // true where the attribute is nominal

    /**
     * Sets the number of neighbors consulted by classifyInstance.
     * @param k number of nearest neighbors
     */
    public void setK(int k) {
        this.m_k = k;
    }

    /**
     * Maps an instance to a normalized attribute vector: numeric attributes
     * are min-max scaled into [0,1]; nominal attributes keep their raw
     * value index.
     * @param ins the instance to convert
     * @return the normalized attribute values
     */
    protected double[] getDimensionValues(Instance ins) {
        double[] point = new double[m_sampleDimension];
        for (int d = 0; d < m_sampleDimension; d++) {
            if (!m_nominalAttribute[d]) {
                if (m_maxValue[d] == m_minValue[d])
                    point[d] = 0; // constant attribute carries no information
                else
                    point[d] = (ins.value(d) - m_minValue[d]) / (m_maxValue[d] - m_minValue[d]);
            }
            else {
                point[d] = ins.value(d);
            }
        }
        return point;
    }

    /**
     * Calculates the distance between two normalized attribute vectors:
     * squared difference for numeric attributes plus a VDM (value difference
     * metric) term for nominal attributes, square-rooted at the end.
     * Attributes with a missing (NaN) value on either side are skipped.
     * @param point1 the first normalized vector
     * @param point2 the second normalized vector
     * @return the distance
     */
    protected double distanceCalc(double[] point1, double[] point2) {
        double dist = 0;
        for (int d = 0; d < m_sampleDimension; d++) {
            // deal with missing value: skip the attribute entirely
            if (Double.isNaN(point1[d]) || Double.isNaN(point2[d]))
                continue;
            if (m_nominalAttribute[d]) {
                // VDM: L1 difference of per-class conditional probabilities
                double tmp = 0;
                for (int i = 0; i < m_numClasses; i++) {
                    tmp += Math.abs(
                        m_probability[d][ (int) (point1[d] - m_minValue[d])][i] - m_probability[d][ (int) (point2[d] - m_minValue[d])][i]
                        );
                }
                dist += Math.abs(tmp); // tmp is already non-negative; abs kept for safety
            }
            else {
                dist += (point1[d] - point2[d]) * (point1[d] - point2[d]);
            }
        }
        dist = Math.sqrt(dist);
        return dist;
    }

    /**
     * Trains the kNN classifier: stores all instances, gathers normalization
     * ranges, builds the VDM probability tables for nominal attributes, and
     * sorts the training set by projected distance to accelerate searching.
     * NOTE(review): treats attributes 0 .. numAttributes()-2 as predictive,
     * i.e. it assumes the class attribute is the last one — confirm callers.
     * @param instances the training set
     * @throws Exception if training fails
     */
    public void buildClassifier(Instances instances) throws java.lang.Exception {
        m_numClasses = instances.numClasses();
        m_sampleDimension = instances.numAttributes() - 1;
        m_refpoint = new double[m_sampleDimension]; // all-zero origin
        m_knnsamples = new KNNSample[instances.numInstances()];

        this.m_maxValue = new double[m_sampleDimension];
        this.m_minValue = new double[m_sampleDimension];
        this.m_probability = new double[m_sampleDimension][][];
        this.m_nominalAttribute = new boolean[m_sampleDimension];

        // initialization
        // FIX: the loop previously stopped at m_sampleDimension-1, leaving
        // the last predictive attribute with a false nominal flag and 0/0
        // min/max bounds, which corrupted its normalization and VDM table.
        for (int i = 0; i < m_sampleDimension; i++) {
            m_nominalAttribute[i] = instances.attribute(i).isNominal();
            // pre-set so any real value updates the bounds
            m_maxValue[i] = Double.NEGATIVE_INFINITY;
            m_minValue[i] = Double.POSITIVE_INFINITY;
        }

        // access for normalization
        for (Enumeration enumer = instances.enumerateInstances(); enumer.hasMoreElements(); ) {
            Instance ins = (Instance) enumer.nextElement();
            // find max and min
            for (int i = 0; i < m_sampleDimension; i++) {
                if (m_nominalAttribute[i]) {
                    // nominal attributes use the raw index range directly
                    m_minValue[i] = 0;
                    m_maxValue[i] = instances.numDistinctValues(i) - 1;
                }
                else {
                    if (ins.value(i) < m_minValue[i])
                        m_minValue[i] = ins.value(i);
                    if (ins.value(i) > m_maxValue[i])
                        m_maxValue[i] = ins.value(i);
                }
            }
        }

        // access for VDM: allocate one (value x class) count table per
        // nominal attribute
        for (int i = 0; i < m_sampleDimension; i++) {
            if (m_nominalAttribute[i])
                m_probability[i] = new double[ (int) (m_maxValue[i] - m_minValue[i] + 1)][m_numClasses];
        }
        // count co-occurrences of (attribute value, class)
        for (Enumeration enumer = instances.enumerateInstances(); enumer.hasMoreElements(); ) {
            Instance ins = (Instance) enumer.nextElement();
            for (int d = 0; d < m_sampleDimension; d++) {
                if (m_nominalAttribute[d]) {
                    double x = ins.value(d) - m_minValue[d];
                    double c = ins.classValue();
                    m_probability[d][ (int) x][ (int) c]++;
                }
            }
        }
        // normalize the counts into class-conditional probabilities
        for (int d = 0; d < m_sampleDimension; d++) {
            if (m_nominalAttribute[d]) {
                for (int x = 0; x < (int) (m_maxValue[d] - m_minValue[d] + 1); x++) {
                    double countInClass = 0;
                    for (int c = 0; c < m_numClasses; c++)
                        countInClass += m_probability[d][x][c];
                    if (countInClass == 0)
                        countInClass = 1; // unseen value: avoid division by zero
                    for (int c = 0; c < m_numClasses; c++)
                        m_probability[d][x][c] /= countInClass;
                }
            }
        }

        int splindex = 0;
        for (Enumeration enumer = instances.enumerateInstances(); enumer.hasMoreElements(); ) {
            // re-store the sample with its precalculated projected distance
            KNNSample knnspl = new KNNSample();
            knnspl.instance = (Instance) enumer.nextElement();
            knnspl.precalc = distanceCalc(getDimensionValues(knnspl.instance), m_refpoint);
            m_knnsamples[splindex++] = knnspl;
        }

        // sort by projected distance so findKNN can binary-search and prune
        Arrays.sort(m_knnsamples, m_comparator);
    }

    /**
     * Finds the index of the sample with the largest stored distance
     * (i.e. the largest precalc field).
     * @param knndist the current neighbor candidates
     * @return index of the farthest candidate
     */
    protected int findMaxKnnSite(KNNSample[] knndist) {
        int maxSite = 0;
        for (int i = 1; i < knndist.length; i++) // int index: short overflowed for length > 32767
            if (knndist[i].precalc > knndist[maxSite].precalc)
                maxSite = i;
        return maxSite;
    }

    /**
     * Finds the k nearest neighbors of the given instance. Starts from the
     * query's position in the projected-distance ordering and scans outward
     * in both directions, pruning by the triangle inequality
     * (|proj(q) - proj(s)| is a lower bound on dist(q, s)).
     * If fewer than k training instances exist, trailing slots of the result
     * are null.
     * @param instance the query instance
     * @param k number of nearest neighbors to be found
     * @return the found neighbors, sorted by increasing distance
     * @throws Exception if the instance's dimension differs from the training data
     */
    public Instance[] findKNN(Instance instance, int k) throws Exception {
        if ( (instance.numAttributes() - 1) != m_sampleDimension)
            throw new Exception("inequal dimension of instance");

        m_refpoint = new double[m_sampleDimension];
        KNNSample clsspl = new KNNSample();
        clsspl.instance = instance;
        clsspl.precalc = distanceCalc(getDimensionValues(instance), m_refpoint);

        // locate the query in the projected-distance ordering
        int locate = Arrays.binarySearch(m_knnsamples, clsspl, m_comparator);
        if (locate < 0)
            locate = -locate - 1; // insertion point for a missed search
        int lsearch = locate - 1, rsearch = locate;
        if (rsearch == m_knnsamples.length)
            rsearch = -1; // -1 marks an exhausted direction

        KNNSample[] knnspl = new KNNSample[k];
        for (int i = 0; i < k; i++) {
            knnspl[i] = new KNNSample();
            knnspl[i].precalc = Double.POSITIVE_INFINITY;
        }
        double maxknnDist = Double.POSITIVE_INFINITY;
        int maxknnSite = 0;
        while (lsearch >= 0 || rsearch >= 0) {
            // never return the query object itself as its own neighbor
            if (lsearch >= 0 && m_knnsamples[lsearch].instance == instance)
                lsearch--;
            if (rsearch >= 0 && m_knnsamples[rsearch].instance == instance)
                rsearch++;
            if (rsearch >= m_knnsamples.length)
                rsearch = -1;
            if (lsearch >= 0) {
                double dist = distanceCalc(getDimensionValues(m_knnsamples[lsearch].instance), getDimensionValues(instance));
                // (removed a leftover debug no-op: "if (isNaN(dist)) dist = dist;")
                if (dist < maxknnDist) {
                    // replace the current farthest candidate
                    knnspl[maxknnSite].precalc = dist;
                    knnspl[maxknnSite].instance = m_knnsamples[lsearch].instance;
                    maxknnSite = findMaxKnnSite(knnspl);
                    maxknnDist = knnspl[maxknnSite].precalc;
                }
                lsearch--;
                // prune: every remaining left sample is at least
                // clsspl.precalc - sample.precalc > maxknnDist away
                if (lsearch >= 0 && m_knnsamples[lsearch].precalc < clsspl.precalc - maxknnDist)
                    lsearch = -1;
            }
            if (rsearch >= 0) {
                double dist = distanceCalc(getDimensionValues(m_knnsamples[rsearch].instance), getDimensionValues(instance));
                if (dist < maxknnDist) {
                    knnspl[maxknnSite].precalc = dist;
                    knnspl[maxknnSite].instance = m_knnsamples[rsearch].instance;
                    maxknnSite = findMaxKnnSite(knnspl);
                    maxknnDist = knnspl[maxknnSite].precalc;
                }
                rsearch++;
                if (rsearch == m_knnsamples.length)
                    rsearch = -1;
                // prune the right direction symmetrically
                if (rsearch >= 0 && m_knnsamples[rsearch].precalc > clsspl.precalc + maxknnDist)
                    rsearch = -1;
            }
        }

        // unfilled slots keep precalc = +Infinity, so real neighbors sort first
        Arrays.sort(knnspl, m_comparator);
        Instance[] kNNs = new Instance[k];
        for (int i = 0; i < k; i++)
            kNNs[i] = knnspl[i].instance;
        return kNNs;
    }

    /**
     * Classifies an instance by majority vote among its k nearest neighbors.
     * Ties are broken in favor of the lower class index.
     * @param instance the instance to classify
     * @return the predicted class index
     * @throws Exception if the neighbor search fails
     */
    public double classifyInstance(Instance instance) throws Exception {
        Instance[] kNNs = this.findKNN(instance, m_k);
        int[] rate = new int[m_numClasses];
        for (int i = 0; i < m_k; i++) {
            // FIX: findKNN leaves null slots when the training set has fewer
            // than k instances; the original dereferenced them and threw NPE
            if (kNNs[i] != null)
                rate[ (int) kNNs[i].classValue()]++;
        }
        int max = 0;
        for (int i = 1; i < m_numClasses; i++)
            if (rate[i] > rate[max])
                max = i;
        return max;
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -