📄 adaboostm1.java
字号:
* Boosting method.
*
* @param data the training data to be used for generating the
* boosted classifier.
* @exception Exception if the classifier could not be built successfully
*/
public void buildClassifier(Instances data) throws Exception {
super.buildClassifier(data);
if (data.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
data = new Instances(data);
data.deleteWithMissingClass();
if (data.numInstances() == 0) {
throw new Exception("No train instances without class missing!");
}
if (data.classAttribute().isNumeric()) {
throw new UnsupportedClassTypeException("AdaBoostM1 can't handle a numeric class!");
}
m_NumClasses = data.numClasses();
if ((!m_UseResampling) &&
(m_Classifier instanceof WeightedInstancesHandler)) {
buildClassifierWithWeights(data);
} else {
buildClassifierUsingResampling(data);
}
}
/**
 * Boosting method. Boosts using resampling.
 *
 * Each round draws a weighted sample of the (optionally quantile-filtered)
 * training data, builds the next base classifier on it, measures its error
 * rate on the full weighted training set and then re-weights the instances
 * through {@link #setWeights(Instances, double)}.
 *
 * @param data the training data to be used for generating the
 * boosted classifier.
 * @exception Exception if the classifier could not be built successfully
 */
protected void buildClassifierUsingResampling(Instances data)
    throws Exception {

  final int instanceCount = data.numInstances();
  final Random rng = new Random(m_Seed);

  // Reset the ensemble state
  m_Betas = new double [m_Classifiers.length];
  m_NumIterationsPerformed = 0;

  // Weights are modified on a separate Instances object so that the
  // weights of the object passed in are left alone
  final Instances training = new Instances(data, 0, instanceCount);

  // Divide every weight by the total weight of the set
  final double totalWeight = training.sumOfWeights();
  for (int i = 0; i < training.numInstances(); i++) {
    Instance current = training.instance(i);
    current.setWeight(current.weight() / totalWeight);
  }

  // Boosting rounds; the field doubles as the loop counter because
  // setWeights() reads it to pick the classifier of the current round
  while (m_NumIterationsPerformed < m_Classifiers.length) {
    final int round = m_NumIterationsPerformed;

    if (m_Debug) {
      System.err.println("Training classifier " + (round + 1));
    }

    // Select instances to train the classifier on
    final Instances trainData = (m_WeightThreshold < 100)
        ? selectWeightQuantile(training, (double)m_WeightThreshold / 100)
        : new Instances(training);

    // Snapshot of the weights that drive the resampling
    final double[] weights = new double[trainData.numInstances()];
    for (int i = 0; i < weights.length; i++) {
      weights[i] = trainData.instance(i).weight();
    }

    // Resample, build and evaluate; a sample that yields zero error is
    // redrawn until the retry limit is reached
    double epsilon;
    int attempts = 0;
    do {
      Instances sample = trainData.resampleWithWeights(rng, weights);
      m_Classifiers[round].buildClassifier(sample);

      Evaluation evaluation = new Evaluation(data);
      evaluation.evaluateModel(m_Classifiers[round], training);
      epsilon = evaluation.errorRate();

      attempts++;
    } while (Utils.eq(epsilon, 0) &&
             (attempts < MAX_NUM_RESAMPLING_ITERATIONS));

    // Stop if error too big or 0
    if (Utils.grOrEq(epsilon, 0.5) || Utils.eq(epsilon, 0)) {
      if (round == 0) {
        // The very first model is kept even though it failed the test
        m_NumIterationsPerformed = 1;
      }
      break;
    }

    // Determine the weight to assign to this model
    final double reweight = (1 - epsilon) / epsilon;
    m_Betas[round] = Math.log(reweight);

    if (m_Debug) {
      System.err.println("\terror rate = " + epsilon
                         +" beta = " + m_Betas[round]);
    }

    // Update instance weights
    setWeights(training, reweight);

    m_NumIterationsPerformed++;
  }
}
/**
 * Sets the weights for the next iteration.
 *
 * Every instance that the classifier of the current round
 * (<code>m_Classifiers[m_NumIterationsPerformed]</code>) predicts
 * incorrectly has its weight multiplied by <code>reweight</code>. Afterwards
 * all weights are rescaled so that their sum equals the sum measured before
 * the update.
 *
 * @param training the instances whose weights are updated in place
 * @param reweight the factor applied to the weight of each misclassified
 * instance
 * @exception Exception if an instance could not be classified
 */
protected void setWeights(Instances training, double reweight)
    throws Exception {

  final double oldSumOfWeights = training.sumOfWeights();
  final int count = training.numInstances();

  // Iterate by index (numInstances()/instance(i)) rather than through the
  // raw Enumeration obtained from "emerateInstances()": the indexed
  // accessors are the ones the rest of this class relies on and need no casts.
  for (int i = 0; i < count; i++) {
    Instance instance = training.instance(i);
    double predicted =
        m_Classifiers[m_NumIterationsPerformed].classifyInstance(instance);
    if (!Utils.eq(predicted, instance.classValue())) {
      instance.setWeight(instance.weight() * reweight);
    }
  }

  // Renormalize weights
  final double newSumOfWeights = training.sumOfWeights();
  for (int i = 0; i < count; i++) {
    Instance instance = training.instance(i);
    instance.setWeight(instance.weight() * oldSumOfWeights
                       / newSumOfWeights);
  }
}
/**
 * Boosting method. Boosts any classifier that can handle weighted
 * instances.
 *
 * Each round builds the next base classifier directly on the (optionally
 * quantile-filtered) weighted training data, measures its error rate on the
 * full weighted training set and then re-weights the instances through
 * {@link #setWeights(Instances, double)}.
 *
 * @param data the training data to be used for generating the
 * boosted classifier.
 * @exception Exception if the classifier could not be built successfully
 */
protected void buildClassifierWithWeights(Instances data)
    throws Exception {

  final int instanceCount = data.numInstances();

  // Reset the ensemble state
  m_Betas = new double [m_Classifiers.length];
  m_NumIterationsPerformed = 0;

  // Weights are modified on a separate Instances object so that the
  // weights of the object passed in are left alone
  final Instances training = new Instances(data, 0, instanceCount);

  // Boosting rounds; the field doubles as the loop counter because
  // setWeights() reads it to pick the classifier of the current round
  while (m_NumIterationsPerformed < m_Classifiers.length) {
    final int round = m_NumIterationsPerformed;

    if (m_Debug) {
      System.err.println("Training classifier " + (round + 1));
    }

    // Select instances to train the classifier on
    final Instances trainData = (m_WeightThreshold < 100)
        ? selectWeightQuantile(training, (double)m_WeightThreshold / 100)
        : new Instances(training, 0, instanceCount);

    // Build the classifier
    m_Classifiers[round].buildClassifier(trainData);

    // Evaluate the classifier
    final Evaluation evaluation = new Evaluation(data);
    evaluation.evaluateModel(m_Classifiers[round], training);
    final double epsilon = evaluation.errorRate();

    // Stop if error too small or error too big and ignore this model
    if (Utils.grOrEq(epsilon, 0.5) || Utils.eq(epsilon, 0)) {
      if (round == 0) {
        // The very first model is kept even though it failed the test
        m_NumIterationsPerformed = 1;
      }
      break;
    }

    // Determine the weight to assign to this model
    final double reweight = (1 - epsilon) / epsilon;
    m_Betas[round] = Math.log(reweight);

    if (m_Debug) {
      System.err.println("\terror rate = " + epsilon
                         +" beta = " + m_Betas[round]);
    }

    // Update instance weights
    setWeights(training, reweight);

    m_NumIterationsPerformed++;
  }
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * With a single model its own distribution is returned unchanged. Otherwise
 * each model adds its beta to the class it predicts and the accumulated
 * sums are converted by <code>Utils.logs2probs</code>.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @exception Exception if instance could not be classified
 * successfully
 */
public double [] distributionForInstance(Instance instance)
    throws Exception {

  if (m_NumIterationsPerformed == 0) {
    throw new Exception("No model built");
  }

  double [] votes = new double [instance.numClasses()];

  // No boosting took place: defer entirely to the only model
  if (m_NumIterationsPerformed == 1) {
    return m_Classifiers[0].distributionForInstance(instance);
  }

  for (int member = 0; member < m_NumIterationsPerformed; member++) {
    int predictedClass = (int)m_Classifiers[member].classifyInstance(instance);
    votes[predictedClass] += m_Betas[member];
  }
  return Utils.logs2probs(votes);
}
/**
 * Returns the boosted model as Java source code.
 *
 * The generated text contains a class named <code>className</code> whose
 * static <code>classify</code> method combines the votes of the member
 * classes <code>className_0</code>, <code>className_1</code>, ..., followed
 * by the source of each of those member classes.
 *
 * @param className the name to give the generated class; member classes
 * are named by appending <code>_</code> and their index
 * @return the boosted model as Java source code
 * @exception Exception if no model has been built, if the base learner is
 * not Sourcable, or if a member fails to produce its source
 */
public String toSource(String className) throws Exception {

  if (m_NumIterationsPerformed == 0) {
    throw new Exception("No model built yet");
  }
  if (!(m_Classifiers[0] instanceof Sourcable)) {
    throw new Exception("Base learner " + m_Classifier.getClass().getName()
                        + " is not Sourcable");
  }

  StringBuilder text = new StringBuilder("class ");
  text.append(className).append(" {\n\n");

  text.append(" public static double classify(Object [] i) {\n");

  if (m_NumIterationsPerformed == 1) {
    // Single model: the wrapper simply forwards to it
    text.append(" return " + className + "_0.classify(i);\n");
  } else {
    // Weighted vote: each member adds its beta to the class it predicts,
    // the class with the largest sum wins
    text.append(" double [] sums = new double [" + m_NumClasses + "];\n");
    for (int i = 0; i < m_NumIterationsPerformed; i++) {
      text.append(" sums[(int) " + className + '_' + i
                  + ".classify(i)] += " + m_Betas[i] + ";\n");
    }
    text.append(" double maxV = sums[0];\n" +
                " int maxI = 0;\n"+
                " for (int j = 1; j < " + m_NumClasses + "; j++) {\n"+
                " if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n"+
                " }\n return (double) maxI;\n");
  }
  text.append(" }\n}\n");

  // Emit only the members that are part of the ensemble. The generated
  // classify() above references indices below m_NumIterationsPerformed only;
  // entries past that point were either rejected or never built when
  // boosting stopped early, so asking them for source is neither needed
  // nor safe.
  for (int i = 0; i < m_NumIterationsPerformed; i++) {
    text.append(((Sourcable)m_Classifiers[i])
                .toSource(className + '_' + i));
  }
  return text.toString();
}
/**
 * Returns description of the boosted classifier.
 *
 * Three cases are distinguished: no model yet, a single model (no boosting
 * took place), and a full ensemble, for which every member is listed
 * together with its weight.
 *
 * @return description of the boosted classifier as a string
 */
public String toString() {

  if (m_NumIterationsPerformed == 0) {
    return "AdaBoostM1: No model built yet.\n";
  }

  if (m_NumIterationsPerformed == 1) {
    return "AdaBoostM1: No boosting possible, one classifier used!\n"
        + m_Classifiers[0].toString() + "\n";
  }

  StringBuilder description =
      new StringBuilder("AdaBoostM1: Base classifiers and their weights: \n\n");
  for (int member = 0; member < m_NumIterationsPerformed; member++) {
    description.append(m_Classifiers[member].toString()).append("\n\n");
    description.append("Weight: "
        + Utils.roundDouble(m_Betas[member], NumberFormatter.MAX_FRACTION_DIGIT)
        + "\n\n");
  }
  description.append("Number of performed Iterations: ")
             .append(m_NumIterationsPerformed)
             .append("\n");
  return description.toString();
}
/**
 * Main method for testing this class.
 *
 * Passes a fresh AdaBoostM1 and the command-line options to
 * <code>Evaluation.evaluateModel</code> and prints the result to standard
 * output; if an exception is thrown its message is printed to standard
 * error instead.
 *
 * @param argv the options
 */
public static void main(String [] argv) {

  AdaBoostM1 booster;
  try {
    booster = new AdaBoostM1();
    System.out.println(Evaluation.evaluateModel(booster, argv));
  } catch (Exception failure) {
    System.err.println(failure.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -