// FilterAttributeEnsemble.java
package fasbir.ensemblers;
import java.util.*;
import weka.classifiers.*;
import weka.classifiers.trees.j48.*;
import weka.core.*;
import weka.filters.*;
import weka.filters.unsupervised.attribute.*;
/**
 * <p>Description: filter attribute ensemble, which filters out attributes holding low
 * information gain before training the base learners</p>
 * @author Y. Yu (yuy@lamda.nju.edu.cn), LAMDA Group (http://lamda.nju.edu.cn/)
 * @version 1.0
 */
public class FilterAttributeEnsemble extends Ensemble implements AdditionalMeasureProducer{

    protected Remove m_removeFilter;

    //Instances m_newdataFormat;
    //int[] m_selectedAttributes;
    //boolean m_useBootstrap = false;

    /**
     * Decides whether data sampling (bootstrap) is used during the training process.
     * @param use boolean
     */
    /*public void setUseBootstrap( boolean use ){
        m_useBootstrap = use;
    }*/
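
    /*
     * If bootstrap sampling were re-enabled, each base learner could be trained on a
     * bootstrap sample of the (filtered) data. A minimal sketch using the standard
     * weka.core.Instances API; this code is not part of the original class:
     *
     *   Instances sample = newdataset.resample(rnd);   // sample with replacement
     *   m_Classifiers[j].buildClassifier(sample);
     */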
    /**
     * Selects the attributes holding high information gain, i.e. higher than averageInfoGain/3.
     * The selection process employs the C45Split model to compute information gain, which is a
     * convenient way to handle both nominal and continuous attributes.
     * @param instances Instances training data set
     * @return int[] selected attributes; j = result[i] means the j-th original attribute is the i-th selected one
     * @throws Exception if no attribute yields a valid split model
     */
    protected int[] selectSubspace( Instances instances ) throws Exception{
        boolean[] attributes = new boolean[instances.numAttributes()];
        double avgInformationGain = 0;   // average information gain
        int numValidModel = 0;
        double sumOfWeight = instances.sumOfWeights();
        C45Split[] attModel = new C45Split[ instances.numAttributes() ];
        // for each attribute, compute its information gain
        for( int i = 0; i < instances.numAttributes(); i++ ){
            attributes[i] = false;
            if( i == instances.classIndex() ){
                attModel[i] = null;
                continue;
            }
            attModel[i] = new C45Split( i, 0, sumOfWeight );
            attModel[i].buildClassifier( instances );
            // keep only attributes for which a valid split model could be built
            if( attModel[i].checkModel() ){
                avgInformationGain += attModel[i].infoGain();
                numValidModel++;
            }else{
                attModel[i] = null;
            }
        }
        if( numValidModel == 0 )
            throw new Exception( "no attribute available" );
        // average information gain over the valid models
        avgInformationGain /= numValidModel;
        numValidModel = 0;   // reuse as a counter of selected attributes
        // select the attributes whose information gain is at least one third of the average
        for( int i = 0; i < instances.numAttributes(); i++ ){
            if( attModel[i] != null && attModel[i].infoGain() >= avgInformationGain / 3 ){
                attributes[i] = true;
                numValidModel++;
            }
        }
        int[] attIndex = new int[ numValidModel + 1 ];
        // the class attribute must always be kept
        attributes[instances.classIndex()] = true;
        for( int i = 0, j = 0; i < instances.numAttributes(); i++ )
            if( attributes[i] )
                attIndex[j++] = i;
        return attIndex;
    }
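
    /*
     * Worked example of the selection rule above (numbers are illustrative only):
     * with four non-class attributes whose information gains are 0.9, 0.3, 0.05
     * and 0.01, the average gain is 1.26/4 = 0.315 and the threshold is
     * 0.315/3 = 0.105, so only the first two attributes (plus the class attribute)
     * are selected, giving an attIndex array of length 3.
     */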
    /**
     * Trains the FilterAttributeEnsemble classifier.
     * Each base learner is trained in the selected subspace.
     * Data sampling could additionally be used by setting setUseBootstrap to true
     * (currently disabled).
     * @param data Instances
     * @throws Exception
     */
    public void buildClassifier(Instances data) throws Exception {
        Random rnd = new Random(m_Seed);
        m_NumClasses = data.numClasses();
        int[] selectedAttributes = selectSubspace( data );
        // set up a filter that removes all unselected attributes
        m_removeFilter = new Remove();
        m_removeFilter.setInvertSelection( true );
        m_removeFilter.setAttributeIndicesArray( selectedAttributes );
        m_removeFilter.setInputFormat( data );
        Instances newdataset = Filter.useFilter( data, m_removeFilter );
        // note: this assumes the class attribute was the last attribute of the
        // original data, so it is also last after filtering
        newdataset.setClassIndex( newdataset.numAttributes() - 1 );
        m_Classifiers = Classifier.makeCopies( m_Classifier, m_NumIterations );
        for (int j = 0; j < m_Classifiers.length; j++) {
            if (m_Classifier instanceof Randomizable) {
                ((Randomizable) m_Classifiers[j]).setSeed(rnd.nextInt());
            }
            // build each base classifier on the filtered data
            m_Classifiers[j].buildClassifier(newdataset);
        }
    }
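
    /*
     * Example of the filtering step above (illustrative indices, not from the
     * original code): with selectedAttributes = {0, 2, 4} and invertSelection
     * set to true, the Remove filter deletes every attribute except 0, 2 and 4,
     * so the filtered data set keeps those three attributes in their original order.
     */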
    /**
     * Gets the class probability distribution for an instance.
     * @param instance Instance
     * @return double[]
     * @throws Exception
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        // project the instance onto the selected subspace
        m_removeFilter.input(instance);
        Instance newIns = m_removeFilter.output();
        m_detail = new double[m_NumIterations];
        double[] dist = new double[m_NumClasses];
        // collect one vote per base classifier
        for( int i = 0; i < m_NumIterations; i++ ){
            m_detail[i] = m_Classifiers[i].classifyInstance(newIns);
            dist[ (int) m_detail[i] ]++;
        }
        // normalize the votes into a distribution
        for (int i = 0; i < m_NumClasses; i++)
            dist[i] /= (double) m_NumIterations;
        return dist;
    }
}
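
/*
 * A minimal usage sketch, not part of the original source. It assumes the fasbir
 * Ensemble base class exposes Weka-style setClassifier(...) and setNumIterations(...)
 * setters (as weka.classifiers.IteratedSingleClassifierEnhancer does); adjust the
 * calls if the actual Ensemble API differs.
 */
class FilterAttributeEnsembleDemo {
    public static void main(String[] args) throws Exception {
        // load an ARFF data set and take the last attribute as the class
        Instances data = new Instances(new java.io.FileReader(args[0]));
        data.setClassIndex(data.numAttributes() - 1);
        FilterAttributeEnsemble fae = new FilterAttributeEnsemble();
        fae.setClassifier(new weka.classifiers.trees.J48());  // assumed inherited setter
        fae.setNumIterations(10);                             // assumed inherited setter
        fae.buildClassifier(data);
        // print the class distribution predicted for the first instance
        System.out.println(java.util.Arrays.toString(
                fae.distributionForInstance(data.instance(0))));
    }
}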