⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisiontree.java

📁 This decision tree is implemented in Java.
💻 JAVA
字号:
// DecisionTree.java

import java.util.StringTokenizer;
import java.util.Vector;
import java.io.*;

// DecisionTree is the command-line controller for the decision-tree program.
// All of its members are static.  The name is a bit misleading -- this class
// does not represent a decision tree itself; an object of class DT is an
// actual decision tree.
//
// Usage:  java DecisionTree -file PREFIX -class N [-depth D] [-skip I J ...]
// reads PREFIX.attrnames, PREFIX.train and PREFIX.test (see checkArguments).
public class DecisionTree
{
    public static String trainingFile = null;
    public static String testFile = null;
    public static String attributeNameFile = null;
    public static int    classIndex = -1;
    public static int    maxDepth = -1;    // -1 means no depth specified or enforced
    public static int[]  skipIndex = new int[100];  // capacity bounds the number of -skip indices
    public static int    skipCount = 0;    // number of valid entries in skipIndex
    public static BufferedReader  trainingIn;
    public static BufferedReader  testIn;
    public static BufferedReader  attrnamesIn;
    public static TrainingData    trainingData;
    public static Vector<String>  attributeName;

    // Entry point: parses the command line, reads the attribute names and the
    // training data, builds and prints a DT, then reports its accuracy on the
    // test data.  Input files are always closed before returning.
    // Throws IllegalArgumentException for bad or missing command-line options,
    // missing files, or malformed input.
    public static void main(String[] arg)
    {
        // Parse and display command line arguments
        checkArguments(arg);
        System.out.println("  training file: " + trainingFile);
        System.out.println("      test file: " + testFile);
        System.out.println("attribute names: " + attributeNameFile);
        System.out.println("class attribute: " + classIndex);
        System.out.println("      max depth: " +
                    (maxDepth == -1 ? "none" : ("" + maxDepth)));
        System.out.print  ("skip attributes: " );
        for (int i = 0; i < skipCount; i++)
        {
            System.out.print(skipIndex[i] + " ");
        }
        System.out.println("");

        if (classIndex == -1)
        {
            throw new IllegalArgumentException("Class index not specified on command line");
        }
        // Without -file all three names are null and FileReader would fail
        // with an uninformative NullPointerException.
        if (trainingFile == null)
        {
            throw new IllegalArgumentException(
                    "File name prefix not specified on command line (use -file)");
        }

        try
        {
            // Open and read the .attrnames and .train files.
            openFiles();
            readNames();
            // classIndex is non-negative here (checkArguments rejects values
            // starting with "-"), so only the upper bound needs checking.
            if (classIndex >= attributeName.size())
            {
                throw new IllegalArgumentException("Class index " + classIndex
                        + " is out of range: only " + attributeName.size()
                        + " attribute names were read");
            }
            readTrainingData();

            // Create a decision tree and print it out on System.out.
            DT dt = new DT(trainingData);
            dt.printOut();

            // Read in testing data and classify it with dt.
            useTreeToClassify(dt);
        }
        finally
        {
            closeFiles();
        }
    }

    // readNames reads and parses the .attrnames file which contains the
    // names of the attributes.  The number of names read is used to
    // know the number of attributes.  Names are separated by whitespace or
    // commas (see TOKENS).  Don't put a space in an attribute name!  --The
    // system will think you are naming two different attributes.
    private static void readNames()
    {
        attributeName = new java.util.Vector<String>();
        String s;
        initTokenizer();
        while ( (s = getToken(attrnamesIn)) != null)
        {
            attributeName.addElement(s);
        }
        System.out.println("Read " + attributeName.size() + " attribute names.");
    }

    // Initializes trainingData from the .train file.  Each element added to
    // trainingData is one record: an array of attributeName.size() Strings.
    // The class attribute and any skipped attributes are included.
    // Throws IllegalArgumentException if the file ends part-way through a record.
    private static void readTrainingData()
    {
        int columns = attributeName.size();
        trainingData = new TrainingData(columns);
        initTokenizer();
        String first;
        while ( (first = getToken(trainingIn)) != null)  // another record?
        {
            trainingData.add(readRecord(trainingIn, first, columns, trainingFile));
        }
    }

    // Reads the .test file, classifies each record with the tree rooted at dt,
    // and prints the percentage classified correctly plus the number of
    // records that received the default classification.
    private static void useTreeToClassify(DT dt)
    {
        int columns = attributeName.size();
        int correct = 0, incorrect = 0, classifiedDefault = 0;
        initTokenizer();
        String first;
        while ( (first = getToken(testIn)) != null)  // another record?
        {
            String[] record = readRecord(testIn, first, columns, testFile);
            String classification = dt.classify(record);

            // A leading "*" marks a record whose values weren't found in the
            // tree, in which case a default is guessed.  Count it, then strip
            // the marker before comparing with the actual class.
            if (classification.startsWith("*"))
            {
                ++classifiedDefault;
                classification = classification.substring(1);
            }

            if (classification.equals(record[classIndex]))
                ++correct;
            else
                ++incorrect;
        }

        int total = correct + incorrect;
        // An empty test file would otherwise print NaN; 100.0 * correct also
        // avoids the int overflow of correct * 100.
        String percent = (total == 0) ? "n/a" : ("" + (100.0 * correct / total));
        System.out.println("\n% correct: " + percent +
                "   (" + correct + " classified correctly and " +
                incorrect + " incorrectly)");
        System.out.println(classifiedDefault + " records received the " +
                "default classification");
    }

    // Reads one record of 'columns' tokens from 'in'.  'first' is the token the
    // caller has already consumed.  Tokens are not tied to lines, so a record
    // may span several lines.  Throws IllegalArgumentException if the input
    // ends part-way through a record (which would otherwise leave null values).
    private static String[] readRecord(BufferedReader in, String first,
                                       int columns, String fileName)
    {
        String[] record = new String[columns];
        record[0] = first;
        for (int i = 1; i < columns; i++)
        {
            record[i] = getToken(in);
            if (record[i] == null)
            {
                throw new IllegalArgumentException("File " + fileName
                        + " ends in the middle of a record: expected " + columns
                        + " values but found " + i);
            }
        }
        return record;
    }

    // Opens the three input files.  A missing file is reported as an
    // IllegalArgumentException("File <name> not found").
    private static void openFiles()
    {
        trainingIn = openReader(trainingFile);
        testIn = openReader(testFile);
        attrnamesIn = openReader(attributeNameFile);
    }

    // Opens fileName for buffered reading, keeping the FileNotFoundException as the cause.
    private static BufferedReader openReader(String fileName)
    {
        try
        {
            return new BufferedReader(new FileReader(fileName));
        }
        catch (FileNotFoundException e)
        {
            throw new IllegalArgumentException("File " + fileName + " not found", e);
        }
    }

    // Closes whichever input files were opened (fields may still be null if
    // openFiles failed part-way).
    private static void closeFiles()
    {
        closeQuietly(trainingIn);
        closeQuietly(testIn);
        closeQuietly(attrnamesIn);
    }

    // Closes reader if non-null.  Runs from main's finally block, so a close
    // failure is deliberately ignored rather than masking an earlier exception.
    private static void closeQuietly(Reader reader)
    {
        if (reader == null)
        {
            return;
        }
        try
        {
            reader.close();
        }
        catch (IOException ignored)
        {
            // Best effort: nothing useful can be done about a failed close of an input file.
        }
    }

    // checkArguments parses the command line.
    // Flags begin with a dash and are usually followed by an argument.
    // Only the first letter after the dash is checked, so "-f" and "-file" are equivalent.
    // Use -file to specify the three input files, which have the same
    // name plus .train, .attrnames, and .test.
    // Use -class to specify which attribute is the class to learn.
    // Use -skip to specify attributes to skip when building the decision tree
    // (all following non-switch arguments are read as indices).
    // Use -depth to specify maximum tree depth.
    // Values consumed by a flag are not skipped by the loop; later iterations
    // see them as non-switch arguments and ignore them.
    private static void checkArguments(String[] arg)
    {
        for (int i = 0; i < arg.length; i++)
        {
            String a = arg[i];
            if (a.startsWith("-f"))    // next arg is name file prefix
            {
                String prefix = valueAfter(arg, i,
                        "-file not followed by file name",
                        "-file not followed by file name");
                trainingFile = prefix + ".train";
                testFile = prefix + ".test";
                attributeNameFile = prefix + ".attrnames";
            }
            else if (a.startsWith("-c"))    // next arg is the class attribute index
            {
                classIndex = parseInteger(valueAfter(arg, i,
                        "-class not followed by integer",
                        "-class not followed by non-negative integer"), "-class");
            }
            else if (a.startsWith("-d"))    // next arg is the max tree depth
            {
                maxDepth = parseInteger(valueAfter(arg, i,
                        "-depth not followed by integer",
                        "-depth not followed by non-negative integer"), "-depth");
            }
            else if (a.startsWith("-s"))    // next args are entries to skip
            {
                // Validates that at least one index follows.
                valueAfter(arg, i,
                        "-skip not followed by integer",
                        "-skip not followed by non-negative integer");
                for (int j = i + 1; j < arg.length && !arg[j].startsWith("-"); j++)
                {
                    if (skipCount >= skipIndex.length)
                    {
                        throw new IllegalArgumentException(
                                "too many -skip attributes (at most "
                                + skipIndex.length + ")");
                    }
                    skipIndex[skipCount++] = parseInteger(arg[j], "-skip");
                }
            }
            else if (a.startsWith("-"))    // some unknown switch
            {
                throw new IllegalArgumentException
                        ("unrecognized switch: " + a);
            }
            // else: not a switch -- ignore
        }
    }

    // Returns arg[i+1], the value following the flag at arg[i].  Throws
    // IllegalArgumentException(missingMessage) if there is no such argument,
    // or IllegalArgumentException(dashMessage) if it starts with "-".
    private static String valueAfter(String[] arg, int i,
                                     String missingMessage, String dashMessage)
    {
        if (i + 1 >= arg.length)
        {
            throw new IllegalArgumentException(missingMessage);
        }
        if (arg[i + 1].startsWith("-"))
        {
            throw new IllegalArgumentException(dashMessage);
        }
        return arg[i + 1];
    }

    // Parses value with Integer.decode, which also accepts hex ("0x1F") and
    // treats a leading "0" as octal.  Reports a bad value as
    // IllegalArgumentException("<flag> not followed by integer").
    private static int parseInteger(String value, String flag)
    {
        try
        {
            return Integer.decode(value).intValue();
        }
        catch (NumberFormatException e)
        {
            throw new IllegalArgumentException(flag + " not followed by integer", e);
        }
    }

    // Returns true if attribute column c must not be used to split the tree:
    // either it is the class attribute or it was listed with -skip.
    public static boolean shouldBeSkipped(int c)
    {
        if (c == classIndex)
            return true;
        for (int i = 0; i < skipCount; i++)
        {
            if (c == skipIndex[i])
                return true;
        }
        return false;
    }


    // st is a member variable so that its remaining tokens are
    // retained from one call of getToken() to the next.
    private static final String TOKENS = " ,\n\r\t";
    private static StringTokenizer st;

    // Discards any tokens left over from the previous file.
    private static void initTokenizer()
    {
        st = null;
    }

    // getToken reads lines from the input file, as necessary, and uses
    // a StringTokenizer to break the characters into tokens.
    // White space (space, tab, newline, carriage return) and commas are ignored.
    // Returns null at end of file.  A read error is reported as an
    // IllegalStateException rather than being mistaken for end of file,
    // which would silently truncate the data.
    private static String getToken(BufferedReader in)
    {
        while (st == null  ||         // no current StringTokenizer
               !st.hasMoreTokens())   // no more tokens on this line
        {
            String line;
            try
            {
                line = in.readLine();
            }
            catch (IOException e)
            {
                throw new IllegalStateException("Error reading input file", e);
            }
            if (line == null)
            {
                return null;    // end of file
            }
            st = new StringTokenizer(line, TOKENS, false);
        }
        return st.nextToken();
    }

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -