📄 instrumentedalternatingtree.java
字号:
package jboost.atree;import java.util.ArrayList;import java.util.Arrays;import java.util.Iterator;import java.util.Vector;import jboost.CandidateSplit;import jboost.ComplexLearner;import jboost.NotSupportedException;import jboost.Predictor;import jboost.WritablePredictor;import jboost.booster.Bag;import jboost.booster.Booster;import jboost.booster.BrownBoost;import jboost.booster.Prediction;import jboost.booster.NormalizedPrediction;import jboost.controller.Configuration;import jboost.controller.ConfigurationException;import jboost.examples.Example;import jboost.learner.Splitter;import jboost.learner.SplitterBuilder;import jboost.monitor.Monitor;import jboost.util.ExecutorSinglet;import EDU.oswego.cs.dl.util.concurrent.CountDown;import EDU.oswego.cs.dl.util.concurrent.Executor;/** * This data structure uses Splitters from the learner package to find * the best alternating decision tree from the m_examples. This class * can easily output an AlternatingTree once it has learned from all * the m_examples. */public class InstrumentedAlternatingTree extends ComplexLearner { /** * A list that holds each {@link PredictorNode} in the tree, * ordered by the order that the learning process added them in. * The first node is the root node. */ private ArrayList m_predictors; /** * A list that holds all the {@link Splitter} nodes in the tree, * ordered by the order that the learning process added them in. * XXX: this is not true. what is this list for? */ private ArrayList m_splitters; /** The time index at which nodes are added */ private int m_index; /** * The list of each {@link SplitterBuilder} used by this tree, * stored as a list of @{link PredictorNodeSB} nodes. */ private ArrayList m_splitterBuilders; /** * The example mask corresponding to the m_examples that reach * each node Each example that has a true value for its mask is * an example that reaches that node */ private ArrayList m_masks; /** A list of indices of examples to be used by this tree */ private int[] m_examples; /** The booster to be used for learning this alternating tree */ private Booster m_booster; /** The tree type used by this alternating tree */ private AtreeType m_treeType; /** This flag is used to all the tree to emulate BoosTexter * functionality */ private boolean m_emulateBoosTexter; /** * Constructor which allows the controller to specify some * SplitterBuilders. * * @param sb The splitter builders for this tree * @param b The m_booster to be used. * @param ex The example indices. * @param config The configuration information. */ public InstrumentedAlternatingTree(Vector sb, Booster b, int[] ex, Configuration config) { init(sb, b, ex, config); // create root node createRoot(); } public InstrumentedAlternatingTree(AlternatingTree tree, Vector splitterbuilders, Booster booster, int[] examples, Configuration config) throws InstrumentException, NotSupportedException { init(splitterbuilders, booster, examples, config); createRoot(); instrumentAlternatingTree(tree); } private void init(Vector splitterbuilders, Booster booster, int[] examples, Configuration config) { // Use the number of boosting iterations as the default // size for the internal lists used by this tree int listSize= config.getInt("numRounds", 200); // initialize the data structures used by the tree m_predictors= new ArrayList(listSize); m_splitters= new ArrayList(listSize); m_splitterBuilders= new ArrayList(listSize); m_masks= new ArrayList(listSize); SplitterBuilder[] initialSplitterBuilders= new SplitterBuilder[splitterbuilders.size()]; splitterbuilders.toArray(initialSplitterBuilders); PredictorNodeSB pnSB= new PredictorNodeSB(0, initialSplitterBuilders); m_splitterBuilders.add(pnSB); m_booster= booster; m_examples= examples; m_index= 1; try { // set configuraiton options setAddType(config); } catch (ConfigurationException e) { // TODO Auto-generated catch block e.printStackTrace(); } m_emulateBoosTexter= config.getBool("BoosTexter", false); // create initial example mask and place it in the masks list boolean[] exampleMask= new boolean[m_examples.length]; Arrays.fill(exampleMask, true); m_masks.add(exampleMask); } /** * Create the root node of this tree. Use the booster to create a * Bag of the examples specified when this tree was constructed. * Create the initial set of predictions by taking the Bag and * passing it back through the booster. The bag used contains the * relative weights of the labels in the training set. Update the * booster with those predictions and create the initial * prediction node. */ private void createRoot() { Bag[] initialWeights= new Bag[1]; Prediction[] tmpPred= null; initialWeights[0]= m_booster.newBag(m_examples); int[][] tmpEx= new int[1][]; tmpEx[0]= m_examples; // To make it behave like boost texter. if (m_emulateBoosTexter) { if (Monitor.logLevel > 3) { Monitor.log("This has been modified to behave like boostexter: " + "Instrumented ATree Constructor."); } initialWeights[0].reset(); tmpPred= m_booster.getPredictions(initialWeights, tmpEx); m_booster.update(tmpPred, tmpEx); PredictorNode predictorNode= new PredictorNode(tmpPred[0], "R", 0, null, null, 0); m_predictors.add(predictorNode); return; } tmpPred= m_booster.getPredictions(initialWeights, tmpEx); m_booster.update(tmpPred, tmpEx); PredictorNode predictorNode= new PredictorNode(tmpPred[0], "R", 0, null, null, 0); m_predictors.add(predictorNode); } /** * Suggest a list of Candidate Splitters * @return * @throws NotSupportedException */ public Vector getCandidates() throws NotSupportedException { // set initial capacity Vector retval= new Vector(m_splitterBuilders.size()); Bag tmpBag= null; if (!m_emulateBoosTexter) { // Add a candidate that would just adjust the // prediction at the root. tmpBag= m_booster.newBag(m_examples); double loss = m_booster.getLoss(new Bag[] { tmpBag }); retval.add(new AtreeCandidateSplit(loss)); } // add all other splitters retval.addAll(buildSplitters()); return retval; } /** * Generate candidates from m_splitterBuilders * @return a vector of candidate splitters * @throws NotSupportedException */ private Vector buildSplitters() throws NotSupportedException { Executor pe=ExecutorSinglet.getExecutor(); int childCount; // create a synchronization barrier that counts the number // of processed splitter builders CountDown sbCount=new CountDown(m_splitterBuilders.size()); Vector splitters=new Vector(m_splitterBuilders.size()); for (Iterator i = m_splitterBuilders.iterator(); i.hasNext(); ) { PredictorNodeSB pSB=(PredictorNodeSB)i.next(); if (m_treeType == AtreeType.ADD_ROOT && pSB.pNode != 0) { while(sbCount.currentCount()!=0) { sbCount.release(); } break; } if (m_treeType == AtreeType.ADD_SINGLES) { childCount= ((PredictorNode) m_predictors.get(pSB.pNode)).getSplitterNodeNo(); if (childCount > 0) { sbCount.release(); continue; } } if (m_treeType == AtreeType.ADD_ROOT_OR_SINGLES && pSB.pNode != 0) { childCount= ((PredictorNode) m_predictors.get(pSB.pNode)).getSplitterNodeNo(); if (childCount > 0) { sbCount.release(); continue; } } SplitterBuilderWorker sbw= new SplitterBuilderWorker(pSB,splitters,sbCount); try { pe.execute(sbw); } catch (InterruptedException ie) { System.err.println("exception ocurred while handing off the " + "splitter job to the pool: " + ie.getMessage()); ie.printStackTrace(); } } // wait on all threads to finish try { sbCount.acquire(); } catch(InterruptedException ie) { if(sbCount.currentCount()!=0) { System.err.println("interrupted exception occurred, but the " + "sbCount is " + sbCount.currentCount()); } }; return splitters; } /** * Build a splitter using a single splitter buildier * @param retval * @throws NotSupportedException */ private void buildSplitter(PredictorNodeSB pSB, Vector splitters) throws NotSupportedException { CandidateSplit split; double trivLoss; long start; long stop; // Create bag containing all m_examples reaching this node: // tmpBag = m_booster.newBag(makeIndices((boolean []) // m_masks.get(pSB.pNode))); Compute loss for trivial split: // trivLoss = m_booster.getLoss(new Bag[] {tmpBag}); // TODO: // need to fix so that splits worse than trivial are not // added. In the meantime, allow all splits. trivLoss= Double.MAX_VALUE; start= System.currentTimeMillis(); int j=0; for (j= 0; j < pSB.SB.length; j++) { split= pSB.SB[j].build(); // only add candidates with loss better than trivial split // TODO: figure out what to do if no splits better // than trivial if (split != null && split.getLoss() < trivLoss) splitters.add(new AtreeCandidateSplit(pSB.pNode, split)); } stop= System.currentTimeMillis(); if (Monitor.logLevel > 3) { Monitor.log("It took an average of " + (stop-start)/(j*1000.0) + " seconds to build " + j + " splitterbuilders."); } } /** * Update the booster and add the new predictor to the root of the tree */ private void updateRoot() { Bag[] bags= new Bag[1]; int[][] partition= new int[1][]; partition[0]= m_examples; bags[0]= m_booster.newBag(m_examples); Prediction[] pred= m_booster.getPredictions(bags, partition); m_booster.update(pred, partition); if (pred.length > 0 && pred[0] instanceof NormalizedPrediction) { System.err.println("Cannot update root with mixed binary pred"); System.exit(2); } ((PredictorNode) m_predictors.get(0)).addToPrediction(pred[0]); if (pred==null) { System.err.println("Updating root pred is null!"); } lastBasePredictor= new AtreePredictor(pred); } /** * Search the children of the parent node for a matching splitter * Return null if no match is found * @param parent * @param splitter * @return null if no matching splitter is found, otherwise the match */ private SplitterNode findSplitter(PredictorNode parent, Splitter splitter) { // Check if this split is already added to this predictor node. SplitterNode sameAs= null; for (int i= 0; i < parent.getSplitterNodeNo(); i++) { sameAs= (SplitterNode) parent.splitterNodes.get(i); if (splitter.equals(sameAs.splitter)) { return sameAs; } } return null; } /** * * @param bags * @param parent * @param splitter * @param parentArray * @param predictions * @param partition */ private SplitterNode insert(Bag[] bags, PredictorNode parent, Splitter splitter, SplitterBuilder[] parentArray, Prediction[] predictions, int[][] partition) { boolean[] examplesMask= null; int s= bags.length; PredictorNode[] pNode= new PredictorNode[s]; int[] pInt= new int[s]; String ID= new String(parent.id); ID= ID + "." + parent.getSplitterNodeNo(); int splitterIndex= m_index++; // create new splitter node SplitterNode sNode= new SplitterNode(splitter, ID, splitterIndex, pNode, parent); // 1) Generate the prediction nodes. for (int i= 0; i < s; i++) { pNode[i]= new PredictorNode(predictions[i], ID + ":" + i, splitterIndex, null, sNode, i); // 1.a) Add the new prediction nodes to the alternating tree list
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -