📄 id3.java
字号:
/* * To change this template, choose Tools | Templates * and open the template in the editor. */package id3;import java.io.BufferedReader;import java.io.File;import java.io.FileInputStream;import java.io.FileOutputStream;import java.io.IOException;import java.io.InputStreamReader;import java.io.PrintWriter;import java.util.ArrayList;/** * * @author Lam *//** * ID3判定树类 * @author Lam * */class ID3Tree{ ID3Tree() { root=null; attributeList=new AttributesList(); dataSet=new Matrix(); testData=new Matrix(); ruleList=new ArrayList<String>(); } /** * 多数表决 * @param matrix 数据集 * @return 数据集中最普通的类 */ private String getMaxLabel(Matrix matrix) { if(matrix.matrix.size()==0) return ""; ArrayList<Target> tarList=new ArrayList<Target>(); ArrayList<RowData> rows=matrix.matrix; for(int i=0;i<rows.size();i++) { ArrayList<String> row=rows.get(i).rowdata; String cLabel=row.get(row.size()-1); int j=0; for( ;j<tarList.size();j++) { if(tarList.get(j).label.equals(cLabel)) { tarList.get(j).counts++; break; } } if(j==tarList.size()) { Target tar=new Target(); tar.label=cLabel; tar.counts++; tarList.add(tar); } } int maxTargetIndex=0; int maxCounts=0; for(int i=0;i<tarList.size();i++) { int count=tarList.get(i).counts; if(count>maxCounts) { maxCounts=count; maxTargetIndex=i; } } String maxLabel=tarList.get(maxTargetIndex).label; return maxLabel; } /** * 产生判定树 * @param attList 属性列表 * @param matrix 数据表 * @param path 节点路径(最终转化为规则) * @return 判定树节点 */ private Node Generate_decision_tree(AttributesList attList,Matrix matrix,String path) { Node t; String label=this.theSameClass(matrix); //在同一类中 if(label!=null) { t=new Node("Leave-Class"+label,path+"->"+label,1,0); } //attribute_list为空 else if(attList.attributes.size()==0) { String maxLabel=this.getMaxLabel(matrix); t=new Node("Leave-Class"+maxLabel,path+"->"+maxLabel,1,0); } else { int maxColIndex=0; double maxEntropy=0.0; ArrayList<Column> cols=attList.attributes; for(int i=0;i<cols.size();i++) { double gain=attList.inentropy-entropy(cols.get(i)); if(gain>maxEntropy) { maxEntropy=gain; maxColIndex=i; } } Column col=cols.get(maxColIndex); int childCounts=col.attributes.size(); t=new Node("Node-"+col.label,path+col.label,-1,childCounts); Node[] childs=t.childs; for(int i=0;i<childs.length;i++) { String subAttr=this.getAttrLabel(attList, maxColIndex, i); Matrix matr=buildMatrix(matrix,attList,maxColIndex,i); if(matr.matrix.size()==0) { String maxLabel=this.getMaxLabel(matrix); childs[i]=new Node("Leave-"+maxLabel,t.path+":"+subAttr+"->"+maxLabel,1,0); } else { AttributesList newAttrList=this.buildAttributesList(matr,attList,maxColIndex); childs[i]=this.Generate_decision_tree(newAttrList, matr,t.path+":"+subAttr); } } } return t; } /** * 返回属性列表中指定的某列的某个子属性 * @param attList 属性列表 * @param colIndex 属性列位置 * @param attIndex 子属性位置 * @return */ private String getAttrLabel(AttributesList attList,int colIndex,int attIndex) { ArrayList<Column> cols=attList.attributes; Column col=cols.get(colIndex); Attribute attr=col.attributes.get(attIndex); return attr.label; } /** * 从指定节点产生判定树规则 * @param t */ private void genRule(Node t) { if(t!=null) { if(t.decision==1) //属于叶节点,记下它的路径 { ruleList.add(t.path); return; } else //递归查找叶节点 { Node[] childs=t.childs; for(int i=0;i<childs.length;i++) { genRule(childs[i]); } } } } /** * 将规则转换为特定形式(类似于0:1->1) * */ private void checkRules() { genRule(root); for(int i=0;i<ruleList.size();i++) { String s=ruleList.get(i); s=s.substring(9); s=s.replaceAll("Attribute", "+"); ruleList.set(i, s); } } /** * 输出判定树规则的特定形式 * */ public void showRules() throws IOException { checkRules(); System.out.println("规则列表:"); File ruleFile=new File("rule.txt"); if(ruleFile.exists()) ruleFile.delete(); FileOutputStream fout=new FileOutputStream(ruleFile); PrintWriter pWriter=new PrintWriter(fout); pWriter.println("ID3分类规则列表:"); for(String s:ruleList) { pWriter.println(s); } pWriter.close(); for(String s:ruleList) { System.out.println(s); } System.out.println("结果同时保存到文件\"rule.txt\"中."); } /** * 判断所有数据是否属于同一类 * @param matrix 数据表 * @return 此类的标识 */ private String theSameClass(Matrix matrix) { int labelIndex=matrix.width-1; ArrayList<RowData> rows=matrix.matrix; String cLabel=rows.get(0).rowdata.get(labelIndex); for(int i=0;i<rows.size();i++) { ArrayList<String> row=rows.get(i).rowdata; if(!row.get(labelIndex).equals(cLabel)) return null; } return cLabel; } /** * 计算某个主属性的属性增益 * @param col 主属性 * @return 属性增益值 */ private double entropy(Column col) { //TotalTarget totalTargets=new TotalTarget(); //ArrayList<Target> totalList=totalTargets.totalTargets; ArrayList<Attribute> colAttr=col.attributes; int totalCount=0; //属性增益 double EAttr=0.0; for(int i=0;i<colAttr.size();i++) { Attribute attr=colAttr.get(i); //数据总项数 totalCount+=attr.counts; } for(int i=0;i<colAttr.size();i++) { //对每个子属性计算属性增益 Attribute attr=colAttr.get(i); int sC=attr.counts; ArrayList<Target> tarList=attr.targets; double inentropy=0.0; for(int j=0;j<tarList.size();j++) { int pC=tarList.get(j).counts; double p=(double)pC/sC; inentropy+=p*Math.log(p)/LN2; } inentropy=-1*inentropy; //sc/totalCount属性增益权重 EAttr+=inentropy*((double)sC/totalCount); } return EAttr; } /** * 由指定的主属性,筛选某个子属性,构造新的数据Matrix * @param matrix 旧的数据集 * @param attList 属性列表 * @param col_index 主属性下标 * @param attr_index 主属性下的子属性下标 * @return 满足条件的新的数据集 */ private Matrix buildMatrix(Matrix matrix,AttributesList attList,int col_index,int attr_index) { //实际上是实现算法中的在一个test_attribute中取出一个分支的数据 ArrayList<Column> cols=attList.attributes; ArrayList<Attribute> colAttr=cols.get(col_index).attributes; Attribute attr=colAttr.get(attr_index); String label=attr.label; ArrayList<RowData> oldData=matrix.matrix; Matrix newMatrix=new Matrix(); ArrayList<RowData> newRowData=newMatrix.matrix; newMatrix.setWidth(matrix.width-1); for(int i=0;i<oldData.size();i++) { ArrayList<String> row=oldData.get(i).rowdata; RowData newrowdata=new RowData(); ArrayList<String> newrow=newrowdata.rowdata; //只选择具有这个属性的行 if(row.get(col_index).equals(label)) { for(int j=0;j<row.size();j++) { if(j!=col_index) newrow.add(row.get(j)); } newRowData.add(newrowdata); } } return newMatrix; } /** * 根据删除了主属性col_index的数据集产生新的属性列表 * @param theMatrix 数据集 * @param oldattList 旧的属性列表 * @param col_index 主属性下标 * @return 新的属性列表 */ private AttributesList buildAttributesList(Matrix theMatrix,AttributesList oldattList,int col_index) { //实际上是构造一个除去test_attribute的attribute_list,即attribute_list-test_attribute //由于上一步的buildMatrix在生成满足test_attribute的数据集theMatrix时,同时将test_attribute这一列数据删除, //而保持了列之间顺序的不变,在这里只需使用旧的列标识从theMatrix中直接生成AttrubutesList AttributesList attList=new AttributesList(); TotalTarget totalTargets=new TotalTarget(); ArrayList<Target> totalList=totalTargets.totalTargets; //属性列集合 ArrayList<Column> newCols=attList.attributes; ArrayList<Column> oldCols=oldattList.attributes; for(int i=0;i<oldCols.size();i++) { //构造新的列,保留除了要删除的列之外的旧列的标识 if(i!=col_index) { Column col=new Column(); col.label=oldCols.get(i).label; newCols.add(col); } } //行数 int height=theMatrix.matrix.size(); //列数 int width=theMatrix.width; //行数据集合 ArrayList<RowData> rows=theMatrix.matrix; for(int i=0;i<height;i++) { totalTargets.totalCount++; //行数据 ArrayList<String> row=rows.get(i).rowdata; //类标识,链表最后一个元素 String cLabel=row.get(width-1);// ***统计总的类标识计数 int t=0; for(;t<totalList.size();t++) { Target tar=totalList.get(t); if(tar.label.equals(cLabel)) { tar.counts++; break; } } if(t==totalList.size()) { Target tar=new Target(); tar.label=cLabel; tar.counts++; totalList.add(tar); } for(int j=0;j<width-1;j++) { //每一个属性列 Column col=newCols.get(j); ArrayList<Attribute> attCol=col.attributes; int k=0; for(;k<attCol.size();k++) { Attribute attr=attCol.get(k); if(attr.label.equals(row.get(j))) { attr.counts++; break; } } //此属性尚未构造 if(k==attCol.size()) { Attribute attrib=new Attribute(); ArrayList<Target> tarList=attrib.targets; attrib.label=row.get(j); attrib.counts++; Target targ=new Target(); targ.label=cLabel; targ.counts++; tarList.add(targ); attCol.add(attrib); } else { Attribute attr=attCol.get(k); ArrayList<Target> tarList=attr.targets; int s=0; for(;s<tarList.size();s++) { Target target=tarList.get(s); if(target.label.equals(cLabel)) { target.counts++; break; } } //此类标识尚未构造 if(s==tarList.size()) { Target targ=new Target(); targ.label=cLabel; targ.counts++; tarList.add(targ); } } } } //计算区分整个样本的熵 double inentropy=0.0; int totalCounts=totalTargets.totalCount; for(int i=0;i<totalList.size();i++) { Target tar=totalList.get(i); int pCount=tar.counts; //每一类含有的行数 double p=(double)pCount/totalCounts; inentropy+=p*Math.log(p)/LN2; } inentropy=-1*inentropy; attList.inentropy=inentropy; attList.rowcounts=theMatrix.matrix.size(); return attList; } /** * 输出原始数据表(测试用) * */ public void showMatrix() { Matrix matrix=this.dataSet; ArrayList<RowData> rowdata=matrix.matrix; for(int i=0;i<rowdata.size();i++) { ArrayList<String> row=rowdata.get(i).rowdata; for(int j=0;j<row.size();j++) { System.out.print(row.get(j)+" "); } System.out.println(); } } /** * 开始构建ID3判定树 * @param dataFile 训练数据文件 * @param classFile 训练数据类文件 * @throws IOException */ public void start(File dataFile,File classFile) throws IOException { dataSet=ID3.buildMatrix(dataFile,attributeList); ID3.buildAttribute(classFile, dataSet,attributeList); root=this.Generate_decision_tree(this.attributeList, this.dataSet, ""); } /** * 从指定节点打印判定树结构(使用前序遍历) * */ public void printTree() { printTree(root); } /** * 打印判定树结构(使用前序遍历) * */ private void printTree(Node t) { if(t!=null) { System.out.print(t.label+" "); Node[] childs=t.childs; if(childs!=null) { for(int i=0;i<childs.length;i++) { printTree(childs[i]); } } System.out.println(); } } /** * 由测试数据文件构造测试数据集 * @param testFile 测试数据文件 * @param testData 测试数据集 * @return 测试数据集
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -