📄 yababoost.java
字号:
package jboost.booster;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import jboost.booster.BrownBoost;
import jboost.booster.BrownBoost.ErfVars;
import jboost.booster.BrownBoost.BrownBag;
import jboost.controller.Configuration;
import jboost.examples.Label;
import jboost.NotSupportedException;
import jboost.booster.MixedBinaryPrediction;
import jboost.booster.NotNormalizedPredException;
import jboost.monitor.Monitor;

/**
 * An implementation of the YabaBoost boosting algorithm. See
 * Freund &amp; Arvey 2008 for details. It may also be instructive
 * to check out the JBoost website for more references.
 *
 * @author Aaron Arvey
 */
public class YabaBoost extends BrownBoost {

    /**
     * Parameter for yaba that allows for decrease of potential
     * by translation of potential curve.
     * m_c1 increase -> potential decrease -> less error tolerance
     * m_c1 decrease -> potential increase -> more error tolerance
     */
    protected double m_c1;
    /** Positive-class counterpart of m_c1; assigned by setCostSensitiveParams. */
    protected double m_posc1;
    /** Negative-class counterpart of m_c1; assigned by setCostSensitiveParams. */
    protected double m_negc1;

    /**
     * Parameter for yaba that scales the variance (width) of the
     * potential.
     * m_c2 increase -> increase in variance
     */
    protected double m_c2;
    /** Positive-class counterpart of m_c2; assigned by setCostSensitiveParams. */
    protected double m_posc2;
    /** Negative-class counterpart of m_c2; assigned by setCostSensitiveParams. */
    protected double m_negc2;

    /**
     * Parameter for yaba that specifies margin goal. This parameter
     * may be varied during the course of the algorithm.
     */
    protected double m_theta;
    /** Positive-class counterpart of m_theta; assigned by setCostSensitiveParams. */
    protected double m_posTheta;
    /** Negative-class counterpart of m_theta; assigned by setCostSensitiveParams. */
    protected double m_negTheta;

    /**
     * This parameter cannot be changed and contains the original value
     * for theta.
     */
    protected double m_origTheta;
    /** Positive-class counterpart of m_origTheta; assigned by setCostSensitiveParams. */
    protected double m_posOrigTheta;
    /** Negative-class counterpart of m_origTheta; assigned by setCostSensitiveParams. */
    protected double m_negOrigTheta;

    /**
     * A parameter indicating when we should stop the game.
     * Technically, this should be 0; however due to numerical instability
     * we finish the game slightly earlier than we should. This leads to
     * an approximation of the 0/1 loss instead of the exact 0/1 loss.
     */
    protected static final double FINISH_GAME_NOW = 0.05;

    /**
     * A parameter indicating whether we should use confidence rated
     * predictions.
     */
    protected static final boolean USE_CONFIDENCE = false;

    /**
     * NOTE(review): the original comment here was a verbatim copy of the
     * FINISH_GAME_NOW description. Judging by the name this is presumably the
     * increment applied when theta is adjusted; it is not referenced in the
     * visible part of this file, so confirm against its uses.
     */
    protected static final double THETA_UPDATE_STEP = 0.02;

    /**
     * Default constructor; simply delegates to the superclass
     * ({@link BrownBoost}) constructor to get everything initialized.
     */
    public YabaBoost() {
        super();
    }

    /**
     * Sets the total running time of the game and resets the remaining time.
     * Stores {@code runtime} in {@code m_c}, copies it to {@code m_s}, and
     * recomputes {@code m_initialPotential} as
     * {@code calculatePotential(0, m_c)}.
     *
     * If {@code runtime} is below {@link #FINISH_GAME_NOW} an error is printed
     * to stderr and the JVM exits with status 2.
     *
     * @param runtime total time to play the game
     */
    public void setRuntime(double runtime) {
        if (runtime < FINISH_GAME_NOW) {
            System.err.println("Yaba runtime is too short!\nc:" + runtime + "\n");
            System.exit(2);
        }
        m_c = runtime;
        m_s = m_c;
        m_initialPotential = calculatePotential(0, m_c);
    }

    /**
     * Sets the yaba parameters and recomputes the initial potential.
     * The current value of {@code m_c} is validated as well, so the runtime
     * has to be set to a valid value before this method is called.
     *
     * @param c1 translation parameter of the potential curve
     * @param c2 variance (width) parameter of the potential
     * @param theta margin goal; also recorded as the original theta
     * @throws Exception if {@code c1}, {@code c2}, {@code theta} or the
     *         current {@code m_c} is smaller than 0.01
     */
    public void setParams(double c1, double c2, double theta) throws Exception {
        double EPS = 0.01;
        if (c1 < EPS || c2 < EPS || theta < EPS || m_c < EPS) {
            throw new Exception("Yaba params are bad!\nc1:" + c1 + "\nc2:"
                                + c2 + "\ntheta:" + theta + "\n");
        }
        m_c1 = c1;
        m_c2 = c2;
        m_theta = theta;
        m_origTheta = theta;
        m_initialPotential = calculatePotential(0, m_c);
    }

    /**
     * Sets separate parameters for positive and negative examples.
     * No validation is performed on the values.
     *
     * @param pc runtime for positive examples
     * @param pc1 c1 for positive examples
     * @param pc2 c2 for positive examples
     * @param ptheta theta (and original theta) for positive examples
     * @param nc runtime for negative examples
     * @param nc1 c1 for negative examples
     * @param nc2 c2 for negative examples
     * @param ntheta theta (and original theta) for negative examples
     */
    public void setCostSensitiveParams(double pc, double pc1, double pc2, double ptheta,
                                       double nc, double nc1, double nc2, double ntheta) {
        m_posc = pc;
        m_posc1 = pc1;
        m_posc2 = pc2;
        m_posTheta = ptheta;
        m_posOrigTheta = ptheta;

        // The negative-class fields must come from the n* arguments. They were
        // previously filled from the p* arguments, which silently discarded
        // nc, nc1, nc2 and ntheta.
        m_negc = nc;
        m_negc1 = nc1;
        m_negc2 = nc2;
        m_negTheta = ntheta;
        m_negOrigTheta = ntheta;
    }

    /**
     * Builds a textual snapshot of the booster state: one header line with
     * m_c, m_s, m_c1, m_c2 and m_theta, followed by one tab-separated line
     * (margin, weight, potential) per entry of {@code m_margins}.
     *
     * @return the formatted snapshot
     */
    public String surfingData() {
        StringBuilder ret = new StringBuilder();
        // NOTE(review): "%4f" is a minimum width of 4 with default precision,
        // not "%.4f"; left untouched because consumers may depend on the
        // current output.
        ret.append(String.format("YabaBoost Params: %.4f %.4f %4f %4f %4f\n",
                                 m_c, m_s, m_c1, m_c2, m_theta));
        for (int i = 0; i < m_margins.length; i++) {
            ret.append(String.format("%.4f\t%.4f\t%.4f\n",
                                     m_margins[i], m_weights[i], m_potentials[i]));
        }
        return ret.toString();
    }

    /**
     * Validates the yaba parameters (terminating the JVM with status 2 when
     * any of m_c1, m_c2, m_theta or m_c is below 0.01), recomputes the
     * initial potential, reports the configuration on stdout and through
     * {@code Monitor.log}, and then delegates to the superclass.
     */
    public void finalizeData() {
        double EPS = 0.01;
        if (m_c1 < EPS || m_c2 < EPS || m_theta < EPS || m_c < EPS) {
System.err.println("Yaba params are bad!\nc:"+m_c+"\nc1:" +m_c1+"\nc2:"+m_c2+"\ntheta:"+m_theta+"\n"); System.exit(2); } m_initialPotential = calculatePotential(0,m_c); String yabaout = "\nYabaBoost:\n" + "\t m_c: " + m_c + "\n" + "\t m_c1: " + m_c1 + "\n" + "\t m_c2: " + m_c2 + "\n" + "\t m_theta: " + m_theta + "\n" + "\t initial potential: " + m_initialPotential + "\n"; System.out.println(yabaout); Monitor.log(yabaout); super.finalizeData(); } /** * @see jboost.booster.BrownBoost#calc_constraints(int[][], double, double) */ protected ErfVars calc_constraints(double alpha, double t, int[] examples) { ErfVars vars = new ErfVars(); // Amount of time the game has been played double c = m_c; double s = m_c - m_s; // The amount of time the game will be played if t is chosen double new_time = s + t; double new_time_remaining = m_c - new_time; double orig_time = s; double orig_time_remaining = m_c - orig_time; double totalWeight = 0; double EPS = 0.0001; double margin, orig_margin, step, new_margin, new_weight, new_pot, orig_pot; int example; for (int i= 0; i < examples.length; i++) { example = examples[i]; margin = m_margins[example]; orig_margin = margin; // m_labels are 0 or 1, this moves it to in margin space +-1 step = getStep(m_labels[example], m_hypPredictions[example]); EPS = 0.0001; if (Math.abs(step) > EPS) { new_margin = (1-alpha)*margin + alpha*step; } else { new_margin = margin; } new_weight = calculateWeight(new_margin, new_time_remaining); new_pot = calculatePotential(new_margin, new_time_remaining); orig_pot = calculatePotential(orig_margin, orig_time_remaining); vars.B += new_weight*step; vars.E += orig_pot - new_pot; vars.Potential += new_pot; totalWeight += new_weight; //System.out.println("Example: " + j + ", Step: " + step + ", Margin: " + margin + ", Weight: " + wj); //System.out.println("aj: " + aj + ", dj: " + dj + ", dj^2/sd^2: " + (dj*dj/(sd*sd)) + ", bj: " + bj); //System.out.println("N(mu,sigma):" + mu + "," + sd + ", B:" + vars.B + ", E:" + 
vars.E); } vars.B /= totalWeight; vars.E /= examples.length; vars.Potential /= examples.length; return vars; } private void dump_everything(ErfVars v){ PrintWriter dumpfile; try { dumpfile = new PrintWriter(new BufferedWriter( new FileWriter("yababoost.dump"))); } catch (IOException e) { String msg = "YabaBoost.dump_everything Cannot output file!"; System.err.println(msg); throw new RuntimeException(msg); } System.err.println(""); System.err.println(""); System.err.println("Dumping everything"); dumpfile.println("% Avg Diff Potential, Gamma correlation"); dumpfile.println("% " + v.E + ", " + v.B); dumpfile.println("% Total Time, time reaming, c1, c2, theta, nada"); dumpfile.println("" + m_c + ", " + m_s + ", " + m_c1 + ", " + m_c2 + ", " + m_theta + ", 0.00"); dumpfile.println("% Example, Margin, Weight, label, hyp, step "); for (int i= 0; i < m_hypPredictions.length; i++) { int example = i; double margin = m_margins[example]; double weight = m_weights[example]; double orig_margin = margin; double step = getStep(m_labels[example], m_hypPredictions[example]); dumpfile.println(example + ", " + margin + ", " + weight + ", " + getLabel(m_labels[example]) + ", " + m_hypPredictions[example] + ", " + step); } dumpfile.flush(); dumpfile.close(); } /** * This is the heart of the YABA booster. For details of what the * constraints are, see Freund and Arvey 2008. The time remaining * in the game is updated in this function. Basic algorithm is * bisection. Eventually will implement Newton Raphson updates. * See inline comments. * @param exampleIndex - Used to iterate over all examples with hypotheses. * @return alpha - An appropriate value of alpha that satisfies constraints. */ protected double solve_constraints(double hyp_err, int[] examples) { /* * If the game has a small amount of time remaining, quit now. * The last little bit of the game is very numerically * instable. As such, we avoid playing. 
*/ if( m_s < FINISH_GAME_NOW){ m_s = -1; return 0; } // how much time the game has been played double s = m_c - m_s; // Used to capture the variables ErfVars vars = new ErfVars(); double new_alpha = 0.0; double new_t = 0.0; /* * We start with alpha and t at one of the extremal values close to 0 or 1. * alpha \in [0,1] * t \in [0,m_s] * We then either increment or decrement each of the values during a binary * search (bisection) algorithm. */ double t_step = 0.1; double t=0.3; double alpha=0.1; /* * If the true t is larger than m_s, then the bisection algorithm * will continue to try to push t higher. However, t will already be too high. * When t > m_s, we have numerical instability and may even create * complex numbers while calculating the constraints. Thus, we cap t at m_s * and say that after NUM_ITERATIONS_FINISH_GAME of t>m_s, we end the game. */ int NUM_ITERATIONS_FINISH_GAME = 10; int count_t_over_s = 0; /* * Sometimes the average difference in potential will plateu * at a non-zero value. This is typically a result of bizarre * boundary conditions when the booster is doing "too well". * We use this to detect the plateu. */ double lastE = 0; double STEP_EPS = 0.0001; boolean first_iter = true; while(Math.abs(t_step) > STEP_EPS) { /* We need this for the end of the game when t_step can be * larger than the time remaining in the game. */ if(Math.abs(t_step) > m_s){ t_step = m_s/2 * sign(t_step); } t+=t_step; /* If t is large, then we cut it down to an appropriate * size. This is not a good idea. We should allow t to * become as large as it needs to be. If it is too large, * we trim it down after the fact. Or maybe we do want to * do this. If we let t>m_s then we wind up with NaN. */ if (t >= m_s) { t = m_s - 0.001; t_step = -t_step; count_t_over_s++; // if we keep going over m_s, the game is probably done
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -