⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 id3.java

📁 id3算法代码,用C 语言实现的.小弟急需JAVA语言写的ID3算法代码,麻烦有的大哥大姐共享下,
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* * To change this template, choose Tools | Templates * and open the template in the editor. */package id3;import java.io.BufferedReader;import java.io.File;import java.io.FileInputStream;import java.io.FileOutputStream;import java.io.IOException;import java.io.InputStreamReader;import java.io.PrintWriter;import java.util.ArrayList;/** * * @author Lam *//** * ID3判定树类 * @author Lam * */class ID3Tree{	ID3Tree()	{		root=null;		attributeList=new AttributesList();		dataSet=new Matrix();		testData=new Matrix();		ruleList=new ArrayList<String>();	}		/**	 * 多数表决 	 * @param matrix 数据集	 * @return 数据集中最普通的类	 */	private String getMaxLabel(Matrix matrix)	{		if(matrix.matrix.size()==0)			return "";		ArrayList<Target> tarList=new ArrayList<Target>();		ArrayList<RowData> rows=matrix.matrix;		for(int i=0;i<rows.size();i++)		{			ArrayList<String> row=rows.get(i).rowdata;			String cLabel=row.get(row.size()-1);			int j=0;			for( ;j<tarList.size();j++)			{				if(tarList.get(j).label.equals(cLabel))				{					tarList.get(j).counts++;					break;				}			}			if(j==tarList.size())			{				Target tar=new Target();				tar.label=cLabel;				tar.counts++;				tarList.add(tar);			}		}		int maxTargetIndex=0;		int maxCounts=0;		for(int i=0;i<tarList.size();i++)		{			int count=tarList.get(i).counts;			if(count>maxCounts)			{				maxCounts=count;				maxTargetIndex=i;			}		}		String maxLabel=tarList.get(maxTargetIndex).label;		return maxLabel;	}		/**	 * 产生判定树	 * @param attList 属性列表	 * @param matrix  数据表	 * @param path 节点路径(最终转化为规则)	 * @return 判定树节点	 */	private Node Generate_decision_tree(AttributesList attList,Matrix matrix,String path)	{		   Node t;					String label=this.theSameClass(matrix);			//在同一类中			if(label!=null)			{				t=new Node("Leave-Class"+label,path+"->"+label,1,0);							}			//attribute_list为空			else if(attList.attributes.size()==0)			{				String maxLabel=this.getMaxLabel(matrix);				t=new Node("Leave-Class"+maxLabel,path+"->"+maxLabel,1,0);			}			else			{				int maxColIndex=0;				double maxEntropy=0.0;				ArrayList<Column> cols=attList.attributes;				for(int i=0;i<cols.size();i++)				{					double gain=attList.inentropy-entropy(cols.get(i));					if(gain>maxEntropy)					{						maxEntropy=gain;						maxColIndex=i;					}				}				Column col=cols.get(maxColIndex);				int childCounts=col.attributes.size();				t=new Node("Node-"+col.label,path+col.label,-1,childCounts);				Node[] childs=t.childs;				for(int i=0;i<childs.length;i++)				{					String subAttr=this.getAttrLabel(attList, maxColIndex, i);					Matrix matr=buildMatrix(matrix,attList,maxColIndex,i);					if(matr.matrix.size()==0)					{						String maxLabel=this.getMaxLabel(matrix);												childs[i]=new Node("Leave-"+maxLabel,t.path+":"+subAttr+"->"+maxLabel,1,0);					}					else					{						AttributesList newAttrList=this.buildAttributesList(matr,attList,maxColIndex);						childs[i]=this.Generate_decision_tree(newAttrList, matr,t.path+":"+subAttr);					}				}			}				return t;			}		/**	 * 返回属性列表中指定的某列的某个子属性	 * @param attList 属性列表	 * @param colIndex 属性列位置	 * @param attIndex  子属性位置	 * @return	 */	private String getAttrLabel(AttributesList attList,int colIndex,int attIndex)	{		ArrayList<Column> cols=attList.attributes;		Column col=cols.get(colIndex);		Attribute attr=col.attributes.get(attIndex);		return attr.label;	}		/**	 * 从指定节点产生判定树规则	 * @param t	 */	private void genRule(Node t)	{		if(t!=null)		{			if(t.decision==1)  //属于叶节点,记下它的路径			{				ruleList.add(t.path);				return;			}			else //递归查找叶节点			{				Node[] childs=t.childs;				for(int i=0;i<childs.length;i++)				{					genRule(childs[i]);				}			}		}	}		/**	 * 将规则转换为特定形式(类似于0:1->1)	 *	 */	private void checkRules()	{		genRule(root);		for(int i=0;i<ruleList.size();i++)		{			String s=ruleList.get(i);			s=s.substring(9);			s=s.replaceAll("Attribute", "+");			ruleList.set(i, s);		}	}		/**	 * 输出判定树规则的特定形式	 *	 */	public void showRules() throws IOException	{		checkRules();		System.out.println("规则列表:");		File ruleFile=new File("rule.txt");		if(ruleFile.exists())			ruleFile.delete();		FileOutputStream fout=new FileOutputStream(ruleFile);		PrintWriter pWriter=new PrintWriter(fout);		pWriter.println("ID3分类规则列表:");		for(String s:ruleList)		{			pWriter.println(s);		}		pWriter.close();		for(String s:ruleList)		{			System.out.println(s);		}                System.out.println("结果同时保存到文件\"rule.txt\"中.");	}			/**	 * 判断所有数据是否属于同一类	 * @param matrix 数据表	 * @return  此类的标识	 */	private String theSameClass(Matrix matrix)	{		int labelIndex=matrix.width-1;		ArrayList<RowData> rows=matrix.matrix;		String cLabel=rows.get(0).rowdata.get(labelIndex);		for(int i=0;i<rows.size();i++)		{			ArrayList<String> row=rows.get(i).rowdata;			if(!row.get(labelIndex).equals(cLabel))					return null;		}		return cLabel;	}			/**	 * 计算某个主属性的属性增益	 * @param col 主属性	 * @return  属性增益值	 */	private double entropy(Column col)	{		//TotalTarget totalTargets=new TotalTarget();		//ArrayList<Target> totalList=totalTargets.totalTargets;		ArrayList<Attribute> colAttr=col.attributes;		int totalCount=0;		//属性增益		double EAttr=0.0;		for(int i=0;i<colAttr.size();i++)		{			Attribute attr=colAttr.get(i);			//数据总项数			totalCount+=attr.counts;		}		for(int i=0;i<colAttr.size();i++)		{			//对每个子属性计算属性增益			Attribute attr=colAttr.get(i);			int sC=attr.counts;			ArrayList<Target> tarList=attr.targets;			double inentropy=0.0;			for(int j=0;j<tarList.size();j++)			{				int pC=tarList.get(j).counts;				double p=(double)pC/sC;				inentropy+=p*Math.log(p)/LN2;			}			inentropy=-1*inentropy;			//sc/totalCount属性增益权重			EAttr+=inentropy*((double)sC/totalCount);		}		return EAttr;	}			/**	 * 由指定的主属性,筛选某个子属性,构造新的数据Matrix	 * @param matrix 旧的数据集	 * @param attList 属性列表	 * @param col_index 主属性下标	 * @param attr_index 主属性下的子属性下标	 * @return 满足条件的新的数据集	 */	private Matrix buildMatrix(Matrix matrix,AttributesList attList,int col_index,int attr_index)	{		//实际上是实现算法中的在一个test_attribute中取出一个分支的数据		ArrayList<Column> cols=attList.attributes;		ArrayList<Attribute> colAttr=cols.get(col_index).attributes;		Attribute attr=colAttr.get(attr_index);		String label=attr.label;		ArrayList<RowData> oldData=matrix.matrix;		Matrix newMatrix=new Matrix();		ArrayList<RowData> newRowData=newMatrix.matrix;		newMatrix.setWidth(matrix.width-1);		for(int i=0;i<oldData.size();i++)		{			ArrayList<String> row=oldData.get(i).rowdata;			RowData newrowdata=new RowData();			ArrayList<String> newrow=newrowdata.rowdata;			//只选择具有这个属性的行			if(row.get(col_index).equals(label))			{				for(int j=0;j<row.size();j++)				{					if(j!=col_index)						newrow.add(row.get(j));				}				newRowData.add(newrowdata);			}		}		return newMatrix;	}			/**         * 根据删除了主属性col_index的数据集产生新的属性列表         * @param theMatrix 数据集         * @param oldattList 旧的属性列表         * @param col_index 主属性下标         * @return 新的属性列表         */	private AttributesList buildAttributesList(Matrix theMatrix,AttributesList oldattList,int col_index)	{            //实际上是构造一个除去test_attribute的attribute_list,即attribute_list-test_attribute            //由于上一步的buildMatrix在生成满足test_attribute的数据集theMatrix时,同时将test_attribute这一列数据删除,            //而保持了列之间顺序的不变,在这里只需使用旧的列标识从theMatrix中直接生成AttrubutesList		AttributesList attList=new AttributesList();		TotalTarget totalTargets=new TotalTarget();		ArrayList<Target> totalList=totalTargets.totalTargets;		//属性列集合		ArrayList<Column> newCols=attList.attributes;				ArrayList<Column> oldCols=oldattList.attributes;		for(int i=0;i<oldCols.size();i++)		{			//构造新的列,保留除了要删除的列之外的旧列的标识			if(i!=col_index)				{					Column col=new Column();					col.label=oldCols.get(i).label;					newCols.add(col);				}		}		//行数		int height=theMatrix.matrix.size();		//列数		int width=theMatrix.width;		//行数据集合		ArrayList<RowData> rows=theMatrix.matrix;		for(int i=0;i<height;i++)		{			totalTargets.totalCount++;			//行数据			ArrayList<String> row=rows.get(i).rowdata;			//类标识,链表最后一个元素			String cLabel=row.get(width-1);//			***统计总的类标识计数			int t=0;			for(;t<totalList.size();t++)			{				Target tar=totalList.get(t);				if(tar.label.equals(cLabel))				{					tar.counts++;					break;				}			}			if(t==totalList.size())			{				Target tar=new Target();				tar.label=cLabel;				tar.counts++;				totalList.add(tar);			}						for(int j=0;j<width-1;j++)			{								//每一个属性列				Column col=newCols.get(j);				ArrayList<Attribute> attCol=col.attributes;				int k=0;				for(;k<attCol.size();k++)				{					Attribute attr=attCol.get(k);					if(attr.label.equals(row.get(j)))					{						attr.counts++;						break;					}				}				//此属性尚未构造				if(k==attCol.size())				{					Attribute attrib=new Attribute();					ArrayList<Target> tarList=attrib.targets;					attrib.label=row.get(j);					attrib.counts++;					Target targ=new Target();					targ.label=cLabel;					targ.counts++;					tarList.add(targ);					attCol.add(attrib);									}				else				{					Attribute attr=attCol.get(k);					ArrayList<Target> tarList=attr.targets;					int s=0;					for(;s<tarList.size();s++)					{						Target target=tarList.get(s);						if(target.label.equals(cLabel))						{							target.counts++;							break;						}					}					//此类标识尚未构造					if(s==tarList.size())					{						Target targ=new Target();						targ.label=cLabel;						targ.counts++;						tarList.add(targ);					}				}			}		}                //计算区分整个样本的熵		double inentropy=0.0;		int totalCounts=totalTargets.totalCount;		for(int i=0;i<totalList.size();i++)		{			Target tar=totalList.get(i);			int pCount=tar.counts; //每一类含有的行数			double p=(double)pCount/totalCounts;			inentropy+=p*Math.log(p)/LN2;		}		inentropy=-1*inentropy;		attList.inentropy=inentropy;		attList.rowcounts=theMatrix.matrix.size();				return attList;	}			/**	 * 输出原始数据表(测试用)	 *	 */	public void showMatrix()	{				Matrix matrix=this.dataSet;		ArrayList<RowData> rowdata=matrix.matrix;						for(int i=0;i<rowdata.size();i++)		{			ArrayList<String> row=rowdata.get(i).rowdata;									for(int j=0;j<row.size();j++)			{				System.out.print(row.get(j)+" ");			}			System.out.println();					}					}			/**	 * 开始构建ID3判定树	 * @param dataFile 训练数据文件	 * @param classFile 训练数据类文件	 * @throws IOException	 */	public void start(File dataFile,File classFile) throws IOException	{		dataSet=ID3.buildMatrix(dataFile,attributeList);		ID3.buildAttribute(classFile, dataSet,attributeList);		root=this.Generate_decision_tree(this.attributeList, this.dataSet, "");	}			/**	 * 从指定节点打印判定树结构(使用前序遍历)	 *	 */	public void printTree()	{		printTree(root);	}			/**	 * 打印判定树结构(使用前序遍历)	 *	 */	private void printTree(Node t)	{		if(t!=null)		{			System.out.print(t.label+" ");			Node[] childs=t.childs;			if(childs!=null)			{				for(int i=0;i<childs.length;i++)				{					printTree(childs[i]);				}			}			System.out.println();		}	}			/**	 * 由测试数据文件构造测试数据集	 * @param testFile 测试数据文件	 * @param testData 测试数据集	 * @return 测试数据集

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -