📄 multiclasswrapmh.java
字号:
package jboost.booster;import jboost.controller.Configuration;import jboost.examples.Label;import jboost.monitor.Monitor;import jboost.NotSupportedException;/** * This wrapper reduces a multiclass problem into a binary problem. * The booster that it is wrapped around can thus be binary and need * no knowledge of the surrounding wrapper. * * The multiclass problem is reduce to a binary problem using a * "one-against-all" reduction. This reduction maps a multiclass * example to k (the # of classes) binary examples. Namely, example * i gets mapped to i*k, i*k + 1, ... , (i+1)*k - 1. Each MultiBag * consists of k underlying binary bags. Each prediction consists of k * underlying predictions. * * @author Rob Schapire (rewritten by Aaron Arvey) * @version $Header: /cvsroot/jboost/jboost/src/jboost/booster/MulticlassWrapMH.java,v 1.5 2007/10/13 04:32:28 aarvey Exp $ */class MulticlassWrapMH extends AbstractBooster { /** The underlying m_booster */ private AbstractBooster m_booster; /** The number of labels */ private int m_numLabels; /** The number of labels */ private boolean m_isMultiLabel; /** * constructor * @param booster associated booster * @param k total number of m_labels */ MulticlassWrapMH(AbstractBooster booster, int numLabels, boolean isMultiLabel) { m_booster = booster; m_numLabels = numLabels; m_isMultiLabel = isMultiLabel; } public String toString() { if (m_booster==null) { String msg = "MulticlassWrapMH.toString: m_booster is null"; if (Monitor.logLevel>3) { Monitor.log(msg); } System.err.println(msg); } return ("MulticlassWrapMH. # of classes = " + m_numLabels + ".\nUnderlying m_booster:\n" + m_booster); } public void addExample(int index, Label l) { addExample(index, l, 1.0); } /** * @see jboost.booster.Booster#addExample(int, jboost.examples.Label, double) */ public void addExample(int index, Label label, double weight) { int s = index * m_numLabels; for (int j = 0; j < m_numLabels; j++) { m_booster.addExample(s+j, new Label(label.getMultiValue(j) ? 1 : 0), weight); } } public void finalizeData() { m_booster.finalizeData(); } public void clear() { m_booster.clear(); } public Bag newBag() { return new MultiBag(); } /** * */ public double[][] getWeights() { int numExamples = m_booster.getNumExamples() / m_numLabels; double[][] r= new double[numExamples][m_numLabels]; double[][] weights = m_booster.getWeights(); for (int i= 0; i < numExamples; i++) for (int j=0; j < m_numLabels; j++) r[i][j]= weights[i*m_numLabels+j][0]; return r; } /** * */ public double[][] getPotentials() { int numExamples = m_booster.getNumExamples() / m_numLabels; double[][] r= new double[numExamples][m_numLabels]; double[][] potentials = m_booster.getPotentials(); for (int i= 0; i < numExamples; i++) for (int j=0; j < m_numLabels; j++) r[i][j]= potentials[i*m_numLabels+j][0]; return r; } /** * * */ public String getParamString() { return m_booster.getParamString(); } /** * According to Rob: "This method should never be called. We need * to rearrange hierarchy so that the multiclass booster doesn't * extend the abstract booster" * * According to Aaron: It seems that multiclass booster should * certainly extend AbstractBooster. MulticlassWrap becomes the * booster and merely insulates the true booster via an Adaptor * design pattern. * * Not sure what context this function would be inappropriate. */ public double calculateWeight(double margin) { return m_booster.calculateWeight(margin); } public void update(Prediction[] preds, int[][] index) { int num_preds = preds.length; Prediction[] upreds = new Prediction[num_preds * m_numLabels]; int[][] uindex = new int[index.length * m_numLabels][]; int i, j, t, k; // create array of predictions to pass to underlying booster t = 0; for (i = 0; i < num_preds; i++) { for (j = 0; j < m_numLabels; j++) { upreds[t] = ((MultiPrediction) preds[i]).preds[j]; t++; } } // create array of indices to pass to underlying booster t = 0; for (i = 0; i < index.length; i++) { for (j = 0; j < m_numLabels; j++) { uindex[t] = new int[index[i].length]; for (k = 0; k < index[i].length; k++) uindex[t][k] = index[i][k] * m_numLabels + j; t++; } } m_booster.update(upreds, uindex); } /** * computes theoretical bound as (m_numLabels/2) * theoretical * bound for underlying booster. This computation may not be * correct in all cases. */ public double getTheoryBound() { return 0.5 * m_numLabels * m_booster.getTheoryBound(); } public double getTotalWeight() { return m_booster.getTotalWeight(); } /** * returns the margin values of all of the "examples" used to * train the underlying booster. */ public double[][] getMargins() { return m_booster.getMargins(); } /** * Returns the predictions associated with a list of bags representing a * partition of the data. */ public Prediction[] getPredictions(Bag[] b, int[][] exampleIndex) { Bag[] ubags = new Bag[b.length * m_numLabels]; for (int i = 0; i < b.length; i++) for (int j = 0; j < m_numLabels; j++) ubags[i * m_numLabels + j] = ((MultiBag) b[i]).bags[j]; Prediction[] upreds = m_booster.getPredictions(ubags, exampleIndex); Prediction[] preds = new Prediction[b.length]; for (int i = 0; i < b.length; i++) { preds[i] = new MultiPrediction(); for (int j = 0; j < m_numLabels; j++) ((MultiPrediction) preds[i]).preds[j] = upreds[i * m_numLabels + j]; } return preds; } public Prediction[] getPredictions(Bag[] b) { System.err.println("Obsolete interface!"); Prediction[] nothing = new Prediction[1]; return nothing; } public double getLoss(Bag[] b) { Bag[] ubags = new Bag[b.length * m_numLabels]; for (int i = 0; i < b.length; i++) for (int j = 0; j < m_numLabels; j++) ubags[i * m_numLabels + j] = ((MultiBag) b[i]).bags[j]; return m_booster.getLoss(ubags); } /** * This is the bag class associated with this booster. Each bag * is composed of an array of bags from the underlying booster, * one for each class. */ class MultiBag extends Bag { private Bag[] bags; // underlying bags private MultiBag() { bags = new Bag[m_numLabels]; for (int j = 0; j < m_numLabels; j++) bags[j] = m_booster.newBag(); } public String toString() { String s = "MultiBag.\n"; for (int j = 0; j < m_numLabels; j++) s += "bag " + j + ":\n" + bags[j]; return s; } public void reset() { for (int j = 0; j < m_numLabels; j++) bags[j].reset(); } public boolean isWeightless() { for (int j = 0; j < m_numLabels; j++) if (! bags[j].isWeightless() ) { return false; } return true; } public void addExample(int index) { int s = index * m_numLabels; for (int j = 0; j < m_numLabels; j++) bags[j].addExample(s + j); } public void subtractExample(int index) { int s = index * m_numLabels; for (int j = 0; j < m_numLabels; j++) bags[j].subtractExample(s + j); } public void addExampleList(int[] l) { int i; int[] s = new int[l.length]; for (i = 0; i < l.length; i++) s[i] = l[i] * m_numLabels; bags[0].addExampleList(s); for (int j = 1; j < m_numLabels; j++) { for (i = 0; i < l.length; i++) s[i]++; bags[j].addExampleList(s); } } public void subtractExampleList(int[] l) { int i; int[] s = new int[l.length]; for (i = 0; i < l.length; i++)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -