📄 reptree.java
字号:
* @param newSeed Value to assign to Seed.
*/
/**
 * Sets the seed used to randomize the data before it is split into
 * growing and pruning sets.
 *
 * @param newSeed the new random number seed
 */
public void setSeed(int newSeed) {
  this.m_Seed = newSeed;
}
/**
 * Returns the tip text for the numFolds property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String numFoldsTipText() {
  return "Determines the amount of data used for pruning. One fold is used for "
      + "pruning, the rest for growing the rules.";
}
/**
 * Gets the number of folds used for reduced-error pruning.
 *
 * @return the current value of NumFolds
 */
public int getNumFolds() {
  return m_NumFolds;
}
/**
 * Sets the number of folds used for reduced-error pruning.
 *
 * @param newNumFolds the new value of NumFolds
 */
public void setNumFolds(int newNumFolds) {
  this.m_NumFolds = newNumFolds;
}
/**
 * Returns the tip text for the maxDepth property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String maxDepthTipText() {
  return "The maximum tree depth (-1 for no restriction).";
}
/**
 * Gets the maximum tree depth (-1 means no restriction).
 *
 * @return the current value of MaxDepth
 */
public int getMaxDepth() {
  return m_MaxDepth;
}
/**
 * Sets the maximum tree depth (-1 means no restriction).
 *
 * @param newMaxDepth the new value of MaxDepth
 */
public void setMaxDepth(int newMaxDepth) {
  this.m_MaxDepth = newMaxDepth;
}
/**
 * Returns an enumeration describing the available command-line options.
 *
 * @return an enumeration of all the available options
 */
public Enumeration listOptions() {

  // Six options are registered below; the previous capacity of 5 forced
  // an unnecessary internal resize of the Vector.
  Vector newVector = new Vector(6);

  newVector.addElement(new Option(
      "\tSet minimum number of instances per leaf (default 2).",
      "M", 1, "-M <minimum number of instances>"));
  newVector.addElement(new Option(
      "\tSet minimum numeric class variance proportion\n" +
      "\tof train variance for split (default 1e-3).",
      "V", 1, "-V <minimum variance for split>"));
  newVector.addElement(new Option(
      "\tNumber of folds for reduced error pruning (default 3).",
      "N", 1, "-N <number of folds>"));
  newVector.addElement(new Option(
      "\tSeed for random data shuffling (default 1).",
      "S", 1, "-S <seed>"));
  newVector.addElement(new Option(
      "\tNo pruning.",
      "P", 0, "-P"));
  // -L takes a value (numArguments == 1), so show a placeholder in the
  // synopsis; it previously read just "-L" unlike the other valued options.
  newVector.addElement(new Option(
      "\tMaximum tree depth (default -1, no maximum)",
      "L", 1, "-L <maximum tree depth>"));

  return newVector.elements();
}
/**
 * Gets the current option settings of this classifier.
 *
 * @return an array of strings holding the current option settings
 */
public String[] getOptions() {
  String[] options = new String[12];
  int pos = 0;

  options[pos++] = "-M";
  options[pos++] = "" + (int)getMinNum();
  options[pos++] = "-V";
  options[pos++] = "" + getMinVarianceProp();
  options[pos++] = "-N";
  options[pos++] = "" + getNumFolds();
  options[pos++] = "-S";
  options[pos++] = "" + getSeed();
  options[pos++] = "-L";
  options[pos++] = "" + getMaxDepth();
  if (getNoPruning()) {
    options[pos++] = "-P";
  }
  // Pad the remaining slots with empty strings so no entries are null.
  for (; pos < options.length; pos++) {
    options[pos] = "";
  }
  return options;
}
/**
 * Parses a given list of options, applying the documented default for
 * every option that is not supplied.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  // NOTE: Utils.getOption consumes entries from the options array, so the
  // parse order below is kept identical to the registration order.
  String minNum = Utils.getOption('M', options);
  m_MinNum = (minNum.length() > 0) ? (double)Integer.parseInt(minNum) : 2;

  String minVar = Utils.getOption('V', options);
  m_MinVarianceProp = (minVar.length() > 0) ? Double.parseDouble(minVar) : 1e-3;

  String numFolds = Utils.getOption('N', options);
  m_NumFolds = (numFolds.length() > 0) ? Integer.parseInt(numFolds) : 3;

  String seed = Utils.getOption('S', options);
  m_Seed = (seed.length() > 0) ? Integer.parseInt(seed) : 1;

  m_NoPruning = Utils.getFlag('P', options);

  String depth = Utils.getOption('L', options);
  m_MaxDepth = (depth.length() > 0) ? Integer.parseInt(depth) : -1;

  Utils.checkForRemainingOptions(options);
}
/**
 * Computes the size of the tree.
 *
 * @return the number of nodes in the tree
 */
public int numNodes() {
  return m_Tree.numNodes();
}
/**
 * Returns an enumeration of the additional measure names.
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  Vector measures = new Vector(1);
  measures.addElement("measureTreeSize");
  return measures.elements();
}
/**
 * Returns the value of the named measure.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  // Guard clause: only "measureTreeSize" is supported.
  if (!additionalMeasureName.equalsIgnoreCase("measureTreeSize")) {
    throw new IllegalArgumentException(additionalMeasureName
        + " not supported (REPTree)");
  }
  return (double) numNodes();
}
/**
 * Builds the tree classifier from the given training data.
 *
 * @param data the training instances
 * @exception Exception if the classifier can't be built: the class must be
 *            nominal or numeric, string attributes are not supported, and at
 *            least one instance with a non-missing class is required
 */
public void buildClassifier(Instances data) throws Exception {
Random random = new Random(m_Seed);
// Check for non-nominal classes
if (!data.classAttribute().isNominal() && !data.classAttribute().isNumeric()) {
throw new UnsupportedClassTypeException("REPTree: nominal or numeric class!");
}
// Delete instances with missing class
// (work on a copy so the caller's dataset is left untouched)
data = new Instances(data);
data.deleteWithMissingClass();
// Check for empty datasets
if (data.numInstances() == 0) {
throw new IllegalArgumentException("REPTree: zero training instances or all " +
"instances have missing class!");
}
if (data.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
// If only the class attribute is present, fall back to ZeroR and skip
// tree building entirely.
m_zeroR = null;
if (data.numAttributes() == 1) {
m_zeroR = new ZeroR();
m_zeroR.buildClassifier(data);
return;
}
// Randomize and stratify
data.randomize(random);
if (data.classAttribute().isNominal()) {
data.stratify(m_NumFolds);
}
// Split data into training and pruning set
// (fold 0 of an m_NumFolds cross-validation split; pruning is skipped
// entirely when m_NoPruning is set)
Instances train = null;
Instances prune = null;
if (!m_NoPruning) {
train = data.trainCV(m_NumFolds, 0, random);
prune = data.testCV(m_NumFolds, 0);
} else {
train = data;
}
// Create array of sorted indices and weights.
// sortedIndices[j] lists instance indices for attribute j (sorted by value
// for numeric attributes); weights[j] carries the matching instance weights.
// The class attribute's entries are left as empty arrays.
int[][] sortedIndices = new int[train.numAttributes()][0];
double[][] weights = new double[train.numAttributes()][0];
double[] vals = new double[train.numInstances()]; // scratch buffer reused per attribute
for (int j = 0; j < train.numAttributes(); j++) {
if (j != train.classIndex()) {
weights[j] = new double[train.numInstances()];
if (train.attribute(j).isNominal()) {
// Handling nominal attributes. Putting indices of
// instances with missing values at the end.
sortedIndices[j] = new int[train.numInstances()];
int count = 0;
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (!inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
} else {
// Sorted indices are computed for numeric attributes
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
vals[i] = inst.value(j);
}
sortedIndices[j] = Utils.sort(vals);
for (int i = 0; i < train.numInstances(); i++) {
weights[j][i] = train.instance(sortedIndices[j][i]).weight();
}
}
}
}
// Compute initial class counts.
// Nominal class: classProbs holds weighted counts per class value.
// Numeric class: classProbs[0] accumulates the weighted sum of class
// values (converted to the weighted mean below), and totalSumSquared
// accumulates the weighted sum of squared class values.
double[] classProbs = new double[train.numClasses()];
double totalWeight = 0, totalSumSquared = 0;
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (data.classAttribute().isNominal()) {
classProbs[(int)inst.classValue()] += inst.weight();
totalWeight += inst.weight();
} else {
classProbs[0] += inst.classValue() * inst.weight();
totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
totalWeight += inst.weight();
}
}
m_Tree = new Tree();
double trainVariance = 0;
if (data.classAttribute().isNumeric()) {
// Variance of the class on the training data; m_MinVarianceProp of this
// is the minimum variance required for a split (see buildTree call).
trainVariance = m_Tree.
singleVariance(classProbs[0], totalSumSquared, totalWeight) / totalWeight;
classProbs[0] /= totalWeight;
}
// Build tree
m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
new Instances(train, 0), m_MinNum, m_MinVarianceProp *
trainVariance, 0, m_MaxDepth);
// Insert pruning data and perform reduced error pruning
if (!m_NoPruning) {
m_Tree.insertHoldOutSet(prune);
m_Tree.reducedErrorPrune();
m_Tree.backfitHoldOutSet(prune);
}
}
/**
 * Computes the class distribution of an instance using the tree
 * (or the ZeroR fallback when the data had no attributes besides
 * the class).
 *
 * @param instance the instance to compute the distribution for
 * @return the computed class distribution
 * @exception Exception if distribution can't be computed
 */
public double[] distributionForInstance(Instance instance)
  throws Exception {
  if (m_zeroR == null) {
    return m_Tree.distributionForInstance(instance);
  }
  return m_zeroR.distributionForInstance(instance);
}
/**
 * Counter for producing unique IDs when outputting the tree as source
 * (hashcode isn't guaranteed unique).
 */
private static long PRINTED_NODES = 0;

/**
 * Gets the next unique node ID.
 *
 * @return the next unique node ID.
 */
protected static long nextID() {
  long id = PRINTED_NODES;
  PRINTED_NODES = id + 1;
  return id;
}

/**
 * Resets the unique node ID counter to zero.
 */
protected static void resetID() {
  PRINTED_NODES = 0;
}
/**
 * Returns the tree as a Java class containing an if-then style
 * classify method.
 *
 * @param className the name to give the generated class
 * @return the tree as Java source code
 * @exception Exception if no model has been built yet
 */
public String toSource(String className)
  throws Exception {
  if (m_Tree == null) {
    throw new Exception("REPTree: No model built yet.");
  }
  StringBuffer[] parts = m_Tree.toSource(className, m_Tree);
  StringBuffer result = new StringBuffer();
  result.append("class " + className + " {\n\n");
  result.append(" public static double classify(Object [] i)\n");
  result.append(" throws Exception {\n\n");
  result.append(" double p = Double.NaN;\n");
  result.append(parts[0]); // Assignment code
  result.append(" return p;\n");
  result.append(" }\n");
  result.append(parts[1]); // Support code
  result.append("}\n");
  return result.toString();
}
/**
 * Returns the type of graph this classifier represents.
 *
 * @return Drawable.TREE
 */
public int graphType() {
return Drawable.TREE;
}
/**
 * Outputs the decision tree as a graph in dot format.
 *
 * @return the tree as a dot digraph
 * @exception Exception if no model has been built yet
 */
public String graph() throws Exception {
  if (m_Tree == null) {
    throw new Exception("REPTree: No model built yet.");
  }
  StringBuffer nodes = new StringBuffer();
  m_Tree.toGraph(nodes, 0, null);
  return "digraph Tree {\n" + "edge [style=bold]\n" + nodes.toString()
      + "\n}\n";
}
/**
 * Outputs a textual description of the decision tree.
 *
 * @return a description of the model (or a note if none was built)
 */
public String toString() {
  if (m_zeroR != null) {
    return "No attributes other than class. Using ZeroR.\n\n" + m_zeroR.toString();
  }
  if (m_Tree == null) {
    return "REPTree: No model built yet.";
  }
  StringBuffer text = new StringBuffer();
  text.append("\nREPTree\n============\n");
  text.append(m_Tree.toString(0, null));
  text.append("\n");
  text.append("\nSize of the tree : ");
  text.append(numNodes());
  return text.toString();
}
/**
 * Main method for running this classifier from the command line.
 *
 * @param argv the command-line options
 */
public static void main(String[] argv) {
  try {
    String result = Evaluation.evaluateModel(new REPTree(), argv);
    System.out.println(result);
  } catch (Exception ex) {
    System.err.println(ex.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -