⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 decisiontreetest.java

📁 学习数据挖掘时写的决策树算法
💻 JAVA
字号:
import java.io.*;
import java.sql.*;
import java.util.*;
import java.awt.*;
import java.awt.event.*;
import javax.swing.*;
import javax.swing.table.*;

/**
 * Main window of the decision-tree demo. It lets the user open a comma-separated
 * data file, builds a decision tree from it, and shows the raw table, the generated
 * rules and a drawing of the tree.
 */
public class DecisionTreeTest extends JFrame{
	private File file=null;
	private DataModel dataModel=new DataModel();
	
	private DefaultTableModel tableModel=new DefaultTableModel();
	private JTable table=new JTable(tableModel);
	private Canvas treeView=new TreeCanvas(dataModel);
	private JTextArea rulesView=new JTextArea();
	
	public DecisionTreeTest(){
		initGUI();
	}
	
	/** Lays out the table, rules and tree panes, installs the File menu and shows the window. */
	private void initGUI(){
		this.setBounds(0,0,1024,768);
		this.setTitle("刑法分则决策");
		// Closing the window terminates the whole application.
		this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

		// Left side: data table on top, generated rules below. Right side: tree drawing.
		JScrollPane tablePane=new JScrollPane(table,JScrollPane.VERTICAL_SCROLLBAR_ALWAYS,JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS);
		JScrollPane rulesPane=new JScrollPane(rulesView,JScrollPane.VERTICAL_SCROLLBAR_ALWAYS,JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS);
		JSplitPane viewSplitPane=new JSplitPane(JSplitPane.VERTICAL_SPLIT,true,tablePane,rulesPane);
		viewSplitPane.setDividerLocation(400);
		viewSplitPane.setLastDividerLocation(0);
		viewSplitPane.setOneTouchExpandable(true);
		JScrollPane treePane=new JScrollPane(treeView,JScrollPane.VERTICAL_SCROLLBAR_ALWAYS,JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS);
		JSplitPane splitPane=new JSplitPane(JSplitPane.HORIZONTAL_SPLIT,true,viewSplitPane,treePane);
		splitPane.setDividerLocation(500);
		splitPane.setLastDividerLocation(0);
		splitPane.setOneTouchExpandable(true);

		Container panel=this.getContentPane();
		panel.setLayout(new BorderLayout());
		panel.add(splitPane,BorderLayout.CENTER);

		JMenuBar menuBar=new JMenuBar();
		JMenu menu=new JMenu("文件");
		JMenuItem menuItem=new JMenuItem("打开文件");
		menuItem.addActionListener(new ActionListener(){
			public void actionPerformed(ActionEvent e){
				openFile();
			}
		});
		menuBar.add(menu);
		menu.add(menuItem);
		this.setJMenuBar(menuBar);
		this.setVisible(true);
	}

	/**
	 * Asks the user for a data file, then rebuilds the model, the rules text, the
	 * tree drawing and the table from it. Does nothing if the dialog is cancelled.
	 */
	private void openFile(){
		JFileChooser fileChooser=new JFileChooser();
		fileChooser.setCurrentDirectory(new File("./data"));	// start browsing in ./data
		// On cancel or error getSelectedFile() may be null. Keep the current data
		// instead of handing a null file to the loaders.
		if(fileChooser.showOpenDialog(this)!=JFileChooser.APPROVE_OPTION) return;
		File selected=fileChooser.getSelectedFile();
		if(selected==null) return;
		file=selected;
		dataModel.clear();
		dataModel.setFile(file);
		dataModel.loadData();
		dataModel.generateDecisionTree();
		rulesView.setText(dataModel.getRules());
		treeView.repaint();
		loadData();
	}
	
	/**
	 * Re-reads the model's data file into a fresh table model and installs it.
	 * The header line becomes the column names (prefixed with "Name:") and every
	 * following line becomes one row. Read errors are reported on stderr, and
	 * whatever was read up to that point is still shown.
	 */
	private void loadData(){
		tableModel=new DefaultTableModel();
		File source=dataModel.getFile();
		if(source!=null){
			BufferedReader reader=null;
			try{
				reader=new BufferedReader(new FileReader(source));
				String headerline=reader.readLine();
				if(headerline!=null){	// an empty file yields an empty table
					StringTokenizer fieldName=new StringTokenizer(headerline,",");
					while(fieldName.hasMoreTokens()){
						tableModel.addColumn("Name:"+fieldName.nextToken());
					}
					for(String line=reader.readLine();line!=null;line=reader.readLine()){
						Vector record=new Vector();
						StringTokenizer field=new StringTokenizer(line,",");
						while(field.hasMoreTokens()){
							record.add(field.nextToken());
						}
						tableModel.addRow(record);
					}
				}
			}
			catch(IOException e){
				System.err.println("Could not read "+source+": "+e);
			}
			finally{
				if(reader!=null){
					try{
						reader.close();
					}
					catch(IOException ignored){
						// Nothing useful can be done if closing fails.
					}
				}
			}
		}
		table.setModel(tableModel);
	}
	
	public static void main(String args[]){
		// Swing components must be created and shown on the event dispatch thread.
		SwingUtilities.invokeLater(new Runnable(){
			public void run(){
				new DecisionTreeTest();
			}
		});
	}
}

/**
 * AWT canvas that draws the decision tree held by a {@link DataModel}.
 * The root node is placed at a fixed position of (350, 50).
 */
class TreeCanvas extends Canvas{
	/** Horizontal position of the root node, in pixels. */
	private static final int ROOT_X=350;
	/** Vertical position of the root node, in pixels. */
	private static final int ROOT_Y=50;

	/** Model whose tree is painted. */
	private DataModel dataModel=null;

	public TreeCanvas(DataModel dataModel){
		this.dataModel=dataModel;
	}

	/** Lets Canvas repaint its background, then asks the model to draw the tree. */
	public void paint(Graphics g){
		super.paint(g);
		DataModel model=this.dataModel;
		model.paintTree(g,ROOT_X,ROOT_Y);
	}
}

/**
 * Holds a data set loaded from a comma-separated file, plus the decision tree built
 * from it. The first line of the file names the columns, and the last column name is
 * used as the decision (class) attribute.
 */
class DataModel{
	/** Column names from the header line, in file order. */
	private Vector attributes=new Vector();
	/** One Hashtable (column name -> value) per data line. */
	private Vector samples=new Vector();
	private String decisionAttribute="";
	private Node decisionTree=null;
	
	private File file;

	/** Forgets the loaded columns, samples, decision attribute and tree. */
	public void clear(){
		attributes.removeAllElements();
		samples.removeAllElements();
		decisionAttribute="";
		decisionTree=null;
	}

	public void setFile(File file){
		this.file=file;
	}
	public File getFile(){
		return this.file;
	}

	/**
	 * Reads the file set by setFile and appends its columns and samples to this model.
	 * Does nothing if no file is set or the file is empty. A line with fewer fields
	 * than the header is skipped; extra fields beyond the header are ignored.
	 * Read errors are reported on stderr, and the rows read so far are kept.
	 */
	public void loadData(){
		if(file==null) return;
		BufferedReader reader=null;
		try{
			reader=new BufferedReader(new FileReader(file));
			String headerline=reader.readLine();
			if(headerline==null) return;	// empty file: nothing to load
			StringTokenizer fieldName=new StringTokenizer(headerline,",");
			String name="";
			while(fieldName.hasMoreTokens()){
				name=fieldName.nextToken();
				attributes.add(name);
			}
			// The last header column is the class attribute.
			this.setDecisionAttribute(name);
			for(String line=reader.readLine();line!=null;line=reader.readLine()){
				Vector values=new Vector();
				StringTokenizer field=new StringTokenizer(line,",");
				while(field.hasMoreTokens()){
					values.add(field.nextToken());
				}
				// StringTokenizer drops empty fields, so blank or short lines cannot be
				// matched to the header. Skip them rather than fail on values.get(i).
				if(values.size()<attributes.size()) continue;
				Hashtable record=new Hashtable();
				for(int i=0;i<attributes.size();i++){
					record.put(attributes.get(i),values.get(i));
				}
				samples.add(record);
			}
		}
		catch(IOException e){
			System.err.println("Could not read "+file+": "+e);
		}
		finally{
			if(reader!=null){
				try{
					reader.close();
				}
				catch(IOException ignored){
					// Nothing useful can be done if closing fails.
				}
			}
		}
	}
	
	public void setDecisionAttribute(String attribute){
		this.decisionAttribute=attribute;
	}
	
	/** Builds a new decision tree from the currently loaded samples. */
	public void generateDecisionTree(){
		// Always start from a fresh root so a rebuild never mixes in old children.
		decisionTree=new Node();
		decisionTree.setDecisionAttribute(this.decisionAttribute);
		decisionTree.initattributeSet(attributes);
		decisionTree.generateDecisionTree(samples,attributes);
	}
	
	/**
	 * Returns the textual rules of the current tree.
	 * Returns "" if no tree has been generated yet.
	 */
	public String getRules(){
		if(decisionTree==null) return "";
		decisionTree.generateRules();
		return decisionTree.getRules();
	}
	
	/** Paints the tree with its root at (x, y); draws nothing if no tree exists. */
	public void paintTree(Graphics g,int x,int y){
		if(decisionTree!=null) decisionTree.paintNode(g,x,y);
	}
}

/**
 * One node of a decision tree built from samples. Each sample is a Hashtable that
 * maps attribute name to attribute value.
 *
 * A leaf has testAttribute equal to decisionAttribute, and the single key of
 * attributeValues is the predicted class. An inner node tests testAttribute and keeps
 * one child per observed value in childNodes.
 *
 * NOTE(review): rules, ruleStack and attributeSet are static and shared by every Node,
 * so only one tree can be built or turned into rules at a time. Not thread-safe.
 */
class Node{
	/** Vertical distance, in pixels, between a node and its children when painted. */
	private static final int LEVEL_HEIGHT=50;
	
	private String testAttribute="";
	/** Inner node: attribute value -> samples with that value. Leaf: class value -> empty Vector. */
	private Hashtable attributeValues=new Hashtable();
	/** Attribute value -> child Node (inner nodes only). */
	private Hashtable childNodes=new Hashtable();
	private String decisionAttribute="";
	/** Rules collected by generateRules(). */
	private static Vector rules=new Vector();
	/** "attribute:value" conditions on the path to the node generateRules() is visiting. */
	private static Stack ruleStack=new Stack();
	/** Attributes still available for testing on the current root-to-node path. */
	private static Vector attributeSet=new Vector();
	
	public void setDecisionAttribute(String attribute){
		this.decisionAttribute=attribute;
	}
	public String getTestAttribute(){
		return this.testAttribute;
	}
	public String getDecisionAttribute(){
		return this.decisionAttribute;
	}

	/**
	 * Chooses the attribute to test at this node.
	 * NOTE(review): this does not compute information gain; it simply returns the
	 * first candidate, and sample is unused.
	 *
	 * @return the first element of attribute, or "" if attribute is empty
	 */
	public String selectTestAttribute(Vector sample,Vector attribute){
		if(attribute.isEmpty()) return "";
		return (String)attribute.firstElement();
	}

	/**
	 * Resets the shared pool of testable attributes to every entry of attribute except
	 * this node's decision attribute. Call it on the root after setDecisionAttribute
	 * and before generateDecisionTree.
	 */
	public void initattributeSet(Vector attribute){
		// The pool is static. Clear it so attributes from a previous data set do not linger.
		attributeSet.removeAllElements();
		for(int i=0;i<attribute.size();i++){
			attributeSet.add(attribute.get(i));
		}
		attributeSet.remove(this.decisionAttribute);
	}

	/**
	 * Recursively builds the subtree rooted at this node.
	 *
	 * @param sample    records (Hashtable attribute -> value) that reach this node
	 * @param attribute candidate attributes for this node; selectTestAttribute picks one
	 */
	public void generateDecisionTree(Vector sample,Vector attribute){
		// Start clean so that regenerating a node never mixes old partitions or children in.
		attributeValues.clear();
		childNodes.clear();

		// Every sample has the same class: this node becomes a leaf.
		Vector classes=distinctDecisionValues(sample);
		if(classes.size()==1){
			makeLeaf((String)classes.firstElement());
			return;
		}

		String candidate=selectTestAttribute(sample,attribute);
		int index=attributeSet.indexOf(candidate);
		if(index<0){
			// No attribute is left to split on, but the samples still disagree.
			// Predict the majority class rather than partitioning on a missing
			// attribute, which used to throw NullPointerException from Hashtable.
			if(!sample.isEmpty()) makeLeaf(majorityDecisionValue(sample,classes));
			return;
		}
		this.testAttribute=candidate;

		// Withdraw the tested attribute from the shared pool while this subtree is
		// built. Children receive a snapshot of what remains.
		attributeSet.remove(index);
		Vector remaining=new Vector(attributeSet);

		// Partition the samples by their value of the test attribute.
		Iterator iter=sample.iterator();
		while(iter.hasNext()){
			Hashtable record=(Hashtable)iter.next();
			Object value=record.get(this.testAttribute);
			Vector partition=(Vector)attributeValues.get(value);
			if(partition==null){
				partition=new Vector();
				attributeValues.put(value,partition);
			}
			partition.add(record);
		}

		// Build one child per partition, each from a copy of its samples.
		iter=attributeValues.keySet().iterator();
		while(iter.hasNext()){
			String value=(String)iter.next();
			Node childNode=new Node();
			childNode.setDecisionAttribute(this.decisionAttribute);
			childNode.generateDecisionTree(new Vector((Vector)attributeValues.get(value)),remaining);
			childNodes.put(value,childNode);
		}

		// Give the attribute back so sibling subtrees of the parent can still test it.
		attributeSet.add(this.testAttribute);
	}

	/** Returns the distinct decision-attribute values found in sample, in first-seen order. */
	private Vector distinctDecisionValues(Vector sample){
		Vector classes=new Vector();
		for(int i=0;i<sample.size();i++){
			Object value=((Hashtable)sample.get(i)).get(this.decisionAttribute);
			if(!classes.contains(value)) classes.add(value);
		}
		return classes;
	}

	/** Turns this node into a leaf that predicts decisionValue. */
	private void makeLeaf(String decisionValue){
		this.testAttribute=this.decisionAttribute;
		this.attributeValues.put(decisionValue,new Vector());
	}

	/** Returns the most frequent decision value in sample; a tie goes to the value seen first. */
	private String majorityDecisionValue(Vector sample,Vector classes){
		int[] counts=new int[classes.size()];
		for(int i=0;i<sample.size();i++){
			Object value=((Hashtable)sample.get(i)).get(this.decisionAttribute);
			counts[classes.indexOf(value)]++;
		}
		int best=0;
		for(int i=1;i<counts.length;i++){
			if(counts[i]>counts[best]) best=i;
		}
		return (String)classes.get(best);
	}

	/** Returns the class predicted by a leaf (first key of attributeValues), or "" if none. */
	private String leafValue(){
		Iterator iter=this.attributeValues.keySet().iterator();
		return iter.hasNext()?(String)iter.next():"";
	}

	/**
	 * Appends one "IF cond AND ... THEN decision:value" rule per leaf of this subtree
	 * to the shared rule list. When called on the root, it first discards rules
	 * produced by earlier calls.
	 */
	public void generateRules(){
		// ruleStack is empty only for the root call, because every child is visited
		// with its parent's condition pushed. Reset the shared list there so rules
		// are not duplicated across calls or data sets.
		if(ruleStack.isEmpty()) rules.removeAllElements();
		if(this.testAttribute.equals(this.decisionAttribute)){
			StringBuffer rule=new StringBuffer("IF ");
			for(int i=0;i<ruleStack.size();i++){
				if(i>0) rule.append(" AND ");
				rule.append(ruleStack.get(i));
			}
			rule.append(" THEN ").append(this.decisionAttribute).append(":").append(leafValue());
			rules.add(rule.toString());
			return;
		}
		Iterator iter=childNodes.keySet().iterator();
		while(iter.hasNext()){
			String value=(String)iter.next();
			ruleStack.push(this.testAttribute+":"+value);
			try{
				((Node)childNodes.get(value)).generateRules();
			}
			finally{
				// Keep the shared stack balanced even if a subtree fails.
				ruleStack.pop();
			}
		}
	}

	/** Formats the collected rules as a "Rules:" block, one rule per line ending in ";". */
	public String getRules(){
		StringBuffer decisions=new StringBuffer("\nRules:\n");
		for(int i=0;i<rules.size();i++){
			decisions.append((String)rules.get(i)).append(";\n");
		}
		return decisions.toString();
	}

	/**
	 * Draws this node at (x, y) and its subtree below it. Children are fanned out
	 * LEVEL_HEIGHT pixels lower, at angles between 30 and 150 degrees.
	 */
	public void paintNode(Graphics g,int x,int y){
		g.setFont(new Font("黑体",Font.BOLD,18));
		// Offset dark-gray disc as a shadow, then the yellow node disc with a black outline.
		g.setColor(Color.DARK_GRAY);
		g.fillOval(x+5,y+5,30,30);
		g.setColor(Color.YELLOW);
		g.fillOval(x,y,30,30);
		g.setColor(Color.BLACK);
		g.drawOval(x,y,30,30);
		g.setColor(Color.RED);
		g.drawString(this.testAttribute,x,y);
		if(this.testAttribute.equals(this.decisionAttribute)){
			g.setColor(Color.BLUE);
			g.drawString(leafValue(),x,y+LEVEL_HEIGHT/2);
			return;
		}
		int childNum=this.childNodes.size();
		double arc;
		double perArc;
		if(childNum<=1){
			// A single child goes straight down. The spread formula below would
			// otherwise divide by childNum-1 == 0.
			arc=Math.PI/2;
			perArc=0;
		}
		else{
			arc=Math.PI*(30.0/180);
			perArc=Math.PI*(120.0/180)/(childNum-1);
		}
		Iterator iter=this.childNodes.keySet().iterator();
		while(iter.hasNext()){
			String value=(String)iter.next();
			int newX=x-(int)(LEVEL_HEIGHT/Math.tan(arc));
			int newY=y+LEVEL_HEIGHT;
			arc+=perArc;
			g.setColor(Color.GREEN);
			g.drawLine(x+15,y+15,newX+15,newY+15);
			g.setColor(Color.BLACK);
			g.drawString(value,(newX+x)/2,(newY+y)/2);
			((Node)this.childNodes.get(value)).paintNode(g,newX,newY);
		}
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -