📄 brownboost.java
字号:
// reverse alpha search direction if(sign(vars.B) != sign(alpha_step)) alpha_step /= -2; if(Math.abs(vars.B) < CORR_EPS) break; } /* System.out.println("(alpha:" + alpha + ", t:" + t + ") is, gamma + pot_diff = |" + vars.B + "| + |" + vars.E + "| = " + (Math.abs(vars.B) + Math.abs(vars.E))); */ // reverse t search direction if(sign(vars.E) != sign(t_step)) t_step /= -2; } // The bisection (binary search) alpha and t double bs_alpha = alpha; double bs_t = t; alpha = bs_alpha; t = bs_t; m_oldS = m_s; m_s -= t; System.out.format("\tBrownBoost: alpha=%.4f, t=%.4f, time left=%.4f, " + "potential=%.4f\n", alpha, t, m_s, vars.Potential); if(t<0) { System.err.println("\nERROR: The value of t: " + t); System.err.println("ERROR: Bad solution for t<0"); m_s = m_oldS; return(0.0); } //System.out.println("\ns: " + m_s); return alpha; } public String getParamString() { String ret = String.format("BrownBoost r=%.4f s=%.4f ", m_c, m_s); return ret; } /** * Update the examples, m_margins, and m_weights using the * brown boost update rule. When m_s reaches * the stopping criterion of m_s < 0, this update returns * immediately and does not actually do any further updating. * @param predictions values for examples * @param exampleIndex the list of examples to update */ public void update(Prediction[] predictions, int[][] exampleIndex) { if (m_s < 0){ return; } for (int i= 0; i < exampleIndex.length; i++) { double p = predictions[i].getClassScores()[1]; double[] value = new double[] { -p, p }; int[] indexes = exampleIndex[i]; for (int j= 0; j < indexes.length; j++) { int example = indexes[j]; m_oldMargins[example] = m_margins[example]; m_margins[example] += value[m_labels[example]]; } } m_totalWeight = 0; m_totalPotential = 0; for (int i=0; i < m_hypPredictions.length; i++) { m_oldWeights[i]= m_weights[i]; m_weights[i] = calculateWeight(m_margins[i]); m_totalWeight += m_weights[i]; m_potentials[i] = calculatePotential(m_margins[i]); m_totalPotential += m_potentials[i]; } } /** * The theoretical bound on error is not defined for BrownBoost, thus * getTheoryBound is undefined. */ public double getTheoryBound() { return -1.0; } /** * Calls calculateWeight(margin,m_s) */ public double calculateWeight(double margin) { return calculateWeight(margin, m_s); } /** * BrownBoost uses (1-erf(-(margin+s)/c))/2 as the potential function */ public double calculatePotential(double margin) { return calculatePotential(margin, m_s); } /** * BrownBoost uses e^(-(margin+s)^2/c) as the weight calculation */ public double calculateWeight(double margin, double time_remaining) { double s = time_remaining; return Math.exp(-1 * Math.pow(margin+s,2)/m_c); } /** * BrownBoost uses (1-erf(-(margin+s)/c))/2 as the potential function */ public double calculatePotential(double margin, double time_remaining) { double s = time_remaining; return (1-erf((margin+s)/Math.sqrt(m_c)))/2; } /** * BrownBag is identical to BinaryBag, except for the method used to * derive the value of prediction (alpha in the literature). BrownBag * uses the value of alpha determined by BrownBoost and its variants. * See comments for AdaBoost.BinaryBag. * @author Aaron Arvey */ class BrownBag extends AdaBoost.BinaryBag{ protected BrownBag(int[] list) { m_w= new double[2]; reset(); this.addExampleList(list); } /** compute the binary prediction associated with this bag */ public BinaryPrediction calcPrediction(double alpha) { BinaryPrediction ret; ret = new BinaryPrediction(m_w[1] > m_w[0] ? 1.0 : -1.0 ); ret.scale(alpha); return ret; } /** compute the binary prediction associated with this bag */ public BinaryPrediction calcPrediction(double posAlpha, double negAlpha) { BinaryPrediction ret; if (m_w[1] > m_w[0]) { ret = new BinaryPrediction(1.0); ret.scale(posAlpha); } else { ret = new BinaryPrediction(-1.0); ret.scale(negAlpha); } return ret; } /** Place holder to ensure that this function is not used in BrownBoost. */ public BinaryPrediction calcPrediction(){ //System.err.println("Need to have alpha for prediction in BrownBag.calcPrediction()!"); return new BinaryPrediction(0); } /** default constructor */ protected BrownBag() { super(); } /** constructor that copies an existing bag */ protected BrownBag(BrownBag bag) { super(bag); } /** Output the weights in the bag */ public String toString() { String s= "BrownBag.\t w0=" + m_w[0] + "\t w1=" + m_w[1] + "\n"; return s; } } /* End BrownBag */ protected double getHypErr(Bag[] bags, int[][] exampleIndex) { double hyp_err = 0.0; double gamma = 0.0; double num_wrong = 0.0; double total_weight = 0.0; double potential = 0.0; int num_predictions = 0; int total = 0; // Keep track of which hypotheses had hypotheses associated with them. boolean[] examplesWithHyp = new boolean[m_margins.length]; // Get all examples that have a hypothesis associated with them for (int i= 0; i < exampleIndex.length; i++) { int[] indexes= exampleIndex[i]; BinaryPrediction pred = ((BrownBag) bags[i]).calcPrediction(1.0); total += 1; for (int j= 0; j < indexes.length; j++) { int example = indexes[j]; examplesWithHyp[example] = true; double step = getStep(m_labels[example], m_hypPredictions[example]); double weight = m_weights[example]; total_weight += weight; gamma += weight*step; if (step > 0){ // we got it right! num_predictions += 1; } else { // We got it wrong hyp_err += 1; num_predictions += 1; } } } // Get all examples that have no hypothesis associated with them. // Also get current potential. for (int i=0; i < m_margins.length; i++) { if(!examplesWithHyp[i]){ int example = i; m_hypPredictions[example] = 0; double weight = m_weights[example]; total_weight += weight; //System.out.println("m_hypPredictions[" + i + "," + example + "]: " + 0 + " (No hyp for example " + example + ")"); } potential += calculatePotential(m_margins[i]); } //updatePotential(exampleIndex); hyp_err /= num_predictions; gamma /= total_weight; potential /= m_margins.length; /* System.out.println("\tTotal number of examples: " + m_margins.length); System.out.println("\tNumber of predictions made: " + num_predictions); String out = "\tgamma (weighted correlation):" + gamma + ", potential (unweighted):" + potential + ", hyp error (unweighted):" + hyp_err; System.out.println(out); */ return gamma; } protected BinaryPrediction getZeroPred() { return new BinaryPrediction(0); } /* * Returns the predictions associated with a list of bags representing a * partition of the data. */ public Prediction[] getPredictions(Bag[] bags, int[][] exampleIndex) { boolean bagsAreWeightless = true; for (int i=0; i < bags.length; i++) { if (!bags[i].isWeightless()) { bagsAreWeightless = false; } } Prediction[] p = new BinaryPrediction[bags.length]; /* * If we have bags that are empty, then we do not process them. * If we have used up all of our time, then we can't do * any more iterations. */ if (bagsAreWeightless || m_s < 0) { for (int i=0; i < bags.length; i++) { p[i] = getZeroPred(); } return p; } // Create a prediction array to accompany the exampleIndex array m_hypPredictions = new double[m_margins.length]; for (int i=0; i < exampleIndex.length; i++){ int[] index = exampleIndex[i]; BrownBag b = (BrownBag)bags[i]; for (int j=0; j < index.length; j++){ int example = index[j]; m_hypPredictions[example] = b.calcPrediction(1.0).getClassScores()[0]; } } // we solve the constraints associated with // the BrownBoost model and obtain alpha. gamma is a good // initial guess for alpha. double gamma = getHypErr(bags, exampleIndex); if (m_isCostSensitive) { System.out.println("Solving positive example constraints"); double posAlpha = solve_constraints(gamma, m_posExamples); System.out.println("Solving negative example constraints"); double negAlpha = solve_constraints(gamma, m_negExamples); for (int i= 0; i < bags.length; i++) { p[i]= ((BrownBag) bags[i]).calcPrediction(posAlpha, negAlpha); System.out.println("p[" + i + "] = " + p[i]); } } else { double alpha = solve_constraints(gamma, m_examples); for (int i= 0; i < bags.length; i++) { p[i]= ((BrownBag) bags[i]).calcPrediction(alpha); System.out.println("p[" + i + "] = " + p[i]); } } return p; } public Bag newBag(int[] list) { return new BrownBag(list); } public Bag newBag() { return new BrownBag(); } public Bag newBag(Bag bag) { return new BrownBag((BrownBag) bag); } /** * Returns the prediction associated with a bag representing a subset of the * data. */ protected Prediction getPrediction(Bag b) { return ((BrownBag) b).calcPrediction(); }} /* End BrownBoost Class */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -