📄 lmtnode.java
字号:
logistic.buildClassifier(filteredData); //return best number of iterations return logistic.getNumRegressions(); } /** * Method to count the number of inner nodes in the tree * @return the number of inner nodes */ public int getNumInnerNodes(){ if (m_isLeaf) return 0; int numNodes = 1; for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].getNumInnerNodes(); return numNodes; } /** * Returns the number of leaves in the tree. * Leaves are only counted if their logistic model has changed compared to the one of the parent node. * @return the number of leaves */ public int getNumLeaves(){ int numLeaves; if (!m_isLeaf) { numLeaves = 0; int numEmptyLeaves = 0; for (int i = 0; i < m_sons.length; i++) { numLeaves += m_sons[i].getNumLeaves(); if (m_sons[i].m_isLeaf && !m_sons[i].hasModels()) numEmptyLeaves++; } if (numEmptyLeaves > 1) { numLeaves -= (numEmptyLeaves - 1); } } else { numLeaves = 1; } return numLeaves; } /** *Updates the numIncorrectModel field for all nodes. This is needed for calculating the alpha-values. */ public void modelErrors() throws Exception{ Evaluation eval = new Evaluation(m_train); if (!m_isLeaf) { m_isLeaf = true; eval.evaluateModel(this, m_train); m_isLeaf = false; m_numIncorrectModel = eval.incorrect(); for (int i = 0; i < m_sons.length; i++) m_sons[i].modelErrors(); } else { eval.evaluateModel(this, m_train); m_numIncorrectModel = eval.incorrect(); } } /** *Updates the numIncorrectTree field for all nodes. This is needed for calculating the alpha-values. */ public void treeErrors(){ if (m_isLeaf) { m_numIncorrectTree = m_numIncorrectModel; } else { m_numIncorrectTree = 0; for (int i = 0; i < m_sons.length; i++) { m_sons[i].treeErrors(); m_numIncorrectTree += m_sons[i].m_numIncorrectTree; } } } /** *Updates the alpha field for all nodes. */ public void calculateAlphas() throws Exception { if (!m_isLeaf) { double errorDiff = m_numIncorrectModel - m_numIncorrectTree; if (errorDiff <= 0) { //split increases training error (should not normally happen). //prune it instantly. m_isLeaf = true; m_sons = null; m_alpha = Double.MAX_VALUE; } else { //compute alpha errorDiff /= m_totalInstanceWeight; m_alpha = errorDiff / (double)(getNumLeaves() - 1); for (int i = 0; i < m_sons.length; i++) m_sons[i].calculateAlphas(); } } else { //alpha = infinite for leaves (do not want to prune) m_alpha = Double.MAX_VALUE; } } /** * Merges two arrays of regression functions into one * @param a1 one array * @param a2 the other array * * @return an array that contains all entries from both input arrays */ protected SimpleLinearRegression[][] mergeArrays(SimpleLinearRegression[][] a1, SimpleLinearRegression[][] a2){ int numModels1 = a1[0].length; int numModels2 = a2[0].length; SimpleLinearRegression[][] result = new SimpleLinearRegression[m_numClasses][numModels1 + numModels2]; for (int i = 0; i < m_numClasses; i++) for (int j = 0; j < numModels1; j++) { result[i][j] = a1[i][j]; } for (int i = 0; i < m_numClasses; i++) for (int j = 0; j < numModels2; j++) result[i][j+numModels1] = a2[i][j]; return result; } /** * Return a list of all inner nodes in the tree * @return the list of nodes */ public Vector getNodes(){ Vector nodeList = new Vector(); getNodes(nodeList); return nodeList; } /** * Fills a list with all inner nodes in the tree * * @param nodeList the list to be filled */ public void getNodes(Vector nodeList) { if (!m_isLeaf) { nodeList.add(this); for (int i = 0; i < m_sons.length; i++) m_sons[i].getNodes(nodeList); } } /** * Returns a numeric version of a set of instances. * All nominal attributes are replaced by binary ones, and the class variable is replaced * by a pseudo-class variable that is used by LogitBoost. */ protected Instances getNumericData(Instances train) throws Exception{ Instances filteredData = new Instances(train); m_nominalToBinary = new NominalToBinary(); m_nominalToBinary.setInputFormat(filteredData); filteredData = Filter.useFilter(filteredData, m_nominalToBinary); return super.getNumericData(filteredData); } /** * Computes the F-values of LogitBoost for an instance from the current logistic model at the node * Note that this also takes into account the (partial) logistic model fit at higher levels in * the tree. * @param instance the instance * @return the array of F-values */ protected double[] getFs(Instance instance) throws Exception{ double [] pred = new double [m_numClasses]; //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions) //and the part of the model fit at this node (m_regressions). //Fs from m_regressions (use method of LogisticBase) double [] instanceFs = super.getFs(instance); //Fs from m_higherRegressions for (int i = 0; i < m_numHigherRegressions; i++) { double predSum = 0; for (int j = 0; j < m_numClasses; j++) { pred[j] = m_higherRegressions[j][i].classifyInstance(instance); predSum += pred[j]; } predSum /= m_numClasses; for (int j = 0; j < m_numClasses; j++) { instanceFs[j] += (pred[j] - predSum) * (m_numClasses - 1) / m_numClasses; } } return instanceFs; } /** *Returns true if the logistic regression model at this node has changed compared to the *one at the parent node. *@return whether it has changed */ public boolean hasModels() { return (m_numRegressions > 0); } /** * Returns the class probabilities for an instance according to the logistic model at the node. * @param instance the instance * @return the array of probabilities */ public double[] modelDistributionForInstance(Instance instance) throws Exception { //make copy and convert nominal attributes instance = (Instance)instance.copy(); m_nominalToBinary.input(instance); instance = m_nominalToBinary.output(); //saet numeric pseudo-class instance.setDataset(m_numericDataHeader); return probs(getFs(instance)); } /** * Returns the class probabilities for an instance given by the logistic model tree. * @param instance the instance * @return the array of probabilities */ public double[] distributionForInstance(Instance instance) throws Exception { double[] probs; if (m_isLeaf) { //leaf: use logistic model probs = modelDistributionForInstance(instance); } else { //sort into appropiate child node int branch = m_localModel.whichSubset(instance); probs = m_sons[branch].distributionForInstance(instance); } return probs; } /** * Returns the number of leaves (normal count). * @return the number of leaves */ public int numLeaves() { if (m_isLeaf) return 1; int numLeaves = 0; for (int i = 0; i < m_sons.length; i++) numLeaves += m_sons[i].numLeaves(); return numLeaves; } /** * Returns the number of nodes. * @return the number of nodes */ public int numNodes() { if (m_isLeaf) return 1; int numNodes = 1; for (int i = 0; i < m_sons.length; i++) numNodes += m_sons[i].numNodes(); return numNodes; } /** * Returns a description of the logistic model tree (tree structure and logistic models) * @return describing string */ public String toString(){ //assign numbers to logistic regression functions at leaves assignLeafModelNumbers(0); try{ StringBuffer text = new StringBuffer(); if (m_isLeaf) { text.append(": "); text.append("LM_"+m_leafModelNum+":"+getModelParameters()); } else { dumpTree(0,text); } text.append("\n\nNumber of Leaves : \t"+numLeaves()+"\n"); text.append("\nSize of the Tree : \t"+numNodes()+"\n"); //This prints logistic models after the tree, comment out if only tree should be printed text.append(modelsToString()); return text.toString(); } catch (Exception e){ return "Can't print logistic model tree"; } } /** * Returns a string describing the number of LogitBoost iterations performed at this node, the total number * of LogitBoost iterations performed (including iterations at higher levels in the tree), and the number * of training examples at this node. * @return the describing string */ public String getModelParameters(){ StringBuffer text = new StringBuffer(); int numModels = m_numRegressions+m_numHigherRegressions; text.append(m_numRegressions+"/"+numModels+" ("+m_numInstances+")"); return text.toString(); } /** * Help method for printing tree structure. * * @throws Exception if something goes wrong */ protected void dumpTree(int depth,StringBuffer text) throws Exception { for (int i = 0; i < m_sons.length; i++) { text.append("\n"); for (int j = 0; j < depth; j++) text.append("| "); text.append(m_localModel.leftSide(m_train)); text.append(m_localModel.rightSide(i, m_train)); if (m_sons[i].m_isLeaf) { text.append(": "); text.append("LM_"+m_sons[i].m_leafModelNum+":"+m_sons[i].getModelParameters()); }else m_sons[i].dumpTree(depth+1,text); } } /** * Assigns unique IDs to all nodes in the tree */ public int assignIDs(int lastID) { int currLastID = lastID + 1; m_id = currLastID; if (m_sons != null) { for (int i = 0; i < m_sons.length; i++) { currLastID = m_sons[i].assignIDs(currLastID); } } return currLastID; } /** * Assigns numbers to the logistic regression models at the leaves of the tree */ public int assignLeafModelNumbers(int leafCounter) { if (!m_isLeaf) { m_leafModelNum = 0; for (int i = 0; i < m_sons.length; i++){ leafCounter = m_sons[i].assignLeafModelNumbers(leafCounter); } } else { leafCounter++; m_leafModelNum = leafCounter; } return leafCounter; } /** * Returns an array containing the coefficients of the logistic regression function at this node. * @return the array of coefficients, first dimension is the class, second the attribute. */ protected double[][] getCoefficients(){ //Need to take into account partial model fit at higher levels in the tree (m_higherRegressions) //and the part of the model fit at this node (m_regressions). //get coefficients from m_regressions: use method of LogisticBase double[][] coefficients = super.getCoefficients(); //get coefficients from m_higherRegressions: for (int j = 0; j < m_numClasses; j++) { for (int i = 0; i < m_numHigherRegressions; i++) { double slope = m_higherRegressions[j][i].getSlope(); double intercept = m_higherRegressions[j][i].getIntercept(); int attribute = m_higherRegressions[j][i].getAttributeIndex(); coefficients[j][0] += intercept; coefficients[j][attribute + 1] += slope; } } return coefficients; } /** * Returns a string describing the logistic regression function at the node. */ public String modelsToString(){ StringBuffer text = new StringBuffer(); if (m_isLeaf) { text.append("LM_"+m_leafModelNum+":"+super.toString()); } else { for (int i = 0; i < m_sons.length; i++) { text.append("\n"+m_sons[i].modelsToString()); } } return text.toString(); } /** * Returns graph describing the tree. * * @throws Exception if something goes wrong */ public String graph() throws Exception { StringBuffer text = new StringBuffer(); assignIDs(-1); assignLeafModelNumbers(0); text.append("digraph LMTree {\n"); if (m_isLeaf) { text.append("N" + m_id + " [label=\"LM_"+m_leafModelNum+":"+getModelParameters()+"\" " + "shape=box style=filled"); text.append("]\n"); }else { text.append("N" + m_id + " [label=\"" + m_localModel.leftSide(m_train) + "\" "); text.append("]\n"); graphTree(text); } return text.toString() +"}\n"; } /** * Helper function for graph description of tree * * @throws Exception if something goes wrong */ private void graphTree(StringBuffer text) throws Exception { for (int i = 0; i < m_sons.length; i++) { text.append("N" + m_id + "->" + "N" + m_sons[i].m_id + " [label=\"" + m_localModel.rightSide(i,m_train).trim() + "\"]\n"); if (m_sons[i].m_isLeaf) { text.append("N" +m_sons[i].m_id + " [label=\"LM_"+m_sons[i].m_leafModelNum+":"+ m_sons[i].getModelParameters()+"\" " + "shape=box style=filled"); text.append("]\n"); } else { text.append("N" + m_sons[i].m_id + " [label=\""+m_sons[i].m_localModel.leftSide(m_train) + "\" "); text.append("]\n"); m_sons[i].graphTree(text); } } } /** * Cleanup in order to save memory. */ public void cleanup() { super.cleanup(); if (!m_isLeaf) { for (int i = 0; i < m_sons.length; i++) m_sons[i].cleanup(); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -