
LMTNode.java

Various data mining algorithms implemented in Java, including clustering, classification, preprocessing, and more.
Language: Java
Page 1 of 2
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    LMTNode.java
 *    Copyright (C) 2003 Niels Landwehr
 *
 */

package weka.classifiers.trees.lmt;

import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.classifiers.trees.j48.ClassifierSplitModel;
import weka.classifiers.trees.j48.ModelSelection;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;

import java.util.Collections;
import java.util.Comparator;
import java.util.Vector;

/**
 * Auxiliary class for list of LMTNodes
 */
class CompareNode
  implements Comparator {

  /**
   * Compares its two arguments for order.
   *
   * @param o1 first object
   * @param o2 second object
   * @return a negative integer, zero, or a positive integer as the first
   *         argument is less than, equal to, or greater than the second.
   */
  public int compare(Object o1, Object o2) {
    if (((LMTNode) o1).m_alpha < ((LMTNode) o2).m_alpha) return -1;
    if (((LMTNode) o1).m_alpha > ((LMTNode) o2).m_alpha) return 1;
    return 0;
  }
}

/**
 * Class for logistic model tree structure.
 *
 * @author Niels Landwehr
 * @author Marc Sumner
 * @version $Revision: 1.4 $
 */
public class LMTNode
  extends LogisticBase {

  /** for serialization */
  static final long serialVersionUID = 1862737145870398755L;

  /** Total number of training instances. */
  protected double m_totalInstanceWeight;

  /** Node id */
  protected int m_id;

  /** ID of logistic model at leaf */
  protected int m_leafModelNum;

  /** Alpha-value (for pruning) at the node */
  public double m_alpha;

  /** Weighted number of training examples currently misclassified by the logistic model at the node */
  public double m_numIncorrectModel;

  /** Weighted number of training examples currently misclassified by the subtree rooted at the node */
  public double m_numIncorrectTree;

  /** Minimum number of instances at which a node is considered for splitting */
  protected int m_minNumInstances;

  /** ModelSelection object (for splitting) */
  protected ModelSelection m_modelSelection;

  /** Filter to convert nominal attributes to binary */
  protected NominalToBinary m_nominalToBinary;

  /** Simple regression functions fit by LogitBoost at higher levels in the tree */
  protected SimpleLinearRegression[][] m_higherRegressions;

  /** Number of simple regression functions fit by LogitBoost at higher levels in the tree */
  protected int m_numHigherRegressions = 0;

  /** Number of folds for CART pruning */
  protected static int m_numFoldsPruning = 5;

  /** Use heuristic that determines the number of LogitBoost iterations only once in the beginning? */
  protected boolean m_fastRegression;

  /** Number of instances at the node */
  protected int m_numInstances;

  /** The ClassifierSplitModel (for splitting) */
  protected ClassifierSplitModel m_localModel;

  /** Array of children of the node */
  protected LMTNode[] m_sons;

  /** True if node is leaf */
  protected boolean m_isLeaf;

  /**
   * Constructor for logistic model tree node.
   *
   * @param modelSelection selection method for local splitting model
   * @param numBoostingIterations sets the numBoostingIterations parameter
   * @param fastRegression sets the fastRegression parameter
   * @param errorOnProbabilities Use error on probabilities for stopping criterion of LogitBoost?
   * @param minNumInstances minimum number of instances at which a node is considered for splitting
   * @param weightTrimBeta sets the weightTrimBeta parameter
   * @param useAIC use the AIC-based stopping criterion for LogitBoost?
   */
  public LMTNode(ModelSelection modelSelection, int numBoostingIterations,
                 boolean fastRegression, boolean errorOnProbabilities,
                 int minNumInstances, double weightTrimBeta, boolean useAIC) {
    m_modelSelection = modelSelection;
    m_fixedNumIterations = numBoostingIterations;
    m_fastRegression = fastRegression;
    m_errorOnProbabilities = errorOnProbabilities;
    m_minNumInstances = minNumInstances;
    m_maxIterations = 200;
    setWeightTrimBeta(weightTrimBeta);
    setUseAIC(useAIC);
  }

  /**
   * Method for building a logistic model tree (only called for the root node).
   * Grows an initial logistic model tree and prunes it back using the CART pruning scheme.
   *
   * @param data the data to train with
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    // heuristic to avoid cross-validating the number of LogitBoost iterations
    // at every node: build a standalone logistic model and take its optimum number
    // of iterations everywhere in the tree.
    if (m_fastRegression && (m_fixedNumIterations < 0)) m_fixedNumIterations = tryLogistic(data);

    // need to cross-validate the alpha-parameter for CART-pruning
    Instances cvData = new Instances(data);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    for (int i = 0; i < m_numFoldsPruning; i++) {
      // for every fold, grow tree on training set...
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      buildTree(train, null, train.numInstances(), 0);

      int numNodes = getNumInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // ... then prune back and log alpha-values and errors on the test set
      prune(alphas[i], errors[i], test);
    }

    // build tree using all the data
    buildTree(data, null, data.numInstances(), 0);
    int numNodes = getNumInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    for (int i = 0; i <= iterations; i++) {
      // compute midpoint alphas
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i + 1]);
      double error = 0;

      // compute error estimate for final trees from the midpoint-alphas and the
      // error estimates obtained in the cross-validation
      for (int k = 0; k < m_numFoldsPruning; k++) {
        int l = 0;
        while (alphas[k][l] <= alpha) l++;
        error += errors[k][l - 1];
      }
      treeErrors[i] = error;
    }

    // find best alpha
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
        bestError = treeErrors[i];
        best = i;
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    // "unprune" final tree (faster than regrowing it)
    unprune();

    // CART-prune it with best alpha
    prune(bestAlpha);

    cleanup();
  }

  /**
   * Method for building the tree structure.
   * Builds a logistic model, splits the node and recursively builds tree for child nodes.
   *
   * @param data the training data passed on to this node
   * @param higherRegressions an array of regression functions produced by LogitBoost at higher
   *        levels in the tree. They represent a logistic regression model that is refined locally
   *        at this node.
   * @param totalInstanceWeight the total number of training examples
   * @param higherNumParameters effective number of parameters in the logistic regression model
   *        built in parent nodes
   * @throws Exception if something goes wrong
   */
  public void buildTree(Instances data, SimpleLinearRegression[][] higherRegressions,
                        double totalInstanceWeight, double higherNumParameters) throws Exception {

    // save some stuff
    m_totalInstanceWeight = totalInstanceWeight;
    m_train = new Instances(data);

    m_isLeaf = true;
    m_sons = null;

    m_numInstances = m_train.numInstances();
    m_numClasses = m_train.numClasses();

    // init
    m_numericData = getNumericData(m_train);
    m_numericDataHeader = new Instances(m_numericData, 0);

    m_regressions = initRegressions();
    m_numRegressions = 0;

    if (higherRegressions != null) m_higherRegressions = higherRegressions;
    else m_higherRegressions = new SimpleLinearRegression[m_numClasses][0];

    m_numHigherRegressions = m_higherRegressions[0].length;

    m_numParameters = higherNumParameters;

    // build logistic model
    if (m_numInstances >= m_numFoldsBoosting) {
      if (m_fixedNumIterations > 0) {
        performBoosting(m_fixedNumIterations);
      } else if (getUseAIC()) {
        performBoostingInfCriterion();
      } else {
        performBoostingCV();
      }
    }

    m_numParameters += m_numRegressions;

    // only keep the simple regression functions that correspond to the selected
    // number of LogitBoost iterations
    m_regressions = selectRegressions(m_regressions);

    boolean grow;

    // split node if more than minNumInstances...
    if (m_numInstances > m_minNumInstances) {
      // split node: either splitting on class value (a la C4.5) or splitting on residuals
      if (m_modelSelection instanceof ResidualModelSelection) {
        // need ps/Ys/Zs/weights
        double[][] probs = getProbs(getFs(m_numericData));
        double[][] trainYs = getYs(m_train);
        double[][] dataZs = getZs(probs, trainYs);
        double[][] dataWs = getWs(probs, trainYs);
        m_localModel = ((ResidualModelSelection) m_modelSelection).selectModel(m_train, dataZs, dataWs);
      } else {
        m_localModel = m_modelSelection.selectModel(m_train);
      }
      // ... and a valid split was found
      grow = (m_localModel.numSubsets() > 1);
    } else {
      grow = false;
    }

    if (grow) {
      // create and build children of node
      m_isLeaf = false;

      Instances[] localInstances = m_localModel.split(m_train);

      m_sons = new LMTNode[m_localModel.numSubsets()];
      for (int i = 0; i < m_sons.length; i++) {
        m_sons[i] = new LMTNode(m_modelSelection, m_fixedNumIterations,
                                m_fastRegression, m_errorOnProbabilities,
                                m_minNumInstances, getWeightTrimBeta(), getUseAIC());
        // the "higherRegressions" (partial logistic model fit at higher levels in the tree) passed
        // on to the children are the "higherRegressions" at this node plus the regressions added
        // at this node (m_regressions).
        m_sons[i].buildTree(localInstances[i],
                            mergeArrays(m_regressions, m_higherRegressions),
                            m_totalInstanceWeight, m_numParameters);

        localInstances[i] = null;
      }
    }
  }

  /**
   * Prunes a logistic model tree using the CART pruning scheme, given a
   * cost-complexity parameter alpha.
   *
   * @param alpha the cost-complexity measure
   * @throws Exception if something goes wrong
   */
  public void prune(double alpha) throws Exception {

    Vector nodeList;

    CompareNode comparator = new CompareNode();

    // determine training error of logistic models and subtrees, and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getNodes();

    boolean prune = (nodeList.size() > 0);

    while (prune) {

      // select node with minimum alpha
      LMTNode nodeToPrune = (LMTNode) Collections.min(nodeList, comparator);

      // only prune the node if its alpha is smaller than the given alpha
      if (nodeToPrune.m_alpha > alpha) break;

      nodeToPrune.m_isLeaf = true;
      nodeToPrune.m_sons = null;

      // update tree errors and alphas
      treeErrors();
      calculateAlphas();

      nodeList = getNodes();
      prune = (nodeList.size() > 0);
    }
  }

  /**
   * Method for performing one fold in the cross-validation of the cost-complexity parameter.
   * Generates a sequence of alpha-values with error estimates for the corresponding (partially
   * pruned) trees, given the test set of that fold.
   *
   * @param alphas array to hold the generated alpha-values
   * @param errors array to hold the corresponding error estimates
   * @param test test set of that fold (to obtain error estimates)
   * @throws Exception if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test) throws Exception {

    Vector nodeList;

    CompareNode comparator = new CompareNode();

    // determine training error of logistic models and subtrees, and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getNodes();

    boolean prune = (nodeList.size() > 0);

    // alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of unpruned tree
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    while (prune) {
      iteration++;

      // get node with minimum alpha
      LMTNode nodeToPrune = (LMTNode) Collections.min(nodeList, comparator);

      nodeToPrune.m_isLeaf = true;
      // do not set m_sons to null; we want to be able to unprune later

      // get alpha-value of node
      alphas[iteration] = nodeToPrune.m_alpha;

      // log error
      if (errors != null) {
        eval = new Evaluation(test);
        eval.evaluateModel(this, test);
        errors[iteration] = eval.errorRate();
      }

      // update errors/alphas
      treeErrors();
      calculateAlphas();

      nodeList = getNodes();
      prune = (nodeList.size() > 0);
    }

    // set last alpha to 1 to indicate the end
    alphas[iteration + 1] = 1.0;

    return iteration;
  }

  /**
   * Method to "unprune" a logistic model tree.
   * Sets all leaf-fields to false.
   * Faster than re-growing the tree because the logistic models do not have to be fit again.
   */
  protected void unprune() {
    if (m_sons != null) {
      m_isLeaf = false;
      for (int i = 0; i < m_sons.length; i++) m_sons[i].unprune();
    }
  }

  /**
   * Determines the optimum number of LogitBoost iterations to perform by building a standalone
   * logistic regression function on the training data. Used for the heuristic that avoids
   * cross-validating this number again at every node.
   *
   * @param data training instances for the logistic model
   * @throws Exception if something goes wrong
   */
  protected int tryLogistic(Instances data) throws Exception {

    // convert nominal attributes
    Instances filteredData = new Instances(data);
    NominalToBinary nominalToBinary = new NominalToBinary();
    nominalToBinary.setInputFormat(filteredData);
    filteredData = Filter.useFilter(filteredData, nominalToBinary);

    LogisticBase logistic = new LogisticBase(0, true, m_errorOnProbabilities);

    // limit LogitBoost to 200 iterations (for speed)
    logistic.setMaxIterations(200);
    logistic.setWeightTrimBeta(getWeightTrimBeta()); // Not in Marc's code. Added by Eibe.
    logistic.setUseAIC(getUseAIC());
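The pruning logic above follows the standard CART cost-complexity scheme: for each inner node t, m_alpha corresponds to alpha(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1), the increase in training error from collapsing the subtree rooted at t divided by the number of leaves saved, and the tree is repeatedly pruned at the node with the smallest alpha.

LMTNode itself is an internal helper class; in Weka the public entry point is weka.classifiers.trees.LMT, which builds and CART-prunes a tree of LMTNodes. The following is a minimal usage sketch, assuming a Weka 3.x jar on the classpath; the dataset path "iris.arff" and the class name LMTDemo are placeholders.

import weka.classifiers.trees.LMT;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class LMTDemo {

  public static void main(String[] args) throws Exception {
    // load a dataset (the path is a placeholder) and use the last attribute as the class
    Instances data = DataSource.read("iris.arff");
    data.setClassIndex(data.numAttributes() - 1);

    LMT lmt = new LMT();
    // mirror the heuristic used above: determine the number of LogitBoost
    // iterations once at the root instead of cross-validating at every node
    lmt.setFastRegression(true);
    // minimum number of instances at which a node is considered for splitting
    lmt.setMinNumInstances(15);

    lmt.buildClassifier(data);
    System.out.println(lmt);
  }
}

With fast regression enabled, training cost is dominated by a single cross-validation of the LogitBoost iteration count at the root plus the 5-fold cross-validation of the pruning parameter, rather than a nested cross-validation at every node.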
