📄 tree.java
字号:
/*
* YALE - Yet Another Learning Environment
* Copyright (C) 2001-2004
* Simon Fischer, Ralf Klinkenberg, Ingo Mierswa,
* Katharina Morik, Oliver Ritthoff
* Artificial Intelligence Unit
* Computer Science Department
* University of Dortmund
* 44221 Dortmund, Germany
* email: yale-team@lists.sourceforge.net
* web: http://yale.cs.uni-dortmund.de/
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
* USA.
*/
package edu.udo.cs.yale.operator.learner.decisiontree;
import edu.udo.cs.yale.example.Attribute;
import edu.udo.cs.yale.example.Example;
import edu.udo.cs.yale.operator.learner.SimpleModel;
import edu.udo.cs.yale.tools.LogService;
import edu.udo.cs.yale.gui.SwingTools;
import java.util.ArrayList;
import java.util.Iterator;
import java.awt.Graphics2D;
import java.awt.Graphics;
import java.awt.Dimension;
import java.awt.Component;
import java.awt.geom.Rectangle2D;
import java.awt.geom.RoundRectangle2D;
import java.awt.Font;
import java.awt.Paint;
import java.awt.Color;
import java.awt.font.FontRenderContext;
import javax.swing.JPanel;
/** Tree is a vertex in a decision tree. It contains several simple conditions and other trees
* to which positive tests reference.
*
* @author Ingo, Dave Meppelink
* @version $Id: Tree.java,v 2.7 2004/08/27 11:57:38 ingomierswa Exp $
*/
public class Tree extends SimpleModel {
private static final int X_SPACING = 15;
private static final int Y_SPACING = 40;
private static final int PREMISE_SPACING = 20;
private static final int TEXT_MARGIN = 4;
private static final Font NODE_FONT = new Font("LucidaSans", Font.BOLD, 12);
private static final Font EDGE_FONT = new Font("LucidaSans", Font.PLAIN, 12);
/** Name of the attribute.
*/
private Attribute attribute;
/** A list of rules.
*/
private ArrayList rules;
/** Indicates if this tree is a leaf with an classification.
*/
private boolean leaf;
/** Class of the tree.
*/
private int goal;
protected Tree() {
super();
}
/** Creates an empty decision tree with the given default classification.
*/
public Tree(Attribute label, Attribute attribute) {
super(label);
this.attribute = attribute;
this.leaf = false;
this.rules = new ArrayList();
}
/** Creates a leaf with the given classification.
*/
public Tree(Attribute classAttribute, int goal) {
super(classAttribute);
this.rules = null;
this.leaf = true;
this.goal = goal;
}
public boolean isLeaf() {
return leaf;
}
/** Adds a new child. Checks if the name of the attribute suits the name of the premise.
*/
public void addChild(Premise premise, Tree goal) {
if (!leaf) {
rules.add(new Rule(goal, premise));
} else LogService.logMessage("Premise can not be add, tree is leaf node.", LogService.ERROR);
}
/** Returns the goal of the tree. Iterates through all rules and returns a rule's goal if
* it is appliable. Therefore the rules have to be ordered. Returns null if no rule is
* applicable.
*/
public int getGoal(Example example) {
if (leaf) {
return goal;
} else {
Iterator i = rules.listIterator();
while (i.hasNext()) {
Rule rule = (Rule)i.next();
if (rule.isTrue(example)) {
return rule.getGoal().getGoal(example);
}
}
return -1;
}
}
public double predict(Example example) {
return getGoal(example);
}
/** Returns a tree like string representation of this tree. */
public String toString(String prefix) {
StringBuffer result = new StringBuffer();
if (leaf) {
result.append(getLabel().getName());
result.append(" = ");
result.append(getLabel().getAsString(goal));
} else {
result.append("\n");
Iterator i = rules.listIterator();
while (i.hasNext()) {
Rule rule = (Rule)i.next();
result.append(rule.toString(prefix));
result.append("\n");
}
}
return result.toString();
}
public String toResultString() {
return toString("");
}
public Component getVisualisationComponent() {
return new JPanel() {
{
Dimension d = Tree.this.getSize();
d.setSize(d.getWidth() + TEXT_MARGIN*2+1, d.getHeight() + TEXT_MARGIN*2+1);
setPreferredSize(d);
}
public void printComponent(Graphics g) {
Dimension d = Tree.this.getSize();
Graphics2D g2d = (Graphics2D)g.create((int)(getWidth()/2-d.getWidth()/2),
(int)(getHeight()/2-d.getHeight()/2),
(int)d.getWidth()+1,
(int)d.getHeight()+1);
render(g2d);
g2d.dispose();
}
public void paintComponent(Graphics g) {
//g.clearRect(0, 0, getWidth(), getHeight());
super.paintComponent(g);
printComponent(g);
}
};
}
private String getNodeLabel() {
if (isLeaf())
return getLabel().getAsString(goal);
else
return attribute.getName();
}
public static Dimension getSize(Font font, String string) {
Rectangle2D stringBounds = font.getStringBounds(string,
new FontRenderContext(null, false, false));
return new Dimension((int)(stringBounds.getWidth() + 2 * TEXT_MARGIN),
(int)(stringBounds.getHeight() + 2 * TEXT_MARGIN));
}
public static void drawFramed(Graphics2D g, String string, Font font, int paintIndex, double x, double y) {
Rectangle2D stringBounds = font.getStringBounds(string, g.getFontRenderContext());
g.setFont(font);
RoundRectangle2D frame =
new RoundRectangle2D.Double(x - stringBounds.getWidth()/2 - TEXT_MARGIN,
y,
stringBounds.getWidth() + 2*TEXT_MARGIN,
stringBounds.getHeight() + 2*TEXT_MARGIN,
2*TEXT_MARGIN, 2*TEXT_MARGIN);
Paint paint = null;
switch (paintIndex) {
case 0: paint = SwingTools.makeBluePaint(frame.getWidth(), frame.getHeight()); break;
case 1: paint = SwingTools.makeYellowPaint(frame.getWidth(), frame.getHeight()); break;
default:
case 2: paint = Color.red; break;
}
g.setPaint(paint);
g.fill(frame);
g.setColor(Color.black);
g.draw(frame);
g.setColor(Color.black);
g.drawString(string,
(int)(x - stringBounds.getWidth()/2),
(int)(y - stringBounds.getY() + TEXT_MARGIN));
}
public Dimension getSize() {
Dimension d = new Dimension(0,0);
if (!isLeaf()) {
Iterator i = rules.iterator();
boolean first = true;
while (i.hasNext()) {
Rule rule = (Rule)i.next();
Dimension subtreeSize = rule.getGoal().getSize();
Dimension premiseSize = getSize(EDGE_FONT,
((SimplePremise)rule.getPremise()).getTreeEdgeString());
Dimension totalSize =
new Dimension((int)Math.max(premiseSize.getWidth(), subtreeSize.getWidth()),
(int)(premiseSize.getHeight()+PREMISE_SPACING+subtreeSize.getHeight()));
d.setSize(d.getWidth() + totalSize.getWidth() + (!first ? X_SPACING : 0),
Math.max(d.getHeight(), totalSize.getHeight()));
first = false;
}
}
Dimension nodeSize = getSize(NODE_FONT, getNodeLabel());
d.setSize(Math.max(d.getWidth(), nodeSize.getWidth()),
d.getHeight() + nodeSize.getHeight() + Y_SPACING);
return d;
}
public void render(Graphics2D g) {
Dimension d = getSize();
double nodeHeight = getSize(NODE_FONT, getNodeLabel()).getHeight();
double lineX1 = d.getWidth() /2;
double lineY1 = nodeHeight/2;
// draw children
if (!isLeaf()) {
double x = 0;
int y = (int)nodeHeight + Y_SPACING;
Iterator i = rules.iterator();
while (i.hasNext()) {
Rule rule = (Rule)i.next();
Dimension premiseSize = getSize(EDGE_FONT,
((SimplePremise)rule.getPremise()).getTreeEdgeString());
Dimension childSize = rule.getGoal().getSize();
double totalWidth = Math.max(premiseSize.getWidth(), childSize.getWidth());
g.setColor(Color.black);
g.drawLine((int)lineX1, (int)lineY1,
(int)(x + totalWidth/2), y);
drawFramed(g, ((SimplePremise)rule.getPremise()).getTreeEdgeString(),
EDGE_FONT, 1, x + totalWidth/2, y);
g.setColor(Color.black);
g.drawLine((int)(x + totalWidth/2), (int)(y+premiseSize.getHeight()),
(int)(x + totalWidth/2), (int)(y+premiseSize.getHeight()+PREMISE_SPACING));
Graphics2D childGr = (Graphics2D)g.create((int)(x + totalWidth/2 - childSize.getWidth()/2),
(int)(y + premiseSize.getHeight()+PREMISE_SPACING),
(int)childSize.getWidth()+1,
(int)childSize.getHeight()+1);
rule.getGoal().render(childGr);
childGr.dispose();
x += totalWidth + X_SPACING;
}
}
// draw the node itself
drawFramed(g, getNodeLabel(), NODE_FONT, 0, d.getWidth()/2, 0);
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -