📄 racedincrementallogitboost.java
字号:
/**
 * Performs one boosting step: trains a fresh array of base classifiers
 * (one per class) on the supplied data, starting from the summed responses
 * of the models already stored in m_models.
 *
 * @param data the training instances for this step; instances with a
 * missing class are ignored
 * @return the newly built classifiers, one per class
 * @exception Exception if copying, evaluating or building a classifier fails
 */
protected Classifier[] boost(Instances data) throws Exception {
  Classifier[] newModel = Classifier.makeCopies(m_Classifier, m_NumClasses);

  // Work on a copy whose class attribute is replaced by a numeric
  // pseudo-class at the same position.
  Instances boostData = new Instances(data);
  boostData.deleteWithMissingClass();
  int numInstances = boostData.numInstances();
  int classIndex = data.classIndex();
  boostData.setClassIndex(-1);
  boostData.deleteAttributeAt(classIndex);
  boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
  boostData.setClassIndex(classIndex);

  double[][] responses = new double[numInstances][m_NumClasses];
  double[][] targets = new double[numInstances][m_NumClasses];

  // Build the 0/1 indicator targets. boostData has dropped the instances
  // with a missing class, so walk the original data in step and skip them.
  int source = 0;
  for (int row = 0; row < numInstances; row++, source++) {
    while (data.instance(source).classIsMissing()) {
      source++;
    }
    double label = data.instance(source).classValue();
    for (int cls = 0; cls < m_NumClasses; cls++) {
      targets[row][cls] = (label == cls) ? 1 : 0;
    }
  }

  // Accumulate the responses of every model already in the committee.
  for (int m = 0; m < m_models.size(); m++) {
    Classifier[] model = (Classifier[]) m_models.elementAt(m);
    for (int row = 0; row < numInstances; row++) {
      double[] pred = new double[m_NumClasses];
      double mean = 0;
      for (int cls = 0; cls < m_NumClasses; cls++) {
        pred[cls] = model[cls].classifyInstance(boostData.instance(row));
        mean += pred[cls];
      }
      mean /= m_NumClasses;
      for (int cls = 0; cls < m_NumClasses; cls++) {
        responses[row][cls] += (pred[cls] - mean) * (m_NumClasses-1)
          / m_NumClasses;
      }
    }
  }

  for (int cls = 0; cls < m_NumClasses; cls++) {
    // Set the pseudo-class value and weight of every instance for this class
    for (int row = 0; row < numInstances; row++) {
      double p = RtoP(responses[row], cls);
      double actual = targets[row][cls];
      double z;
      if (actual == 1) {
        double raw = 1.0 / p;
        z = (raw > Z_MAX) ? Z_MAX : raw; // threshold
      } else if (actual == 0) {
        double raw = -1.0 / (1.0 - p);
        z = (raw < -Z_MAX) ? -Z_MAX : raw; // threshold
      } else {
        z = (actual - p) / (p * (1 - p));
      }
      double w = (actual - p) / z;
      Instance current = boostData.instance(row);
      current.setValue(classIndex, z);
      current.setWeight(numInstances * w);
    }

    // Either train on the weighted data directly or on a weighted resample.
    Instances trainData;
    if (!m_UseResampling) {
      trainData = boostData;
    } else {
      double[] weights = new double[boostData.numInstances()];
      for (int n = 0; n < weights.length; n++) {
        weights[n] = boostData.instance(n).weight();
      }
      trainData = boostData.resampleWithWeights(m_RandomInstance,
                                                weights);
    }

    // Build the classifier
    newModel[cls].buildClassifier(trainData);
  }
  return newModel;
}
/*
 * Returns a textual description of the committee: each stored model with
 * its per-class base classifiers, then the model count and chunk size.
 */
public String toString() {
  StringBuffer out = new StringBuffer();
  out.append("RacedIncrementalLogitBoost: Best committee on validation data\n");
  out.append("Base classifiers: \n");

  int modelNo = 0;
  while (modelNo < m_models.size()) {
    out.append("\nModel ").append(modelNo + 1);
    Classifier[] perClass = (Classifier[]) m_models.elementAt(modelNo);
    for (int cls = 0; cls < m_NumClasses; cls++) {
      out.append("\n\tClass ").append(cls + 1)
         .append(" (").append(m_ClassAttribute.name())
         .append("=").append(m_ClassAttribute.value(cls)).append(")\n\n")
         .append(perClass[cls].toString()).append("\n");
    }
    modelNo++;
  }

  out.append("Number of models: ").append(m_models.size()).append("\n");
  out.append("Chunk size per model: " + m_chunkSize + "\n");
  return out.toString();
}
}
/**
 * Builds the classifier.
 *
 * Validates the configuration, records the class information, creates one
 * committee per chunk size (the minimum chunk size doubled repeatedly while
 * it does not exceed the maximum), and then feeds a randomised copy of the
 * data to {@link #updateClassifier(Instance)} one instance at a time.
 *
 * @param data the instances to train the classifier with
 * @exception Exception if the class is numeric, no base classifier is set,
 * the data contains string attributes, the minimum chunk size is smaller
 * than 1, or training fails
 */
public void buildClassifier(Instances data) throws Exception {
  m_RandomInstance = new Random(m_Seed);
  int classIndex = data.classIndex();

  if (data.classAttribute().isNumeric()) {
    throw new Exception("LogitBoost can't handle a numeric class!");
  }
  if (m_Classifier == null) {
    throw new Exception("A base classifier has not been specified!");
  }
  // A base learner that cannot use instance weights has to be trained on
  // weighted resamples instead, so resampling is switched on for it.
  if (!(m_Classifier instanceof WeightedInstancesHandler) &&
      !m_UseResampling) {
    m_UseResampling = true;
  }
  if (data.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
  }
  // The committee sizes are produced by repeated doubling; a starting size
  // of zero or less would never grow past the maximum and the loop below
  // would allocate committees forever.
  if (m_minChunkSize < 1) {
    throw new Exception("Minimum chunk size must be at least 1!");
  }

  m_NumClasses = data.numClasses();
  m_ClassAttribute = data.classAttribute();

  // Create a copy of the data with the class transformed into numeric
  Instances boostData = new Instances(data);
  boostData.deleteWithMissingClass();

  // Temporarily unset the class index
  boostData.setClassIndex(-1);
  boostData.deleteAttributeAt(classIndex);
  boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
  boostData.setClassIndex(classIndex);
  m_NumericClassData = new Instances(boostData, 0);

  // Shuffle a private copy so the caller's data set is left untouched.
  data = new Instances(data);
  data.randomize(m_RandomInstance);

  // create the committees
  // The counter is a long so that doubling a size close to
  // Integer.MAX_VALUE cannot wrap around to a negative value and keep the
  // loop running.
  m_committees = new FastVector();
  for (long cSize = m_minChunkSize; cSize <= m_maxChunkSize; cSize *= 2) {
    m_committees.addElement(new Committee((int) cSize));
    m_maxBatchSizeRequired = (int) cSize;
  }

  // set up for consumption
  m_validationSet = new Instances(data, m_validationChunkSize);
  m_currentSet = new Instances(data, m_maxBatchSizeRequired);
  m_bestCommittee = null;
  m_numInstancesConsumed = 0;

  // start eating what we've been given
  for (int i = 0; i < data.numInstances(); i++) {
    updateClassifier(data.instance(i));
  }
}
/**
 * Updates the classifier with a single training instance.
 *
 * The instance goes into the validation set until that set has reached its
 * configured size; afterwards it is added to the current training buffer,
 * every committee is given the chance to update itself, and the committee
 * with the lowest validation error becomes the best committee.
 *
 * @param instance the next instance in the stream of training data
 * @exception Exception if something goes wrong
 */
public void updateClassifier(Instance instance) throws Exception {
  m_numInstancesConsumed++;

  // Fill the validation set first; nothing is trained until it is full.
  if (m_validationSet.numInstances() < m_validationChunkSize) {
    m_validationSet.add(instance);
    m_validationSetChanged = true;
    return;
  }

  m_currentSet.add(instance);

  // update each committee
  boolean anyUpdated = false;
  for (int idx = 0; idx < m_committees.size(); idx++) {
    Committee committee = (Committee) m_committees.elementAt(idx);
    if (!committee.update()) {
      continue;
    }
    anyUpdated = true;
    if (m_PruningType != PRUNETYPE_LOGLIKELIHOOD) {
      committee.keepLastModel(); // no pruning
      continue;
    }
    double before = committee.logLikelihood();
    double after = committee.logLikelihoodAfter();
    if (after >= before && committee.committeeSize() > 1) {
      committee.pruneLastModel();
      if (m_Debug) {
        System.out.println("Pruning " + committee.chunkSize() + " committee ("
                           + before + " < " + after + ")");
      }
    } else {
      committee.keepLastModel();
    }
  }

  if (anyUpdated) {
    if (m_Debug) {
      System.out.println("After consuming " + m_numInstancesConsumed
                         + " instances... (" + m_validationSet.numInstances()
                         + " + " + m_currentSet.numInstances()
                         + " instances currently in memory)");
    }
    // find best committee
    double lowestError = 1.0;
    for (int idx = 0; idx < m_committees.size(); idx++) {
      Committee committee = (Committee) m_committees.elementAt(idx);
      if (committee.committeeSize() <= 0) {
        continue;
      }
      double err = committee.validationError();
      double ll = committee.logLikelihood();
      if (m_Debug) {
        System.out.println("Chunk size " + committee.chunkSize() + " with "
                           + committee.committeeSize()
                           + " models, has validation error of "
                           + err + ", log likelihood of " + ll);
      }
      if (err < lowestError) {
        lowestError = err;
        m_bestCommittee = committee;
      }
    }
  }

  // Once the buffer holds the largest batch any committee needs, start a
  // new empty buffer and reset the consumption counts.
  if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) {
    m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired);
    for (int idx = 0; idx < m_committees.size(); idx++) {
      Committee committee = (Committee) m_committees.elementAt(idx);
      committee.resetConsumed();
    }
  }
}
/**
 * Converts function responses into a probability for one class.
 *
 * Each response is shifted by the largest response before it is
 * exponentiated, and the result for class j is divided by the sum of all
 * exponentiated values.
 *
 * @param Fs an array containing the responses from each function
 * @param j the class value of interest
 * @return the probability prediction for j
 * @exception Exception if the exponentiated responses sum to zero
 */
protected static double RtoP(double []Fs, int j)
  throws Exception {
  final int count = Fs.length;

  // Largest response; subtracting it keeps every exponent at or below zero.
  double largest = -Double.MAX_VALUE;
  int pos = 0;
  while (pos < count) {
    if (Fs[pos] > largest) {
      largest = Fs[pos];
    }
    pos++;
  }

  // Exponentiate the shifted responses and accumulate the normaliser.
  double[] shifted = new double[count];
  double total = 0;
  for (pos = 0; pos < count; pos++) {
    double e = Math.exp(Fs[pos] - largest);
    shifted[pos] = e;
    total += e;
  }

  if (total == 0) {
    throw new Exception("Can't normalize");
  }
  return shifted[j] / total;
}
/**
 * Computes the class distribution of an instance.
 *
 * Uses the best committee when one has been chosen. Otherwise falls back to
 * a ZeroR model built on the validation set, rebuilding it only when no
 * such model exists yet or the validation set has changed since the last
 * build.
 *
 * @param instance the instance to compute the distribution for
 * @return the class distribution
 * @exception Exception if the distribution cannot be computed
 */
public double[] distributionForInstance(Instance instance) throws Exception {
  if (m_bestCommittee == null) {
    boolean needsRebuild = (m_zeroR == null) || m_validationSetChanged;
    if (needsRebuild) {
      m_zeroR = new ZeroR();
      m_zeroR.buildClassifier(m_validationSet);
      m_validationSetChanged = false;
    }
    return m_zeroR.distributionForInstance(instance);
  }
  return m_bestCommittee.distributionForInstance(instance);
}
/**
 * Returns an enumeration describing the available options: the options of
 * this scheme, followed by those of the base learner when it is an
 * OptionHandler.
 *
 * @return an enumeration of all the available options
 */
public Enumeration listOptions() {
  Vector options = new Vector(9);

  options.addElement(new Option("\tTurn on debugging output.",
                                "D", 0, "-D"));
  options.addElement(new Option("\tMinimum size of chunks.\n\t(default 500)",
                                "C", 1, "-C <num>"));
  options.addElement(new Option("\tMaximum size of chunks.\n\t(default 20000)",
                                "M", 1, "-M <num>"));
  options.addElement(new Option("\tSize of validation set.\n\t(default 5000)",
                                "V", 1, "-V <num>"));
  options.addElement(new Option("\tFull name of 'weak' learner to boost.\n\teg: weka.classifiers.DecisionStump",
                                "W", 1, "-W <learner class name>"));
  options.addElement(new Option("\tCommittee pruning to perform.\n\t0=none, 1=log likelihood (default)",
                                "P", 1, "-P <pruning type>"));
  options.addElement(new Option("\tUse resampling for boosting.",
                                "Q", 0, "-Q"));
  options.addElement(new Option("\tSeed for resampling. (Default 1)",
                                "S", 1, "-S <num>"));

  // instanceof is false for null, so no separate null check is required.
  if (m_Classifier instanceof OptionHandler) {
    String heading = "\nOptions specific to weak learner "
      + m_Classifier.getClass().getName() + ":";
    options.addElement(new Option("", "", 0, heading));
    for (Enumeration em = ((OptionHandler) m_Classifier).listOptions();
         em.hasMoreElements();) {
      options.addElement(em.nextElement());
    }
  }
  return options.elements();
}
/**
* Parses a given list of options. Valid options are:<p>
*
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String minChunkSize = Utils.getOption('C', options);
if (minChunkSize.length() != 0) {
setMinChunkSize(Integer.parseInt(minChunkSize));
} else {
setMinChunkSize(500);
}
String maxChunkSize = Utils.getOption('M', options);
if (maxChunkSize.length() != 0) {
setMaxChunkSize(Integer.parseInt(maxChunkSize));
} else {
setMaxChunkSize(20000);
}
String validationChunkSize = Utils.getOption('V', options);
if (validationChunkSize.length() != 0) {
setValidationChunkSize(Integer.parseInt(validationChunkSize));
} else {
setValidationChunkSize(5000);
}
String pruneType = Utils.getOption('P', options);
if (pruneType.length() != 0) {
setPruningType(new SelectedTag(Integer.parseInt(pruneType), TAGS_PRUNETYPE));
} else {
setPruningType(new SelectedTag(PRUNETYPE_LOGLIKELIHOOD, TAGS_PRUNETYPE));
}
setUseResampling(Utils.getFlag('Q', options));
String seedString = Utils.getOption('S', options);
if (seedString.length() != 0) {
setSeed(Integer.parseInt(seedString));
} else {
setSeed(1);
}
setDebug(Utils.getFlag('D', options));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -