⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 apriori.java

📁 数据挖掘关联规则Apriori算法的Java实现,关于算法可参考韩家炜的数据挖掘教程
💻 JAVA
字号:
/**
 * Copyright (c) 2003 HCYCOM, Inc. All Rights Reserved.
 *
 * A Java application of Apriori algorithm.
 *
 * Make sure that your DBMS service is running before you start the program.
 * Either SQL Server 2000 or MySQL 4.0 is a good choice.
 *
 * @author Zengli, China Agricultural University.
 * @version 0.01 03/07/09
 * 
 * mailto:zengli@hcycom.com
 */
 
import java.sql.*;
import java.util.*;
import java.math.*;

public class Apriori
{
	Connection conn;
	int rowCount;
	private String url,user,passwd;
	String[] itemList = null;
	int[] itemCount = null;
	final double MIN_SUP = 0.4;//最小支持度s
	final double MIN_CONF = 0.5;//最小置信度c
	
	PreparedStatement pstmt = null;
	ResultSet rs = null;			//定义结果集	
	
	PreparedStatement ckPstmt = null;
	ResultSet ckRs = null;	
	PreparedStatement lkPstmt = null;
	
	//构造函数
	Apriori()
	{		
		url = "jdbc:odbc:AprioriMySQL";//定义JDBC-ODBC桥的URL
		user = "root";	
		passwd = "";	
		
		conn = null;					
		rowCount = 0;
		
		//预设项列表(商品集)={啤酒,尿布,擦身粉,面包,雨伞,牛奶,洗衣粉,可乐}
		itemList = new String[]{"beer","diaper","toiletpowder","bread","umbrella","milk","washingpowder","coke"};
		itemCount = new int[itemList.length];
	}
		
	public static void main(String[] args)
	{				
		Apriori a = new Apriori();
		a.execute();
	}
	
	//加载指定的数据库驱动程序
	void loadJdbcOdbcDriver(String driverName)
	{
		try{
			System.out.println("加载数据库驱动程序 ...OK!");
			Class.forName(driverName);
		}
		catch(java.lang.ClassNotFoundException e){
			System.err.print("ClassNotFoundException:");
			System.err.println(e.getMessage());
		}	
	}
	
	//Apriori算法实现
	void execute()
	{	
		
		//加载数据库驱动程序				
		loadJdbcOdbcDriver("sun.jdbc.odbc.JdbcOdbcDriver");
		
		System.out.println("初次扫描数据库commodity ...OK!\n");
		try{
			//建立连接ODBC数据源
			conn = DriverManager.getConnection(url,user,passwd);
			
			//提交连接请求
			conn.setAutoCommit(false);	
			
			//------- 初次扫描数据库commodity的transaction表,对每个候选(即商品项)以及事务(即交易)计数 ------
			
			pstmt = conn.prepareStatement("SELECT * FROM transaction",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);				
			rs = pstmt.executeQuery();
			
			int id;
			String commodity_list;			
			Statement stmt5 = null;
			Statement stmt8 = null;
			Statement stmt9 = null;
			stmt5 = conn.createStatement();
			stmt5.execute("DELETE FROM lk");			
			
			for (int i = 0; i < itemList.length; i++){
				rowCount = 0;
				while (rs.next()){
					id = rs.getInt(1);	
					commodity_list = rs.getString("commodity_list");					
					if (commodity_list.indexOf(itemList[i]) != -1) itemCount[i]++;//对每个候选计数
					rowCount++;	
				}
				System.out.println(itemList[i] + "\t\t\t(" + itemCount[i] + ")");
				
				//比较候选支持度计数与最小支持度计数,得出L1表				
				if (itemCount[i] >= (int)(MIN_SUP * rowCount)){
					pstmt = conn.prepareStatement("INSERT INTO Lk(itemsets,sup_count) VALUES(?,?)",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);						
					pstmt.setString(1,itemList[i]);
					pstmt.setInt(2,itemCount[i]);
					pstmt.executeUpdate();	
				}  
				rs.beforeFirst();
			}
			
			//------- 算法核心 : 使用候选项集Ck找频繁项集Lk ------- 

			for (int k = 2; ; k++){

				if (!aproiri_gen(k - 1, MIN_SUP)) break;//连接-剪枝生成候选集Ck(即Ck表)
				
				//读取Ck表的每一个候选项集
				pstmt = conn.prepareStatement("SELECT * FROM ck",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);
				rs = pstmt.executeQuery();	
				
				stmt8 = conn.createStatement();
				stmt8.execute("DELETE FROM lk");//清空旧的Lk表(即Lk-1表)			
				while (rs.next()){
					String ck = rs.getString("itemsets");//读出Ck中每个候选项集(即Ck表的第一个字段)
					
					//将Ck当前候选项集中各元素分离并复制到Vector向量的各分量中
					Vector ckItemList = new Vector(k);
					ckItemList = splitItemList(k, ck);
					
					int ckSupCnt = 0;//Ck项集支持度
					
					//扫描数据库commodity的transaction表,为Ck中的每个候选项集计数
					ckPstmt = conn.prepareStatement("SELECT * FROM transaction",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);
					ckRs = ckPstmt.executeQuery();				
					while (ckRs.next()){//循环读入transaction表的记录
						String s = ckRs.getString("commodity_list");//得到commodity_list字段内容
						System.out.println("commodity_field = "+s);
						int j = 0;
						for (; j < k; j++){
							if (s.indexOf(ckItemList.elementAt(j).toString()) == -1) break;//扫描当前记录,判断是否同时存在Ck当前候选项集中的所有元素							
						}
						
						if (j >= k) ckSupCnt++;//Ck候选项集支持度计数
					}					
					
					//比较候选支持度计数与最小支持度计数,生成新的Lk表
					System.out.println("ckSupCnt = " + ckSupCnt +" (int)(MIN_SUP * rowCount)=" +(int)(MIN_SUP * rowCount));
					if (ckSupCnt >= (int)(MIN_SUP * rowCount)){
						stmt9 = conn.createStatement();
						stmt9.execute("INSERT INTO lk(itemsets,sup_count) VALUES('" + ck + "'," + ckSupCnt + ")");
					}					
				}				
			}
				
			//操作语句提交
			conn.commit();
			conn.setAutoCommit(true);
			
			if (ckPstmt != null) ckPstmt.close();			
			if (lkPstmt != null) lkPstmt.close();
			if (pstmt != null) pstmt.close();			
			if (conn != null) conn.close();
			
			System.out.println("-- END --");
		}
		catch(BatchUpdateException b){
			System.err.println("--BatchUpdateException--");	
			System.err.println("SQLState: " + b.getSQLState());
			System.err.println("Message: " + b.getMessage());
			System.err.println("Vendor: " + b.getErrorCode());
			System.err.print("Update counts: ");
			int [] updateCounts = b.getUpdateCounts();
			for(int i = 0; i<updateCounts.length;i++){
				System.err.print(updateCounts[i] + " ");
			}
			System.err.println("");			
		}
		catch(SQLException ex){
			System.err.println("--SQLException--");
			System.err.println("SQLState: " + ex.getSQLState());
			System.err.println("Message: " + ex.getMessage());
			System.err.println("Vendor: " + ex.getErrorCode());
		}

	}
	
	
	//连接-剪枝动作,产生候选k-项集(即得到Ck表)
	boolean aproiri_gen(int k, double min_sup)
	{
		ResultSet lkRs1 = null;
		ResultSet lkRs2 = null;
		
		int ckCnt = 0;
		
		Statement stmt = null;
		Statement stmt2 = null;
		Statement stmt7 = null;
		try{
			//读取旧Lk表(即Lk-1表)中的所有项集			
			stmt = conn.createStatement();
			lkRs1 = stmt.executeQuery("SELECT * FROM lk");			
			stmt7 = conn.createStatement();
			stmt7.execute("DELETE FROM ck");
		
			ResultSet tRs1 = lkRs1;							
			while (lkRs1.next()){
				String lkItemsets1 = lkRs1.getString("itemsets");//读取当前lkRs1结果集记录的第一个字段
				stmt2 = conn.createStatement();
				lkRs2 = stmt2.executeQuery("SELECT * FROM lk");
				while (lkRs2.next()){
					String lkItemsets2 = lkRs2.getString("itemsets");//读取当前lkRs2结果集记录的第一个字段
					if (isLinkable(k, lkItemsets1, lkItemsets2)){//判断当前lkRs1项集与当前lkRs2项集是否存在可连接性	
						System.out.println("k="+ k +" lkItemsets1 = "+ lkItemsets1 + "  lkItemsets2= "+ lkItemsets2);				
						if (!pruneAction(joinAction(k, lkItemsets1, lkItemsets2), tRs1)) ckCnt++;//连接-剪枝;统计剪枝后剩余的Ck项集数										
					}
					else break;//不存在可连接性则中断本次lkRs2结果集循环
				}
			}
		}
		catch(Exception e){
			e.printStackTrace();	
		}		
		System.out.println("ckCnt=" + ckCnt);
		if (ckCnt == 0)	return false;//如剪枝后Ck项集的计数值为0则算法终止
		else return true;			
	}
	
	//分离字符串并返回到向量中
	Vector splitItemList(int k, String cl)
	{
		String[] strArray = cl.split(",");
		Vector v = new Vector(Arrays.asList(strArray));//将字符串数组转换为向量
		
		System.out.println("splitItem = "+v);		
		return v;	
	}
	
	//判断是否具有可连接性
	boolean isLinkable(int k, String s1,String s2)
	{		
		Vector v1 = splitItemList(k, s1);
		Vector v2 = splitItemList(k, s2);
		
		int i = 0;
		int cnt = 0;
		for (; i < v1.size(); i++){
				if (v1.elementAt(i) == v2.elementAt(i)) cnt++;//break;//未满足可连接条件			
		}
		
		if ((cnt == v1.size() -1) && !v1.equals(v2)){
			return true;//可连接
		}
		else return false;
	}
	
	//连接动作
	Vector joinAction(int k, String s1,String s2)
	{
		Vector v1 = splitItemList(k, s1);
		Vector v2 = splitItemList(k, s2);
		
		v1.add(v2.elementAt(v2.size() -1));//取v2最后一个分量插入到v1的尾部
		
		return v1;
	}
	
	//判断当前频繁项集Ck中的所有(k-1)子集是否频繁,否则进行剪枝
	boolean pruneAction(Vector v, ResultSet rs)
	{	
		//----------- 该处对原Apriori算法中描述的剪枝动作有所改动 ------------
				
		//用频繁项集L(k-1)与剪枝前的候选项集Ck相匹配,然后由 ?(计数值 > k) 判断是否应当剪去该候选项集		
		Statement stmt3 = null;
		Statement stmt4 = null;
		ResultSet rs2 = null;
		ResultSet rs3 = null;
		
		int containCnt = 0;
		
		try{
			stmt3 = conn.createStatement();
			rs2 = stmt3.executeQuery("SELECT * FROM lk");
			while (rs2.next()){
				String lkItemsets = rs2.getString("itemsets");//读入L(k-1)的各项集		
				Vector lkItemList = splitItemList(v.size() - 1, lkItemsets);//(v.size - 1)即 k - 1				
				System.out.println("v_list = " +v +"  lkItemList = "+lkItemList);
				if (v.containsAll(lkItemList)) containCnt++;		
			}			
		
		System.out.println("v.size = " +v.size() +"  containCnt = "+containCnt);
			if (containCnt >= v.size()){//匹配数 >= k说明该候选项集需保留(非剪枝)
			
				//------ 得到候选集(Ck表) --------
				String s = new String(v.elementAt(0).toString());			
				for (int i = 1; i < v.size(); i++){
					s = s + "," + v.elementAt(i);	
				}			
				System.out.println("s = " +s);
				
				stmt4 = conn.createStatement();
				stmt4.execute("INSERT INTO ck(itemsets,sup_count) VALUES('" + s + "',0)");				
				stmt4.close();
				return	false;//非剪枝,保留			
			}
		}
		catch(Exception e){
			e.printStackTrace();
		}		
		return true;//剪枝,返回true	
	}		
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -