📄 reptree.java
字号:
if (inst.isMissing(m_Attribute)) {
  // Distribute instance
  // The split attribute's value is missing: pass a fractional copy of the
  // instance down every branch, weighting by the branch proportions m_Prop
  // (branches with zero proportion receive nothing).
  for (int i = 0; i < m_Successors.length; i++) {
    if (m_Prop[i] > 0) {
      m_Successors[i].insertHoldOutInstance(inst, weight * m_Prop[i], this);
    }
  }
} else {
  if (m_Info.attribute(m_Attribute).isNominal()) {
    // Treat nominal attributes
    // The attribute's value indexes the successor directly.
    m_Successors[(int)inst.value(m_Attribute)].
      insertHoldOutInstance(inst, weight, this);
  } else {
    // Treat numeric attributes
    // Values strictly below the split point go to successor 0, others to 1.
    if (inst.value(m_Attribute) < m_SplitPoint) {
      m_Successors[0].insertHoldOutInstance(inst, weight, this);
    } else {
      m_Successors[1].insertHoldOutInstance(inst, weight, this);
    }
  }
}
} // closes a conditional opened before this chunk -- presumably the "internal node?" guard; confirm against full file
} // end of insertHoldOutInstance (method header is outside this chunk)
} // end of inner class Tree

/** The Tree object */
protected Tree m_Tree = null;

/** Number of folds for reduced error pruning. */
protected int m_NumFolds = 3;

/** Seed for random data shuffling. */
protected int m_Seed = 1;

/** Don't prune */
protected boolean m_NoPruning = false;

/** The minimum number of instances per leaf. */
protected double m_MinNum = 2;

/** The minimum proportion of the total variance (over all the data)
    required for split. */
protected double m_MinVarianceProp = 1e-3;

/** Upper bound on the tree depth (-1 means no bound). */
protected int m_MaxDepth = -1;

/**
 * Get the value of NoPruning.
 *
 * @return Value of NoPruning.
 */
public boolean getNoPruning() {

  return m_NoPruning;
}

/**
 * Set the value of NoPruning.
 *
 * @param newNoPruning Value to assign to NoPruning.
 */
public void setNoPruning(boolean newNoPruning) {

  m_NoPruning = newNoPruning;
}

/**
 * Get the value of MinNum.
 *
 * @return Value of MinNum.
 */
public double getMinNum() {

  return m_MinNum;
}

/**
 * Set the value of MinNum.
 *
 * @param newMinNum Value to assign to MinNum.
 */
public void setMinNum(double newMinNum) {

  m_MinNum = newMinNum;
}

/**
 * Get the value of MinVarianceProp.
 *
 * @return Value of MinVarianceProp.
 */
public double getMinVarianceProp() {

  return m_MinVarianceProp;
}

/**
 * Set the value of MinVarianceProp.
 *
 * @param newMinVarianceProp Value to assign to MinVarianceProp.
*/ public void setMinVarianceProp(double newMinVarianceProp) { m_MinVarianceProp = newMinVarianceProp; } /** * Get the value of Seed. * * @return Value of Seed. */ public int getSeed() { return m_Seed; } /** * Set the value of Seed. * * @param newSeed Value to assign to Seed. */ public void setSeed(int newSeed) { m_Seed = newSeed; } /** * Get the value of NumFolds. * * @return Value of NumFolds. */ public int getNumFolds() { return m_NumFolds; } /** * Set the value of NumFolds. * * @param newNumFolds Value to assign to NumFolds. */ public void setNumFolds(int newNumFolds) { m_NumFolds = newNumFolds; } /** * Get the value of MaxDepth. * * @return Value of MaxDepth. */ public int getMaxDepth() { return m_MaxDepth; } /** * Set the value of MaxDepth. * * @param newMaxDepth Value to assign to MaxDepth. */ public void setMaxDepth(int newMaxDepth) { m_MaxDepth = newMaxDepth; } /** * Lists the command-line options for this classifier. */ public Enumeration listOptions() { Vector newVector = new Vector(5); newVector. addElement(new Option("\tSet minimum number of instances per leaf " + "(default 2).", "M", 1, "-M <minimum number of instances>")); newVector. addElement(new Option("\tSet minimum numeric class variance proportion " + "of train variance for split (default 1e-3).", "V", 1, "-V <minimum variance for split>")); newVector. addElement(new Option("\tNumber of folds for reduced error pruning " + "(default 3).", "N", 1, "-N <number of folds>")); newVector. addElement(new Option("\tSeed for random data shuffling (default 1).", "S", 1, "-S <seed>")); newVector. addElement(new Option("\tNo pruning.", "P", 0, "-P")); newVector. addElement(new Option("\tMaximum tree depth (default -1, no maximum)", "D", 1, "-D")); return newVector.elements(); } /** * Gets options from this classifier. 
*/ public String[] getOptions() { String [] options = new String [12]; int current = 0; options[current++] = "-M"; options[current++] = "" + getMinNum(); options[current++] = "-V"; options[current++] = "" + getMinVarianceProp(); options[current++] = "-N"; options[current++] = "" + getNumFolds(); options[current++] = "-S"; options[current++] = "" + getSeed(); options[current++] = "-D"; options[current++] = "" + getMaxDepth(); if (getNoPruning()) { options[current++] = "-P"; } while (current < options.length) { options[current++] = ""; } return options; } /** * Parses a given list of options. * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String minNumString = Utils.getOption('M', options); if (minNumString.length() != 0) { m_MinNum = (double)Integer.parseInt(minNumString); } else { m_MinNum = 2; } String minVarString = Utils.getOption('V', options); if (minVarString.length() != 0) { m_MinVarianceProp = Double.parseDouble(minVarString); } else { m_MinVarianceProp = 1e-3; } String numFoldsString = Utils.getOption('N', options); if (numFoldsString.length() != 0) { m_NumFolds = Integer.parseInt(numFoldsString); } else { m_NumFolds = 3; } String seedString = Utils.getOption('S', options); if (seedString.length() != 0) { m_Seed = Integer.parseInt(seedString); } else { m_Seed = 1; } m_NoPruning = Utils.getFlag('P', options); String depthString = Utils.getOption('D', options); if (depthString.length() != 0) { m_MaxDepth = Integer.parseInt(depthString); } else { m_MaxDepth = -1; } Utils.checkForRemainingOptions(options); } /** * Computes size of the tree. */ public int numNodes() { return m_Tree.numNodes(); } /** * Returns an enumeration of the additional measure names. 
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {

  Vector newVector = new Vector(1);
  newVector.addElement("measureTreeSize");
  return newVector.elements();
}

/**
 * Returns the value of the named measure.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {

  if (additionalMeasureName.equals("measureTreeSize")) {
    return (double) numNodes();
  } else {
    throw new IllegalArgumentException(additionalMeasureName
				       + " not supported (REPTree)");
  }
}

/**
 * Builds classifier.
 *
 * Randomizes (and, for nominal classes, stratifies) the data, splits off a
 * hold-out fold for reduced-error pruning unless pruning is disabled,
 * precomputes per-attribute sorted instance indices and weights, grows the
 * tree, and finally prunes it against the hold-out set.
 *
 * @param data the training instances
 * @exception Exception if the classifier cannot be built
 */
public void buildClassifier(Instances data) throws Exception {

  Random random = new Random(m_Seed);

  // Check for non-nominal classes
  if (!data.classAttribute().isNominal() &&
      !data.classAttribute().isNumeric()) {
    throw new UnsupportedClassTypeException("REPTree: nominal or numeric class, please.");
  }

  // Delete instances with missing class (work on a copy so the caller's
  // dataset is not modified)
  data = new Instances(data);
  data.deleteWithMissingClass();

  // Check for empty datasets
  if (data.numInstances() == 0) {
    throw new Exception("REPTree: zero training instances or all instances " +
			"have missing class!");
  }

  // Randomize and stratify
  data.randomize(random);
  if (data.classAttribute().isNominal()) {
    data.stratify(m_NumFolds);
  }

  // Split data into training and pruning set: fold 0 of an m_NumFolds-fold
  // cross-validation split is held out for pruning. With pruning disabled,
  // all data is used for growing.
  Instances train = null;
  Instances prune = null;
  if (!m_NoPruning) {
    train = data.trainCV(m_NumFolds, 0);
    prune = data.testCV(m_NumFolds, 0);
  } else {
    train = data;
  }

  // Create array of sorted indices and weights, one row per attribute
  // (class attribute's row stays empty).
  int[][] sortedIndices = new int[train.numAttributes()][0];
  double[][] weights = new double[train.numAttributes()][0];
  double[] vals = new double[train.numInstances()];
  for (int j = 0; j < train.numAttributes(); j++) {
    if (j != train.classIndex()) {
      weights[j] = new double[train.numInstances()];
      if (train.attribute(j).isNominal()) {

	// Handling nominal attributes. Putting indices of
	// instances with missing values at the end.
	sortedIndices[j] = new int[train.numInstances()];
	int count = 0;
	for (int i = 0; i < train.numInstances(); i++) {
	  Instance inst = train.instance(i);
	  if (!inst.isMissing(j)) {
	    sortedIndices[j][count] = i;
	    weights[j][count] = inst.weight();
	    count++;
	  }
	}
	for (int i = 0; i < train.numInstances(); i++) {
	  Instance inst = train.instance(i);
	  if (inst.isMissing(j)) {
	    sortedIndices[j][count] = i;
	    weights[j][count] = inst.weight();
	    count++;
	  }
	}
      } else {

	// Sorted indices are computed for numeric attributes
	// (missing numeric values sort to the end -- assuming Utils.sort
	// places NaN last; confirm against weka.core.Utils)
	for (int i = 0; i < train.numInstances(); i++) {
	  Instance inst = train.instance(i);
	  vals[i] = inst.value(j);
	}
	sortedIndices[j] = Utils.sort(vals);
	for (int i = 0; i < train.numInstances(); i++) {
	  weights[j][i] = train.instance(sortedIndices[j][i]).weight();
	}
      }
    }
  }

  // Compute initial class counts. For a nominal class, classProbs holds the
  // per-class weight totals; for a numeric class, classProbs[0] accumulates
  // the weighted sum of class values (later normalized to the mean).
  double[] classProbs = new double[train.numClasses()];
  double totalWeight = 0, totalSumSquared = 0;
  for (int i = 0; i < train.numInstances(); i++) {
    Instance inst = train.instance(i);
    if (data.classAttribute().isNominal()) {
      classProbs[(int)inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    } else {
      classProbs[0] += inst.classValue() * inst.weight();
      totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
      totalWeight += inst.weight();
    }
  }
  m_Tree = new Tree();
  double trainVariance = 0;
  if (data.classAttribute().isNumeric()) {
    // Variance of the class over the whole training set; the minimum-variance
    // split threshold below is expressed as a proportion of it.
    trainVariance = m_Tree.
      singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
    classProbs[0] /= totalWeight;
  }

  // Build tree
  m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
		   new Instances(train, 0), m_MinNum,
		   m_MinVarianceProp * trainVariance, 0, m_MaxDepth);

  // Insert pruning data and perform reduced error pruning
  if (!m_NoPruning) {
    m_Tree.insertHoldOutSet(prune);
    m_Tree.reducedErrorPrune();
  }
}

/**
 * Computes class distribution of an instance using the tree.
*/ public double[] distributionForInstance(Instance instance) throws Exception { return m_Tree.distributionForInstance(instance); } /** * Outputs the decision tree as a graph */ public String graph() throws Exception { StringBuffer resultBuff = new StringBuffer(); m_Tree.toGraph(resultBuff, 0, null); String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString() + "\n}\n"; return result; } /** * Outputs the decision tree. */ public String toString() { if ((m_Tree == null)) { return "REPTree: No model built yet."; } return "\nREPTree\n============\n" + m_Tree.toString(0, null) + "\n" + "\nSize of the tree : " + numNodes(); } /** * Main method for this class. */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new REPTree(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -