// BayBoostModel.java
/*
* YALE - Yet Another Learning Environment
* Copyright (C) 2001-2004
* Simon Fischer, Ralf Klinkenberg, Ingo Mierswa,
* Katharina Morik, Oliver Ritthoff
* Artificial Intelligence Unit
* Computer Science Department
* University of Dortmund
* 44221 Dortmund, Germany
* email: yale-team@lists.sourceforge.net
* web: http://yale.cs.uni-dortmund.de/
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
* USA.
*/
package edu.udo.cs.yale.operator.learner.meta;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.ExampleSet;
import edu.udo.cs.yale.operator.OperatorException;
import edu.udo.cs.yale.operator.learner.IOModel;
import edu.udo.cs.yale.operator.learner.Model;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.tools.Ontology;
/**
* @author scholz
*/
public class BayBoostModel extends IOModel {

    /** The static class identifier of this model, used for (de)serialization. */
    public static final String ID = "YALE BayBoost Model";

    private static final int FILE_MODEL = 1;

    private static final int IO_MODEL = 2;

    // Holds the models and their weights: a List of Object[2] entries, each
    // { Model, double[][] }. The matrix holds one factor per
    // prediction/label combination. Access via getModel / getFactorsForModel.
    private List modelInfo;

    // The class priors in the training set, starting with index 0
    // (not the internal YALE label representation).
    private double[] priors;

    // Indicates if the output is a crisp (nominal) or a soft (real) classifier.
    private boolean crispPredictions;

    /** Needed for creation of IOModel (when loading from file).
     * @param label the class label
     */
    public BayBoostModel(Attribute label) {
        super(label);
    }

    /**
     * @param label the class label
     * @param crispPredictions indicates whether the output is a crisp or soft classifier
     */
    public BayBoostModel(Attribute label, boolean crispPredictions) {
        super(label);
        this.crispPredictions = crispPredictions;
    }

    /**
     * @param label the class label
     * @param modelInfo a <code>List</code> of <code>Object[2]</code>
     * arrays, each entry holding a model and a <code>double[][]</code> array
     * containing weights for all prediction/label combinations.
     * @param priors an array of the prior probabilities of labels
     * @param crispPredictions indicates whether the output is a crisp or soft classifier
     */
    public BayBoostModel(Attribute label, List modelInfo, double[] priors, boolean crispPredictions) {
        super(label);
        this.modelInfo = modelInfo;
        this.priors = priors;
        this.crispPredictions = crispPredictions;
    }

    /** @return the static class identifier of this model */
    public String getIdentifier() { return ID; }

    /**
     * Reads all models from file. Must mirror the write order of
     * {@link #writeData(ObjectOutputStream)}: model count, crisp flag, then
     * per model the embedded model followed by its weight matrix
     * (rows, cols, row-major doubles), and finally the class prior array.
     */
    public void readData(ObjectInputStream in) throws IOException {
        Vector modelList = new Vector();
        int numModels = in.readInt();
        this.crispPredictions = in.readBoolean();
        for (int i = 0; i < numModels; i++) {
            // Read embedded model:
            Model model = Model.readModel(in);
            // Read weight matrix (row-major):
            int rows = in.readInt();
            int cols = in.readInt();
            double[][] factors = new double[rows][cols];
            for (int j = 0; j < rows; j++) {
                for (int k = 0; k < cols; k++) {
                    factors[j][k] = in.readDouble();
                }
            }
            // Integrate into object array format:
            modelList.add(new Object[] { model, factors });
        }
        // Read the class prior double array:
        double[] classPriors = new double[in.readInt()];
        for (int i = 0; i < classPriors.length; i++) {
            classPriors[i] = in.readDouble();
        }
        this.modelInfo = modelList;
        this.priors = classPriors;
    }

    /**
     * Writes the models subsequently to the output stream.
     * The format must stay in sync with {@link #readData(ObjectInputStream)}.
     */
    public void writeData(ObjectOutputStream out) throws IOException {
        List modelList = this.modelInfo;
        out.writeInt(modelList.size());
        out.writeBoolean(this.crispPredictions);
        Iterator it = modelList.iterator();
        while (it.hasNext()) {
            // Decompose and cast model description:
            Object[] obj = (Object[]) it.next();
            Model model = (Model) obj[0];
            double[][] factors = (double[][]) obj[1];
            // Write embedded model:
            model.writeModel(out);
            // Write weight matrix; guard against an empty or ragged first row:
            int rows = factors.length;
            int cols = (rows > 0 && factors[0] != null) ? factors[0].length : 0;
            out.writeInt(rows);
            out.writeInt(cols);
            for (int j = 0; j < rows; j++) {
                for (int k = 0; k < cols; k++) {
                    out.writeDouble(factors[j][k]);
                }
            }
        }
        // Write the class prior double array:
        out.writeInt(this.priors.length);
        for (int i = 0; i < this.priors.length; i++) {
            out.writeDouble(this.priors[i]);
        }
    }

    /** @return a <code>String</code> representation of this boosting model. */
    public String toString() {
        // Use a StringBuilder instead of repeated String concatenation, and
        // always separate embedded model entries by a newline (previously the
        // first entry was glued onto the header line).
        StringBuilder result = new StringBuilder(super.toString());
        result.append("\nNumber of inner models: ").append(this.getNumberOfModels());
        for (int i = 0; i < this.getNumberOfModels(); i++) {
            Model model = this.getModel(i);
            result.append("\n(Embedded model #").append(i).append("):");
            result.append(model.toResultString());
        }
        return result.toString();
    }

    /** @return the number of embedded models */
    public int getNumberOfModels() {
        return modelInfo.size();
    }

    /**
     * Gets factors for models in the case of general nominal class labels.
     * The indices are not in YALE format, so add <code>Attribute.FIRST_CLASS_INDEX</code>
     * before calling this method and before reading from the returned array.
     * @param modelNr the number of the model
     * @param predicted the predicted label
     * @return a <code>double[]</code> with one factor per class label,
     * <code>Double.POSITIVE_INFINITY</code> if the rule deterministically predicts a
     * value, and <code>RULE_DOES_NOT_APPLY</code> if no prediction can be made.
     */
    private double[] getFactorsForModel(int modelNr, int predicted) {
        Object[] obj = (Object[]) this.modelInfo.get(modelNr);
        double[][] factor = (double[][]) obj[1];
        return factor[predicted];
    }

    /**
     * Getter method for prior class probabilities estimated as the relative frequencies
     * in the training set.
     * @param classIndex the index of a class starting with 0 (not the internal YALE representation!)
     * @return the prior probability of the specified class
     */
    private double getPriorOfClass(int classIndex) {
        return this.priors[classIndex];
    }

    /**
     * Getter method for embedded models.
     * @param index the number of a model part of this boost model
     * @return binary or nominal decision model for the given classification index.
     */
    public Model getModel(int index) {
        Object[] obj = (Object[]) this.modelInfo.get(index);
        return (Model) obj[0];
    }

    /**
     * Iterates over all models and predicts the class with maximum likelihood
     * for each example, writing the prediction into the predicted label.
     * @param exampleSet the set of examples to be classified
     */
    public void apply(ExampleSet exampleSet) throws OperatorException {
        // Prepare an ExampleSet for each model.
        ExampleSet[] eSet = new ExampleSet[this.getNumberOfModels()];
        // Each model may either be a probability estimator or a crisp predictor:
        boolean[] isCrispModel = new boolean[this.getNumberOfModels()];
        for (int i = 0; i < this.getNumberOfModels(); i++) {
            Model model = this.getModel(i);
            eSet[i] = (ExampleSet) exampleSet.clone();
            model.createPredictedLabel(eSet[i]);
            model.apply(eSet[i]);
            // Only for binary classification tasks probability estimates are supported.
            // In this case the type of the predicted label is used to distinguish between
            // crisp and soft classifiers.
            Attribute predLabel = eSet[i].getPredictedLabel();
            isCrispModel[i] = predLabel.isNominal();
        }
        // Prepare one ExampleReader per ExampleSet:
        ExampleReader[] reader = new ExampleReader[eSet.length];
        for (int r = 0; r < reader.length; r++) {
            reader[r] = eSet[r].getExampleReader();
        }
        // Variables to keep track of the training error, only used for debug messages:
        int errors = 0;
        double probCorrect = 0;
        // Apply all models:
        ExampleReader originalReader = exampleSet.getExampleReader();
        while (originalReader.hasNext()) {
            // Initialize one odds-ratio accumulator per class with the prior ratio
            // Prob(C) / Prob(neg(C)); the deterministic case (prior == 1) maps to infinity.
            double[] intermediateProducts = new double[this.getLabel().getValues().size()];
            for (int k = 0; k < intermediateProducts.length; k++) {
                double pri = this.getPriorOfClass(k);
                pri = (pri == 1) ? Double.POSITIVE_INFINITY : (pri / (1 - pri));
                intermediateProducts[k] = pri;
            }
            boolean classKnown = false;
            for (int k = 0; k < reader.length; k++) {
                // Every reader must be advanced for every example, even after the
                // class is already deterministically known.
                Example e = reader[k].next();
                if (classKnown) {
                    continue;
                }
                double[] biasFactors;
                if (isCrispModel[k]) {
                    int predicted = ((int) e.getPredictedLabel()) - Attribute.FIRST_CLASS_INDEX;
                    biasFactors = this.getFactorsForModel(k, predicted);
                }
                else {
                    // Soft classifier: interpolate linearly between the factor vectors
                    // of the two (binary) predictions, skipping non-applicable rules.
                    double predicted = e.getPredictedLabel();
                    double[] biasFactors0 = this.getFactorsForModel(k, 0);
                    double[] biasFactors1 = this.getFactorsForModel(k, 1);
                    biasFactors = new double[biasFactors0.length];
                    for (int i = 0; i < biasFactors.length; i++) {
                        if (biasFactors0[i] == WeightedPerformanceMeasures.RULE_DOES_NOT_APPLY)
                            biasFactors[i] = biasFactors1[i];
                        else if (biasFactors1[i] == WeightedPerformanceMeasures.RULE_DOES_NOT_APPLY)
                            biasFactors[i] = biasFactors0[i];
                        else biasFactors[i] = (1 - predicted) * biasFactors0[i] + predicted * biasFactors1[i];
                    }
                }
                classKnown = adjustIntermediateProducts(intermediateProducts, biasFactors);
            }
            // Turn bias ratios into conditional probabilities:
            double probSum = 0;
            double[] classProb = new double[intermediateProducts.length];
            int bestIndex = 0;
            for (int n = 0; n < classProb.length; n++) {
                // The probability Prob( C | x ) for class C given the description can
                // be calculated from factor = Prob(C | x) / Prob(neg(C) | x) as
                // Prob( C | x ) = factor / (1 + factor):
                classProb[n] = (intermediateProducts[n] == Double.POSITIVE_INFINITY) ? 1
                        : intermediateProducts[n] / (1 + intermediateProducts[n]);
                probSum += classProb[n]; // accumulate probabilities, should be 1
                if (classProb[n] > classProb[bestIndex]) {
                    bestIndex = n;
                }
            }
            // Normalize probabilities if the sum is not 1. This can happen if the
            // subset defined by a rule does not contain all classes. Guard against
            // probSum == 0 (no rule applied at all), which would otherwise turn
            // every class probability into NaN by a 0/0 division.
            if (probSum > 0 && probSum != 1.0) {
                for (int k = 0; k < classProb.length; k++) {
                    classProb[k] /= probSum;
                }
            }
            // Store the final prediction:
            Example example = (Example) originalReader.next();
            if (crispPredictions) {
                example.setPredictedLabel(bestIndex + Attribute.FIRST_CLASS_INDEX);
            }
            else {
                int posIndex = exampleSet.getLabel().getPositiveIndex() - Attribute.FIRST_CLASS_INDEX;
                example.setPredictedLabel(classProb[posIndex]);
            }
            int correctLabel = (int) example.getLabel() - Attribute.FIRST_CLASS_INDEX;
            if (bestIndex != correctLabel) {
                errors++;
            }
            probCorrect += classProb[correctLabel];
        }
        probCorrect /= ((double) exampleSet.getSize());
        LogService.logMessage("< Number of models: " + this.getNumberOfModels()
                + " - Total number of errors: " + errors
                + ", prob. to predict correct label: " + probCorrect + " >", LogService.STATUS);
    }

    /**
     * Helper method to adjust the intermediate products during model application.
     * @param products the intermediate products, these values are changed by the method
     * @param biasFactors the factor vector that applies for the prediction for the current example
     * @return <code>true</code> iff the class is deterministically known after applying this method
     */
    public static boolean adjustIntermediateProducts(double[] products, double biasFactors[]) {
        for (int i = 0; i < biasFactors.length; i++) {
            // Adjust the intermediate estimates, taking care of deterministic
            // and non-applicable rules:
            if (biasFactors[i] == WeightedPerformanceMeasures.RULE_DOES_NOT_APPLY) {
                continue; // rule makes no statement about this class
            }
            else if (biasFactors[i] == Double.POSITIVE_INFINITY) {
                if (products[i] != 0) {
                    for (int j = 0; j < products.length; j++) {
                        products[j] = 0; // reset all probabilities to 0
                    }
                    products[i] = biasFactors[i]; // class is deterministically correct
                    return true; // class is known
                }
                // otherwise ignore the factor, the class is already known
                // to be deterministically incorrect
            }
            else { // the "normal" case
                products[i] *= biasFactors[i];
            }
        }
        return false;
    }

    /** Creates a predicted label with the given name. If name is null, the name
     * "prediction(labelname)" is used. Soft classifiers get a REAL-valued label. */
    public Attribute createPredictedLabel(ExampleSet exampleSet, String name) {
        Attribute predictedLabel = super.createPredictedLabel(exampleSet, name);
        if (!crispPredictions) {
            predictedLabel.setValueType(Ontology.REAL);
        }
        return predictedLabel;
    }
}