perturbation_subspace.java

来自「Description: FASBIR(Filtered Attribute S」· Java 代码 · 共 86 行

JAVA
86
字号
package fasbir.ensemblers;

import weka.classifiers.*;
import weka.core.*;
import java.util.Random;
import weka.filters.unsupervised.attribute.*;
import weka.filters.Filter;

/**
 * <p>Description: Random subspace perturbation method</p>
 * A perturbation method that projects the original training dataset onto a randomly selected
 * attribute subspace. The class attribute is never removed. Of the remaining attributes,
 * round(numNonClass * (1 - subspaceRatio)) are dropped at random, using a {@link Random}
 * seeded with {@code m_Seed}. Test instances are projected with the same filter before they
 * are handed to the base classifier.
 * @author Y. Yu (yuy@lamda.nju.edu.cn), LAMDA Group (http://lamda.nju.edu.cn/)
 * @version 1.0
 */
public class Perturbation_Subspace extends PerturbationEncapsulation{
    /** Fraction of the non-class attributes kept in the subspace, in [0, 1]; defaults to 1 (keep all). */
    protected double m_SubspaceRatio = 1;
    /** Remove filter built by the last call to projectDataset; reused to project test instances. */
    protected Remove m_removeFilter;

    /**
     * Sets the fraction of non-class attributes kept in the random subspace.
     * @param ratio fraction in [0, 1]; 1 keeps every attribute, 0 keeps only the class attribute
     * @throws IllegalArgumentException if ratio is NaN or outside [0, 1]
     */
    public void setSubspaceRatio( double ratio ){
        // Written so that NaN also fails the check.
        if (!(ratio >= 0 && ratio <= 1)) {
            throw new IllegalArgumentException("subspace ratio must be in [0, 1], got " + ratio);
        }
        this.m_SubspaceRatio = ratio;
    }

    /**
     * Returns the fraction of non-class attributes kept in the random subspace.
     * @return the subspace ratio
     */
    public double getSubspaceRatio(){
        return this.m_SubspaceRatio;
    }

    /**
     * Builds the base classifier on a random attribute-subspace projection of the training data.
     * The projection is driven by a Random seeded with m_Seed, so repeated builds with the same
     * seed and data select the same attributes.
     *
     * @param dataset set of instances serving as training data
     * @throws Exception if the projection or the base classifier's training fails
     */
    public void buildClassifier(Instances dataset) throws Exception {
        Random rnd = new Random(m_Seed);
        Instances perturbedDataset = projectDataset(rnd, dataset);
        m_Classifier.buildClassifier(perturbedDataset);
    }

    /**
     * Removes a random subset of the non-class attributes from a dataset.
     * <p>
     * The class attribute is {@code dataset.classIndex()}, or the last attribute when no class
     * index is set. The number of attributes removed is
     * round(numNonClass * (1 - m_SubspaceRatio)). As a side effect, the Remove filter that was
     * used is stored in m_removeFilter so that test instances can be projected the same way.
     *
     * @param rnd source of randomness for choosing the attributes to remove
     * @param dataset the dataset to project
     * @return the projected dataset, with its class index pointing at the class attribute
     * @throws Exception if the Remove filter cannot be configured or applied
     */
    protected Instances projectDataset(Random rnd, Instances dataset) throws Exception{
        int numAttributes = dataset.numAttributes();
        // Without an explicit class index, fall back to the last attribute (the historical assumption).
        int classIndex = dataset.classIndex() >= 0 ? dataset.classIndex() : numAttributes - 1;
        int numCandidates = numAttributes - 1; // every attribute except the class
        int numToRemove = (int)Math.round(numCandidates * (1 - m_SubspaceRatio));
        // Guard against an out-of-range ratio assigned directly to the protected field.
        numToRemove = Math.max(0, Math.min(numCandidates, numToRemove));

        int[] positions = drawDistinctPositions(rnd, numCandidates, numToRemove);
        int[] removeAttr = new int[numToRemove];
        int removedBeforeClass = 0;
        for (int i = 0; i < numToRemove; i++) {
            // Candidate positions skip the class attribute: those at or after it shift by one.
            removeAttr[i] = positions[i] < classIndex ? positions[i] : positions[i] + 1;
            if (removeAttr[i] < classIndex) {
                removedBeforeClass++;
            }
        }

        m_removeFilter = new Remove();
        m_removeFilter.setInvertSelection(false);
        m_removeFilter.setAttributeIndicesArray(removeAttr);
        m_removeFilter.setInputFormat(dataset);

        Instances newdataset = Filter.useFilter(dataset, m_removeFilter);
        // The class attribute moves left by the number of removed attributes that preceded it.
        newdataset.setClassIndex(classIndex - removedBeforeClass);
        return newdataset;
    }

    /**
     * Draws {@code count} distinct positions from 0..numPositions-1 and returns them in draw order.
     * Each draw picks uniformly among the positions not taken yet. The procedure is kept as-is so
     * that a given seed keeps reproducing the same subspace.
     */
    private static int[] drawDistinctPositions(Random rnd, int numPositions, int count) {
        boolean[] taken = new boolean[numPositions];
        int[] drawn = new int[count];
        for (int i = 0; i < count; i++) {
            // Pick the k-th (0-based) untaken position: each taken slot at or before it pushes it right.
            int pos = rnd.nextInt(numPositions - i);
            for (int j = 0; j <= pos; j++) {
                if (taken[j]) {
                    pos++;
                }
            }
            taken[pos] = true;
            drawn[i] = pos;
        }
        return drawn;
    }

    /**
     * Projects an instance onto the attribute subspace chosen when the classifier was built.
     * @param instance the instance to project
     * @return the projected instance
     * @throws IllegalStateException if no subspace has been built yet, or the filter produced no output
     * @throws Exception if the filter rejects the instance
     */
    private Instance projectInstance(Instance instance) throws Exception {
        if (m_removeFilter == null) {
            throw new IllegalStateException("No attribute subspace: buildClassifier has not been called");
        }
        if (!m_removeFilter.input(instance)) {
            throw new IllegalStateException("Remove filter produced no output for the instance");
        }
        return m_removeFilter.output();
    }

    /**
     * Classifies a given instance after projecting it onto the training subspace.
     * @param instance Instance
     * @return the base classifier's prediction for the projected instance
     * @throws Exception if projection or classification fails
     */
    public double classifyInstance(Instance instance) throws Exception {
        return m_Classifier.classifyInstance(projectInstance(instance));
    }

    /**
     * Returns the base classifier's class distribution for the projected instance.
     * @param instance Instance
     * @return the base classifier's distribution for the projected instance
     * @throws Exception if projection or classification fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        return m_Classifier.distributionForInstance(projectInstance(instance));
    }

}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?