📄 ensemble.java
字号:
package fasbir.ensemblers;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.*;
/**
 * <p>Description: Ensemble package, compatible with Weka.</p>
 *
 * <p>Every ensemble member is a copy of the configured template classifier and is
 * trained on the entire training set (no resampling). Prediction is by unweighted
 * majority vote: each member casts one vote for the class it predicts.</p>
 *
 * <p>Note: each call to {@link #distributionForInstance(Instance)} overwrites the
 * per-member result array returned by {@link #getDetailClassificationResult()}, so a
 * single instance should not be used for concurrent predictions.</p>
 *
 * @author Y. Yu (yuy@lamda.nju.edu.cn), LAMDA Group (http://lamda.nju.edu.cn/)
 * @version 1.0
 */
public class Ensemble
    extends Classifier implements EnsembleClassifier, AdditionalMeasureProducer, Randomizable {

  /** Number of classes of the training data, recorded by buildClassifier. */
  protected int m_NumClasses;

  /** Prediction of every ensemble member for the most recently classified instance. */
  protected double[] m_detail;

  /** Seed of the generator that derives seeds for randomizable members. */
  protected int m_Seed;

  /** The trained ensemble members; null until buildClassifier has been called. */
  protected Classifier[] m_Classifiers;

  /** Template classifier that is copied to create the ensemble members. */
  protected Classifier m_Classifier;

  /** Number of members to create when the ensemble is built (the ensemble size). */
  protected int m_NumIterations;

  /**
   * Sets the randomization seed.
   *
   * @param seed int
   */
  public void setSeed(int seed) {
    this.m_Seed = seed;
  }

  /**
   * Gets the randomization seed.
   *
   * @return int
   */
  public int getSeed() {
    return m_Seed;
  }

  /**
   * Sets the base (template) classifier of the ensemble.
   *
   * @param baseclassifier Classifier
   */
  public void setClassifier(Classifier baseclassifier) {
    this.m_Classifier = baseclassifier;
  }

  /**
   * Gets the base (template) classifier of the ensemble.
   *
   * @return Classifier the template set via setClassifier, or null if none was set
   */
  public Classifier getClassifier() {
    return m_Classifier;
  }

  /**
   * Sets the number of base classifiers the ensemble will generate, usually called the
   * ensemble size. Takes effect the next time buildClassifier is called.
   *
   * @param iterations int
   */
  public void setNumIterations(int iterations) {
    this.m_NumIterations = iterations;
  }

  /**
   * Gets the number of base classifiers the ensemble will generate.
   *
   * @return int the configured ensemble size
   */
  public int getNumIterations() {
    return m_NumIterations;
  }

  /**
   * Ensemble training method.
   * It trains every ensemble member on the entire training data.
   *
   * @param data Instances the training data
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    m_NumClasses = data.numClasses();
    Random rnd = new Random(m_Seed);
    // m_Classifiers is declared in this class: one copy of the template per member.
    m_Classifiers = Classifier.makeCopies(m_Classifier, m_NumIterations);
    for (int j = 0; j < m_Classifiers.length; j++) {
      // Each randomizable member gets its own seed drawn from a generator seeded
      // with m_Seed, so a fixed m_Seed gives reproducible ensembles.
      if (m_Classifiers[j] instanceof Randomizable) {
        ((Randomizable) m_Classifiers[j]).setSeed(rnd.nextInt());
      }
      m_Classifiers[j].buildClassifier(data);
    }
  }

  /**
   * Classifies an instance by majority vote; ties go to the lowest class index.
   *
   * @param instance Instance instance to be classified
   * @return double classification code (index of the class with the most votes)
   * @throws Exception if the classifier could not classify the instance successfully
   */
  public double classifyInstance(Instance instance) throws Exception {
    double[] dist = distributionForInstance(instance);
    int max = 0;
    for (int i = 1; i < dist.length; i++) {
      if (dist[i] > dist[max]) {
        max = i;
      }
    }
    return max;
  }

  /**
   * Gets the class probability distribution as the fraction of members voting for
   * each class. Members whose prediction is missing (NaN) do not vote; if no member
   * casts a vote, an all-zero distribution is returned.
   *
   * @param instance Instance
   * @return double[] vote fractions, one entry per class
   * @throws Exception if the ensemble has not been built, or a member fails to classify
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    if (m_Classifiers == null) {
      throw new Exception("Ensemble has not been built: call buildClassifier first.");
    }
    // Size by the trained members, not m_NumIterations, which may have been changed
    // via setNumIterations after training.
    int size = m_Classifiers.length;
    m_detail = new double[size];
    double[] dist = new double[m_NumClasses];
    int votes = 0;
    for (int i = 0; i < size; i++) {
      double prediction = m_Classifiers[i].classifyInstance(instance);
      m_detail[i] = prediction;
      // A missing prediction (NaN) would otherwise be cast to 0 and count as class 0.
      if (Double.isNaN(prediction)) {
        continue;
      }
      dist[(int) prediction]++;
      votes++;
    }
    // Guard against 0/0 when the ensemble is empty or every member abstained.
    if (votes > 0) {
      for (int i = 0; i < dist.length; i++) {
        dist[i] /= votes;
      }
    }
    return dist;
  }

  /**
   * Gets the classification result of every ensemble member for the most recently
   * classified instance.
   *
   * @return double[] per-member predictions, or null if no instance has been classified yet
   */
  public double[] getDetailClassificationResult() {
    return m_detail;
  }

  /**
   * Lists the additional measures supported by this class (none).
   *
   * @return Enumeration an empty enumeration
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(0);
    return newVector.elements();
  }

  /**
   * Always fails, since no additional measures are supported.
   *
   * @param additionalMeasureName name of the requested measure
   * @return never returns normally
   * @throws IllegalArgumentException always
   */
  public double getMeasure(String additionalMeasureName) {
    throw new IllegalArgumentException(additionalMeasureName + " is not supported by Ensemble class.");
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -