📄 reptree.java
字号:
} /** * Get the value of NumFolds. * * @return Value of NumFolds. */ public int getNumFolds() { return m_NumFolds; } /** * Set the value of NumFolds. * * @param newNumFolds Value to assign to NumFolds. */ public void setNumFolds(int newNumFolds) { m_NumFolds = newNumFolds; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String maxDepthTipText() { return "The maximum tree depth (-1 for no restriction)."; } /** * Get the value of MaxDepth. * * @return Value of MaxDepth. */ public int getMaxDepth() { return m_MaxDepth; } /** * Set the value of MaxDepth. * * @param newMaxDepth Value to assign to MaxDepth. */ public void setMaxDepth(int newMaxDepth) { m_MaxDepth = newMaxDepth; } /** * Lists the command-line options for this classifier. * * @return an enumeration over all commandline options */ public Enumeration listOptions() { Vector newVector = new Vector(5); newVector. addElement(new Option("\tSet minimum number of instances per leaf " + "(default 2).", "M", 1, "-M <minimum number of instances>")); newVector. addElement(new Option("\tSet minimum numeric class variance proportion\n" + "\tof train variance for split (default 1e-3).", "V", 1, "-V <minimum variance for split>")); newVector. addElement(new Option("\tNumber of folds for reduced error pruning " + "(default 3).", "N", 1, "-N <number of folds>")); newVector. addElement(new Option("\tSeed for random data shuffling (default 1).", "S", 1, "-S <seed>")); newVector. addElement(new Option("\tNo pruning.", "P", 0, "-P")); newVector. addElement(new Option("\tMaximum tree depth (default -1, no maximum)", "L", 1, "-L")); return newVector.elements(); } /** * Gets options from this classifier. 
   *
   * @return the options for the current setup
   */
  public String[] getOptions() {

    // 12 slots: five option/value pairs plus the optional -P flag;
    // unused slots are padded with empty strings below.
    String [] options = new String [12];
    int current = 0;
    options[current++] = "-M"; options[current++] = "" + (int)getMinNum();
    options[current++] = "-V"; options[current++] = "" + getMinVarianceProp();
    options[current++] = "-N"; options[current++] = "" + getNumFolds();
    options[current++] = "-S"; options[current++] = "" + getSeed();
    options[current++] = "-L"; options[current++] = "" + getMaxDepth();
    if (getNoPruning()) {
      options[current++] = "-P";
    }
    // Pad the remainder so the returned array contains no nulls.
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Parses a given list of options. <p/>
   *
   * <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -M <minimum number of instances>
   *  Set minimum number of instances per leaf (default 2).</pre>
   *
   * <pre> -V <minimum variance for split>
   *  Set minimum numeric class variance proportion
   *  of train variance for split (default 1e-3).</pre>
   *
   * <pre> -N <number of folds>
   *  Number of folds for reduced error pruning (default 3).</pre>
   *
   * <pre> -S <seed>
   *  Seed for random data shuffling (default 1).</pre>
   *
   * <pre> -P
   *  No pruning.</pre>
   *
   * <pre> -L
   *  Maximum tree depth (default -1, no maximum)</pre>
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    // Each option falls back to its documented default when absent.
    String minNumString = Utils.getOption('M', options);
    if (minNumString.length() != 0) {
      // Parsed as int but stored as double (field is a double weight threshold).
      m_MinNum = (double)Integer.parseInt(minNumString);
    } else {
      m_MinNum = 2;
    }
    String minVarString = Utils.getOption('V', options);
    if (minVarString.length() != 0) {
      m_MinVarianceProp = Double.parseDouble(minVarString);
    } else {
      m_MinVarianceProp = 1e-3;
    }
    String numFoldsString = Utils.getOption('N', options);
    if (numFoldsString.length() != 0) {
      m_NumFolds = Integer.parseInt(numFoldsString);
    } else {
      m_NumFolds = 3;
    }
    String seedString = Utils.getOption('S', options);
    if (seedString.length() != 0) {
      m_Seed = Integer.parseInt(seedString);
    } else {
      m_Seed = 1;
    }
    m_NoPruning = Utils.getFlag('P', options);
    String depthString = Utils.getOption('L', options);
    if (depthString.length() != 0) {
      m_MaxDepth = Integer.parseInt(depthString);
    } else {
      // -1 means no depth restriction (see maxDepthTipText).
      m_MaxDepth = -1;
    }
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Computes size of the tree.
   *
   * @return the number of nodes
   */
  public int numNodes() {

    return m_Tree.numNodes();
  }

  /**
   * Returns an enumeration of the additional measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {

    Vector newVector = new Vector(1);
    newVector.addElement("measureTreeSize");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {

    if (additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
      return (double) numNodes();
    } else {throw new IllegalArgumentException(additionalMeasureName
                      + " not supported (REPTree)");
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Builds classifier.
   *
   * Shuffles (and for nominal classes stratifies) the data, optionally holds
   * out one CV fold for reduced-error pruning, precomputes per-attribute
   * sorted instance indices and weights, then grows (and optionally prunes)
   * the tree.
   *
   * @param data the data to train with
   * @throws Exception if building fails
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    // Only the class attribute present: fall back to a ZeroR model.
    if (data.numAttributes() == 1) {
      m_zeroR = new ZeroR();
      m_zeroR.buildClassifier(data);
      return;
    }

    // Randomize and stratify
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
      data.stratify(m_NumFolds);
    }

    // Split data into training and pruning set
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
      // Fold 0 of m_NumFolds-fold CV: test part becomes the hold-out prune set.
      train = data.trainCV(m_NumFolds, 0, random);
      prune = data.testCV(m_NumFolds, 0);
    } else {
      train = data;
    }

    // Create array of sorted indices and weights.
    // For nominal attributes: instance order preserved, missing values moved
    // to the end. For numeric attributes: indices sorted by attribute value.
    int[][] sortedIndices = new int[train.numAttributes()][0];
    double[][] weights = new double[train.numAttributes()][0];
    double[] vals = new double[train.numInstances()];
    for (int j = 0; j < train.numAttributes(); j++) {
      if (j != train.classIndex()) {
        weights[j] = new double[train.numInstances()];
        if (train.attribute(j).isNominal()) {

          // Handling nominal attributes. Putting indices of
          // instances with missing values at the end.
          sortedIndices[j] = new int[train.numInstances()];
          int count = 0;
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (!inst.isMissing(j)) {
              sortedIndices[j][count] = i;
              weights[j][count] = inst.weight();
              count++;
            }
          }
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (inst.isMissing(j)) {
              sortedIndices[j][count] = i;
              weights[j][count] = inst.weight();
              count++;
            }
          }
        } else {

          // Sorted indices are computed for numeric attributes
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            vals[i] = inst.value(j);
          }
          sortedIndices[j] = Utils.sort(vals);
          for (int i = 0; i < train.numInstances(); i++) {
            weights[j][i] = train.instance(sortedIndices[j][i]).weight();
          }
        }
      }
    }

    // Compute initial class counts.
    // Nominal class: classProbs holds per-class weight sums.
    // Numeric class: classProbs[0] accumulates the weighted class-value sum
    // (converted to the weighted mean further below).
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
      Instance inst = train.instance(i);
      if (data.classAttribute().isNominal()) {
        classProbs[(int)inst.classValue()] += inst.weight();
        totalWeight += inst.weight();
      } else {
        classProbs[0] += inst.classValue() * inst.weight();
        totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
        totalWeight += inst.weight();
      }
    }
    m_Tree = new Tree();
    double trainVariance = 0;
    if (data.classAttribute().isNumeric()) {
      trainVariance = m_Tree.
        singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
      classProbs[0] /= totalWeight;
    }

    // Build tree. The variance-based split threshold scales the user-set
    // proportion by the overall training variance.
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
                     new Instances(train, 0), m_MinNum, m_MinVarianceProp *
                     trainVariance, 0, m_MaxDepth);

    // Insert pruning data and perform reduced error pruning
    if (!m_NoPruning) {
      m_Tree.insertHoldOutSet(prune);
      m_Tree.reducedErrorPrune();
      m_Tree.backfitHoldOutSet(prune);
    }
  }

  /**
   * Computes class distribution of an instance using the tree.
* * @param instance the instance to compute the distribution for * @return the computed class probabilities * @throws Exception if computation fails */ public double[] distributionForInstance(Instance instance) throws Exception { if (m_zeroR != null) { return m_zeroR.distributionForInstance(instance); } else { return m_Tree.distributionForInstance(instance); } } /** * For getting a unique ID when outputting the tree source * (hashcode isn't guaranteed unique) */ private static long PRINTED_NODES = 0; /** * Gets the next unique node ID. * * @return the next unique node ID. */ protected static long nextID() { return PRINTED_NODES ++; } /** * resets the counter for the nodes */ protected static void resetID() { PRINTED_NODES = 0; } /** * Returns the tree as if-then statements. * * @param className the name for the generated class * @return the tree as a Java if-then type statement * @throws Exception if something goes wrong */ public String toSource(String className) throws Exception { if (m_Tree == null) { throw new Exception("REPTree: No model built yet."); } StringBuffer [] source = m_Tree.toSource(className, m_Tree); return "class " + className + " {\n\n" +" public static double classify(Object [] i)\n" +" throws Exception {\n\n" +" double p = Double.NaN;\n" + source[0] // Assignment code +" return p;\n" +" }\n" + source[1] // Support code +"}\n"; } /** * Returns the type of graph this classifier * represents. * @return Drawable.TREE */ public int graphType() { return Drawable.TREE; } /** * Outputs the decision tree as a graph * * @return the tree as a graph * @throws Exception if generation fails */ public String graph() throws Exception { if (m_Tree == null) { throw new Exception("REPTree: No model built yet."); } StringBuffer resultBuff = new StringBuffer(); m_Tree.toGraph(resultBuff, 0, null); String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString() + "\n}\n"; return result; } /** * Outputs the decision tree. 
* * @return a string representation of the classifier */ public String toString() { if (m_zeroR != null) { return "No attributes other than class. Using ZeroR.\n\n" + m_zeroR.toString(); } if ((m_Tree == null)) { return "REPTree: No model built yet."; } return "\nREPTree\n============\n" + m_Tree.toString(0, null) + "\n" + "\nSize of the tree : " + numNodes(); } /** * Main method for this class. * * @param argv the commandline options */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new REPTree(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -