⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 dsstree.java

📁 ID3决策树算法的JAVA实现:ID3算法是机器学习中的一种分类方法
💻 JAVA
字号:
package dsstree;import java.util.Vector;import javax.swing.*;import javax.swing.tree.*;import javax.swing.event.*;import java.awt.*;public class dssTree{  private TreeNode rootNode = null;  private int[][] data;  private int row, col;  private int[] att_No;  private String[] att_Name;  TreeNode tempNode=null;    DefaultTreeModel treeModel=null;    DefaultMutableTreeNode treeroot=null,treenode=null,layerChildNode=null;    DefaultTreeCellRenderer cellRenderer=null;    JTree tree=null;  public dssTree(int[][] data1, int row1, int col1, int[] att_No1, String[] att_Name1) {    data = data1;    row = row1;    col = col1;    att_No = att_No1;    att_Name = att_Name1;    genTree(); //generate the dssTree  }  private void genTree() {    int p = 0, n = 0;    double e[] = {        0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //存放各属性的期望    double g[] = {        0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //存放各属性的增益    for (int i = 0; i < row; i++) { //计算出 p n      if (data[i][5] == 0)        n++;      if (data[i][5] == 1)        p++;    }    for (int i = 1; i < col; i++) { //计算每个属性的信息增益      for (int j = 1; j <= att_No[i]; j++) { //每个属性所有的属性值        int p1 = 0, n1 = 0; //在该属性值下的 p n        for (int k = 0; k < row; k++) { //统计在该属性下的 p n 的值          if ( (data[k][i - 1] == j) && (data[k][5] == 0))            n1++;          if ( (data[k][i - 1] == j) && (data[k][5] == 1))            p1++;        }        e[i] += 1.0 * (p1 + n1) * calI(p1, n1) / (p + n);      }      g[i] = calI(p, n) - e[i];    }    int m = 1; //找信息增益最大的属性号    for (int i = 2; i < col; i++) {      if (g[i] > g[m])        m = i;    }    mainApp.text.append("\nMax:g[" + m + "]=" + g[m]);    rootNode = new TreeNode(m, 0); //生成根结点    rootNode.setParent(null);    genChildTree(rootNode); //调用递归函数,产生决策树  }  private void genChildTree(TreeNode root) {    //产生根结点的各子结点    int m = root.getattr(); //m 为取得待处理结点的属性号    if (root.getParent() != null) {      mainApp.text.append("\nroot: " + att_Name[m] + " Parent->" +                          att_Name[root.getParent().getattr()]);    }    else {      mainApp.text.append("\nroot: " + att_Name[m] + " Parent->NULL");    }    for (int h = 1; h <= att_No[m]; h++) { //h:属性号为 m 的属性值个数      int p = 0, n = 0;      double e[] = {          0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //存放各属性的期望      double g[] = {          0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; //存放各属性的增益      mainApp.text.append("\n" + att_Name[m] + "" + h);      for (int i = 0; i < row; i++) { //计算出 p n        boolean bool = (data[i][m - 1] == h); //用于判断第 i 行第 m 各属性的属性值是否为 h        TreeNode tempNode = root; //用于找所有父结点中已经得到的各个属性        while (tempNode.getParent() != null) { //在满足各父结点都存在的情况下,n p          bool = bool &&              (data[i][tempNode.getParent().getattr() - 1] ==               tempNode.getParent_attr_value());          tempNode = tempNode.getParent();        }        if (bool && data[i][5] == 0)          n++;        if (bool && data[i][5] == 1)          p++;      }      if (p + n == 0)        continue;      if (n == 0) { //the leafNode is positive        TreeNode theNode = new TreeNode( -1, h);        theNode.setParent(root);        root.addChild(theNode);        mainApp.text.append("\nleafNode:" + att_Name[root.getattr()] + "[" + h +                            "]" + "--Positive");        continue;      }      if (p == 0) { //the leafNode is negative        TreeNode theNode = new TreeNode( -2, h);        theNode.setParent(root);        root.addChild(theNode);        mainApp.text.append("\nleafNode:" + att_Name[root.getattr()] + "[" + h +                            "]" + "--Negative");        continue;      }      for (int i = 1; i < col; i++) { //计算每个属性的信息增益,其中 i 为属性号        boolean bool = (i != m); //用于判断和父结点属性是否相等        TreeNode tempNode = root.getParent(); //用于找所有父结点        while (tempNode != null) { //除去父结点中具有相同属性号的属性          bool = bool && (i != tempNode.getattr());          tempNode = tempNode.getParent();        }        if (bool) {          for (int j = 1; j <= att_No[i]; j++) { //对属性 i 的每个属性值            int p1 = 0, n1 = 0; //在该属性值下的 p n            //int count=0;            for (int k = 0; k < row; k++) { //统计在该属性下的 p n 的值              boolean boolx = (data[k][m - 1] == h); //用于判断和父结点属性是否相等              TreeNode tempNodex = root; //用于找所有父结点              //text.append("nullPointer test,the root is="+att_Name[root.getattr()]+"");              while (tempNodex != null && tempNodex.getParent() != null) {                //text.append("nullPointer test,the root is="+att_Name[tempNodex.getattr()]+"");                //text.append(""+tempNodex.getattr()+" "+count);                //count++;                boolx = boolx &&                    (data[k][tempNodex.getParent().getattr() - 1] ==                     tempNodex.getParent_attr_value());                //text.append("boolx="+bool);                tempNodex = tempNodex.getParent();              }              if (boolx && data[k][i - 1] == j && data[k][5] == 0)                n1++;              if (boolx && data[k][i - 1] == j && data[k][5] == 1)                p1++;            }            e[i] += 1.0 * (p1 + n1) * calI(p1, n1) / (p + n);          }          g[i] = calI(p, n) - e[i];          mainApp.text.append("\ng[" + att_Name[i] + "]=" + g[i]);        }      }      int k = 1;      for (int m1 = 2; m1 < col; m1++) {        if (g[m1] > g[k])          k = m1;      }      if (g[k] > 0) {        TreeNode theNode = new TreeNode(k, h); //生成子结点        theNode.setParent(root);        root.addChild(theNode);        mainApp.text.append("\nMax:g[" + att_Name[k] + "]=" + g[k]);      }    }    //对root结点的每个子结点,递归创建子结点的子结点    for (int child_num = 0; child_num < root.getchildrenNum(); child_num++) {      TreeNode tNode = (TreeNode) root.getChildren().get(child_num);      if (tNode.getattr() > 0) {        genChildTree(tNode);      }    }  }  private double calI(int p, int n) {    double x;    if (p * n != 0) {      x = ( -1.0 * p * Math.log( (p * 1.0 / (p + n))) / Math.log(2) -           1.0 * n * Math.log( (n * 1.0 / (p + n))) / Math.log(2)) / (p + n);    }    else {      x = 0.0;    }    return x;  }  public TreeNode getRoot() {    return rootNode;  }  public void preOrderRules(TreeNode root) {    if (root != null) {      //Output the rules of each leafNode      if (root.getchildrenNum() == 0) {        mainApp.text.append("\n " + root.getattr() + " " +                            root.getParent_attr_value() + " " +                            root.getchildrenNum());        String theRule = "";        if (root.getattr() == -1) {          theRule = " then is T";        }        else {          theRule = " then is F";        }        TreeNode tempNode = root;        while (tempNode.getParent() != null) {          theRule = " and " + att_Name[tempNode.getParent().getattr()] + "=" +              att_Name[tempNode.getParent().getattr()] +              tempNode.getParent_attr_value() + theRule;          tempNode = tempNode.getParent();        }        mainApp.text.append(" :    if" + theRule.substring(4));      }      int num = root.getchildrenNum();      if (num > 0) {        Vector childList = root.getChildren();        for (int i = 0; i < num; i++) {          TreeNode childNode = (TreeNode) childList.get(i);          if (childNode != null) {            preOrderRules(childNode);          }        }      }    }  }  public void preOrderAllNodes(TreeNode root) {    if (root != null) {      //Output the information of each TreeNode in the dssTree      int att = root.getattr();      String str = "zhufubao";      if (att == -1)        str = "Positive";      else if (att == -2)        str = "Negative";      else        str = att_Name[att];      mainApp.text.append("\n " + root.getattr() + " " +                          root.getParent_attr_value() + " " +                          root.getchildrenNum() + "  " + str);      int num = root.getchildrenNum();      if (num > 0) {        Vector childList = root.getChildren();        for (int i = 0; i < num; i++) {          TreeNode childNode = (TreeNode) childList.get(i);          if (childNode != null) {            preOrderAllNodes(childNode);          }        }      }    }  }  public void layerOrderToGenTree(TreeNode root)  {    /*TreeNode tempNode=null;    DefaultTreeModel treeModel=null;    DefaultMutableTreeNode treeroot=null,treenode=null,layerChildNode=null;    DefaultTreeCellRenderer cellRenderer=null;    JTree tree=null;*/    Vector queue=new Vector(0,1);//定义队列    Vector vnode=new Vector(0,1);//暂存结点    DefaultMutableTreeNode layerNode=null;    if (root != null)    {      String stemp=att_Name[root.getattr()];      treeroot=new DefaultMutableTreeNode(stemp);      treenode=treeroot;      treeModel = new DefaultTreeModel(treeroot);      queue.add(root);      vnode.add(treenode);    }    int i = 0;    while(i < queue.size())//队列不空时    {      tempNode=(TreeNode)queue.elementAt(i);      treenode=(DefaultMutableTreeNode)vnode.elementAt(i);      i++;      int num = tempNode.getchildrenNum();      if (num > 0)      {        Vector childList = tempNode.getChildren();        for (int j = 0; j < num; j++)        {          TreeNode childNode = (TreeNode) childList.get(j);          queue.add(childNode);          if(childNode.getattr()>0)            layerChildNode = new DefaultMutableTreeNode(att_Name[childNode.getattr()]+":"+att_Name[childNode.getParent().getattr()]+"="+att_Name[childNode.getParent().getattr()]+childNode.getParent_attr_value());          if(childNode.getattr()==-1)            layerChildNode = new DefaultMutableTreeNode("Pos"+":"+att_Name[childNode.getParent().getattr()]+"="+att_Name[childNode.getParent().getattr()]+childNode.getParent_attr_value());          if(childNode.getattr()==-2)            layerChildNode = new DefaultMutableTreeNode("Neg"+":"+att_Name[childNode.getParent().getattr()]+"="+att_Name[childNode.getParent().getattr()]+childNode.getParent_attr_value());          vnode.add(layerChildNode);          treeModel.insertNodeInto(layerChildNode, treenode, treenode.getChildCount());        }      }    }    //treeroot.add(layerChildNode);    tree = new JTree(treeModel);    tree.setRowHeight(20);    cellRenderer = (DefaultTreeCellRenderer) tree.getCellRenderer();    cellRenderer.setFont(new Font("宋体", Font.PLAIN, 12));    cellRenderer.setBorderSelectionColor(Color.blue);    cellRenderer.setBackgroundSelectionColor(Color.blue);    cellRenderer.setTextSelectionColor(Color.white);    cellRenderer.setTextNonSelectionColor(Color.red);    mainApp.scrollPane.setViewportView(tree);    //mainApp.scrollPane.setVerticalScrollBarPolicy(JScrollPane.VERTICAL_SCROLLBAR_ALWAYS);    mainApp.scrollPane.setViewportBorder(BorderFactory.createEtchedBorder());    tree.addTreeSelectionListener(new javax.swing.event.TreeSelectionListener(){      public void valueChanged(TreeSelectionEvent e)       {        tree_valueChanged(e);      }       });  }  void tree_valueChanged(TreeSelectionEvent e) {    DefaultMutableTreeNode node = (DefaultMutableTreeNode)tree.getLastSelectedPathComponent();    TreePath pathnode = (TreePath)tree.getLeadSelectionPath();    if (node.isRoot()) {      mainApp.result.setText("Hello!");    }    else if(node.isLeaf())    {      String theRule = " and " + node.toString().substring(node.toString().indexOf(":")+1);      if (node.toString().indexOf("Pos")>-1) {        theRule = theRule + " then is T";      }      else {        theRule = theRule + " then is F";      }      DefaultMutableTreeNode tnode=(DefaultMutableTreeNode)node.getParent();      while (!tnode.isRoot()) {        theRule = " and " + tnode.toString().substring(tnode.toString().indexOf(":")+1) + theRule;        tnode=(DefaultMutableTreeNode)tnode.getParent();      }      mainApp.result.setText("if" + theRule.substring(4));    }    else    {      mainApp.result.setText("");    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -