📄 decisiontree.java
字号:
// DecisionTree.java
import java.util.StringTokenizer;
import java.util.Vector;
import java.io.*;
// This class is the controller. All its methods are static.
// The name is a bit misleading -- it doesn't represent a decision tree.
// An object of class DT is an actual decision tree.
/**
 * Command-line controller for the decision-tree learner. All members are
 * static; an actual tree is represented by an object of class DT.
 *
 * Flags are recognised by prefix:
 *   -f PREFIX   use PREFIX.train, PREFIX.test and PREFIX.attrnames
 *   -c INDEX    index of the class attribute (required)
 *   -d DEPTH    maximum tree depth (optional; -1 means no limit)
 *   -s I [J..]  indexes of attributes to skip when building the tree
 * Arguments that are not flags and are not consumed by a flag are ignored.
 */
public class DecisionTree
{
    public static String trainingFile = null;
    public static String testFile = null;
    public static String attributeNameFile = null;
    public static int classIndex = -1;
    public static int maxDepth = -1; // -1 means no depth specified or enforced
    public static int[] skipIndex = new int[100];
    public static int skipCount = 0;
    public static BufferedReader trainingIn;
    public static BufferedReader testIn;
    public static BufferedReader attrnamesIn;
    public static TrainingData trainingData;
    public static Vector<String> attributeName;

    /**
     * Parses the command line, reads the attribute names and training data,
     * builds and prints a decision tree, then classifies the test data.
     *
     * @param arg command-line arguments (see the class comment for the flags)
     * @throws IllegalArgumentException if a required flag is missing, a flag
     *         is malformed, an input file cannot be found, or the class index
     *         does not refer to one of the attribute names
     */
    public static void main(String[] arg)
    {
        // Parse and display command line arguments
        checkArguments(arg);
        System.out.println(" training file: " + trainingFile);
        System.out.println(" test file: " + testFile);
        System.out.println("attribute names: " + attributeNameFile);
        System.out.println("class attribute: " + classIndex);
        System.out.println(" max depth: " +
            (maxDepth == -1 ? "none" : ("" + maxDepth)));
        System.out.print ("skip attributes: " );
        for (int i = 0; i < skipCount; i++)
        {
            System.out.print(skipIndex[i] + " ");
        }
        System.out.println("");

        if (classIndex == -1)
        {
            throw new IllegalArgumentException("Class index not specified on command line");
        }
        // Without -file the three file names are still null and opening
        // them would fail with an uninformative NullPointerException.
        if (trainingFile == null)
        {
            throw new IllegalArgumentException("File name prefix not specified on command line");
        }

        try
        {
            // Open and read the .train and .attrnames files.
            openFiles();
            readNames();
            // Every record is indexed by classIndex, so it must name a column.
            if (classIndex >= attributeName.size())
            {
                throw new IllegalArgumentException(
                    "Class index " + classIndex + " is out of range: only " +
                    attributeName.size() + " attribute names were read");
            }
            readTrainingData();
            // Create a decision tree.
            DT dt = new DT(trainingData);
            // Print it out on System.out
            dt.printOut();
            // Read in testing data and classify it with dt.
            useTreeToClassify(dt);
        }
        finally
        {
            // Release the input files whether or not the run succeeded.
            closeQuietly(trainingIn);
            closeQuietly(testIn);
            closeQuietly(attrnamesIn);
        }
    }

    // readNames reads and parses the .attrnames file which contains the
    // names of the attributes. The number of names read is used as the
    // number of attributes. Names are separated by any of the delimiter
    // characters in TOKENS, so a name must not contain a space or comma:
    // it would be read as two different attributes.
    private static void readNames()
    {
        attributeName = new java.util.Vector<String>();
        String s;
        initTokenizer();
        while ((s = getToken(attrnamesIn)) != null)
        {
            attributeName.addElement(s);
        }
        System.out.println("Read " + attributeName.size() + " attribute names.");
    }

    // Initialize trainingData, which is an object of type TrainingData.
    // The .train file is read in and parsed. Each element added to
    // trainingData holds one record from the training file as an array of
    // Strings; the class attribute and any skipped attributes are included.
    // A final record that is cut short by end of file is reported on
    // System.err and not added.
    private static void readTrainingData()
    {
        int columns = attributeName.size();
        trainingData = new TrainingData(columns);
        String first;
        initTokenizer();
        while ((first = getToken(trainingIn)) != null) // another record?
        {
            String[] record = readRecord(trainingIn, first, columns);
            if (record == null)
            {
                System.err.println("Warning: " + trainingFile +
                    " ends with an incomplete record, which was ignored");
                break;
            }
            trainingData.add(record);
        }
    }

    // Reads the testIn file and classifies each record using the decision
    // tree dt, then prints the percentage classified correctly and the
    // number of records that received the default classification (those
    // for which dt.classify returned a value beginning with "*").
    // A final record that is cut short by end of file is reported on
    // System.err and not classified.
    private static void useTreeToClassify(DT dt)
    {
        int columns = attributeName.size();
        int correct = 0, incorrect = 0, classifiedDefault = 0;
        String first;
        initTokenizer();
        while ((first = getToken(testIn)) != null) // another record?
        {
            String[] record = readRecord(testIn, first, columns);
            if (record == null)
            {
                System.err.println("Warning: " + testFile +
                    " ends with an incomplete record, which was ignored");
                break;
            }
            String decisionTreesClassification = dt.classify(record);
            // A leading "*" marks a guess made because the record's values
            // were not found in the tree.
            if (decisionTreesClassification.startsWith("*"))
            {
                ++classifiedDefault;
                // strip off the "*" at the beginning
                decisionTreesClassification =
                    decisionTreesClassification.substring(1);
            }
            if (decisionTreesClassification.equals(record[classIndex]))
            {
                ++correct;
            }
            else
            {
                ++incorrect;
            }
        }
        int total = correct + incorrect;
        // Avoid printing NaN when the test file contained no records.
        double percent = (total == 0)
            ? 0.0
            : (double)(correct * 100) / (double)total;
        System.out.println("\n% correct: " + percent +
            " (" + correct + " classified correctly and " +
            incorrect + " incorrectly)");
        System.out.println(classifiedDefault + " records received the " +
            "default classification");
    }

    // Builds one record of `columns` tokens read from `in`. `first` is the
    // record's first token, which the caller has already read. Returns null
    // if the input ends before the record is complete.
    private static String[] readRecord(BufferedReader in, String first, int columns)
    {
        String[] record = new String[columns];
        record[0] = first;
        for (int i = 1; i < columns; i++)
        {
            String token = getToken(in);
            if (token == null)
            {
                return null;
            }
            record[i] = token;
        }
        return record;
    }

    // This function opens the three input files.
    // Throws IllegalArgumentException if any of them cannot be found.
    private static void openFiles()
    {
        trainingIn = openReader(trainingFile);
        testIn = openReader(testFile);
        attrnamesIn = openReader(attributeNameFile);
    }

    // Opens fileName for buffered reading, reporting a missing file as an
    // IllegalArgumentException that keeps the original exception as its cause.
    private static BufferedReader openReader(String fileName)
    {
        try
        {
            return new BufferedReader(new FileReader(fileName));
        }
        catch (FileNotFoundException e)
        {
            throw new IllegalArgumentException(
                "File " + fileName + " not found", e);
        }
    }

    // Closes r if it was opened. A failure to close is ignored because the
    // files are only read and nothing useful can be done about it here.
    private static void closeQuietly(Reader r)
    {
        if (r == null)
        {
            return;
        }
        try
        {
            r.close();
        }
        catch (IOException ignored)
        {
            // intentionally ignored; see the comment above
        }
    }

    // checkArguments parses the command line.
    // Flags begin with a dash and are followed by an argument.
    // Use -file to specify the three input files, which have the same
    // name plus .train, .attrnames, and .test.
    // Use -class to specify which attribute is the class to learn
    // Use -skip to specify attributes to skip when building the decision tree.
    // Use -depth to specify maximum tree depth
    // Only the first letter after the dash is examined. An operand may not
    // begin with a dash, so negative numbers are rejected.
    private static void checkArguments(String[] arg)
    {
        int i = 0;
        while (i < arg.length)
        {
            String a = arg[i];
            if (a.startsWith("-f")) // next arg is name file prefix
            {
                requireOperand(arg, i,
                    "-file not followed by file name",
                    "-file not followed by file name");
                String prefix = arg[i + 1];
                trainingFile = prefix + ".train";
                testFile = prefix + ".test";
                attributeNameFile = prefix + ".attrnames";
                i++; // operand consumed
            }
            else if (a.startsWith("-c")) // next arg is the class attribute index
            {
                requireOperand(arg, i,
                    "-class not followed by integer",
                    "-class not followed by non-negative integer");
                classIndex = decodeInt(arg[i + 1], "-class");
                i++; // operand consumed
            }
            else if (a.startsWith("-d")) // next arg is the max tree depth
            {
                requireOperand(arg, i,
                    "-depth not followed by integer",
                    "-depth not followed by non-negative integer");
                maxDepth = decodeInt(arg[i + 1], "-depth");
                i++; // operand consumed
            }
            else if (a.startsWith("-s")) // next args are entries to skip
            {
                requireOperand(arg, i,
                    "-skip not followed by integer",
                    "-skip not followed by non-negative integer");
                // Consume every operand up to the next flag or the end.
                int next = i + 1;
                do
                {
                    if (skipCount >= skipIndex.length)
                    {
                        throw new IllegalArgumentException(
                            "-skip given more than " + skipIndex.length +
                            " attributes");
                    }
                    int value = decodeInt(arg[next], "-skip");
                    skipIndex[skipCount++] = value;
                    next++;
                }
                while (next < arg.length && !arg[next].startsWith("-"));
                i = next - 1; // last operand consumed
            }
            else if (a.startsWith("-")) // some unknown switch
            {
                throw new IllegalArgumentException
                    ("unrecognized switch: " + a);
            }
            // anything else is not a switch and is ignored
            i++;
        }
    }

    // Verifies that the flag at arg[i] is followed by an operand that does
    // not itself look like a flag. Throws IllegalArgumentException with
    // missingMessage when there is no next argument, or with dashMessage
    // when the next argument begins with a dash.
    private static void requireOperand(String[] arg, int i,
                                       String missingMessage, String dashMessage)
    {
        if (i + 1 >= arg.length)
        {
            throw new IllegalArgumentException(missingMessage);
        }
        if (arg[i + 1].startsWith("-"))
        {
            throw new IllegalArgumentException(dashMessage);
        }
    }

    // Converts text to an int with Integer.decode (so decimal, hex and
    // octal forms are accepted). flag names the switch for the error
    // message used when text is not a valid integer.
    private static int decodeInt(String text, String flag)
    {
        try
        {
            return Integer.decode(text).intValue();
        }
        catch (NumberFormatException e)
        {
            throw new IllegalArgumentException
                (flag + " not followed by integer", e);
        }
    }

    /**
     * Tells whether attribute c is excluded from tree building.
     *
     * @param c attribute index
     * @return true if c is the class attribute or one of the indexes given
     *         with -skip; false otherwise
     */
    public static boolean shouldBeSkipped(int c)
    {
        if (c == classIndex)
        {
            return true;
        }
        for (int i = 0; i < skipCount; i++)
        {
            if (c == skipIndex[i])
            {
                return true;
            }
        }
        return false;
    }

    // Characters that separate tokens.
    private static final String TOKENS = " ,\n\r\t";
    // st is a member variable so that the unread tokens of the current
    // line are retained from one call of getToken() to the next.
    private static StringTokenizer st;

    // Discards any partly consumed line. Call before reading a new file.
    private static void initTokenizer()
    {
        st = null;
    }

    // getToken reads lines from the input file, as necessary, and uses
    // a StringTokenizer to break the characters into tokens.
    // White space (space, tab, newline, carriage return) and commas are
    // treated as separators and never returned.
    // Returns the next token, or null at end of file.
    // Throws IllegalStateException if the file cannot be read, so that a
    // read failure is not mistaken for the end of the data.
    private static String getToken(BufferedReader in)
    {
        while (st == null ||        // no current StringTokenizer
               !st.hasMoreTokens()) // no more tokens on this line
        {
            String line;
            try
            {
                line = in.readLine();
            }
            catch (IOException e)
            {
                throw new IllegalStateException("I/O error while reading input", e);
            }
            if (line == null)
            {
                return null; // end of file
            }
            st = new StringTokenizer(line, TOKENS, false);
        }
        return st.nextToken();
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -