📄 tree.java
字号:
/* * YALE - Yet Another Learning Environment * Copyright (C) 2002, 2003 * Simon Fischer, Ralf Klinkenberg, Ingo Mierswa, * Katharina Morik, Oliver Ritthoff * Artificial Intelligence Unit * Computer Science Department * University of Dortmund * 44221 Dortmund, Germany * email: yale@ls8.cs.uni-dortmund.de * web: http://yale.cs.uni-dortmund.de/ * * This program is free software; you can redistribute it and/or * modify it under the terms of the GNU General Public License as * published by the Free Software Foundation; either version 2 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but * WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 * USA. */package edu.udo.cs.yale.operator.learner.decisiontree;import edu.udo.cs.yale.MethodNotSupportedException;import edu.udo.cs.yale.example.Attribute;import edu.udo.cs.yale.example.Example;import edu.udo.cs.yale.example.ExampleSet;import edu.udo.cs.yale.operator.learner.Model;import edu.udo.cs.yale.operator.learner.SimpleModel;import edu.udo.cs.yale.tools.Ontology;import edu.udo.cs.yale.tools.LogService;import edu.udo.cs.yale.gui.SwingTools;import java.util.ArrayList;import java.util.Iterator;import java.awt.Graphics2D;import java.awt.Graphics;import java.awt.Dimension;import java.awt.Component;import java.awt.geom.Rectangle2D;import java.awt.geom.RoundRectangle2D;import java.awt.Font;import java.awt.Paint;import java.awt.Color;import java.awt.font.FontRenderContext;import javax.swing.JPanel;/** Tree is a vertex in a decision tree. It contains several simple conditions and other trees * to which positive tests reference. * @author Ingo * @version $Id: Tree.java,v 2.2 2003/04/04 11:59:29 fischer Exp $ */public class Tree extends SimpleModel { private static final int X_SPACING = 15; private static final int Y_SPACING = 40; private static final int PREMISE_SPACING = 20; private static final int TEXT_MARGIN = 4; private static final Font NODE_FONT = new Font("LucidaSans", Font.BOLD, 12); private static final Font EDGE_FONT = new Font("LucidaSans", Font.PLAIN, 12); /** Name of the attribute. */ private Attribute attribute; /** A list of rules. */ private ArrayList rules; /** Indicates if this tree is a leaf with an classification. */ private boolean leaf; /** Class of the tree. */ private int goal; /** Creates an empty decision tree with the given default classification. */ public Tree(Attribute attribute) { this.attribute = attribute; this.leaf = false; this.rules = new ArrayList(); } /** Creates a leaf with the given classification. */ public Tree(Attribute classAttribute, int goal) { this.attribute = classAttribute; this.rules = null; this.leaf = true; this.goal = goal; } public boolean isLeaf() { return leaf; } /** Adds a new child. Checks if the name of the attribute suits the name of the premise. */ public void addChild(Premise premise, Tree goal) { if (!leaf) { rules.add(new Rule(goal, premise)); } else LogService.logMessage("Premise can not be add, tree is leaf node.", LogService.ERROR); } /** Returns the goal of the tree. Iterates through all rules and returns a rule's goal if * it is appliable. Therefore the rules have to be ordered. Returns null if no rule is * applicable. */ public int getGoal(Example example) { if (leaf) { return goal; } else { boolean predicted = false; Iterator i = rules.listIterator(); while (i.hasNext()) { Rule rule = (Rule)i.next(); if (rule.isTrue(example)) { return rule.getGoal().getGoal(example); } } return -1; } } public double predict(Example example) { return getGoal(example); } /** Returns a tree like string representation of this tree. */ public String toString(String prefix) { String result = ""; if (leaf) { result += attribute.getName() + " = " + attribute.getAsString(goal); } else { result = "\n"; Iterator i = rules.listIterator(); while (i.hasNext()) { Rule rule = (Rule)i.next(); result += rule.toString(prefix) + "\n"; } } return result; } public String toResultString() { return toString(""); } public Component getVisualisationComponent() { return new JPanel() { { Dimension d = Tree.this.getSize(); d.setSize(d.getWidth() + TEXT_MARGIN*2+1, d.getHeight() + TEXT_MARGIN*2+1); setPreferredSize(d); } public void printComponent(Graphics g) { Dimension d = Tree.this.getSize(); Graphics2D g2d = (Graphics2D)g.create((int)(getWidth()/2-d.getWidth()/2), (int)(getHeight()/2-d.getHeight()/2), (int)d.getWidth()+1, (int)d.getHeight()+1); render(g2d); g2d.dispose(); } public void paintComponent(Graphics g) { //g.clearRect(0, 0, getWidth(), getHeight()); super.paintComponent(g); printComponent(g); } }; } private String getNodeLabel() { if (isLeaf()) return attribute.getAsString(goal); else return attribute.getName(); } public static Dimension getSize(Font font, String string) { Rectangle2D stringBounds = font.getStringBounds(string, new FontRenderContext(null, false, false)); return new Dimension((int)(stringBounds.getWidth() + 2 * TEXT_MARGIN), (int)(stringBounds.getHeight() + 2 * TEXT_MARGIN)); } public static void drawFramed(Graphics2D g, String string, Font font, int paintIndex, double x, double y) { Rectangle2D stringBounds = font.getStringBounds(string, g.getFontRenderContext()); g.setFont(font); RoundRectangle2D frame = new RoundRectangle2D.Double(x - stringBounds.getWidth()/2 - TEXT_MARGIN, y, stringBounds.getWidth() + 2*TEXT_MARGIN, stringBounds.getHeight() + 2*TEXT_MARGIN, 2*TEXT_MARGIN, 2*TEXT_MARGIN); Paint paint = null; switch (paintIndex) { case 0: paint = SwingTools.makeBluePaint(frame.getWidth(), frame.getHeight()); break; case 1: paint = SwingTools.makeYellowPaint(frame.getWidth(), frame.getHeight()); break; default: case 2: paint = Color.red; break; } g.setPaint(paint); g.fill(frame); g.setColor(Color.black); g.draw(frame); g.setColor(Color.black); g.drawString(string, (int)(x - stringBounds.getWidth()/2), (int)(y - stringBounds.getY() + TEXT_MARGIN)); } public Dimension getSize() { Dimension d = new Dimension(0,0); if (!isLeaf()) { Iterator i = rules.iterator(); boolean first = true; while (i.hasNext()) { Rule rule = (Rule)i.next(); Dimension subtreeSize = rule.getGoal().getSize(); Dimension premiseSize = getSize(EDGE_FONT, ((SimplePremise)rule.getPremise()).getTreeEdgeString()); Dimension totalSize = new Dimension((int)Math.max(premiseSize.getWidth(), subtreeSize.getWidth()), (int)(premiseSize.getHeight()+PREMISE_SPACING+subtreeSize.getHeight())); d.setSize(d.getWidth() + totalSize.getWidth() + (!first ? X_SPACING : 0), Math.max(d.getHeight(), totalSize.getHeight())); first = false; } } Dimension nodeSize = getSize(NODE_FONT, getNodeLabel()); d.setSize(Math.max(d.getWidth(), nodeSize.getWidth()), d.getHeight() + nodeSize.getHeight() + Y_SPACING); return d; } public void render(Graphics2D g) { Dimension d = getSize(); double nodeHeight = getSize(NODE_FONT, getNodeLabel()).getHeight(); double lineX1 = d.getWidth() /2; double lineY1 = nodeHeight/2; // draw children if (!isLeaf()) { double x = 0; int y = (int)nodeHeight + Y_SPACING; Iterator i = rules.iterator(); while (i.hasNext()) { Rule rule = (Rule)i.next(); Dimension premiseSize = getSize(EDGE_FONT, ((SimplePremise)rule.getPremise()).getTreeEdgeString()); Dimension childSize = rule.getGoal().getSize(); double totalWidth = Math.max(premiseSize.getWidth(), childSize.getWidth()); g.setColor(Color.black); g.drawLine((int)lineX1, (int)lineY1, (int)(x + totalWidth/2), y); drawFramed(g, ((SimplePremise)rule.getPremise()).getTreeEdgeString(), EDGE_FONT, 1, x + totalWidth/2, y); g.setColor(Color.black); g.drawLine((int)(x + totalWidth/2), (int)(y+premiseSize.getHeight()), (int)(x + totalWidth/2), (int)(y+premiseSize.getHeight()+PREMISE_SPACING)); Graphics2D childGr = (Graphics2D)g.create((int)(x + totalWidth/2 - childSize.getWidth()/2), (int)(y + premiseSize.getHeight()+PREMISE_SPACING), (int)childSize.getWidth()+1, (int)childSize.getHeight()+1); rule.getGoal().render(childGr); childGr.dispose(); x += totalWidth + X_SPACING; } } // draw the node itself drawFramed(g, getNodeLabel(), NODE_FONT, 0, d.getWidth()/2, 0); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -