📄 adaboost.java
字号:
/* -*- mode: java; c-basic-offset: 2; indent-tabs-mode: nil -*- */package jboost.booster;import java.text.DecimalFormat;import java.text.NumberFormat;import java.util.ArrayList;import java.util.List;import jboost.controller.Configuration;import jboost.examples.Label;/** * The simplest possible implementation of a booster. confidence-rated adaboost * based on equality/inequality of m_labels * * @author Yoav Freund * @version $Header: * /proj/gene/cvs-repository/jboost/src/jboost/booster/AdaBoost.java,v * 1.2 2003/10/01 18:36:15 freund Exp $ */public class AdaBoost extends AbstractBooster { /** permanent storage for m_labels */ protected short[] m_labels; /** permanent storage for m_margins */ protected double[] m_margins; /** permanent storage for old m_margins */ protected double[] m_oldMargins; /** permanent storage for example m_weights */ protected double[] m_weights; /** permanent storage for example's old m_weights */ protected double[] m_oldWeights; /** Records the potentials. Similar to m_margins and m_weights. */ protected double[] m_potentials; /** */ protected int[] m_posExamples; /** */ protected int[] m_negExamples; /** */ protected int m_numPosExamples; /** */ protected int m_numNegExamples; /** sampling weights for the examples */ protected double[] m_sampleWeights; /** if false, then assume all sample weights are 1 */ //protected boolean m_useSampleWeights; protected double m_totalWeight; // total weight of all examples protected int m_numExamples= 0; // number of examples in training set protected double m_smooth; protected double m_epsilon; // the hedge term for the // calculation of the prediction /** temporary location for storing the examples as they are read in */ protected List m_tmpList; /** * default constructor * */ public AdaBoost() { this(0.0); } /** * Constructor which takes a smoothing factor * * @param smooth "smoothing" factor */ AdaBoost(double smooth) { m_tmpList= new ArrayList(); m_numExamples= 0; m_smooth= smooth; init(new Configuration()); } /** * @see jboost.booster.Booster#init(jboost.controller.Configuration) */ public void init(Configuration config) { m_smooth= config.getDouble(PREFIX + "smooth", 0.5); } /** * Add an example to the data set of this booster * @param index * @param label * @param weight */ public void addExample(int index, Label label, double weight) { int l= label.getSingleValue(); String failed= null; if (l == 1 || l == 0) { if (index == m_numExamples) { m_numExamples++; m_tmpList.add(new TmpData(index, (short) l, weight)); if (l==1) m_numPosExamples++; } else { failed= "AdaBoost.addExample received index " + index + ", when it expected index " + m_numExamples; } } else { failed= "Adaboost.addExample expected a label which is either 0 or 1. It received " + l; } if (failed != null) { throw new IllegalArgumentException(failed); } } /** * Add an example to the dataset * Default the weight for this example to 1 * If this method is used, then this booster will assume * that all the sample weights are 1 * @param index * @param label */ public void addExample(int index, Label label) { addExample(index, label, 1); } /** reset the booster */ public void clear() { m_labels= null; m_margins= null; m_potentials= null; m_weights= null; m_oldWeights= null; m_sampleWeights= null; m_tmpList.clear(); m_numExamples= 0; } protected void finalizeData(double defaultWeight) { m_margins= new double[m_numExamples]; m_oldMargins= new double[m_numExamples]; m_weights= new double[m_numExamples]; m_oldWeights= new double[m_numExamples]; m_potentials= new double[m_numExamples]; m_labels= new short[m_numExamples]; m_sampleWeights= new double[m_numExamples]; m_epsilon= m_smooth / m_numExamples; m_posExamples = new int[m_numPosExamples]; m_numNegExamples = m_numExamples - m_numPosExamples; m_negExamples = new int[m_numNegExamples]; int m_posIndex=0, m_negIndex=0; for (int i= 0; i < m_tmpList.size(); i++) { TmpData a= (TmpData) m_tmpList.get(i); int index= a.getIndex(); m_margins[index]= 0.0; m_weights[index]= m_oldWeights[index]= defaultWeight; m_labels[index]= a.getLabel(); if (a.getLabel()==1) m_posExamples[m_posIndex++] = index; else m_negExamples[m_negIndex++] = index; m_sampleWeights[index]= a.getWeight(); } m_totalWeight= defaultWeight*m_numExamples; m_tmpList.clear(); // free the memory } public void finalizeData() { finalizeData(1.0); } /** * Return the theoretical bound on the training error. */ public double getTheoryBound() { return m_totalWeight / m_numExamples; } /** * Returns the margin values of the training examples. */ public double[][] getMargins() { double[][] r= new double[m_numExamples][1]; for (int i= 0; i < m_numExamples; i++) r[i][0]= m_margins[i]; return r; } /** * */ public double[][] getWeights() { double[][] r= new double[m_numExamples][1]; for (int i= 0; i < m_numExamples; i++) r[i][0]= m_weights[i]; return r; } /** * */ public double[][] getPotentials() { double[][] r= new double[m_numExamples][1]; for (int i= 0; i < m_numExamples; i++) r[i][0]= m_potentials[i]; return r; } /** * */ public int getNumExamples() { return m_numExamples; } /** * */ public double getTotalWeight() { return m_totalWeight; } /** * Returns a string with all the weights, margins, etc */ public String getExampleData() { StringBuffer ret = new StringBuffer(""); ret.append(getParamString()); for (int i=0; i<m_margins.length; i++){ ret.append(String.format("[%d];[%.4f];[%.4f];[%.4f];\n", m_labels[i], m_margins[i], m_weights[i], m_potentials[i])); } return ret.toString(); } public String getParamString() { String ret = String.format("None (AdaBoost)"); return ret; } /** output AdaBoost contents as a human-readable string */ public String toString() { String s= "Adaboost. No of examples = " + m_numExamples + ", m_epsilon = " + m_epsilon; s += "\nindex\tmargin\tweight\told weight\tlabel\n"; NumberFormat f= new DecimalFormat("0.00"); for (int i= 0; i < m_numExamples; i++) { s += " " + i + " \t " + f.format(m_margins[i]) + " \t " + f.format(m_weights[i]) + " \t " + f.format(m_oldWeights[i]) + " \t" + f.format(m_sampleWeights[i]) + "\t\t" + m_labels[i] + "\n"; } return s; } public Bag newBag(int[] list) { return new BinaryBag(list); } public Bag newBag() { return new BinaryBag(); } public Bag newBag(Bag bag) { return new BinaryBag((BinaryBag) bag); } /** * Returns the prediction associated with a bag representing a subset of the * data. */ protected Prediction getPrediction(Bag b) { return ((BinaryBag) b).calcPrediction(); } /* * Returns the predictions associated with a list of bags representing a * partition of the data. */ public Prediction[] getPredictions(Bag[] b) { Prediction[] p= new BinaryPrediction[b.length]; for (int i= 0; i < b.length; i++) { p[i]= ((BinaryBag) b[i]).calcPrediction(); } return p; } /** * Returns the predictions associated with a list of bags representing a * partition of the data. AdaBoost does not use the partition in exampleIndex. */ public Prediction[] getPredictions(Bag[] bags, int[][] exampleIndex) { return getPredictions(bags); } /** * AdaBoost uses e^(-margin) as the weight calculation */ public double calculateWeight(double margin) { return Math.exp(-1 * margin); } /** * Update the examples m_margins and m_weights using the * exponential update * @param predictions values for examples * @param exampleIndex the list of examples to update */ public void update(Prediction[] predictions, int[][] exampleIndex) { // save old m_weights for (int i= 0; i < m_weights.length; i++) m_oldWeights[i]= m_weights[i]; // update m_weights and m_margins for (int i= 0; i < exampleIndex.length; i++) { double p= predictions[i].getClassScores()[1]; double[] value= new double[] { -p, p }; int[] indexes= exampleIndex[i]; for (int j= 0; j < indexes.length; j++) { int example= indexes[j]; m_margins[example] += value[m_labels[example]]; m_totalWeight -= m_weights[example]; m_weights[example]= calculateWeight(m_margins[example]); m_totalWeight += m_weights[example]; } } } /** * Defines the state of an example * Inner class used to store a list of Examples * The list is converted into the internal data structures for the * Booster by finalizeData(); */ protected class TmpData { int m_index; short m_label; double m_weight; /** * Ctor for a TmpData object * @param index * @param label * @param weight */ TmpData(int index, short label, double weight) { m_index= index; m_label= label; m_weight= weight; } /** * Get the index for this example * @return m_index */ protected int getIndex() { return m_index; } /** * Get the label for this example * @return m_label */ protected short getLabel() { return m_label; } /** * Get the weigh for this example * @return m_weight */ protected double getWeight() { return m_weight; } } /** * This is the definition of a bag for AdaBoost. The two m_labels are * internally referred to as 0 or 1. The bag maintains the total weight of * examples labeled 0 and the total weight of examples labeled 1. * This bag uses the weights and labels stored in the booster. * @author Yoav Freund */ class BinaryBag extends Bag { /** total weight for examples of each label */ protected double[] m_w; /** default constructor */ protected BinaryBag() { m_w= new double[2]; reset(); } /** constructor that copies an existing bag */ protected BinaryBag(BinaryBag bag) { m_w= new double[2]; m_w[0]= bag.m_w[0]; m_w[1]= bag.m_w[1]; } /** a constructor that initializes a bag the given list of axamples */ protected BinaryBag(int[] list) { m_w= new double[2]; reset(); this.addExampleList(list); } public String toString() { String s= "BinaryBag.\t w0=" + m_w[0] + "\t w1=" + m_w[1] + "\n"; return s; } /** * Resets the bag to empty */ public void reset() { m_w[0]= 0.0; m_w[1]= 0.0; } /** * Checks if the bag has any weight. */ public boolean isWeightless() { double EPS = 0.0000001; if (m_w[0] < EPS && m_w[1] < EPS) { return true; } return false; } /** * Adds one example index to the bag. * Update the weights in this bag using the weights from the booster * The example index is used to find the label and weight for this example * * @param index the example that is being added to this bag. The index refers to the booster's * internal data structures */ public void addExample(int index) { m_w[m_labels[index]] += m_weights[index]*m_sampleWeights[index]; } /** * Subtracts one example index from the bag. */ public void subtractExample(int i) { if ((m_w[m_labels[i]] -= m_weights[i]) < 0.0) m_w[m_labels[i]]= 0.0; } /** * Adds the given bag to this one. It is assumed that the two bags are * disjoint and the same type. */ public void addBag(Bag b) { m_w[0] += ((BinaryBag) b).m_w[0]; m_w[1] += ((BinaryBag) b).m_w[1]; } /** * Subtracts the given bag from this one. It is assumed that the bag being * subtracted is a subset of the other one, and that the two bags are the * same type. */ public void subtractBag(Bag b) { if ((m_w[0] -= ((BinaryBag) b).m_w[0]) < 0.0) m_w[0]= 0.0; if ((m_w[1] -= ((BinaryBag) b).m_w[1]) < 0.0) m_w[1]= 0.0; } /** * Copies a given bag of the same type into this one. */ public void copyBag(Bag b) { m_w[0]= ((BinaryBag) b).m_w[0]; m_w[1]= ((BinaryBag) b).m_w[1]; } /** * Updates the weight of a single example contained in this bag. In other * words, subtracts its old weight and adds its new weight. */ public void refresh(int i) { short label= m_labels[i]; if ((m_w[label] += m_weights[i] - m_oldWeights[i]) < 0.0) m_w[label]= 0.0; } /** * Computes the loss using the following formula: 2*Sqrt(w_0 * w_1) - w_0 - w_1 * Where w_0 and w_1 are the weights of the 0 and 1 labeled examples, respectively * If w_0 and w_1 are equal, then the loss will return 0. * @return Z the result of the computation */ public double getLoss() { return 2 * Math.sqrt(m_w[0] * m_w[1]) - m_w[0] - m_w[1]; } /** * compute the optimal binary prediction associated with this bag */ public BinaryPrediction calcPrediction() { double smoothFactor= m_epsilon * m_totalWeight; double EPS = 1e-50; if (Double.isNaN(smoothFactor) || (Math.abs(m_totalWeight)<EPS) || (Math.abs(smoothFactor)<EPS) || Double.isNaN(m_totalWeight)) { return new BinaryPrediction(0.0); } BinaryPrediction p = new BinaryPrediction( m_w[1] == m_w[0] ? 0.0 : // handle case that w0=w1=0 0.5 * Math.log((m_w[1] + smoothFactor) / (m_w[0] + smoothFactor))); return p; } /** * Compare a bag to this bag and output true if they are equal * * @param other * bag to compare to this bag * @return result true if this bag has the same values as the other bag */ public boolean equals(BinaryBag other) { return (m_w[0] == other.m_w[0]) && (m_w[1] == other.m_w[1]); } } /** end of class BinaryBag */} /** end of class AdaBoost */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -