📄 racedincrementallogitboost.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* RacedIncrementalLogitBoost.java
* Copyright (C) 2002 Richard Kirkby, Eibe Frank
*
*/
package weka.classifiers.meta;
import weka.classifiers.*;
import weka.classifiers.rules.ZeroR;
import weka.core.*;
import java.util.*;
import java.io.Serializable;
/**
* Classifier for incremental learning of large datasets by way of racing logit-boosted committees.
*
* Valid options are:<p>
*
* -C num <br>
* Set the minimum chunk size (default 500). <p>
*
* -M num <br>
* Set the maximum chunk size (default 2000). <p>
*
* -V num <br>
* Set the validation set size (default 1000). <p>
*
* -D <br>
* Turn on debugging output.<p>
*
* -W classname <br>
* Specify the full class name of a weak learner as the basis for
* boosting (default weka.classifiers.trees.DecisionStump).<p>
*
* -Q <br>
* Use resampling instead of reweighting.<p>
*
* -S seed <br>
* Random number seed for resampling (default 1).<p>
*
* -P type <br>
* The type of pruning to use. <p>
*
* Options after -- are passed to the designated learner.<p>
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision: 1.1 $
*/
public class RacedIncrementalLogitBoost extends RandomizableSingleClassifierEnhancer
implements UpdateableClassifier {
/** The pruning types: constant for "No pruning". */
public static final int PRUNETYPE_NONE = 0;
/** Pruning type constant for "Log likelihood pruning" (the default, see m_PruningType). */
public static final int PRUNETYPE_LOGLIKELIHOOD = 1;
/** Tags pairing each PRUNETYPE_* constant with its human-readable label. */
public static final Tag [] TAGS_PRUNETYPE = {
new Tag(PRUNETYPE_NONE, "No pruning"),
new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning")
};
/** The committees being raced (presumably Committee objects; TODO confirm against buildClassifier). */
protected FastVector m_committees;
/** The pruning type used: one of the PRUNETYPE_* constants (default PRUNETYPE_LOGLIKELIHOOD). */
protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD;
/** Whether to resample each boosting chunk by weight instead of passing weights to the base learner (default false). */
protected boolean m_UseResampling = false;
/** The number of classes; each boosting iteration builds one base model per class. */
protected int m_NumClasses;
/** Cap on the magnitude of the LogitBoost working responses z (Friedman suggests between 2 and 4). */
protected static final double Z_MAX = 4;
/** Dummy dataset with a numeric class; instances are attached to it before the per-class models score them. */
protected Instances m_NumericClassData;
/** The actual class attribute (used to print class names in the committee description). */
protected Attribute m_ClassAttribute;
/** The minimum chunk size used for training (default 500). */
protected int m_minChunkSize = 500;
/** The maximum chunk size used for training (default 2000). */
protected int m_maxChunkSize = 2000;
/** The size of the validation set (default 1000); also the number of rows in each committee's score arrays. */
protected int m_validationChunkSize = 1000;
/** The number of instances consumed */
protected int m_numInstancesConsumed;
/** The instances used to score (validate) every committee. */
protected Instances m_validationSet;
/** The instances currently in memory for training; committees take fixed-size chunks from it. */
protected Instances m_currentSet;
/** The current best committee */
protected Committee m_bestCommittee;
/** The default scheme used when committees aren't ready (null until assigned elsewhere). */
protected ZeroR m_zeroR = null;
/** Whether the validation set has recently been changed */
protected boolean m_validationSetChanged;
/** The maximum number of instances required for processing (presumably the largest committee chunk size; TODO confirm). */
protected int m_maxBatchSizeRequired;
/** The random number generator used for resampling; re-created from the seed in buildClassifier. */
protected Random m_RandomInstance = null;
/**
 * Creates a new raced incremental LogitBoost classifier whose base
 * learner defaults to a decision stump.
 */
public RacedIncrementalLogitBoost() {
  super();
  this.m_Classifier = new weka.classifiers.trees.DecisionStump();
}
/**
 * Gives the fully qualified class name of the base learner used when
 * none is specified explicitly.
 *
 * @return the class name of the default classifier (a decision stump)
 */
protected String defaultClassifierString() {
  final String stumpClassName = "weka.classifiers.trees.DecisionStump";
  return stumpClassName;
}
/**
 * A committee of LogitBoost models, each trained on a consecutive chunk of
 * m_currentSet of the same fixed size. Every model is stored as a
 * Classifier[] holding one base classifier per class. The committee caches
 * the additive per-class scores ("Fs") it assigns to each validation
 * instance, so adding a model only requires scoring the validation set with
 * that single new model. It reads the enclosing classifier's m_currentSet,
 * m_validationSet, m_NumClasses, m_NumericClassData and m_Classifier.
 */
protected class Committee implements Serializable {

  /** Number of training instances fed to each boosting iteration. */
  protected int m_chunkSize;

  /** How many instances of m_currentSet this committee has used so far. */
  protected int m_instancesConsumed;

  /** The boosted models; every element is a Classifier[] (one per class). */
  protected FastVector m_models;

  /** Cached result of validationError(). */
  protected double m_lastValidationError;

  /** Cached result of logLikelihood(). */
  protected double m_lastLogLikelihood;

  /** True when m_lastValidationError must be recomputed. */
  protected boolean m_modelHasChanged;

  /** True when m_lastLogLikelihood must be recomputed. */
  protected boolean m_modelHasChangedLL;

  /** Committed per-class scores, one row per validation instance. */
  protected double[][] m_validationFs;

  /** Tentative per-class scores that include the most recently added model. */
  protected double[][] m_newValidationFs;

  /**
   * Creates an empty committee.
   *
   * @param chunkSize the number of training instances per boosting iteration
   */
  public Committee(int chunkSize) {
    m_chunkSize = chunkSize;
    m_instancesConsumed = 0;
    m_models = new FastVector();
    m_lastValidationError = 1.0;
    m_lastLogLikelihood = Double.MAX_VALUE;
    invalidateCaches();
    m_validationFs = new double[m_validationChunkSize][m_NumClasses];
    m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
  }

  /** Marks both cached validation statistics as stale. */
  private void invalidateCaches() {
    m_modelHasChanged = true;
    m_modelHasChangedLL = true;
  }

  /**
   * Boosts one new model for every complete chunk of m_currentSet that this
   * committee has not consumed yet, appending each model to the committee.
   * Tentative validation scores for the new model are written to
   * m_newValidationFs; keepLastModel() commits them and pruneLastModel()
   * drops the model instead.
   * NOTE(review): each tentative row is derived from the committed scores,
   * so if several chunks are consumed in one call only the newest model is
   * reflected in m_newValidationFs.
   *
   * @return true if at least one model was added
   * @exception Exception if boosting or scoring fails
   */
  public boolean update() throws Exception {
    int added = 0;
    while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) {
      Instances chunk = new Instances(m_currentSet, m_instancesConsumed, m_chunkSize);
      Classifier[] freshModel = boost(chunk);
      int numValidation = m_validationSet.numInstances();
      for (int v = 0; v < numValidation; v++) {
        m_newValidationFs[v] =
          updateFS(m_validationSet.instance(v), freshModel, m_validationFs[v]);
      }
      m_models.addElement(freshModel);
      m_instancesConsumed += m_chunkSize;
      added++;
    }
    if (added == 0) {
      return false;
    }
    invalidateCaches();
    return true;
  }

  /** Restarts consumption from the beginning of m_currentSet. */
  public void resetConsumed() {
    m_instancesConsumed = 0;
  }

  /** Removes the most recently added model, if there is one. */
  public void pruneLastModel() {
    if (m_models.size() == 0) {
      return;
    }
    m_models.removeElementAt(m_models.size() - 1);
    invalidateCaches();
  }

  /**
   * Commits the tentative validation scores of the most recently added model
   * and allocates a fresh tentative buffer.
   */
  public void keepLastModel() throws Exception {
    m_validationFs = m_newValidationFs;
    m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
    invalidateCaches();
  }

  /**
   * Mean negative log likelihood of the validation set under the committed
   * scores (lower is better); cached until the committee changes.
   */
  public double logLikelihood() throws Exception {
    if (m_modelHasChangedLL) {
      m_lastLogLikelihood = averageLogLikelihood(m_validationFs);
      m_modelHasChangedLL = false;
    }
    return m_lastLogLikelihood;
  }

  /**
   * Mean negative log likelihood of the validation set under the tentative
   * scores, i.e. as if the most recently added model were kept. Not cached.
   */
  public double logLikelihoodAfter() throws Exception {
    return averageLogLikelihood(m_newValidationFs);
  }

  /** Averages the per-instance negative log likelihood over the validation set. */
  private double averageLogLikelihood(double[][] scores) throws Exception {
    int count = m_validationSet.numInstances();
    double total = 0.0;
    for (int v = 0; v < count; v++) {
      int actualClass = (int) m_validationSet.instance(v).classValue();
      total += logLikelihood(scores[v], actualClass);
    }
    return total / (double) count;
  }

  /** Negative log of the probability the given scores assign to classIndex. */
  private double logLikelihood(double[] Fs, int classIndex) throws Exception {
    double probability = distributionForInstance(Fs)[classIndex];
    return -Math.log(probability);
  }

  /**
   * Fraction of validation instances misclassified under the committed
   * scores; cached until the committee changes.
   */
  public double validationError() throws Exception {
    if (!m_modelHasChanged) {
      return m_lastValidationError;
    }
    int count = m_validationSet.numInstances();
    int mistakes = 0;
    for (int v = 0; v < count; v++) {
      if (classifyInstance(m_validationFs[v]) != m_validationSet.instance(v).classValue()) {
        mistakes++;
      }
    }
    m_lastValidationError = (double) mistakes / (double) count;
    m_modelHasChanged = false;
    return m_lastValidationError;
  }

  /** @return the number of training instances used per model */
  public int chunkSize() {
    return m_chunkSize;
  }

  /** @return the number of models currently in the committee */
  public int committeeSize() {
    return m_models.size();
  }

  /**
   * Predicts a class index from precomputed per-class scores.
   *
   * @return the most probable class index, or Instance.missingValue() if no
   *         class has positive probability
   */
  public double classifyInstance(double[] Fs) throws Exception {
    return argMaxOrMissing(distributionForInstance(Fs));
  }

  /**
   * Predicts the class of an instance with the whole committee. Nominal
   * classes yield the most probable index (or missing); numeric classes
   * yield the first distribution entry; other types yield missing.
   */
  public double classifyInstance(Instance instance) throws Exception {
    double[] dist = distributionForInstance(instance);
    int classType = instance.classAttribute().type();
    if (classType == Attribute.NOMINAL) {
      return argMaxOrMissing(dist);
    } else if (classType == Attribute.NUMERIC) {
      return dist[0];
    }
    return Instance.missingValue();
  }

  /**
   * Index of the first strictly largest positive entry, or
   * Instance.missingValue() when no entry exceeds zero.
   */
  private double argMaxOrMissing(double[] dist) {
    int bestIndex = 0;
    double bestValue = 0;
    for (int c = 0; c < dist.length; c++) {
      if (dist[c] > bestValue) {
        bestValue = dist[c];
        bestIndex = c;
      }
    }
    return (bestValue > 0) ? bestIndex : Instance.missingValue();
  }

  /** Converts per-class scores into class probabilities via the enclosing class's RtoP. */
  public double[] distributionForInstance(double[] Fs) throws Exception {
    double[] probs = new double[m_NumClasses];
    for (int c = 0; c < probs.length; c++) {
      probs[c] = RtoP(Fs, c);
    }
    return probs;
  }

  /**
   * Scores an instance with every per-class classifier of one boosting
   * iteration and returns each class's centred contribution, scaled by
   * (numClasses - 1) / numClasses. The caller is responsible for giving the
   * instance a dataset whose class matches what the base models expect.
   */
  private double[] modelIncrement(Classifier[] model, Instance inst) throws Exception {
    double[] raw = new double[m_NumClasses];
    double mean = 0;
    for (int c = 0; c < m_NumClasses; c++) {
      raw[c] = model[c].classifyInstance(inst);
      mean += raw[c];
    }
    mean /= m_NumClasses;
    double[] delta = new double[m_NumClasses];
    for (int c = 0; c < m_NumClasses; c++) {
      delta[c] = (raw[c] - mean) * (m_NumClasses - 1) / m_NumClasses;
    }
    return delta;
  }

  /**
   * Returns a new score vector equal to Fs plus the contribution of newModel
   * for the given instance. The instance itself is not modified.
   */
  public double[] updateFS(Instance instance, Classifier[] newModel, double[] Fs) throws Exception {
    Instance numericCopy = (Instance) instance.copy();
    numericCopy.setDataset(m_NumericClassData);
    double[] delta = modelIncrement(newModel, numericCopy);
    double[] updated = new double[Fs.length];
    for (int c = 0; c < m_NumClasses; c++) {
      updated[c] = Fs[c] + delta[c];
    }
    return updated;
  }

  /** Class probabilities for an instance, summing the scores of all models. */
  public double[] distributionForInstance(Instance instance) throws Exception {
    Instance numericCopy = (Instance) instance.copy();
    numericCopy.setDataset(m_NumericClassData);
    double[] scores = new double[m_NumClasses];
    for (int m = 0; m < m_models.size(); m++) {
      double[] delta = modelIncrement((Classifier[]) m_models.elementAt(m), numericCopy);
      for (int c = 0; c < m_NumClasses; c++) {
        scores[c] += delta[c];
      }
    }
    return distributionForInstance(scores);
  }

  /**
   * Runs one LogitBoost iteration on a chunk, starting from the scores of the
   * models already in the committee, and returns one freshly built base
   * classifier per class fit to that class's working responses.
   */
  protected Classifier[] boost(Instances data) throws Exception {
    Classifier[] newModel = Classifier.makeCopies(m_Classifier, m_NumClasses);

    // Copy the chunk, drop unlabeled rows, and swap the class for a numeric pseudo class.
    Instances boostData = new Instances(data);
    boostData.deleteWithMissingClass();
    int numInstances = boostData.numInstances();
    int classIndex = data.classIndex();
    boostData.setClassIndex(-1);
    boostData.deleteAttributeAt(classIndex);
    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    boostData.setClassIndex(classIndex);

    // 0/1 indicators of the true class, aligned with the rows of boostData.
    double[][] trainYs = new double[numInstances][m_NumClasses];
    int source = 0;
    for (int row = 0; row < numInstances; row++, source++) {
      while (data.instance(source).classIsMissing()) {
        source++;
      }
      double label = data.instance(source).classValue();
      for (int c = 0; c < m_NumClasses; c++) {
        trainYs[row][c] = (label == c) ? 1 : 0;
      }
    }

    // Current committee scores for every training row.
    double[][] trainFs = new double[numInstances][m_NumClasses];
    for (int m = 0; m < m_models.size(); m++) {
      for (int row = 0; row < numInstances; row++) {
        double[] delta =
          modelIncrement((Classifier[]) m_models.elementAt(m), boostData.instance(row));
        for (int c = 0; c < m_NumClasses; c++) {
          trainFs[row][c] += delta[c];
        }
      }
    }

    for (int c = 0; c < m_NumClasses; c++) {
      // Working response z (magnitude capped at Z_MAX) and weight w for class c.
      for (int row = 0; row < numInstances; row++) {
        double p = RtoP(trainFs[row], c);
        double actual = trainYs[row][c];
        double z;
        if (actual == 1) {
          z = Math.min(1.0 / p, Z_MAX);
        } else if (actual == 0) {
          z = Math.max(-1.0 / (1.0 - p), -Z_MAX);
        } else {
          z = (actual - p) / (p * (1 - p));
        }
        double w = (actual - p) / z;
        Instance current = boostData.instance(row);
        current.setValue(classIndex, z);
        current.setWeight(numInstances * w);
      }
      Instances trainData = m_UseResampling
        ? boostData.resampleWithWeights(m_RandomInstance, weightsOf(boostData))
        : boostData;
      newModel[c].buildClassifier(trainData);
    }
    return newModel;
  }

  /** Collects the current instance weights of a dataset into an array. */
  private double[] weightsOf(Instances insts) {
    double[] weights = new double[insts.numInstances()];
    for (int r = 0; r < weights.length; r++) {
      weights[r] = insts.instance(r).weight();
    }
    return weights;
  }

  /** Describes every model in the committee, class by class, plus summary counts. */
  public String toString() {
    StringBuffer text = new StringBuffer();
    text.append("RacedIncrementalLogitBoost: Best committee on validation data\n");
    text.append("Base classifiers: \n");
    for (int m = 0; m < m_models.size(); m++) {
      text.append("\nModel ").append(m + 1);
      Classifier[] perClass = (Classifier[]) m_models.elementAt(m);
      for (int c = 0; c < m_NumClasses; c++) {
        text.append("\n\tClass ").append(c + 1)
          .append(" (").append(m_ClassAttribute.name())
          .append("=").append(m_ClassAttribute.value(c)).append(")\n\n")
          .append(perClass[c].toString()).append("\n");
      }
    }
    text.append("Number of models: ").append(m_models.size()).append("\n");
    text.append("Chunk size per model: ").append(m_chunkSize).append("\n");
    return text.toString();
  }
}
/**
* Builds the classifier.
*
* @param data the instances to train the classifier with
* @exception Exception if something goes wrong
*/
public void buildClassifier(Instances data) throws Exception {
m_RandomInstance = new Random(m_Seed);
Instances boostData;
int classIndex = data.classIndex();
if (data.classAttribute().isNumeric()) {
throw new Exception("RacedIncrementalLogitBoost can't handle a numeric class!");
}
if (m_Classifier == null) {
throw new Exception("A base classifier has not been specified!");
}
if (!(m_Classifier instanceof WeightedInstancesHandler) &&
!m_UseResampling) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -