📄 id3.java
字号:
* @throws IOException if the test file cannot be opened or read
     */
    private Matrix buildTestMatrix(File testFile, Matrix testData) throws IOException {
        // Close the reader even if readLine() throws, so the file handle is never leaked.
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(testFile)));
        try {
            ArrayList<RowData> rows = testData.matrix;
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                // Each line becomes one row; values are separated by single spaces.
                RowData rowData = new RowData();
                for (String item : line.split(" ")) {
                    rowData.rowdata.add(item);
                }
                rows.add(rowData);
            }
        } finally {
            reader.close();
        }
        return testData;
    }

    /**
     * Classifies every row of the test data set using the rule list.
     * <p>
     * A rule matches a row when, for every ':' in the rule, the row's value at the column whose
     * index is the single character just before the ':' equals the single character just after
     * it. Only one-character column indexes and values are read. The last character of the
     * first matching rule is appended to the row as its class label. If no rule matches, an
     * empty string is appended.
     *
     * @param testData test data set (modified in place)
     * @param ruleList rule list, tried in order
     * @return {@code testData}, with a class label appended to every row
     */
    private Matrix genTestMatrix(Matrix testData, ArrayList<String> ruleList) {
        for (RowData rowData : testData.matrix) {
            ArrayList<String> row = rowData.rowdata;
            String target = "";
            // Try the rules in order; the first rule whose conditions all hold decides the class.
            for (String rule : ruleList) {
                String remaining = rule;
                int labelIndex = remaining.indexOf(":");
                while (labelIndex != -1) {
                    // The character before the colon is the column (main attribute) index.
                    char carrayIndex = remaining.charAt(labelIndex - 1);
                    // The character after the colon is the expected value (sub-attribute).
                    String label = "" + remaining.charAt(labelIndex + 1);
                    int arrayIndex = Integer.parseInt("" + carrayIndex);
                    if (!row.get(arrayIndex).equals(label)) {
                        break;
                    }
                    // Condition holds: continue scanning from the matched value onwards.
                    remaining = remaining.substring(labelIndex + 1);
                    labelIndex = remaining.indexOf(":");
                }
                // labelIndex == -1 means every condition in the rule was satisfied.
                if (labelIndex == -1) {
                    // The last character of the rule is its class label.
                    target = "" + remaining.charAt(remaining.length() - 1);
                    break;
                }
            }
            // Append the predicted class label as the last column of the row.
            row.add(target);
        }
        return testData;
    }

    /**
     * Checks the classification accuracy and prints it as a percentage.
     * <p>
     * Line {@code i} of {@code testClassFile} is the expected class of row {@code i} of
     * {@code testData}. The predicted class is the last value of that row. The denominator is
     * the number of lines in {@code testClassFile}.
     *
     * @param testData      test data set whose rows already end with a predicted class label
     * @param testClassFile file with one expected class label per line
     * @throws IOException if the class file cannot be opened or read
     */
    private void checkClass(Matrix testData, File testClassFile) throws IOException {
        ArrayList<String> classList = new ArrayList<String>();
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(testClassFile)));
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                classList.add(line);
            }
        } finally {
            reader.close();
        }
        int totalCounts = classList.size(); // total number of expected labels
        int correctCounts = 0;              // rows whose predicted label is correct

        ArrayList<RowData> rows = testData.matrix;
        // Compare only rows that have an expected label, so a class file shorter than the
        // test data cannot cause an IndexOutOfBoundsException.
        int comparable = Math.min(rows.size(), classList.size());
        for (int i = 0; i < comparable; i++) {
            ArrayList<String> row = rows.get(i).rowdata;
            String cLabel = row.get(row.size() - 1);
            if (cLabel.equals(classList.get(i))) {
                correctCounts++;
            }
        }
        // Avoid 0/0, which would print "NaN%" for an empty class file.
        float correctPercent = totalCounts == 0 ? 0f : (float) correctCounts / totalCounts;
        System.out.println("分类正确率:" + correctPercent * 100 + "%");
    }

    /**
     * Tests classification accuracy (public entry point).
     * <p>
     * Loads the test rows into {@code testData}, classifies them with the rule list, and prints
     * the percentage of rows whose predicted label matches {@code testClassFile}.
     *
     * @param testFile      test data file, one space-separated row per line
     * @param testClassFile expected class labels, one per line
     * @throws IOException if either file cannot be opened or read
     */
    public void checkClassification(File testFile, File testClassFile) throws IOException {
        if (testData == null) {
            // Create the test data set lazily so buildTestMatrix never dereferences null.
            testData = new Matrix();
        }
        testData = buildTestMatrix(testFile, testData);
        testData = genTestMatrix(testData, ruleList);
        this.checkClass(testData, testClassFile);
    }

    private AttributesList attributeList;     // attribute list
    private Matrix dataSet;                   // training data set
    private Matrix testData;                  // test data set
    private ArrayList<String> ruleList;       // rule list
    private final double LN2 = Math.log(2.0); // ln(2)
    Node root;                                // root node
}

/**
 * A node of the decision tree.
 *
 * @author Lam
 */
class Node {
    Node() {
        this("", -1, 0);
    }

    /**
     * @param theLabel    node label
     * @param theDecision node kind: -1 for an ordinary node, 1 for a leaf
     * @param childCounts number of branches; no child array is allocated when this is &lt;= 0
     */
    Node(String theLabel, int theDecision, int childCounts) {
        this(theLabel, null, theDecision, childCounts);
    }

    /**
     * @param theLabel    node label
     * @param thePath     node path
     * @param theDecision node kind: -1 for an ordinary node, 1 for a leaf
     * @param childCounts number of branches; no child array is allocated when this is &lt;= 0
     */
    Node(String theLabel, String thePath, int theDecision, int childCounts) {
        this.label = theLabel;
        this.path = thePath;
        this.decision = theDecision;
        if (childCounts > 0) {
            childs = new Node[childCounts];
        }
    }

    String label;  // node label
    String path;   // node path
    int decision;  // node kind flag: -1 = ordinary node, 1 = leaf
    Node[] childs; // branches (null when created with no children)
}

/**
 * A data set: a list of rows plus its column count.
 *
 * @author Lam
 */
class Matrix {
    Matrix() {
        matrix = new ArrayList<RowData>();
    }

    void setWidth(int theWidth) {
        width = theWidth;
    }

    int width;                 // number of columns
    ArrayList<RowData> matrix; // the rows
}

/**
 * One row of a data set.
 *
 * @author Lam
 */
class RowData {
    RowData() {
        rowdata = new ArrayList<String>();
    }

    ArrayList<String> rowdata; // the row's values, one per column
}

/**
 * The attribute list of a data set.
 *
 * @author Lam
 */
class AttributesList {
    AttributesList() {
        attributes = new ArrayList<Column>();
        rowcounts = 0;
        inentropy = 0.0;
    }

    int rowcounts;                // number of rows in the data set
    double inentropy;             // entropy of the whole data set with respect to the class label
    ArrayList<Column> attributes; // one entry per attribute column
}

/**
 * An attribute column (main attribute).
 *
 * @author Lam
 */
class Column {
    Column() {
        attributes = new ArrayList<Attribute>();
        label = "";
    }

    ArrayList<Attribute> attributes; // distinct values (sub-attributes) seen in this column
    String label;                    // column label
}

/**
 * One distinct value (sub-attribute) of an attribute column.
 *
 * @author Lam
 */
class Attribute {
    Attribute() {
        targets = new ArrayList<Target>();
    }

    int counts;                // number of rows having this value
    String label;              // the value itself
    ArrayList<Target> targets; // class-label counts among those rows
}

/**
 * A class label together with an occurrence count.
 *
 * @author Lam
 */
class Target {
    String label; // class label
    int counts;   // number of rows with this class label
}

/**
 * Class-label counts over a set of rows; used to compute the information I needed to
 * classify a sample.
 *
 * @author Lam
 */
class TotalTarget {
    TotalTarget() {
        totalTargets = new ArrayList<Target>();
        totalCount = 0;
    }

    ArrayList<Target> totalTargets; // one entry per distinct class label
    int totalCount;                 // total number of rows counted
}

public class ID3 {
    private static final double LN2 = Math.log(2.0); // ln(2), converts natural log to log base 2

    public ID3() {
    }

    /**
     * Builds the data table from a data file.
     * <p>
     * Each line becomes one row, with values separated by single spaces. For every value on the
     * first line, one {@link Column} labelled {@code "Attribute" + i} is appended to
     * {@code attrList}, and the table width is set to that count.
     *
     * @param dataFile data file
     * @param attrList attribute list that receives one column per attribute
     * @return the data table; empty (width 0, no columns added) if the file is empty
     * @throws IOException if the file cannot be opened or read
     */
    static Matrix buildMatrix(File dataFile, AttributesList attrList) throws IOException {
        Matrix matrixD = new Matrix();
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(dataFile)));
        try {
            String line = reader.readLine();
            if (line == null) {
                // An empty file yields an empty table instead of a NullPointerException.
                return matrixD;
            }
            // The first line determines how many attribute columns there are.
            String[] items = line.split(" ");
            for (int i = 0; i < items.length; i++) {
                Column col = new Column();
                col.label = "Attribute" + i;
                attrList.attributes.add(col);
            }
            matrixD.setWidth(items.length);
            for (; line != null; line = reader.readLine()) {
                RowData rowD = new RowData();
                for (String item : line.split(" ")) {
                    rowD.rowdata.add(item);
                }
                matrixD.matrix.add(rowD);
            }
        } finally {
            reader.close();
        }
        return matrixD;
    }

    /**
     * Reads the class file alongside the data table and builds the attribute statistics.
     * <p>
     * Line {@code i} of {@code classFile} is the class label of row {@code i} of
     * {@code theData}. For each row, this method:
     * <ul>
     * <li>counts the class label overall;</li>
     * <li>counts each column value (sub-attribute) and the class labels seen with it;</li>
     * <li>appends the class label to the row.</li>
     * </ul>
     * Afterwards it increments the table width by one. It then stores the data set's entropy
     * and row count in {@code attrList}. If the class file has fewer lines than the table has
     * rows, the remaining rows are left without a class label.
     *
     * @param classFile class file, one class label per line
     * @param theData   data table
     * @param attrList  attribute list whose columns already exist (for example, created by
     *                  {@link #buildMatrix})
     * @throws IOException if the class file cannot be read, or has more lines than
     *                     {@code theData} has rows
     */
    static void buildAttribute(File classFile, Matrix theData, AttributesList attrList) throws IOException {
        TotalTarget totalTargets = new TotalTarget();
        ArrayList<Target> totalList = totalTargets.totalTargets;
        // Each column is a main attribute holding its distinct values (sub-attributes).
        ArrayList<Column> cols = attrList.attributes;
        ArrayList<RowData> rows = theData.matrix;

        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(classFile)));
        try {
            int lncount = 0; // index of the data row paired with the current class line
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                // Overall count for this class label.
                countTarget(totalList, line);
                totalTargets.totalCount++;

                if (lncount >= rows.size()) {
                    throw new IOException("Class file " + classFile
                            + " has more lines than the data table has rows (" + rows.size() + ")");
                }
                ArrayList<String> rowdata = rows.get(lncount++).rowdata;
                // The row holds one value per column.
                for (int i = 0; i < cols.size(); i++) {
                    ArrayList<Attribute> attCol = cols.get(i).attributes;
                    String value = rowdata.get(i);
                    Attribute attr = findAttribute(attCol, value);
                    if (attr == null) {
                        // First occurrence of this value in the column.
                        attr = new Attribute();
                        attr.label = value;
                        attCol.add(attr);
                    }
                    attr.counts++;
                    // Count this row's class label under the value.
                    countTarget(attr.targets, line);
                }
                // Append the class label to the row.
                rowdata.add(line);
            }
        } finally {
            reader.close();
        }
        // The class label is now an extra column.
        theData.setWidth(theData.width + 1);

        // Entropy of the whole data set with respect to the class label: -sum(p * log2(p)).
        int totalCounts = totalTargets.totalCount;
        double inentropy = 0.0;
        for (Target tar : totalList) {
            double p = (double) tar.counts / totalCounts;
            inentropy += p * Math.log(p) / LN2;
        }
        attrList.inentropy = -1 * inentropy;
        attrList.rowcounts = theData.matrix.size();
    }

    /** Increments the count of the target labelled {@code label}, adding it with count 1 if absent. */
    private static void countTarget(ArrayList<Target> targets, String label) {
        for (Target target : targets) {
            if (target.label.equals(label)) {
                target.counts++;
                return;
            }
        }
        Target target = new Target();
        target.label = label;
        target.counts = 1;
        targets.add(target);
    }

    /** Returns the attribute in {@code attCol} whose label equals {@code value}, or null if none. */
    private static Attribute findAttribute(ArrayList<Attribute> attCol, String value) {
        for (Attribute attr : attCol) {
            if (attr.label.equals(value)) {
                return attr;
            }
        }
        return null;
    }

    /**
     * Runs the data-set test.
     * <p>
     * Builds a decision tree from Train_Attr.txt and Train_Class.txt in the current directory,
     * and prints the derived rules. It then reports the classification accuracy on
     * Test_Attr.txt and Test_Class.txt.
     *
     * @param args unused
     * @throws IOException if any of the data files cannot be read
     */
    public static void main(String[] args) throws IOException {
        File dataFile = requireFile("Train_Attr.txt");
        File classFile = requireFile("Train_Class.txt");
        File testDataFile = requireFile("Test_Attr.txt");
        File testClassFile = requireFile("Test_Class.txt");

        ID3Tree id3tree = new ID3Tree();
        System.out.println("开始构建判定树...");
        id3tree.start(dataFile, classFile);
        System.out.println("由训练数据集构建判定树完成.以下是导出的自定义规则:");
        id3tree.showRules();
        System.out.println("测试分类正确性...");
        id3tree.checkClassification(testDataFile, testClassFile);
    }

    /**
     * Returns the named file in the current directory. If the file does not exist, prints an
     * error to stderr and exits with status -1.
     */
    private static File requireFile(String name) {
        File file = new File(name);
        if (!file.exists()) {
            System.err.println("在当前目录下找不到文件" + name + "!");
            System.exit(-1);
        }
        return file;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -