// RacedIncrementalLogitBoost.java
m_UseResampling = true;
}
if (data.checkForStringAttributes()) {
throw new Exception("Can't handle string attributes!");
}
m_NumClasses = data.numClasses();
m_ClassAttribute = data.classAttribute();
// Create a copy of the data with the class transformed into numeric
boostData = new Instances(data);
boostData.deleteWithMissingClass();
// Temporarily unset the class index
boostData.setClassIndex(-1);
boostData.deleteAttributeAt(classIndex);
boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
boostData.setClassIndex(classIndex);
m_NumericClassData = new Instances(boostData, 0);
data = new Instances(data);
data.randomize(m_RandomInstance);
// create the committees
int cSize = m_minChunkSize;
m_committees = new FastVector();
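// Chunk sizes start at the minimum and double until they exceed the maximum;
// the largest chunk size built here determines how many instances must be
// buffered before every committee has had a chance to update.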
while (cSize <= m_maxChunkSize) {
m_committees.addElement(new Committee(cSize));
m_maxBatchSizeRequired = cSize;
cSize *= 2;
}
// set up for consumption
m_validationSet = new Instances(data, m_validationChunkSize);
m_currentSet = new Instances(data, m_maxBatchSizeRequired);
m_bestCommittee = null;
m_numInstancesConsumed = 0;
// start eating what we've been given
for (int i=0; i<data.numInstances(); i++) updateClassifier(data.instance(i));
}
/**
* Updates the classifier.
*
* @param instance the next instance in the stream of training data
* @exception Exception if something goes wrong
*/
public void updateClassifier(Instance instance) throws Exception {
m_numInstancesConsumed++;
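// The first m_validationChunkSize instances from the stream are held out for
// validation; the committees only start learning once this set is full.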
if (m_validationSet.numInstances() < m_validationChunkSize) {
m_validationSet.add(instance);
m_validationSetChanged = true;
} else {
m_currentSet.add(instance);
boolean hasChanged = false;
// update each committee
for (int i=0; i<m_committees.size(); i++) {
Committee c = (Committee) m_committees.elementAt(i);
if (c.update()) {
hasChanged = true;
if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) {
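// Log-likelihood pruning: discard the newly added model if it has a negative
// effect on the validation log-likelihood measure (as computed here, a larger
// value means a worse fit), but never prune a committee down to zero models.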
double oldLL = c.logLikelihood();
double newLL = c.logLikelihoodAfter();
if (newLL >= oldLL && c.committeeSize() > 1) {
c.pruneLastModel();
if (m_Debug) System.out.println("Pruning " + c.chunkSize() + " committee (" +
oldLL + " <= " + newLL + ")");
} else c.keepLastModel();
} else c.keepLastModel(); // no pruning
}
}
if (hasChanged) {
if (m_Debug) System.out.println("After consuming " + m_numInstancesConsumed
+ " instances... (" + m_validationSet.numInstances()
+ " + " + m_currentSet.numInstances()
+ " instances currently in memory)");
// find best committee
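// The race: whichever committee currently has the lowest error on the
// validation set becomes the model used for prediction.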
double lowestError = 1.0;
for (int i=0; i<m_committees.size(); i++) {
Committee c = (Committee) m_committees.elementAt(i);
if (c.committeeSize() > 0) {
double err = c.validationError();
double ll = c.logLikelihood();
if (m_Debug) System.out.println("Chunk size " + c.chunkSize() + " with "
+ c.committeeSize() + " models, has validation error of "
+ err + ", log likelihood of " + ll);
if (err < lowestError) {
lowestError = err;
m_bestCommittee = c;
}
}
}
}
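// Once enough instances have been buffered to feed the largest chunk size,
// discard the buffer and reset each committee's consumption count so the
// next batch can be collected.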
if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) {
m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired);
// reset consumption counts
for (int i=0; i<m_committees.size(); i++) {
Committee c = (Committee) m_committees.elementAt(i);
c.resetConsumed();
}
}
}
}
/**
* Converts the function responses into a probability for the class of interest.
*
* @param Fs an array containing the response from each function
* @param j the class value of interest
* @return the probability prediction for j
*/
protected static double RtoP(double []Fs, int j)
throws Exception {
double maxF = -Double.MAX_VALUE;
for (int i = 0; i < Fs.length; i++) {
if (Fs[i] > maxF) {
maxF = Fs[i];
}
}
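// Subtracting the largest response before exponentiating avoids numerical
// overflow; the normalisation below is the usual multinomial (softmax)
// conversion p_j = exp(F_j) / sum_k exp(F_k).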
double sum = 0;
double[] probs = new double[Fs.length];
for (int i = 0; i < Fs.length; i++) {
probs[i] = Math.exp(Fs[i] - maxF);
sum += probs[i];
}
if (sum == 0) {
throw new Exception("Can't normalize");
}
return probs[j] / sum;
}
/**
* Computes the class distribution of an instance using the best committee.
*
* @param instance the instance to get the distribution for
* @return the class probability distribution
* @exception Exception if the distribution can't be computed successfully
*/
public double[] distributionForInstance(Instance instance) throws Exception {
if (m_bestCommittee != null) return m_bestCommittee.distributionForInstance(instance);
else {
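// No committee has been trained yet, so fall back to a ZeroR model built on
// whatever validation data has been collected so far.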
if (m_validationSetChanged || m_zeroR == null) {
m_zeroR = new ZeroR();
m_zeroR.buildClassifier(m_validationSet);
m_validationSetChanged = false;
}
return m_zeroR.distributionForInstance(instance);
}
}
/**
* Returns an enumeration describing the available options
*
* @return an enumeration of all the available options
*/
public Enumeration listOptions() {
Vector newVector = new Vector(9);
newVector.addElement(new Option(
"\tMinimum size of chunks.\n"
+"\t(default 500)",
"C", 1, "-C <num>"));
newVector.addElement(new Option(
"\tMaximum size of chunks.\n"
+"\t(default 2000)",
"M", 1, "-M <num>"));
newVector.addElement(new Option(
"\tSize of validation set.\n"
+"\t(default 1000)",
"V", 1, "-V <num>"));
newVector.addElement(new Option(
"\tCommittee pruning to perform.\n"
+"\t0=none, 1=log likelihood (default)",
"P", 1, "-P <pruning type>"));
newVector.addElement(new Option(
"\tUse resampling for boosting.",
"Q", 0, "-Q"));
Enumeration enu = super.listOptions();
while (enu.hasMoreElements()) {
newVector.addElement(enu.nextElement());
}
return newVector.elements();
}
/**
* Parses a given list of options. Valid options are:<p>
*
* -C num <br> Minimum size of chunks (default 500).<p>
* -M num <br> Maximum size of chunks (default 2000).<p>
* -V num <br> Size of the validation set (default 1000).<p>
* -P pruning_type <br> Committee pruning to perform: 0=none, 1=log likelihood (default).<p>
* -Q <br> Use resampling for boosting.<p>
*
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String minChunkSize = Utils.getOption('C', options);
if (minChunkSize.length() != 0) {
setMinChunkSize(Integer.parseInt(minChunkSize));
} else {
setMinChunkSize(500);
}
String maxChunkSize = Utils.getOption('M', options);
if (maxChunkSize.length() != 0) {
setMaxChunkSize(Integer.parseInt(maxChunkSize));
} else {
setMaxChunkSize(2000);
}
String validationChunkSize = Utils.getOption('V', options);
if (validationChunkSize.length() != 0) {
setValidationChunkSize(Integer.parseInt(validationChunkSize));
} else {
setValidationChunkSize(1000);
}
String pruneType = Utils.getOption('P', options);
if (pruneType.length() != 0) {
setPruningType(new SelectedTag(Integer.parseInt(pruneType), TAGS_PRUNETYPE));
} else {
setPruningType(new SelectedTag(PRUNETYPE_LOGLIKELIHOOD, TAGS_PRUNETYPE));
}
setUseResampling(Utils.getFlag('Q', options));
super.setOptions(options);
}
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String [] getOptions() {
String [] superOptions = super.getOptions();
String [] options = new String [superOptions.length + 9];
int current = 0;
if (getUseResampling()) {
options[current++] = "-Q";
}
options[current++] = "-C"; options[current++] = "" + getMinChunkSize();
options[current++] = "-M"; options[current++] = "" + getMaxChunkSize();
options[current++] = "-V"; options[current++] = "" + getValidationChunkSize();
options[current++] = "-P"; options[current++] = "" + m_PruningType;
System.arraycopy(superOptions, 0, options, current,
superOptions.length);
current += superOptions.length;
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
* @return a description of the classifier suitable for
* displaying in the explorer/experimenter gui
*/
public String globalInfo() {
return "Classifier for incremental learning of large datasets by way of racing logit-boosted committees.";
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String minChunkSizeTipText() {
return "The minimum number of instances to train the base learner with.";
}
/**
* Set the minimum chunk size
*
* @param chunkSize the minimum number of instances per chunk
*/
public void setMinChunkSize(int chunkSize) {
m_minChunkSize = chunkSize;
}
/**
* Get the minimum chunk size
*
* @return the chunk size
*/
public int getMinChunkSize() {
return m_minChunkSize;
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String maxChunkSizeTipText() {
return "The maximum number of instances to train the base learner with. The chunk sizes used will start at minChunkSize and grow twice as large for as many times as they are less than or equal to the maximum size.";
}
/**
* Set the maximum chunk size
*
* @param chunkSize the maximum number of instances per chunk
*/
public void setMaxChunkSize(int chunkSize) {
m_maxChunkSize = chunkSize;
}
/**
* Get the maximum chunk size
*
* @return the chunk size
*/
public int getMaxChunkSize() {
return m_maxChunkSize;
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String validationChunkSizeTipText() {
return "The number of instances to hold out for validation. These instances will be taken from the beginning of the stream, so learning will not start until these instances have been consumed first.";
}
/**
* Set the validation chunk size
*
* @param chunkSize the number of instances to hold out for validation
*/
public void setValidationChunkSize(int chunkSize) {
m_validationChunkSize = chunkSize;
}
/**
* Get the validation chunk size
*
* @return the chunk size
*/
public int getValidationChunkSize() {
return m_validationChunkSize;
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String pruningTypeTipText() {
return "The pruning method to use within each committee. Log likelihood pruning will discard new models if they have a negative effect on the log likelihood of the validation data.";
}
/**
* Set the pruning type
*
* @param pruneType the pruning type to use
*/
public void setPruningType(SelectedTag pruneType) {
if (pruneType.getTags() == TAGS_PRUNETYPE) {
m_PruningType = pruneType.getSelectedTag().getID();
}
}
/**
* Get the pruning type
*
* @return the type
*/
public SelectedTag getPruningType() {
return new SelectedTag(m_PruningType, TAGS_PRUNETYPE);
}
/**
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String useResamplingTipText() {
return "Force the use of resampling data rather than using the weight-handling capabilities of the base classifier. Resampling is always used if the base classifier cannot handle weighted instances.";
}
/**
* Set resampling mode
*
* @param r true if resampling should be done
*/
public void setUseResampling(boolean r) {
m_UseResampling = r;
}
/**
* Get whether resampling is turned on
*
* @return true if resampling is used
*/
public boolean getUseResampling() {
return m_UseResampling;
}
/**
* Get the best committee chunk size
*
* @return the chunk size of the best committee, or 0 if no committee has been selected yet
*/
public int getBestCommitteeChunkSize() {
if (m_bestCommittee != null) {
return m_bestCommittee.chunkSize();
}
else return 0;
}
/**
* Get the number of members in the best committee
*
* @return the number of models in the best committee, or 0 if no committee has been selected yet
*/
public int getBestCommitteeSize() {
if (m_bestCommittee != null) {
return m_bestCommittee.committeeSize();
}
else return 0;
}
/**
* Get the best committee's error on the validation data
*
* @return the validation error as a percentage, or 100 if no committee has been selected yet (or the error cannot be computed)
*/
public double getBestCommitteeErrorEstimate() {
if (m_bestCommittee != null) {
try {
return m_bestCommittee.validationError() * 100.0;
} catch (Exception e) {
System.err.println(e.getMessage());
return 100.0;
}
}
else return 100.0;
}
/**
* Get the best committee's log likelihood on the validation data
*
* @return the log likelihood, or Double.MAX_VALUE if no committee has been selected yet (or the value cannot be computed)
*/
public double getBestCommitteeLLEstimate() {
if (m_bestCommittee != null) {
try {
return m_bestCommittee.logLikelihood();
} catch (Exception e) {
System.err.println(e.getMessage());
return Double.MAX_VALUE;
}
}
else return Double.MAX_VALUE;
}
/**
* Returns description of the boosted classifier.
*
* @return description of the boosted classifier as a string
*/
public String toString() {
if (m_bestCommittee != null) {
return m_bestCommittee.toString();
} else {
if ((m_validationSetChanged || m_zeroR == null) && m_validationSet != null
&& m_validationSet.numInstances() > 0) {
m_zeroR = new ZeroR();
try {
m_zeroR.buildClassifier(m_validationSet);
} catch (Exception e) {}
m_validationSetChanged = false;
}
if (m_zeroR != null) {
return ("RacedIncrementalLogitBoost: insufficient data to build model, resorting to ZeroR:\n\n"
+ m_zeroR.toString());
}
else return ("RacedIncrementalLogitBoost: no model built yet.");
}
}
/**
* Main method for this class.
*/
public static void main(String[] argv) {
try {
System.out.println(Evaluation.evaluateModel(new RacedIncrementalLogitBoost(), argv));
} catch (Exception e) {
e.printStackTrace();
System.err.println(e.getMessage());
}
}
}