📄 abstractbooster.java
字号:
package jboost.booster;import java.io.Serializable;import jboost.controller.Configuration;import jboost.monitor.Monitor;/** * This is the abstract definition of a booster. The booster is * responsible for maintaining, for a set of examples, the m_margins and * m_weights of these examples. Each subclass implements a specific * boosting algorithm, e.g., adaboost, brownboost, etc. The booster * also computes the loss and predictions for a partition of the data. * AbstractBooster also operates as a factory, where it generate the * appropriate Booster with the appropriate configuration. * * @author Rob Schapire (rewritten by Aaron Arvey) * @version $Header: /cvsroot/jboost/jboost/src/jboost/booster/AbstractBooster.java,v 1.6 2007/10/23 22:45:40 aarvey Exp $ */public abstract class AbstractBooster implements Booster, Serializable { protected static final String PREFIX= "booster_"; /** * Factory method to build a booster instance according to * given configuration. Uses reflection to do this. * * @param c set of options for the booster * @param num_labels the number of m_labels in the data * @param isMultiLabel true if multilabled data * @return Booster */ public static Booster getInstance(Configuration c, int num_labels, boolean isMultiLabel) throws ClassNotFoundException, InstantiationException, IllegalAccessException, Exception { AbstractBooster result = null; // Get the booster type from configuration and // create a class of that type. String boosterType= c.getString(PREFIX + "type", "jboost.booster.AdaBoost"); System.out.println("Booster type: " + boosterType); Class boosterClass = Class.forName(boosterType); result = (AbstractBooster) boosterClass.newInstance(); result.init(c); // Get the runtime of the booster (if applicable). // If the booster is a discrete iterative scheme, the // number of iterations is dealt with elsewhere. if (result instanceof jboost.booster.BrownBoost) { double eps = 0.001; double runtime = Double.parseDouble(c.getString("boostingRuntime", "0.0")); if (runtime <= eps) { String str = "Need to specify runtime for m_booster " + result + ". Runtime must be larger than " + eps + "."; Monitor.log(str); throw new Exception(str); } jboost.booster.BrownBoost brown = (jboost.booster.BrownBoost) result; brown.setRuntime(runtime); result = brown; if (result instanceof jboost.booster.YabaBoost) { double c1=0, c2=0, theta=0; double rpos=0, c1pos=0, c2pos=0, thetapos=0; double rneg=0, c1neg=0, c2neg=0, thetaneg=0; try { c1 = Double.parseDouble(c.getString("c1", "Z1.0")); c2 = Double.parseDouble(c.getString("c2", "Z1.0")); theta = Double.parseDouble(c.getString("theta", "Z0.15")); //c1 = Double.parseDouble(c.getString("pos_r", "Z1.0")); //c1 = Double.parseDouble(c.getString("pos_c1", "Z1.0")); //c2 = Double.parseDouble(c.getString("pos_c2", "Z1.0")); //theta = Double.parseDouble(c.getString("pos_theta", "Z0.15")); } catch (NumberFormatException e) { String s = "Need to supply r, c1, c2, and theta!"; System.err.println(s); throw new InstantiationException(s); } jboost.booster.YabaBoost yaba = (jboost.booster.YabaBoost) result; yaba.setParams(c1,c2,theta); yaba.setCostSensitiveParams(rpos, c1pos, c2pos, thetapos, rneg, c1neg, c2neg, thetaneg); result = yaba; } } // If we have a multilable or multiclass problem, we need to wrap it. if (num_labels > 2 || isMultiLabel) { result= new MulticlassWrapMH(result, num_labels, isMultiLabel); } // If we are debugging, then wrap in paranoia boolean paranoid= c.getBool(PREFIX + "paranoid", false); if (paranoid) { result= new DebugWrap(result); } return result; } public int getNumExamples(){ return 0; } public String getParamString() { return "No parameters defined"; } /** * Create and return a new Bag which initially contains the * elements in the list. * * @param list initial items to add to the Bag */ public Bag newBag(int[] list) { Bag bag= newBag(); bag.addExampleList(list); return bag; } /** * Clone a bag * * @param orig the bag to clone * @return new bag */ public Bag newBag(Bag orig) { Bag newbag= newBag(); newbag.copyBag(orig); return newbag; } /** * Find the best binary split for a sorted list of example indices * with given split points. * @param l an array of example indices, sorted. * @param sp an array with true in position i when a split between * positions i-1 and i should be checked * @param b0 - a bag with all points below the best split (upon return) * @param b1 - a bag with all points at or above the best split (upon return) * @return the index in l where the best split occurred (possibly * 0 if the best split puts all points on one side) */ public int findBestSplit(Bag b0, Bag b1, int[] l, boolean[] sp) { Bag[] bags= new Bag[2]; bags[0]= newBag(); // init an empty bag bags[1]= newBag(l); // init a full bag b0.reset(); b1.copyBag(bags[1]); if (l.length == 0) return 0; double bestLoss= getLoss(bags); int bestIndex= 0; double loss; for (int i= 0; i < l.length - 1; i++) { bags[1].subtractExample(l[i]); bags[0].addExample(l[i]); if (sp[i + 1]) { // if this is a potential split point if ((loss= getLoss(bags)) < bestLoss) { bestLoss= loss; bestIndex= i + 1; b0.copyBag(bags[0]); b1.copyBag(bags[1]); } } } return bestIndex; } /** * Compute the loss associated with an array of bags where small * loss is considered "better". We assume that loss is additive * for a set of bags. * * @param bags array of bags whose losses will be added up and returned * @return loss the sum of the losses for all the bags */ public double getLoss(Bag[] bags) { double loss = 0; for (int i=0; i < bags.length; i++) { loss += bags[i].getLoss(); } return loss; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -