
📄 ensemble.java

📁 Description: FASBIR(Filtered Attribute Subspace based Bagging with Injected Randomness) is a variant
package fasbir.ensemblers;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.*;

/**
 * <p>Description: Ensemble package, compatible with Weka.</p>
 * @author Y. Yu (yuy@lamda.nju.edu.cn), LAMDA Group (http://lamda.nju.edu.cn/)
 * @version 1.0
 */

public class Ensemble
    extends Classifier implements EnsembleClassifier, AdditionalMeasureProducer, Randomizable {

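    protected int m_NumClasses;           // number of classes in the training dataset
    protected double[] m_detail;          // class predicted by each base learner for the most recently classified instance
    protected int m_Seed;                 // seed for the random number generator
    protected Classifier[] m_Classifiers; // trained copies of the base classifier
    protected Classifier m_Classifier;    // base classifier template
    protected int m_NumIterations;        // number of base classifiers to build (ensemble size)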

    /**
     * Set the randomization seed.
     * @param seed int
     */
    public void setSeed( int seed){
        this.m_Seed = seed;
    }
    /**
     * Get the randomization seed.
     * @return int
     */
    public int getSeed( ){
        return m_Seed;
    }
    /**
     * set base classifier of the ensemble
     * @param baseclassifier Classifier
     */
    public void setClassifier(Classifier baseclassifier){
        this.m_Classifier = baseclassifier;
    }
    /**
     * Set the number of base classifiers the ensemble will generate, usually called the ensemble size.
     * @param iterations int
     */
    public void setNumIterations(int iterations ){
        this.m_NumIterations = iterations;
    }
    /**
     * Ensemble training method.
     * It simply trains every base learner on the entire training data; the only
     * difference between copies is the random seed given to Randomizable base learners.
     *
     * @param data Instances the training data
     * @exception Exception if the classifier could not be built successfully
     */
    public void buildClassifier(Instances data) throws Exception {

        m_NumClasses = data.numClasses();

        Random rnd = new Random(m_Seed);

        // create one copy of the base classifier per iteration (m_Classifiers is declared in this class)
        m_Classifiers = Classifier.makeCopies(m_Classifier, m_NumIterations);

        for (int j = 0; j < m_Classifiers.length; j++) {
            if (m_Classifier instanceof Randomizable) {
                ( (Randomizable) m_Classifiers[j]).setSeed(rnd.nextInt());
            }
            // build classifiers
            m_Classifiers[j].buildClassifier(data);
        }
    }

    /**
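     * Classify an instance by majority vote of the base learners.
     * Ties go to the lowest class index.
     *
     * @param instance Instance the instance to be classified
     * @return double index of the predicted class
     * @throws Exception if the classifier could not classify the instance successfully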
     */
    public double classifyInstance(Instance instance) throws Exception {
        double[] dist =  distributionForInstance( instance );
        int max = 0;
        for (int i = 1; i < m_NumClasses; i++)
            if (dist[i] > dist[max])
                max = i;

        // return classification result
        return max;
    }

    /**
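     * Get the class probability distribution, computed as the fraction of
     * base learners voting for each class.
     * @param instance Instance the instance to be classified
     * @return double[] vote fraction for each class
     * @throws Exception if a base learner could not classify the instance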
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
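        // use the number of trained classifiers, so a later setNumIterations() call cannot cause a mismatch
        int numClassifiers = m_Classifiers.length;
        m_detail = new double[numClassifiers];
        double[] dist = new double[m_NumClasses];
        // record each base learner's predicted class and count it as one vote
        for (int i = 0; i < numClassifiers; i++) {
            m_detail[i] = m_Classifiers[i].classifyInstance(instance);
            dist[(int) m_detail[i]]++;
        }
        // normalize the vote counts into fractions
        for (int i = 0; i < m_NumClasses; i++)
            dist[i] /= numClassifiers;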
        return dist;
    }

    /**
     * Get the class predicted by each base learner for the most recently classified instance.
     *
     * @return double[] predicted class index of each base learner, or null if no instance has been classified yet
     */
    public double[] getDetailClassificationResult() {
        return m_detail;
    }

    public Enumeration enumerateMeasures() {
        Vector newVector = new Vector(0);
        return newVector.elements();
    }

    public double getMeasure(String additionalMeasureName) {
        // no additional measures are supported
        throw new IllegalArgumentException(additionalMeasureName + " is not supported by Ensemble class.");
    }
}
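
The voting logic in distributionForInstance and classifyInstance above can be tried out without Weka. The following is a minimal sketch, assuming the base learners' predictions are already available as integer class indices; the class VoteSketch and its methods voteDistribution and argMax are hypothetical names used only for illustration and are not part of the FASBIR package.

```java
import java.util.Arrays;

// Hypothetical helper class for illustration; not part of the FASBIR package.
class VoteSketch {

    /** Turn per-learner class predictions into the fraction of votes per class. */
    static double[] voteDistribution(int[] votes, int numClasses) {
        double[] dist = new double[numClasses];
        for (int vote : votes) {
            dist[vote]++;              // one vote per base learner
        }
        for (int c = 0; c < numClasses; c++) {
            dist[c] /= votes.length;   // normalize counts to fractions
        }
        return dist;
    }

    /** Index of the largest entry; ties go to the lowest index. */
    static int argMax(double[] dist) {
        int max = 0;
        for (int i = 1; i < dist.length; i++) {
            if (dist[i] > dist[max]) {
                max = i;
            }
        }
        return max;
    }

    public static void main(String[] args) {
        int[] votes = {0, 2, 2, 1, 2};  // assumed predictions of five base learners
        double[] dist = voteDistribution(votes, 3);
        System.out.println(Arrays.toString(dist));
        System.out.println(argMax(dist));
    }
}
```

With the five assumed votes, class 2 receives three of the five votes and is returned by argMax. Because the comparison is strict, a tie is resolved in favor of the lowest class index, which matches the loop in classifyInstance.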
