⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 id3.java

📁 id3算法代码,用C 语言实现的.小弟急需JAVA语言写的ID3算法代码,麻烦有的大哥大姐共享下,
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
	 * @throws IOException	 */	private Matrix buildTestMatrix(File testFile,Matrix testData) throws IOException	{		FileInputStream fin=new FileInputStream(testFile);		BufferedReader reader=new BufferedReader(new InputStreamReader(fin));		String line=reader.readLine();		ArrayList<RowData> rows=testData.matrix;		while(line!=null)		{			RowData rowdata=new RowData();			String[] items=line.split(" ");			ArrayList<String> row=rowdata.rowdata;			for(int i=0;i<items.length;i++)				row.add(items[i]);			rows.add(rowdata);			line=reader.readLine();				}		reader.close();		return testData;			}			/**	 * 由规则列表对测试数据集进行分类	 * @param testData 测试数据集	 * @param ruleList 规则列表	 * @return 带有类标识的数据集	 */	private Matrix genTestMatrix(Matrix testData,ArrayList<String> ruleList)	{		ArrayList<RowData> rows=testData.matrix;		for(int i=0;i<rows.size();i++)		{			ArrayList<String> row=rows.get(i).rowdata;			String target="";			//对每个规则,检验是否匹配			for(int j=0;j<ruleList.size();j++)			{				String rule=ruleList.get(j);				int labelIndex=rule.indexOf(":");				while(labelIndex!=-1)				{					//冒号前的字符表示主属性的下标					char carrayIndex=rule.charAt(labelIndex-1);					//冒号后的字符表示主属性下的子属性标识					String label=""+rule.charAt(labelIndex+1);					int arrayIndex=Integer.parseInt(""+carrayIndex);					if(row.get(arrayIndex).equals(label))					{						rule=rule.substring(labelIndex+1);						labelIndex=rule.indexOf(":");					}					else					{						break;					}				}				//成功的匹配				if(labelIndex==-1)				{					target+=rule.charAt(rule.length()-1);					break;				}			} //end for(int j=0;i<ruleList			//将类标识添加到测试数据集末尾			row.add(target);		}		return testData;	}			/**	 * 测试分类的正确性	 * @param testData 测试数据集	 * @param testClassFile 目标类测试文件	 * @throws IOException	 */	private void checkClass(Matrix testData,File testClassFile) throws IOException	{		FileInputStream fin=new FileInputStream(testClassFile);		BufferedReader reader=new BufferedReader(new InputStreamReader(fin));		String line=reader.readLine();		ArrayList<String> classList=new ArrayList<String>();		int totalCounts=0; //数据行总数		int correctCounts=0; //正确标识的行数		while(line!=null)		{			totalCounts++;			classList.add(line);			line=reader.readLine();		}		reader.close();		ArrayList<RowData> rows=testData.matrix;		for(int i=0;i<rows.size();i++)		{			ArrayList<String> row=rows.get(i).rowdata;			String cLabel=row.get(row.size()-1);			if(cLabel.equals(classList.get(i)))				correctCounts++;		}		float correctPercent=(float)correctCounts/totalCounts;		System.out.println("分类正确率:"+correctPercent*100+"%");	}			/**	 * 测试分类正确性(外部接口)	 * @param testFile	 * @param testClassFile	 * @throws IOException	 */	public void checkClassification(File testFile,File testClassFile) throws IOException	{		testData=buildTestMatrix(testFile, testData);		testData=genTestMatrix(testData, ruleList);		this.checkClass(testData, testClassFile);	}		private AttributesList attributeList;      //属性列表	private Matrix dataSet;                    //训练数据集	private Matrix testData;                   //测试数据集	private ArrayList<String> ruleList;        //规则列表	private final double LN2=Math.log(2.0);   //ln(2)	Node root;                                  //根节点}/** * 判定树节点类 * @author Lam * */class Node{	Node()	{		this("",-1,0);	}		/**	 * 	 * @param theLabel 节点标识	 * @param theDecision 是否叶节点	 * @param childCounts 分支数目	 */	Node(String theLabel,int theDecision,int childCounts)	{		label=theLabel;		decision=theDecision;		if(childCounts>0)		{			childs=new Node[childCounts];		}	}		/**	 * 	 * @param theLabel 节点标识	 * @param thePath  节点路径	 * @param theDecision 是否叶节点	 * @param childCounts 分支数目	 */	Node(String theLabel,String thePath,int theDecision,int childCounts)	{		this.label=theLabel;		this.path=thePath;		this.decision=theDecision;		if(childCounts>0)		{			childs=new Node[childCounts];		}	}		String label; //节点标识	String path; //节点路径	int decision; //区分节点标志,-1是普通节点,1是叶节点	Node[] childs; //分支}/** * 数据集类 * @author Lam * */class Matrix{	Matrix()	{		matrix=new ArrayList<RowData>();	}		void setWidth(int theWidth)	{		width=theWidth;	}		int width;	ArrayList<RowData> matrix;}/** * 行数据 * @author Lam * */class RowData{	RowData()	{		rowdata=new ArrayList<String>();	}	ArrayList<String> rowdata;}/** * 属性列表 * @author Lam * */class AttributesList{	AttributesList()	{		attributes=new ArrayList<Column>();		rowcounts=0;		inentropy=0.0;	}	int rowcounts;  //数据集行数	double inentropy; //类区分整个数据集的熵	ArrayList<Column> attributes;	}/** * 属性列 * @author Lam * */class Column{	Column()	{		attributes=new ArrayList<Attribute>();		label="";	}	ArrayList<Attribute> attributes; //列下的属性列表	String label; //列标识}/** * 主属性 * @author Lam * */class Attribute{	Attribute()	{		targets=new ArrayList<Target>();	}		int counts; //属性包含元素个数	String label; //属性标识	ArrayList<Target> targets; //属性下的目标分类}/** * 目标类 * @author Lam * */class Target{	String label; //目标类标识	int counts;   //目标类包含的元素个数}/** * 基于某个属性列表的目标类列表 * 用于计算对于一个样本分类所需的信息I * @author Lam * */class TotalTarget{	TotalTarget()	{		totalTargets=new ArrayList<Target>();		totalCount=0;	}	ArrayList<Target> totalTargets; //目标类列表	int totalCount;  //总的元素个数}public class ID3 {		public ID3()	{		//dataSet=new Matrix();		//attrList=new AttributesList();		//totalTarget=new ArrayList<Target>();	}	/**	 * 从数据文件中构造数据表Matrix	 * @param dataFile 数据文件	 * @param attrList 属性列表	 * @return	 * @throws IOException	 */	static Matrix buildMatrix(File dataFile,AttributesList attrList) throws IOException	{		FileInputStream fin=new FileInputStream(dataFile);		BufferedReader reader=new BufferedReader(new InputStreamReader(fin));		String line=reader.readLine();		String[] items;		Matrix matrixD=new Matrix();		RowData rowD;		items=line.split(" ");		//属性列表初始化		for(int i=0;i<items.length;i++)		{			Column col=new Column();			col.label="Attribute"+i;			attrList.attributes.add(col);					}		matrixD.setWidth(items.length);		while(line!=null)		{			items=line.split(" ");			rowD=new RowData();			for(int i=0;i<items.length;i++)			{								rowD.rowdata.add(items[i]);			}			matrixD.matrix.add(rowD);			line=reader.readLine();					}		reader.close();		return matrixD;	}			/**	 * 读取类文件以及Matrix数据表	 * 构造属性表	 * @param classFile 类文件	 * @param theData   数据表	 * @throws IOException	 */	static void buildAttribute(File classFile,Matrix theData,AttributesList attrList) throws IOException	{		FileInputStream fin=new FileInputStream(classFile);		BufferedReader reader=new BufferedReader(new InputStreamReader(fin));		String line=reader.readLine();		//AttributesList attrList=new AttributesList();		//类		Target ta;		TotalTarget totalTargets=new TotalTarget();		ArrayList<Target> totalList=totalTargets.totalTargets;		//列集(每列是一个主属性,包含若干子属性)		ArrayList<Column> cols=attrList.attributes;		//行数		int lncount=0;		while(line!=null)		{			//*******目标各属性的总个数			int t=0;			for(;t<totalList.size();t++)			{				Target tag=totalList.get(t);				if(tag.label.equals(line))				{					totalTargets.totalCount++;					tag.counts++;					break;				}			}			//类标识尚未建立			if(t==totalList.size())			{				ta=new Target();				ta.label=line;				ta.counts++;				totalList.add(ta);				totalTargets.totalCount++;			}			//*******//			//*******读取Matrix的每一行,构造Attribute			ArrayList<RowData> rows=theData.matrix;			//			ArrayList<String> rowdata=rows.get(lncount++).rowdata;			//row的长度和cols的长度一样			for(int i=0;i<cols.size();i++)			{				int j=0;				//取得每一列的引用				Column col=cols.get(i);				//列的子属性列表				ArrayList<Attribute> attCol=col.attributes;												for( ;j<attCol.size();j++)				{										//列的第j个属性					Attribute attr=attCol.get(j);										//					if(attr.label.equals(rowdata.get(i)))					{						attr.counts++;						break;					}				} //end for( ;j<attCol.size				//这个属性尚未构建				if(j==attCol.size())				{					Attribute attr=new Attribute();					attr.label=rowdata.get(i);					attr.counts++;					//设置子属性的目标类个数					ArrayList<Target> tarList=attr.targets;					Target target=new Target();					target.label=line;					target.counts++;					tarList.add(target);										attCol.add(attr);				}				else				{					//找到这个属性,对属性下的子属性的类标识计数					int k=0;					Attribute attr=attCol.get(j);					ArrayList<Target> tarList=attr.targets;					for(;k<tarList.size();k++)					{						Target tar=tarList.get(k);						if(tar.label.equals(line))						{							tar.counts++;							break;						}					}					//在Attribute表中找不到这个目标属性					//Attribute表为空时					if(k==tarList.size())					{						Target newtar=new Target();						newtar.label=line;						newtar.counts++;						tarList.add(newtar);					}				}							}//			将类标识加入到Matrix中			rowdata.add(line);			//读取下一行			line=reader.readLine();					}		reader.close();		//加入类标识,列数加1		theData.setWidth(theData.width+1);		//计算区分整个数据集的熵		int totalCounts=totalTargets.totalCount;		double inentropy=0.0;		for(int i=0;i<totalList.size();i++)		{			Target tar=totalList.get(i);			int pCount=tar.counts;			double p=(double)pCount/totalCounts;			inentropy+=p*Math.log(p)/LN2;		}		inentropy=-1*inentropy;		attrList.inentropy=inentropy;		attrList.rowcounts=theData.matrix.size();		//return attrList;	}		private static final double LN2=Math.log(2.0); //ln(2)			/**	 * 数据集测试	 * @param args	 */	public static void main(String[] args)throws IOException{		// TODO 自动生成方法存根		String dataS="Train_Attr.txt";		File dataFile=new File(dataS);		if(!dataFile.exists())		{			System.err.println("在当前目录下找不到文件Train_Attr.txt!");			System.exit(-1);		}		String classS="Train_Class.txt";		File classFile=new File(classS);		if(!classFile.exists())		{			System.err.println("在当前目录下找不到文件Train_Class.txt!");			System.exit(-1);		}		String testDataS="Test_Attr.txt";		File testDataFile=new File(testDataS);		if(!testDataFile.exists())		{			System.err.println("在当前目录下找不到文件Test_Attr.txt!");			System.exit(-1);		}		String testClassS="Test_Class.txt";		File testClassFile=new File(testClassS);		if(!testClassFile.exists())		{			System.err.println("在当前目录下找不到文件Test_Class.txt!");			System.exit(-1);		}		ID3Tree id3tree=new ID3Tree();		System.out.println("开始构建判定树...");		id3tree.start(dataFile, classFile);		//id3tree.start(testDataFile, testClassFile);		System.out.println("由训练数据集构建判定树完成.以下是导出的自定义规则:");		//id3tree.printTree();		id3tree.showRules();		System.out.println("测试分类正确性...");		id3tree.checkClassification(testDataFile, testClassFile);			}}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -