📄 bayesianboosting.java
字号:
/**
 * Trains the boosted ensemble on the input example set and returns it as the
 * single output object.
 *
 * Reads the input ExampleSet, reads an optional start model, validates that a
 * label exists (and is binary when soft predictions are requested), optionally
 * splits off an external validation set, trains the boosting model, and
 * optionally writes the model to a file.
 *
 * @return a one-element array containing the trained boosting model
 * @throws OperatorException if the label is missing (error 105), the label is
 *         not binary while USE_DISTRIBUTION is set (error 118), or writing the
 *         model file fails (error 303)
 */
public IOObject[] apply() throws OperatorException {
    // Read the input example set and initialize its weights.
    ExampleSet exampleSet = (ExampleSet) this.getInput(ExampleSet.class);

    // Read start model if present.
    this.readOptionalParameters();

    // Check if a label is present and fits the learning task.
    if (exampleSet.getLabel() == null) {
        throw new UserError(this, 105);
    }
    // Soft (distribution) predictions are only supported for binary labels.
    if (this.getParameterAsBoolean(USE_DISTRIBUTION)
            && (exampleSet.getLabel().getValues().size() != 2)) {
        throw new UserError(this, 118, new Object[] {
                exampleSet.getLabel(),
                new Integer(exampleSet.getLabel().getValues().size()),
                new Integer(2) });
    }

    this.prepareWeights(exampleSet);

    // A holdout ratio strictly between 0 and 1 enables an external validation set.
    final double holdoutSet = this.getParameterAsDouble(VALIDATION_SET);
    final boolean useValidationSet = (holdoutSet > 0 && holdoutSet < 1);
    LogService.logMessage(useValidationSet
            ? "Using external validation set."
            : "No external validation set for measuring performance.",
            LogService.STATUS);
    if (useValidationSet) {
        // Subset 0 becomes the training part, subset 1 the validation part.
        exampleSet = new SplittedExampleSet(exampleSet, 1 - holdoutSet);
    }

    Model model = this.trainBoostingModel(exampleSet, useValidationSet);

    // If the parameter for storing the model to a file is set, try to store it.
    // The try block is scoped to the I/O call only.
    String modelFile = this.getParameterAsString(MODEL_FILE);
    if (modelFile != null) {
        try {
            model.writeModel(getExperiment().resolveFileName(modelFile));
        }
        catch (IOException e) {
            throw new UserError(this, e, 303, new Object[] { modelFile, e.getMessage() });
        }
    }
    return new IOObject[] { model };
}
/**
 * Helper method applying the start model and adding priors and model info
 * collection accordingly.
 *
 * <p>NOTE(review): the original version tried to hand the computed class
 * priors back to the caller by reassigning the <code>classPriors</code>
 * parameter. Java passes references by value, so that assignment was invisible
 * to the caller and the start model's priors were silently discarded. The
 * priors are now returned instead; the parameter is kept for call-site
 * compatibility and is returned unchanged when no start model is present.
 * Callers should capture the return value.</p>
 *
 * @param trainingSet the example set the start model is applied to
 * @param modelInfo collection receiving <code>{model, biasMatrix}</code> pairs
 * @param classPriors the caller's current priors (may be <code>null</code>)
 * @return the label priors computed from the start model's predictions, or
 *         <code>classPriors</code> unchanged if there is no start model
 * @throws OperatorException if applying the start model fails
 */
private double[] applyPriorModel(ExampleSet trainingSet, Vector modelInfo, double[] classPriors)
throws OperatorException
{
    // If the input contains a model already, initialise the example weights.
    if (this.startModel != null) {
        this.startModel.createPredictedLabel(trainingSet);
        this.startModel.apply(trainingSet);
        // Initial values and the input model are stored in the output model.
        WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(trainingSet);
        classPriors = wp.getLabelPriors();
        wp.reweightExamples(trainingSet, this.getParameterAsBoolean(EQUALLY_PROB_LABELS));
        modelInfo.add(new Object[] { this.startModel, wp.createBiasMatrix() });
    }
    return classPriors;
}
/**
 * Main method for training the ensemble classifier.
 *
 * Iteratively trains base models, measures their weighted performance,
 * reweights the examples, and collects {model, biasMatrix} pairs until the
 * iteration limit is reached or a stopping criterion fires (single remaining
 * class, insufficient advantage, or zero total weight).
 */
private BayBoostModel trainBoostingModel(ExampleSet trainingSet, boolean useValidationSet)
throws OperatorException
{
// for priors, models and their probability estimates
double[] classPriors = null;
Vector modelInfo = new Vector();
// if present apply the start model first
// NOTE(review): applyPriorModel reassigns its classPriors parameter
// internally, but Java passes references by value, so classPriors is still
// null here afterwards — any priors computed from the start model never reach
// this method. Priors are instead (re)computed in the first iteration below.
// Verify whether that is the intended behavior.
this.applyPriorModel(trainingSet, modelInfo, classPriors);
// check whether to use the complete training set for training
// (a split ratio strictly between 0 and 1 enables internal bootstrapping)
final double splitRatio = this.getParameterAsDouble(INTERNAL_BOOTSTRAP);
final boolean bootstrap = ((splitRatio > 0) && (splitRatio < 1.0));
LogService.logMessage(bootstrap ? "Bootstrapping enabled." : "Bootstrapping disabled.", LogService.STATUS);
// maximum number of iterations
final int iterations = this.getParameterAsInt(NUM_OF_ITERATIONS);
// Labeled loop so the stopping criteria below can break out explicitly.
L: for (int i=0; i < iterations; i++) {
this.currentIteration = i;
int size = trainingSet.getSize();
ExampleSet splittedSet = trainingSet;
if (bootstrap == true) {
// Train only on subset 0; subset 1 acts as the out-of-bag sample.
splittedSet = new SplittedExampleSet(trainingSet, splitRatio);
((SplittedExampleSet) splittedSet).selectSingleSubset(0); // switch to training set
}
// train one model per iteration
Model model = this.trainModel(splittedSet);
if (bootstrap == true) {
((SplittedExampleSet) splittedSet).selectSingleSubset(1); // switch to out-of-bag set
model.apply(trainingSet); // apply model to all examples
}
// get the weighted performance value of the example set with respect to the
// model; when bootstrapping, splittedSet currently views the out-of-bag subset.
WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(splittedSet);
if (classPriors == null) {
// First iteration: derive priors from the measured label distribution.
classPriors = wp.getLabelPriors();
}
if (classPriors.length == 2) {
this.debugMessage(wp);
}
{ // Stop if only one class is present/left.
int nonEmptyClasses = 0;
for (int j=0; (j<wp.getNumberOfLabels() && nonEmptyClasses<2); j++) {
double c = wp.getProbabilityLabel(j + Attribute.FIRST_CLASS_INDEX);
if (c > 0) {
nonEmptyClasses++;
}
}
if (nonEmptyClasses < 2) {
// Using the model here is just necessary to avoid a NullPointerException.
// One could use an empty model instead:
modelInfo.add(new Object[] { model, wp.createBiasMatrix() });
break L; // No more iterations!
}
}
// Reweight the example set with respect to the weighted performance values:
// (returns false once the total weight has collapsed to zero)
boolean positiveWeight = wp.reweightExamples(trainingSet, this.getParameterAsBoolean(EQUALLY_PROB_LABELS));
final double[][] biasMatrix = wp.createBiasMatrix();
// Add the new model and its weights to the collection of models:
modelInfo.add(new Object[] { model, biasMatrix });
if (useValidationSet == true && (trainingSet instanceof SplittedExampleSet)) {
// build a new composed model from all models collected so far
Model intermediateModel =
new BayBoostModel(trainingSet.getLabel(), modelInfo, classPriors, true); // always use crisp models
double accuracy;
{ // Switch to validation subset:
SplittedExampleSet validationSet = (SplittedExampleSet) trainingSet;
validationSet.selectSingleSubset(1);
// apply it to the validation set and count misclassifications
intermediateModel.apply(validationSet);
ExampleReader reader = validationSet.getExampleReader();
int errors = 0;
while (reader.hasNext()) {
Example example = (Example) reader.next();
if (example.getLabel() != example.getPredictedLabel()) {
errors++;
}
}
accuracy = 1 - ((double) errors) / validationSet.getSize();
// switch back to training set:
validationSet.selectSingleSubset(0);
}
// Keep the model only if it improves validation accuracy by more than
// MIN_ADVANTAGE over the best accuracy seen so far (this.performance).
if (this.performance >= accuracy - MIN_ADVANTAGE) {
LogService.logMessage("Discard model because of low advantage on validation set.", LogService.STATUS);
modelInfo.remove(modelInfo.size() - 1);
break L;
}
else this.performance = accuracy;
}
else if (this.isModelUseful(biasMatrix) == false) {
// If the model is not considered to be useful (low advantage) then discard it and stop.
LogService.logMessage("Discard model because of low advantage on training data.", LogService.STATUS);
modelInfo.remove(modelInfo.size() - 1);
break L;
}
// Stop if weight is null, because all examples have been explained "deterministically"!
if (!positiveWeight) {
break L;
}
}
// Build a Model object. Last parameter is "crispPredictions", so invert "use distribution".
return new BayBoostModel(trainingSet.getLabel(), modelInfo, classPriors, ! this.getParameterAsBoolean(USE_DISTRIBUTION));
}
/**
 * Logs a confusion-style summary of the base learner's weighted training
 * performance (rates, predicted totals, and label totals) for the binary case.
 */
private void debugMessage(WeightedPerformanceMeasures wp) {
    StringBuffer buffer = new StringBuffer();
    buffer.append("\nModel learned - training performance of base learner:");
    buffer.append("\nTPR: ").append(wp.getProbability(0, 0));
    buffer.append(" FPR: ").append(wp.getProbability(1, 0));
    buffer.append(" | Positively predicted: ").append(wp.getProbability(1, 0) + wp.getProbability(0, 0));
    buffer.append("\nFNR: ").append(wp.getProbability(0, 1));
    buffer.append(" TNR: ").append(wp.getProbability(1, 1));
    buffer.append(" | Negatively predicted: ").append(wp.getProbability(0, 1) + wp.getProbability(1, 1));
    buffer.append("\nPositively labelled: ").append(wp.getProbability(0, 0) + wp.getProbability(0, 1));
    buffer.append("\nNegatively labelled: ").append(wp.getProbability(1, 0) + wp.getProbability(1, 1));
    LogService.logMessage(buffer.toString(), LogService.STATUS);
}
/**
 * Helper method to decide whether a model improves the training error enough
 * to be considered. A model counts as useful iff at least one entry of the
 * bias matrix deviates from 1 by more than MIN_ADVANTAGE.
 *
 * @param biasMatrix the bias matrix as returned by the getter of the WeightedPerformance class
 * @return <code>true</code> iff the advantage is high enough to consider the model to be useful
 */
private boolean isModelUseful(double[][] biasMatrix) {
    boolean useful = false;
    // Scan every matrix entry, stopping as soon as one shows enough advantage.
    for (int i = 0; (i < biasMatrix.length) && !useful; i++) {
        for (int j = 0; (j < biasMatrix[i].length) && !useful; j++) {
            double deviation = Math.abs(biasMatrix[i][j] - 1);
            useful = (deviation > MIN_ADVANTAGE);
        }
    }
    return useful;
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -