⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 multiclasswrapmh.java

📁 Boosting算法软件包
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
		s[i] = l[i] * m_numLabels;	    bags[0].subtractExampleList(s);	    for (int j = 1; j < m_numLabels; j++) {		for (i = 0; i < l.length; i++)		    s[i]++;		bags[j].subtractExampleList(s);	    }	}    	public void addBag(Bag b) {	    MultiBag other = (MultiBag) b;	    for (int j = 0; j < m_numLabels; j++)		bags[j].addBag(other.bags[j]);	}    	public void subtractBag(Bag b) {	    MultiBag other = (MultiBag) b;	    for (int j = 0; j < m_numLabels; j++)		bags[j].subtractBag(other.bags[j]);	}    	public void copyBag(Bag b) {	    MultiBag other = (MultiBag) b;	    for (int j = 0; j < m_numLabels; j++)		bags[j].copyBag(other.bags[j]);	}    	public void refresh(int index) {	    int s = index * m_numLabels;	    for (int j = 0; j < m_numLabels; j++)		bags[j].refresh(s + j);	}    	public void refreshList(int[] l) {	    int i;	    int[] s = new int[l.length];	    for (i = 0; i < l.length; i++)		s[i] = l[i] * m_numLabels;	    bags[0].refreshList(s);	    for (int j = 1; j < m_numLabels; j++) {		for (i = 0; i < l.length; i++)		    s[i]++;		bags[j].refreshList(s);	    }	}    	/**	 * Computes the loss for this bag.  This loss is only meaningful for	 * additive losses.	 */	public double getLoss() {	    double loss = 0.0;	    for (int j = 0; j < m_numLabels; j++)		loss += bags[j].getLoss();	    return loss;	}    	public double getLoss(int s)	    throws jboost.NotSupportedException {	    double loss = 0.0;	    for (int j = 0; j < m_numLabels; j++)		loss += bags[j].getLoss(s);	    return loss;	}    }      /**     * This is the prediction class associated with this booster.     * Each prediction is composed of an array of predictions from the     * underlying booster, one for each class.       */    class MultiPrediction extends Prediction {	/**	 * The predictions made.  Has same length as the number of classes.	 */	private Prediction[] preds;    	/**	 * Constructor.	 */	private MultiPrediction() {	    preds = new Prediction[m_numLabels];	}    	public Object clone() {	    MultiPrediction newpred = new MultiPrediction();      	    for (int j = 0; j < m_numLabels; j++) {		newpred.preds[j] = (Prediction) preds[j].clone();	    }	    return newpred;	}    	public Prediction add(Prediction p) {	    for (int j = 0; j < m_numLabels; j++) {		preds[j].add(((MultiPrediction) p).preds[j]);	    }	    return this;	}    	public Prediction scale(double w) {	    for (int j = 0; j < m_numLabels; j++)		preds[j].scale(w);	    return this;	}    	public Prediction add(double w, Prediction p) {	    for (int j = 0; j < m_numLabels; j++) {		preds[j].add(w, ((MultiPrediction) p).preds[j]);	    }	    return this;	}    	public double[] getMarginsSingleLabel(Label l) {	    int maxClass = -1;	    double maxScore = Double.MIN_VALUE;	    int thisClass = -1;	    double thisScore = 0;	    for (int j = 0; j < m_numLabels; j++) {		double predScore = preds[j].getClassScores()[1];		if (l.getMultiValue(j)){		    thisClass = j;		    thisScore = predScore;		} else if (maxScore < predScore) {		    maxScore = predScore;		    maxClass = j;		}	    }	    double[] ret = new double[1];	    ret[0] = thisScore - maxScore;	    return ret;	}		public double[] getMarginsMultiLabel(Label l) {	    double[] ret = new double[m_numLabels];	    for (int j = 0; j < m_numLabels; j++) {		ret[j] = preds[j].getMargins(new Label(l.getMultiValue(j) ?						       1 : 0))[0];	    }	    return ret;	}	public double[] getMargins(Label l) {	    if (m_isMultiLabel) 		return getMarginsMultiLabel(l);	    return getMarginsSingleLabel(l);	}    	public double[] getClassScores() {	    double[] scores = new double[m_numLabels];	    double[] uscore;      	    for (int j = 0; j < m_numLabels; j++) {		uscore = preds[j].getClassScores();		scores[j] = uscore[1];	    }	    return scores;	}    	/**	 * Check to see if this MultiPrediction is the same as the other	 * @param other	 * @return true if all the Predictions of this object are equal to this other's	 */	public boolean equals(Prediction p) {	    boolean retval= true;	    MultiPrediction other= (MultiPrediction) p;	    for (int k=0; k < m_numLabels; k++) {		if (!preds[k].equals(other.preds[k])) {		    retval= false;		}	    }	    return retval;	}    	public String toString() {	    String s = "MultiPrediction.\n";	    for (int j = 0; j < m_numLabels; j++)		s += "prediction " + j + ": " + preds[j] + "\n";	    return s;	}    	public String shortText() {	    String s = "[,"+preds[0];	    for (int j = 0; j < m_numLabels; j++)		s += ","+preds[j];	    return s+"]";	}    	public String cPreamble() {	    String code = "";      	    code +=  "typedef double Prediction_t[" + m_numLabels + "];\n";      	    code +=  "#define reset_pred()  { \\\n";	    for (int i = 0; i < m_numLabels; i++)		code += "        p["+i+"] = 0.0; \\\n";	    code += "     }\n";      	    code +=  "#define add_pred(";	    for (int i = 0; i < m_numLabels; i++)		code += (i == 0 ? "" : ",") + "X" + i;	    code += ") { \\\n";	    for (int i = 0; i < m_numLabels; i++)		code += "        p["+i+"] += X"+i+"; \\\n";	    code += "     }\n";      	    code += "#define finalize_pred() \\\n";	    code += "        (r ? ( \\\n";	    for (int i = 0; i < m_numLabels; i++)		code += "                r["+i+"] = p["+i+"], \\\n";	    code += "                p[0])       : p[0])\n";      	    return code;	}    	public String javaPreamble() {	    String code = "";      	    code += ""		+ "  static private double[] p = new double[" + m_numLabels + "];\n"		+ "  static private void reset_pred() {\n"		+ "    java.util.Arrays.fill(p, 0.0);\n"		+ "  }\n"		+ "  static private void add_pred(";	    for (int i = 0; i < m_numLabels; i++)		code += (i == 0 ? "" : ",") + "double x" + i;	    code += ") {\n";	    for (int i = 0; i < m_numLabels; i++)		code += "    p["+i+"] += x"+i+";\n";	    code += ""		+ "  }\n"		+ "  static private double[] finalize_pred() {\n"		+ "    return (double[]) p.clone();\n"		+ "  }\n";	    return code;	}    	public double[] toCodeArray() {	    return getClassScores();	}    }      /** a main for testing */    public static void main(String[] argv) {	try {	    AbstractBooster ada =		new DebugWrap(new MulticlassWrapMH(new DebugWrap(new AdaBoost(0.0)), 2, true));      	    for(int i=0; i< 10; i++) {		ada.addExample(i,new Label(i % 2));	    }      	    ada.finalizeData();	    if(Monitor.logLevel>3) Monitor.log(ada);      	    int[] elements = {1,2,6,3,4};	    String s="\n Generating a bag with elements:\n" +elements[0];	    for(int i=1; i<elements.length; i++) s+=", " + elements[i];	    if(Monitor.logLevel>3) Monitor.log(s);	    Bag bag = ada.newBag(elements);	    if(Monitor.logLevel>3) Monitor.log(bag);      	    Prediction[] p;	    p = ada.getPredictions(new Bag[] {bag}); // calc optimal prediction for this bag	    if(Monitor.logLevel>3) Monitor.log("best prediction for this bag is " + 					       p[0].getClassScores()[1]);      	    int[][] exampleList = new int[1][];	    exampleList[0] = elements;	    ada.update(p,exampleList);	    if(Monitor.logLevel>3) Monitor.log("Adaboost after updating the m_weights is");	    if(Monitor.logLevel>3) Monitor.log(ada);	    bag.refreshList(elements);	    if(Monitor.logLevel>3) Monitor.log("refreshed bag: "+bag);	    if(Monitor.logLevel>3) Monitor.log("now best prediction for this bag is "+					       (ada.getPredictions(new Bag[] {bag})[0]).getClassScores()[1]);      	}	catch(Exception e) {	    if(Monitor.logLevel>3) Monitor.log(e.getMessage());	    e.printStackTrace();	}            }      /* (non-Javadoc)     * @see jboost.booster.Booster#init(jboost.controller.Configuration)     */    public void init(Configuration config) {	// TODO Auto-generated method stub        }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -