⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 instrumentedalternatingtree.java

📁 Boosting算法软件包
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
package jboost.atree;import java.util.ArrayList;import java.util.Arrays;import java.util.Iterator;import java.util.Vector;import jboost.CandidateSplit;import jboost.ComplexLearner;import jboost.NotSupportedException;import jboost.Predictor;import jboost.WritablePredictor;import jboost.booster.Bag;import jboost.booster.Booster;import jboost.booster.BrownBoost;import jboost.booster.Prediction;import jboost.booster.NormalizedPrediction;import jboost.controller.Configuration;import jboost.controller.ConfigurationException;import jboost.examples.Example;import jboost.learner.Splitter;import jboost.learner.SplitterBuilder;import jboost.monitor.Monitor;import jboost.util.ExecutorSinglet;import EDU.oswego.cs.dl.util.concurrent.CountDown;import EDU.oswego.cs.dl.util.concurrent.Executor;/**  * This data structure uses Splitters from the learner package to find * the best alternating decision tree from the m_examples.  This class * can easily output an AlternatingTree once it has learned from all * the m_examples. */public class InstrumentedAlternatingTree extends ComplexLearner {      /**      * A list that holds each {@link PredictorNode} in the tree,     * ordered by the order that the learning process added them in.     * The first node is the root node.     */    private ArrayList m_predictors;      /**      * A list that holds all the {@link Splitter} nodes in the tree,     * ordered by the order that the learning process added them in.     * XXX: this is not true. what is this list for?     */    private ArrayList m_splitters;      /** The time index at which nodes are added */                 private int m_index;      /**      * The list of each {@link SplitterBuilder} used by this tree,     * stored as a list of @{link PredictorNodeSB} nodes.     
*/    private ArrayList m_splitterBuilders;      /**     *  The example mask corresponding to the m_examples that reach     *  each node Each example that has a true value for its mask is     *  an example that reaches that node     */    private ArrayList m_masks;      /** A list of indices of examples to be used by this tree */    private int[] m_examples;      /** The booster to be used for learning this alternating tree */    private Booster m_booster;      /** The tree type used by this alternating tree */    private AtreeType m_treeType;      /** This flag is used to all the tree to emulate BoosTexter     * functionality */    private boolean m_emulateBoosTexter;          /**      * Constructor which allows the controller to specify some     * SplitterBuilders.     *     * @param sb The splitter builders for this tree     * @param b The m_booster to be used.     * @param ex The example indices.     * @param config The configuration information.     */    public InstrumentedAlternatingTree(Vector sb, Booster b, int[] ex,				       Configuration config) {    	init(sb, b, ex, config);    	// create root node	createRoot();    }      public InstrumentedAlternatingTree(AlternatingTree tree, 				       Vector splitterbuilders,				       Booster booster,  int[] examples, 				       Configuration config)	throws InstrumentException, NotSupportedException {    	init(splitterbuilders, booster, examples, config);   	createRoot();	instrumentAlternatingTree(tree);    }      private void init(Vector splitterbuilders, Booster booster,		      int[] examples, Configuration config) {	// Use the number of boosting iterations as the default	// size for the internal lists used by this tree	int listSize= config.getInt("numRounds", 200);    	// initialize the data structures used by the tree	m_predictors= new ArrayList(listSize);	m_splitters= new ArrayList(listSize);	m_splitterBuilders= new ArrayList(listSize);    	m_masks= new ArrayList(listSize);	SplitterBuilder[] 
initialSplitterBuilders= 	    new SplitterBuilder[splitterbuilders.size()];	splitterbuilders.toArray(initialSplitterBuilders);	PredictorNodeSB pnSB= new PredictorNodeSB(0, initialSplitterBuilders);	m_splitterBuilders.add(pnSB);    	m_booster= booster;	m_examples= examples;	m_index= 1;    	try {	    // set configuraiton options	    setAddType(config);	} catch (ConfigurationException e) {	    // TODO Auto-generated catch block	    e.printStackTrace();	}    	m_emulateBoosTexter= config.getBool("BoosTexter", false);    	// create initial example mask and place it in the masks list	boolean[] exampleMask= new boolean[m_examples.length];	Arrays.fill(exampleMask, true);	m_masks.add(exampleMask);        }    /**     * Create the root node of this tree.  Use the booster to create a     * Bag of the examples specified when this tree was constructed.     * Create the initial set of predictions by taking the Bag and     * passing it back through the booster. The bag used contains the     * relative weights of the labels in the training set.  Update the     * booster with those predictions and create the initial     * prediction node.     */    private void createRoot() {	Bag[] initialWeights= new Bag[1];	Prediction[] tmpPred= null;	initialWeights[0]= m_booster.newBag(m_examples);	int[][] tmpEx= new int[1][];	tmpEx[0]= m_examples;	//  To make it behave like boost texter.	
if (m_emulateBoosTexter) {	    if (Monitor.logLevel > 3) {		Monitor.log("This has been modified to behave like boostexter: "			    + "Instrumented ATree Constructor.");	    }	    initialWeights[0].reset();	    tmpPred= m_booster.getPredictions(initialWeights, tmpEx);	    m_booster.update(tmpPred, tmpEx);	    PredictorNode predictorNode= new PredictorNode(tmpPred[0], "R", 						       0, null, null, 0);	    m_predictors.add(predictorNode);	    return;	}    	tmpPred= m_booster.getPredictions(initialWeights, tmpEx);	m_booster.update(tmpPred, tmpEx);    	PredictorNode predictorNode= new PredictorNode(tmpPred[0], "R", 						       0, null, null, 0);	m_predictors.add(predictorNode);    }      /**     *  Suggest a list of Candidate Splitters     *  @return     *  @throws NotSupportedException 						     */    public Vector getCandidates() throws NotSupportedException {	// set initial capacity	Vector retval= new Vector(m_splitterBuilders.size()); 	Bag tmpBag= null;    	if (!m_emulateBoosTexter) {	    // Add a candidate that would just adjust the 	    // prediction at the root.	    
tmpBag= m_booster.newBag(m_examples);	    double loss = m_booster.getLoss(new Bag[] { tmpBag });	    retval.add(new AtreeCandidateSplit(loss));	}    	// add all other splitters	retval.addAll(buildSplitters());    	return retval;    }      /**     * Generate candidates from m_splitterBuilders     * @return a vector of candidate splitters     * @throws NotSupportedException     */    private Vector buildSplitters() throws NotSupportedException {	Executor pe=ExecutorSinglet.getExecutor();	int childCount;      	// create a synchronization barrier that counts the number	// of processed splitter builders	CountDown sbCount=new CountDown(m_splitterBuilders.size());	Vector splitters=new Vector(m_splitterBuilders.size());	for (Iterator i = m_splitterBuilders.iterator(); i.hasNext();  ) {	    	    PredictorNodeSB pSB=(PredictorNodeSB)i.next();	    	    if (m_treeType == AtreeType.ADD_ROOT && pSB.pNode != 0) {		while(sbCount.currentCount()!=0) {		    sbCount.release();		}		break;	    }          	    if (m_treeType == AtreeType.ADD_SINGLES) {		childCount= ((PredictorNode) 			     m_predictors.get(pSB.pNode)).getSplitterNodeNo();		if (childCount > 0) {		    sbCount.release();		    continue;		}	    }          	    if (m_treeType == AtreeType.ADD_ROOT_OR_SINGLES && pSB.pNode != 0) {		childCount= ((PredictorNode) 			     m_predictors.get(pSB.pNode)).getSplitterNodeNo();		if (childCount > 0) {		    sbCount.release();		    continue;		}	    }          	    SplitterBuilderWorker sbw=		new SplitterBuilderWorker(pSB,splitters,sbCount);	    try {		pe.execute(sbw);	    } catch (InterruptedException ie) {		System.err.println("exception ocurred while handing off the "				   + "splitter job to the pool: "				   + ie.getMessage());		ie.printStackTrace();	    }	}      	// wait on all threads to finish	try {	    sbCount.acquire();	} catch(InterruptedException ie) {	    if(sbCount.currentCount()!=0) {		System.err.println("interrupted exception occurred, but the "				   + "sbCount is " + 
sbCount.currentCount());	    }	};      	return splitters;    }      /**     * Build a splitter using a single splitter buildier     * @param retval     * @throws NotSupportedException     */    private void buildSplitter(PredictorNodeSB pSB, Vector splitters) 	throws NotSupportedException {	CandidateSplit split;	double trivLoss;	long start;	long stop;    	// Create bag containing all m_examples reaching this node:	// tmpBag = m_booster.newBag(makeIndices((boolean [])	// m_masks.get(pSB.pNode))); Compute loss for trivial split:	// trivLoss = m_booster.getLoss(new Bag[] {tmpBag}); 	// TODO:	// need to fix so that splits worse than trivial are not	// added.  In the meantime, allow all splits.	trivLoss= Double.MAX_VALUE;	start= System.currentTimeMillis();	int j=0;	for (j= 0; j < pSB.SB.length; j++) {	    split= pSB.SB[j].build();          	    // only add candidates with loss better than trivial split	    // TODO: figure out what to do if no splits better	    // than trivial	    if (split != null && split.getLoss() < trivLoss)		splitters.add(new AtreeCandidateSplit(pSB.pNode, split));	}	stop= System.currentTimeMillis();	if (Monitor.logLevel > 3) {	    Monitor.log("It took an average of " + (stop-start)/(j*1000.0) + 			" seconds to build " + j + " splitterbuilders.");	}         }      /**     * Update the booster and add the new predictor to the root of the tree      */    private void updateRoot() {	Bag[] bags= new Bag[1];	int[][] partition= new int[1][];	partition[0]= m_examples;	bags[0]= m_booster.newBag(m_examples);	Prediction[] pred= m_booster.getPredictions(bags, partition);	m_booster.update(pred, partition);	if (pred.length > 0 && pred[0] instanceof NormalizedPrediction) {	    System.err.println("Cannot update root with mixed binary pred");	    System.exit(2);	}	((PredictorNode) m_predictors.get(0)).addToPrediction(pred[0]);	if (pred==null) {	    System.err.println("Updating root pred is null!");	}	lastBasePredictor= new AtreePredictor(pred);    }      /**     * 
* Search the children of the parent node for a matching splitter.
     * Scans the splitter nodes already attached to {@code parent} and
     * returns the first one whose splitter is {@code equals} to
     * {@code splitter}.
     *
     * @param parent the predictor node whose splitter children are searched
     * @param splitter the splitter to look for
     * @return null if no matching splitter is found, otherwise the match
     */
    private SplitterNode findSplitter(PredictorNode parent, Splitter splitter) {
        // Check if this split is already added to this predictor node.
        SplitterNode sameAs= null;
        // only the first getSplitterNodeNo() entries of parent.splitterNodes are examined
        for (int i= 0; i < parent.getSplitterNodeNo(); i++) {
            sameAs= (SplitterNode) parent.splitterNodes.get(i);
            if (splitter.equals(sameAs.splitter)) {
                return sameAs;
            }
        }
        return null;
    }

    /**
     *
     * @param bags
     * @param parent
     * @param splitter
     * @param parentArray
     * @param predictions
     * @param partition
     */
    private SplitterNode insert(Bag[] bags, PredictorNode parent,
                                Splitter splitter, SplitterBuilder[] parentArray,
                                Prediction[] predictions, int[][] partition) {
        boolean[] examplesMask= null;
        int s= bags.length;
        PredictorNode[] pNode= new PredictorNode[s];
        int[] pInt= new int[s];
        String ID= new String(parent.id);
        ID= ID + "." + parent.getSplitterNodeNo();
        int splitterIndex= m_index++;
        // create new splitter node
        SplitterNode sNode= new SplitterNode(splitter, ID, splitterIndex, pNode, parent);
        // 1) Generate the prediction nodes.
        for (int i= 0; i < s; i++) {
            pNode[i]=  new PredictorNode(predictions[i], ID + ":" + i, splitterIndex,
                                         null, sNode, i);
            // 1.a) Add the new prediction nodes to the alternating tree list

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -