⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lmtnode.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
     * Leaves are only counted if their logistic model has changed compared to the one of the parent node.
     * @return the number of leaves
     */
     public int getNumLeaves(){
        // A leaf always counts as exactly one leaf.
        if (m_isLeaf) {
            return 1;
        }

        int total = 0;
        int emptyChildLeaves = 0;
        for (int s = 0; s < m_sons.length; s++) {
            total += m_sons[s].getNumLeaves();
            // Child leaves without their own regression functions (hasModels() == false)
            // reuse the parent's model.
            boolean emptyLeaf = m_sons[s].m_isLeaf && !m_sons[s].hasModels();
            if (emptyLeaf) {
                emptyChildLeaves++;
            }
        }

        // All model-less child leaves together are counted as a single leaf.
        if (emptyChildLeaves > 1) {
            total -= emptyChildLeaves - 1;
        }
        return total;
    }

    /**
     * Updates the m_numIncorrectModel field for all nodes in this subtree. This is needed for
     * calculating the alpha-values.
     * <p>
     * Each node evaluates itself on its own training data (m_train) while temporarily flagged
     * as a leaf. It stores eval.incorrect() in m_numIncorrectModel and then recurses into its
     * children. The leaf flag is restored even if the evaluation throws.
     *
     * @exception Exception if creating the Evaluation or evaluating the model fails
     */
    public void modelErrors() throws Exception{

        Evaluation eval = new Evaluation(m_train);

        // Temporarily mark this node as a leaf so it is judged by its own logistic model.
        // Presumably evaluateModel goes through distributionForInstance, whose leaf branch
        // uses modelDistributionForInstance rather than routing to the children.
        boolean wasLeaf = m_isLeaf;
        m_isLeaf = true;
        try {
            eval.evaluateModel(this, m_train);
        } finally {
            // Always restore the node type; otherwise an exception would leave an inner
            // node permanently flagged as a leaf, silently cutting off its subtree.
            m_isLeaf = wasLeaf;
        }
        m_numIncorrectModel = eval.incorrect();

        if (!wasLeaf) {
            for (int i = 0; i < m_sons.length; i++) {
                m_sons[i].modelErrors();
            }
        }
    }
    
    /**
     * Updates the m_numIncorrectTree field for all nodes in this subtree. This is needed for
     * calculating the alpha-values.
     * <p>
     * A leaf copies its m_numIncorrectModel. An inner node first updates its children and then
     * stores the sum of their m_numIncorrectTree values.
     */
    public void treeErrors(){
        if (m_isLeaf) {
            m_numIncorrectTree = m_numIncorrectModel;
            return;
        }

        m_numIncorrectTree = 0;
        for (int s = 0; s < m_sons.length; s++) {
            m_sons[s].treeErrors();
            m_numIncorrectTree += m_sons[s].m_numIncorrectTree;
        }
    }

    /**
     * Updates the m_alpha field for all nodes in this subtree.
     * <p>
     * Reads m_numIncorrectModel and m_numIncorrectTree, so modelErrors() and treeErrors() are
     * expected to have run first. Leaves get Double.MAX_VALUE. An inner node whose split does
     * not reduce the training error is pruned on the spot (it becomes a leaf and gets
     * Double.MAX_VALUE). Every other inner node gets
     * ((m_numIncorrectModel - m_numIncorrectTree) / m_totalInstanceWeight) / (getNumLeaves() - 1),
     * and its children are then processed.
     * <p>
     * NOTE(review): if getNumLeaves() returns 1 here (all children are model-less leaves), the
     * division is by zero. Assuming m_totalInstanceWeight > 0, alpha then becomes positive
     * infinity rather than Double.MAX_VALUE.
     *
     * @exception Exception declared for compatibility with callers
     */
    public void calculateAlphas() throws Exception {
        if (m_isLeaf) {
            // Leaves are never pruned: give them an "infinite" alpha.
            m_alpha = Double.MAX_VALUE;
            return;
        }

        double errorReduction = m_numIncorrectModel - m_numIncorrectTree;

        if (errorReduction <= 0) {
            // The split does not lower the training error (should not normally happen),
            // so collapse this node into a leaf right away.
            m_isLeaf = true;
            m_sons = null;
            m_alpha = Double.MAX_VALUE;
            return;
        }

        // Leaf count is taken before the children are processed, because processing
        // them may prune some of them.
        double normalizedReduction = errorReduction / m_totalInstanceWeight;
        m_alpha = normalizedReduction / (double)(getNumLeaves() - 1);

        for (int s = 0; s < m_sons.length; s++) {
            m_sons[s].calculateAlphas();
        }
    }
    
    /**
     * Merges two arrays of regression functions into one.
     * <p>
     * For each of the first m_numClasses rows, the result holds the a1 entries followed by the
     * a2 entries. The row lengths are taken from a1[0] and a2[0].
     *
     * @param a1 one array (first dimension: class, second: regression function)
     * @param a2 the other array, same layout as a1
     *
     * @return an array that contains all entries from both input arrays
     */
    protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] a1,
                                                     SimpleLinearRegression[][] a2){
        int numModels1 = a1[0].length;
        int numModels2 = a2[0].length;

        SimpleLinearRegression[][] result =
            new SimpleLinearRegression[m_numClasses][numModels1 + numModels2];

        for (int i = 0; i < m_numClasses; i++) {
            // a1 entries first, then a2 entries directly after them.
            System.arraycopy(a1[i], 0, result[i], 0, numModels1);
            System.arraycopy(a2[i], 0, result[i], numModels1, numModels2);
        }
        return result;
    }

    /**
     * Returns a list of all inner (non-leaf) nodes in the tree, in pre-order.
     *
     * @return the list of nodes
     */
    public Vector getNodes(){
        Vector innerNodes = new Vector();
        getNodes(innerNodes);
        return innerNodes;
    }

    /**
     * Appends all inner (non-leaf) nodes of this subtree to a list, in pre-order: the node
     * itself first, then the inner nodes of each child in turn.
     *
     * @param nodeList the list to be filled
     */
    public void getNodes(Vector nodeList) {
        if (m_isLeaf) {
            return;
        }
        nodeList.add(this);
        for (int s = 0; s < m_sons.length; s++) {
            m_sons[s].getNodes(nodeList);
        }
    }
    
    /**
     * Returns a numeric version of a set of instances.
     * <p>
     * A copy of the data is passed through a fresh NominalToBinary filter, which is stored in
     * m_nominalToBinary so later instances can be converted the same way. The result is then
     * handed to the superclass, which replaces the class variable by a pseudo-class used by
     * LogitBoost.
     *
     * @param train the training instances
     * @return the numeric version of the data
     * @exception Exception if filtering fails
     */
    protected Instances getNumericData(Instances train) throws Exception{
        Instances copy = new Instances(train);

        m_nominalToBinary = new NominalToBinary();
        m_nominalToBinary.setInputFormat(copy);
        Instances binarized = Filter.useFilter(copy, m_nominalToBinary);

        return super.getNumericData(binarized);
    }

    /**
     * Computes the F-values of LogitBoost for an instance from the current logistic model at
     * the node.
     * <p>
     * This combines the regressions fitted at this node (via LogisticBase.getFs) with the
     * partial model fitted at higher levels of the tree (m_higherRegressions). For each higher
     * regression round, every class's F-value is increased by
     * (pred - mean(pred)) * (numClasses - 1) / numClasses.
     *
     * @param instance the instance
     * @return the array of F-values
     * @exception Exception if a regression function cannot classify the instance
     */
    protected double[] getFs(Instance instance) throws Exception{
        // Per-class predictions of one higher-level regression round (reused across rounds).
        double[] classPreds = new double[m_numClasses];

        // Contribution of the regressions fitted at this node.
        double[] fs = super.getFs(instance);

        // Contribution of the regressions inherited from ancestor nodes.
        for (int r = 0; r < m_numHigherRegressions; r++) {
            double mean = 0;
            for (int c = 0; c < m_numClasses; c++) {
                classPreds[c] = m_higherRegressions[c][r].classifyInstance(instance);
                mean += classPreds[c];
            }
            mean /= m_numClasses;

            for (int c = 0; c < m_numClasses; c++) {
                fs[c] += (classPreds[c] - mean) * (m_numClasses - 1)
                    / m_numClasses;
            }
        }
        return fs;
    }
    
    /**
     * Tells whether regression functions were fitted at this node itself (m_numRegressions > 0).
     * If so, the node's logistic model differs from the one at the parent node.
     *
     * @return whether the model has changed
     */
    public boolean hasModels() {
        boolean fittedHere = m_numRegressions > 0;
        return fittedHere;
    }

    /**
     * Returns the class probabilities for an instance according to the logistic model at this
     * node, regardless of whether the node is a leaf.
     * <p>
     * The instance is copied, converted with m_nominalToBinary, attached to
     * m_numericDataHeader, and its F-values are turned into probabilities.
     *
     * @param instance the instance (not modified)
     * @return the array of probabilities
     * @exception Exception if conversion or scoring fails
     */
    public double[] modelDistributionForInstance(Instance instance) throws Exception {
        // Work on a copy and replace nominal attributes by binary ones.
        Instance converted = (Instance)instance.copy();
        m_nominalToBinary.input(converted);
        converted = m_nominalToBinary.output();

        // Use the numeric header, which carries the pseudo-class.
        converted.setDataset(m_numericDataHeader);

        double[] fs = getFs(converted);
        return probs(fs);
    }

    /**
     * Returns the class probabilities for an instance given by the logistic model tree.
     * <p>
     * The instance is passed down the tree via m_localModel.whichSubset until a leaf is
     * reached. That leaf's logistic model then produces the distribution.
     *
     * @param instance the instance
     * @return the array of probabilities
     * @exception Exception if routing or scoring fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_isLeaf) {
            return modelDistributionForInstance(instance);
        }
        // Inner node: delegate to the child the split assigns this instance to.
        int branch = m_localModel.whichSubset(instance);
        return m_sons[branch].distributionForInstance(instance);
    }

    /**
     * Returns the number of leaves (plain count; unlike getNumLeaves, model-less leaves are
     * not merged).
     *
     * @return the number of leaves
     */
    public int numLeaves() {
        if (!m_isLeaf) {
            int count = 0;
            for (int s = 0; s < m_sons.length; s++) {
                count += m_sons[s].numLeaves();
            }
            return count;
        }
        return 1;
    }
    
    /**
     * Returns the number of nodes (inner nodes plus leaves) in this subtree.
     *
     * @return the number of nodes
     */
    public int numNodes() {
        if (!m_isLeaf) {
            // Count this node, then everything below it.
            int count = 1;
            for (int s = 0; s < m_sons.length; s++) {
                count += m_sons[s].numNodes();
            }
            return count;
        }
        return 1;
    }

    /**
     * Returns a description of the logistic model tree: the tree structure, the number of
     * leaves and nodes, and then the logistic models at the leaves.
     * <p>
     * Leaf model numbers are (re)assigned first. If anything throws while the text is being
     * built, a fixed fallback message is returned instead.
     *
     * @return describing string
     */
    public String toString(){
        // Number the leaf models so tree and model listing refer to the same LM_ ids.
        assignLeafModelNumbers(0);
        try {
            StringBuffer out = new StringBuffer();

            if (!m_isLeaf) {
                dumpTree(0, out);
            } else {
                out.append(": ");
                out.append("LM_" + m_leafModelNum + ":" + getModelParameters());
            }

            out.append("\n\nNumber of Leaves  : \t" + numLeaves() + "\n");
            out.append("\nSize of the Tree : \t" + numNodes() + "\n");

            // Logistic models are printed after the tree; remove this to print the tree only.
            out.append(modelsToString());
            return out.toString();
        } catch (Exception ignored) {
            // toString() cannot propagate checked exceptions; fall back to a fixed message.
            return "Can't print logistic model tree";
        }
    }

    /**
     * Returns a string of the form "local/total (instances)". Here "local" is the number of
     * LogitBoost iterations performed at this node, "total" also includes the iterations
     * performed at higher levels in the tree, and "instances" is the number of training
     * examples at this node.
     *
     * @return the describing string
     */
    public String getModelParameters(){
        int totalModels = m_numRegressions + m_numHigherRegressions;
        return m_numRegressions + "/" + totalModels + " (" + m_numInstances + ")";
    }
    
   
    /**
     * Helper method for printing the tree structure.
     * <p>
     * For each child, a new line is appended, indented with "|   " once per depth level,
     * followed by the split condition. Leaf children are followed by their model id and
     * parameters; inner children are printed recursively one level deeper.
     *
     * @param depth current indentation depth
     * @param text buffer the description is appended to
     * @exception Exception if something goes wrong
     */
    protected void dumpTree(int depth,StringBuffer text)
        throws Exception {

        // The indentation is the same for every child of this node.
        StringBuffer indent = new StringBuffer();
        for (int d = 0; d < depth; d++) {
            indent.append("|   ");
        }

        for (int s = 0; s < m_sons.length; s++) {
            text.append("\n");
            text.append(indent.toString());
            text.append(m_localModel.leftSide(m_train));
            text.append(m_localModel.rightSide(s, m_train));
            if (!m_sons[s].m_isLeaf) {
                m_sons[s].dumpTree(depth + 1, text);
            } else {
                text.append(": ");
                text.append("LM_" + m_sons[s].m_leafModelNum + ":" + m_sons[s].getModelParameters());
            }
        }
    }

    /**
     * Assigns consecutive IDs to all nodes in this subtree in pre-order, starting at
     * lastID + 1 for this node.
     *
     * @param lastID the last ID already handed out
     * @return the last ID assigned within this subtree
     */
    public int assignIDs(int lastID) {
        int lastAssigned = lastID + 1;
        m_id = lastAssigned;

        if (m_sons == null) {
            return lastAssigned;
        }
        for (int s = 0; s < m_sons.length; s++) {
            lastAssigned = m_sons[s].assignIDs(lastAssigned);
        }
        return lastAssigned;
    }
    
    /**
     * Assigns consecutive numbers (starting at leafCounter + 1) to the logistic regression
     * models at the leaves of this subtree, from left to right. Inner nodes get 0.
     *
     * @param leafCounter the last number already handed out
     * @return the last number assigned within this subtree
     */
    public int assignLeafModelNumbers(int leafCounter) {
        if (m_isLeaf) {
            m_leafModelNum = leafCounter + 1;
            return m_leafModelNum;
        }

        m_leafModelNum = 0;
        int counter = leafCounter;
        for (int s = 0; s < m_sons.length; s++) {
            counter = m_sons[s].assignLeafModelNumbers(counter);
        }
        return counter;
    }

    /**
     * Returns an array containing the coefficients of the logistic regression function at
     * this node.
     * <p>
     * Starts from the coefficients of the regressions fitted at this node (LogisticBase) and
     * adds the intercepts and slopes of the regressions inherited from higher levels
     * (m_higherRegressions). Column 0 holds the intercept; column attributeIndex + 1 holds
     * that attribute's slope.
     *
     * @return the array of coefficients, first dimension is the class, second the attribute.
     */
    protected double[][] getCoefficients(){
        double[][] coefficients = super.getCoefficients();

        for (int c = 0; c < m_numClasses; c++) {
            double[] row = coefficients[c];
            for (int r = 0; r < m_numHigherRegressions; r++) {
                double slope = m_higherRegressions[c][r].getSlope();
                double intercept = m_higherRegressions[c][r].getIntercept();
                int attrIndex = m_higherRegressions[c][r].getAttributeIndex();
                row[0] += intercept;
                row[attrIndex + 1] += slope;
            }
        }
        return coefficients;
    }
    
    /**
     * Returns a string describing the logistic regression functions at the leaves of this
     * subtree. A leaf yields "LM_n:" followed by its model. An inner node concatenates its
     * children's descriptions, each preceded by a newline.
     *
     * @return the describing string
     */
    public String modelsToString(){
        if (m_isLeaf) {
            return "LM_" + m_leafModelNum + ":" + super.toString();
        }

        StringBuffer text = new StringBuffer();
        for (int s = 0; s < m_sons.length; s++) {
            text.append("\n" + m_sons[s].modelsToString());
        }
        return text.toString();
    }

    /**
     * Returns a graph description of the tree in GraphViz "dot" syntax (digraph LMTree).
     * Node IDs and leaf model numbers are (re)assigned first.
     *
     * @return the graph description
     * @exception Exception if something goes wrong
     */
    public String graph() throws Exception {
        StringBuffer text = new StringBuffer();

        assignIDs(-1);
        assignLeafModelNumbers(0);

        text.append("digraph LMTree {\n");
        if (!m_isLeaf) {
            text.append("N" + m_id + " [label=\"" + m_localModel.leftSide(m_train) + "\" ");
            text.append("]\n");
            graphTree(text);
        } else {
            text.append("N" + m_id + " [label=\"LM_" + m_leafModelNum + ":" + getModelParameters()
                        + "\" " + "shape=box style=filled");
            text.append("]\n");
        }
        text.append("}\n");
        return text.toString();
    }

    /**
     * Helper function for the graph description of the tree.
     * <p>
     * For each child, appends an edge labelled with the split condition. It then appends the
     * child's node declaration: a filled box with model id and parameters for leaves, or the
     * split attribute for inner nodes, followed by a recursive call.
     *
     * @param text buffer the description is appended to
     * @exception Exception if something goes wrong
     */
    private void graphTree(StringBuffer text) throws Exception {
        for (int s = 0; s < m_sons.length; s++) {
            // Edge from this node to the child.
            text.append("N" + m_id + "->" + "N" + m_sons[s].m_id
                        + " [label=\"" + m_localModel.rightSide(s, m_train).trim() + "\"]\n");

            if (!m_sons[s].m_isLeaf) {
                text.append("N" + m_sons[s].m_id + " [label=\""
                            + m_sons[s].m_localModel.leftSide(m_train) + "\" ");
                text.append("]\n");
                m_sons[s].graphTree(text);
            } else {
                text.append("N" + m_sons[s].m_id + " [label=\"LM_" + m_sons[s].m_leafModelNum + ":"
                            + m_sons[s].getModelParameters() + "\" " + "shape=box style=filled");
                text.append("]\n");
            }
        }
    }
    
    /**
     * Cleanup in order to save memory: calls the superclass cleanup on this node, then on
     * every node below it.
     */
    public void cleanup() {
        super.cleanup();
        if (m_isLeaf) {
            return;
        }
        for (int s = 0; s < m_sons.length; s++) {
            m_sons[s].cleanup();
        }
    }
}





⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -