📄 apriori.java
字号:
/**
* Copyright (c) 2003 HCYCOM, Inc. All Rights Reserved.
*
* A Java implementation of the Apriori algorithm.
*
* Make sure that your DBMS service is running before you start the program.
* Either SQL Server 2000 or MySQL 4.0 is a good starting point.
*
* @author Zengli, China Agricultural University.
* @version 0.01 03/07/09
*
* mailto:zengli@hcycom.com
*/
import java.sql.*;
import java.util.*;
import java.math.*;
public class Apriori
{
Connection conn;
int rowCount;
private String url,user,passwd;
String[] itemList = null;
int[] itemCount = null;
final double MIN_SUP = 0.4;//最小支持度s
final double MIN_CONF = 0.5;//最小置信度c
PreparedStatement pstmt = null;
ResultSet rs = null; //定义结果集
PreparedStatement ckPstmt = null;
ResultSet ckRs = null;
PreparedStatement lkPstmt = null;
//构造函数
Apriori()
{
url = "jdbc:odbc:AprioriMySQL";//定义JDBC-ODBC桥的URL
user = "root";
passwd = "";
conn = null;
rowCount = 0;
//预设项列表(商品集)={啤酒,尿布,擦身粉,面包,雨伞,牛奶,洗衣粉,可乐}
itemList = new String[]{"beer","diaper","toiletpowder","bread","umbrella","milk","washingpowder","coke"};
itemCount = new int[itemList.length];
}
public static void main(String[] args)
{
Apriori a = new Apriori();
a.execute();
}
//加载指定的数据库驱动程序
void loadJdbcOdbcDriver(String driverName)
{
try{
System.out.println("加载数据库驱动程序 ...OK!");
Class.forName(driverName);
}
catch(java.lang.ClassNotFoundException e){
System.err.print("ClassNotFoundException:");
System.err.println(e.getMessage());
}
}
//Apriori算法实现
void execute()
{
//加载数据库驱动程序
loadJdbcOdbcDriver("sun.jdbc.odbc.JdbcOdbcDriver");
System.out.println("初次扫描数据库commodity ...OK!\n");
try{
//建立连接ODBC数据源
conn = DriverManager.getConnection(url,user,passwd);
//提交连接请求
conn.setAutoCommit(false);
//------- 初次扫描数据库commodity的transaction表,对每个候选(即商品项)以及事务(即交易)计数 ------
pstmt = conn.prepareStatement("SELECT * FROM transaction",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);
rs = pstmt.executeQuery();
int id;
String commodity_list;
Statement stmt5 = null;
Statement stmt8 = null;
Statement stmt9 = null;
stmt5 = conn.createStatement();
stmt5.execute("DELETE FROM lk");
for (int i = 0; i < itemList.length; i++){
rowCount = 0;
while (rs.next()){
id = rs.getInt(1);
commodity_list = rs.getString("commodity_list");
if (commodity_list.indexOf(itemList[i]) != -1) itemCount[i]++;//对每个候选计数
rowCount++;
}
System.out.println(itemList[i] + "\t\t\t(" + itemCount[i] + ")");
//比较候选支持度计数与最小支持度计数,得出L1表
if (itemCount[i] >= (int)(MIN_SUP * rowCount)){
pstmt = conn.prepareStatement("INSERT INTO Lk(itemsets,sup_count) VALUES(?,?)",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);
pstmt.setString(1,itemList[i]);
pstmt.setInt(2,itemCount[i]);
pstmt.executeUpdate();
}
rs.beforeFirst();
}
//------- 算法核心 : 使用候选项集Ck找频繁项集Lk -------
for (int k = 2; ; k++){
if (!aproiri_gen(k - 1, MIN_SUP)) break;//连接-剪枝生成候选集Ck(即Ck表)
//读取Ck表的每一个候选项集
pstmt = conn.prepareStatement("SELECT * FROM ck",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);
rs = pstmt.executeQuery();
stmt8 = conn.createStatement();
stmt8.execute("DELETE FROM lk");//清空旧的Lk表(即Lk-1表)
while (rs.next()){
String ck = rs.getString("itemsets");//读出Ck中每个候选项集(即Ck表的第一个字段)
//将Ck当前候选项集中各元素分离并复制到Vector向量的各分量中
Vector ckItemList = new Vector(k);
ckItemList = splitItemList(k, ck);
int ckSupCnt = 0;//Ck项集支持度
//扫描数据库commodity的transaction表,为Ck中的每个候选项集计数
ckPstmt = conn.prepareStatement("SELECT * FROM transaction",ResultSet.TYPE_SCROLL_SENSITIVE,ResultSet.CONCUR_UPDATABLE);
ckRs = ckPstmt.executeQuery();
while (ckRs.next()){//循环读入transaction表的记录
String s = ckRs.getString("commodity_list");//得到commodity_list字段内容
System.out.println("commodity_field = "+s);
int j = 0;
for (; j < k; j++){
if (s.indexOf(ckItemList.elementAt(j).toString()) == -1) break;//扫描当前记录,判断是否同时存在Ck当前候选项集中的所有元素
}
if (j >= k) ckSupCnt++;//Ck候选项集支持度计数
}
//比较候选支持度计数与最小支持度计数,生成新的Lk表
System.out.println("ckSupCnt = " + ckSupCnt +" (int)(MIN_SUP * rowCount)=" +(int)(MIN_SUP * rowCount));
if (ckSupCnt >= (int)(MIN_SUP * rowCount)){
stmt9 = conn.createStatement();
stmt9.execute("INSERT INTO lk(itemsets,sup_count) VALUES('" + ck + "'," + ckSupCnt + ")");
}
}
}
//操作语句提交
conn.commit();
conn.setAutoCommit(true);
if (ckPstmt != null) ckPstmt.close();
if (lkPstmt != null) lkPstmt.close();
if (pstmt != null) pstmt.close();
if (conn != null) conn.close();
System.out.println("-- END --");
}
catch(BatchUpdateException b){
System.err.println("--BatchUpdateException--");
System.err.println("SQLState: " + b.getSQLState());
System.err.println("Message: " + b.getMessage());
System.err.println("Vendor: " + b.getErrorCode());
System.err.print("Update counts: ");
int [] updateCounts = b.getUpdateCounts();
for(int i = 0; i<updateCounts.length;i++){
System.err.print(updateCounts[i] + " ");
}
System.err.println("");
}
catch(SQLException ex){
System.err.println("--SQLException--");
System.err.println("SQLState: " + ex.getSQLState());
System.err.println("Message: " + ex.getMessage());
System.err.println("Vendor: " + ex.getErrorCode());
}
}
//连接-剪枝动作,产生候选k-项集(即得到Ck表)
boolean aproiri_gen(int k, double min_sup)
{
ResultSet lkRs1 = null;
ResultSet lkRs2 = null;
int ckCnt = 0;
Statement stmt = null;
Statement stmt2 = null;
Statement stmt7 = null;
try{
//读取旧Lk表(即Lk-1表)中的所有项集
stmt = conn.createStatement();
lkRs1 = stmt.executeQuery("SELECT * FROM lk");
stmt7 = conn.createStatement();
stmt7.execute("DELETE FROM ck");
ResultSet tRs1 = lkRs1;
while (lkRs1.next()){
String lkItemsets1 = lkRs1.getString("itemsets");//读取当前lkRs1结果集记录的第一个字段
stmt2 = conn.createStatement();
lkRs2 = stmt2.executeQuery("SELECT * FROM lk");
while (lkRs2.next()){
String lkItemsets2 = lkRs2.getString("itemsets");//读取当前lkRs2结果集记录的第一个字段
if (isLinkable(k, lkItemsets1, lkItemsets2)){//判断当前lkRs1项集与当前lkRs2项集是否存在可连接性
System.out.println("k="+ k +" lkItemsets1 = "+ lkItemsets1 + " lkItemsets2= "+ lkItemsets2);
if (!pruneAction(joinAction(k, lkItemsets1, lkItemsets2), tRs1)) ckCnt++;//连接-剪枝;统计剪枝后剩余的Ck项集数
}
else break;//不存在可连接性则中断本次lkRs2结果集循环
}
}
}
catch(Exception e){
e.printStackTrace();
}
System.out.println("ckCnt=" + ckCnt);
if (ckCnt == 0) return false;//如剪枝后Ck项集的计数值为0则算法终止
else return true;
}
//分离字符串并返回到向量中
Vector splitItemList(int k, String cl)
{
String[] strArray = cl.split(",");
Vector v = new Vector(Arrays.asList(strArray));//将字符串数组转换为向量
System.out.println("splitItem = "+v);
return v;
}
//判断是否具有可连接性
boolean isLinkable(int k, String s1,String s2)
{
Vector v1 = splitItemList(k, s1);
Vector v2 = splitItemList(k, s2);
int i = 0;
int cnt = 0;
for (; i < v1.size(); i++){
if (v1.elementAt(i) == v2.elementAt(i)) cnt++;//break;//未满足可连接条件
}
if ((cnt == v1.size() -1) && !v1.equals(v2)){
return true;//可连接
}
else return false;
}
//连接动作
Vector joinAction(int k, String s1,String s2)
{
Vector v1 = splitItemList(k, s1);
Vector v2 = splitItemList(k, s2);
v1.add(v2.elementAt(v2.size() -1));//取v2最后一个分量插入到v1的尾部
return v1;
}
//判断当前频繁项集Ck中的所有(k-1)子集是否频繁,否则进行剪枝
boolean pruneAction(Vector v, ResultSet rs)
{
//----------- 该处对原Apriori算法中描述的剪枝动作有所改动 ------------
//用频繁项集L(k-1)与剪枝前的候选项集Ck相匹配,然后由 ?(计数值 > k) 判断是否应当剪去该候选项集
Statement stmt3 = null;
Statement stmt4 = null;
ResultSet rs2 = null;
ResultSet rs3 = null;
int containCnt = 0;
try{
stmt3 = conn.createStatement();
rs2 = stmt3.executeQuery("SELECT * FROM lk");
while (rs2.next()){
String lkItemsets = rs2.getString("itemsets");//读入L(k-1)的各项集
Vector lkItemList = splitItemList(v.size() - 1, lkItemsets);//(v.size - 1)即 k - 1
System.out.println("v_list = " +v +" lkItemList = "+lkItemList);
if (v.containsAll(lkItemList)) containCnt++;
}
System.out.println("v.size = " +v.size() +" containCnt = "+containCnt);
if (containCnt >= v.size()){//匹配数 >= k说明该候选项集需保留(非剪枝)
//------ 得到候选集(Ck表) --------
String s = new String(v.elementAt(0).toString());
for (int i = 1; i < v.size(); i++){
s = s + "," + v.elementAt(i);
}
System.out.println("s = " +s);
stmt4 = conn.createStatement();
stmt4.execute("INSERT INTO ck(itemsets,sup_count) VALUES('" + s + "',0)");
stmt4.close();
return false;//非剪枝,保留
}
}
catch(Exception e){
e.printStackTrace();
}
return true;//剪枝,返回true
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -