// File: apriori.java
// (web-page residue, kept as a comment: "字号:" = "font size:")
package datamining;
import java.io.*;
import java.util.*;
/**
 * A bare-bones, clean implementation of the Apriori algorithm for
 * finding frequent itemsets. Good for educational purposes and as a
 * root class for experimenting with optimizations.
 *
 * <p>Candidates are stored in a trie of {@code Item}s: every path from the
 * root spells an itemset, and each node holds the support counted for that
 * itemset. Pass k grows the trie by one level, counts the supports of the
 * new leaves with one scan of the database, and prunes the nodes whose
 * support is below {@code minsup}.
 *
 * <p>In the latest version a {@code DataHandler} is used for reading
 * the database.
 *
 * @author Michael Holler
 * @version 0.8, 16.03.2004
 */
public class Apriori {
    int pass;              // number of the current pass (= size of the itemsets being counted)
    int total;             // total number of frequent itemsets found so far
    int minsup;            // minimal support (an absolute count) of a frequent itemset
    String filename;       // the filename of the database (not set by the constructors below)
    Item root;             // the root item of the candidate trie
    BufferedWriter writer; // writer for the output file, or null when no output file is used
    DataHandler dh;        // the handler for reading the database

    /**
     * Default constructor for creating an Apriori object. Reads the
     * database from {@code "test.dat"} with a minimal support of 4 and
     * writes no output file.
     */
    public Apriori() {
        this.pass = 0;
        this.minsup = 4;
        this.dh = new DataHandler("test.dat");
        this.root = new Item(0);
    }

    /**
     * Constructor for creating an Apriori object with parameters.
     *
     * <p>If {@code outfile} is non-empty, the output file is opened for writing.
     * If it cannot be opened, an error is reported on standard error and
     * mining proceeds without file output.
     *
     * @param filename the name of the database file
     * @param minsup   the minimal support threshold
     * @param outfile  the name of the output file; {@code null} or {@code ""} means no file output
     */
    public Apriori(String filename, int minsup, String outfile) {
        this.pass = 0;
        this.minsup = minsup;
        this.dh = new DataHandler(filename);
        this.root = new Item(0);
        if (outfile != null && outfile.length() > 0) {
            try {
                this.writer = new BufferedWriter(new FileWriter(outfile));
            } catch (IOException e) {
                // Best effort: keep mining, but do not hide why nothing gets written.
                System.err.println("Could not open output file '" + outfile + "': " + e.getMessage());
                this.writer = null;
            }
        }
    }

    /**
     * Constructor for creating an Apriori object with parameters.
     * This one is used with other mining algorithms. No output file is used.
     *
     * @param minsup      the minimal support threshold
     * @param datahandler the handler for the database
     */
    public Apriori(int minsup, DataHandler datahandler) {
        this.pass = 0;
        this.minsup = minsup;
        this.dh = datahandler;
        this.root = new Item(0);
    }

    /**
     * The workhorse method for the basic implementation of the Apriori
     * algorithm. Runs passes until no larger frequent itemsets can exist,
     * printing a one-line summary of each pass to standard output.
     */
    public void findFrequentSets() {
        boolean running = true;
        while (running) {
            this.pass++;
            // Grow the trie by one level. On pass 1 the root has no children yet,
            // so nothing is generated here.
            int candidates = this.generateCandidates(this.root, new Vector(), 1);
            // Count supports. On pass 1 this also creates the 1-item candidates
            // and adds their number to total.
            this.countSupport();
            int pruned = this.pruneCandidates(this.root);
            int itemsets = candidates - pruned;
            // On pass 1 the candidates were counted into total by countSupport()
            // (assuming total was 0 before), so report that number. itemsets is
            // then -pruned, and "total += itemsets" leaves total equal to the
            // number of frequent 1-itemsets.
            if (this.pass == 1) {
                candidates = total;
            }
            total += itemsets;
            // Every (k+1)-itemset has k+1 subsets of size k that must all be
            // frequent. With at most k frequent k-itemsets, no larger frequent
            // itemset can exist.
            if (itemsets <= this.pass && this.pass > 1) {
                running = false;
            }
            System.out.println("pass: " + this.pass +
                    ", total: " + total +
                    ", candidates: " + candidates +
                    ", pruned: " + pruned);
        }
    }

    /**
     * Method for generating new candidates. Descends the trie to depth
     * {@code pass - 1} and, for every node there, copies its later siblings
     * to its children.
     *
     * @param item    the item whose subtrie is extended
     * @param current the itemset on the path from the root to {@code item};
     *                used as scratch space and restored before returning
     * @param depth   the depth of recursion (1 for the root's children)
     * @return the number of new candidates generated
     */
    public int generateCandidates(Item item, Vector current, int depth) {
        Vector v = item.getChildren();
        int generated = 0;
        for (Enumeration e = v.elements(); e.hasMoreElements(); ) {
            Item child = (Item) e.nextElement();
            current.add(child);
            if (depth == this.pass - 1) {
                generated += this.copySiblings(child, v, current);
            } else {
                generated += this.generateCandidates(child, current, depth + 1);
            }
            current.remove(child);
        }
        return generated;
    }

    /**
     * Method for copying the siblings that follow an Item to its children.
     * From pass 3 on, a sibling is copied only if all subsets of the
     * resulting itemset are present in the trie.
     *
     * @param item     the item to which the siblings are copied
     * @param siblings the siblings to be copied (including {@code item} itself)
     * @param current  the current itemset to be generated
     * @return the number of siblings copied
     */
    public int copySiblings(Item item, Vector siblings, Vector current) {
        Enumeration e = siblings.elements();
        Item parent = item;
        // Sentinel start value. NOTE(review): assumes a no-arg Item has a label
        // lower than every real label, so the skip loop below runs at least once.
        // Confirm in Item.
        Item sibling = new Item();
        int copied = 0;
        // Skip siblings up to the first one whose label is >= the parent's label
        // (the parent itself, if the siblings are in ascending label order).
        while (sibling.getLabel() < parent.getLabel() && e.hasMoreElements()) {
            sibling = (Item) e.nextElement();
        }
        while (e.hasMoreElements()) {
            sibling = (Item) e.nextElement();
            current.add(sibling);
            // Up to pass 2 every combination of frequent items is a valid candidate.
            if (this.pass <= 2 || this.checkSubsets(current, this.root.getChildren(), 0, 1)) {
                parent.addChild(new Item(sibling.getLabel()));
                copied++;
            }
            current.remove(sibling);
        }
        return copied;
    }

    /**
     * Checks if the subsets of the itemset to be generated are all frequent,
     * i.e. still present in the (pruned) trie.
     *
     * <p>At recursion depth d the item looked up is {@code current[i]} for
     * {@code mark <= i <= depth}. Each complete descent therefore spells a
     * subset of {@code current} with one item left out. Lookups use
     * {@code Vector.indexOf}, which compares with {@code equals}; presumably
     * {@code Item.equals} compares labels (TODO confirm).
     *
     * @param current  the current itemset to be generated
     * @param children the children in the trie on this depth
     * @param mark     the lowest index in {@code current} that may be used at this depth
     * @param depth    depth of recursion (1 at the root's children)
     * @return true if the subsets are frequent, else false
     */
    public boolean checkSubsets(Vector current, Vector children, int mark, int depth) {
        boolean ok = true;
        Item child;
        int index;
        int i = depth;
        if (children == null) return false;
        while (ok && (mark <= i)) {
            index = children.indexOf(current.elementAt(i));
            if (index >= 0) {
                if (depth < this.pass - 1) {
                    child = (Item) children.elementAt(index);
                    ok = checkSubsets(current, child.getChildren(), i + 1, depth + 1);
                }
            } else {
                ok = false;
            }
            i--;
        }
        return ok;
    }

    /**
     * Method for counting the supports of the candidates generated on this
     * pass, with one scan over the database. On pass 1 it also builds the
     * 1-item candidates and counts the transactions in the root's support.
     *
     * <p>An empty array from {@code DataHandler.read()} ends the scan.
     * NOTE(review): presumably that signals end of data; confirm that an empty
     * transaction cannot end the scan early. Also confirm that {@code open()}
     * releases the stream from a previous pass, as nothing is closed here.
     *
     * @return the number of transactions from which the support was counted
     */
    public int countSupport() {
        int rowcount = 0;
        int[] items;
        this.dh.open();
        for (items = this.dh.read(); items.length > 0; items = this.dh.read()) {
            rowcount++;
            if (this.pass == 1) {
                this.root.incSupport();
                this.total += generateFirstCandidates(items);
            } else {
                countSupport(root, items, 0, 1);
            }
        }
        return rowcount;
    }

    /**
     * Method generates the first candidates by adding each item found in
     * the database to the children of the root item. Also counts the
     * supports of the items found in the database.
     *
     * <p>This is a merge of the transaction with the root's children. Both
     * appear to be expected in ascending label order. New children are
     * inserted into {@code v} while it is being enumerated. That relies on
     * {@code Vector.elements()} not being fail-fast, although the Vector
     * Javadoc calls the result undefined. NOTE(review): the no-arg Item is
     * assumed to have a label lower than any real label. Confirm in Item.
     *
     * @param items the array of integer items from the database
     * @return the number of candidates generated
     */
    public int generateFirstCandidates(int[] items) {
        Vector v = root.getChildren();
        Enumeration e = v.elements();
        Item item = new Item();
        int generated = 0;
        for (int i = 0; i < items.length; i++) {
            // Advance to the first child whose label is not below items[i].
            while (e.hasMoreElements() && item.getLabel() < items[i]) {
                item = (Item) e.nextElement();
            }
            if (item.getLabel() == items[i]) {
                // Known item: count it and move past it.
                item.incSupport();
                if (e.hasMoreElements())
                    item = (Item) e.nextElement();
            } else if (item.getLabel() > items[i]) {
                // Unknown item that sorts before the current child: insert it there.
                int index = v.indexOf(item);
                Item child = new Item(items[i]);
                child.incSupport();
                this.root.addChild(child, index);
                generated++;
            } else {
                // Unknown item larger than every child: append it.
                Item child = new Item(items[i]);
                child.incSupport();
                this.root.addChild(child);
                generated++;
            }
        }
        return generated;
    }

    /**
     * Adds the cover of the Item given as parameter and all the Items in
     * the trie below it. Only nodes at depth {@code pass} have their support
     * increased. Shallower nodes are merely descended through.
     *
     * @param item  the item the cover of which is to be counted
     * @param items the array of integer items from the database
     * @param i     the position in the array from which to search
     * @param depth the depth of recursion (1 for the root's children)
     */
    public void countSupport(Item item, int[] items, int i, int depth) {
        Vector v = item.getChildren();
        Item child;
        int tmp;
        Enumeration e = v.elements();
        // loop through the children to check
        while (e.hasMoreElements()) {
            child = (Item) e.nextElement();
            // break, if the whole transaction is checked
            if (i == items.length) { break; }
            // do a linear search for the child in the transaction starting from i
            tmp = i;
            while (tmp < items.length && items[tmp] < child.getLabel()) tmp++;
            // if the same item exists, increase support or go deeper
            if (tmp < items.length && child.getLabel() == items[tmp]) {
                if (depth == this.pass) {
                    child.incSupport();
                } else {
                    countSupport(child, items, tmp + 1, depth + 1);
                }
                i = tmp + 1;
            }
        }
    }

    /**
     * Method for pruning the candidates. Removes items that are not
     * frequent, together with their subtries, from the trie.
     *
     * @param item the item the children of which will be pruned
     * @return the number of items pruned from the candidates
     */
    public int pruneCandidates(Item item) {
        Vector v = item.getChildren();
        Item child;
        int pruned = 0;
        // Iterate over a copy so that removing from v is safe.
        for (Enumeration e = new Vector(v).elements(); e.hasMoreElements(); ) {
            child = (Item) e.nextElement();
            // infrequent: drop it; frequent: prune its subtrie
            if (child.getSupport() < this.minsup) {
                v.remove(child);
                pruned++;
            } else {
                pruned += pruneCandidates(child);
            }
        }
        return pruned;
    }

    /**
     * Method gets and returns the root of the candidate trie.
     *
     * @return the root of the candidate trie
     */
    public Item getTrie() {
        return this.root;
    }

    /**
     * Writes the frequent itemsets to the output file, if one was opened, and
     * prints the total number of frequent itemsets to standard output.
     */
    public void printFrequentSets() {
        if (this.writer != null) {
            print(root, "");
        }
        System.out.println("\nnumber of frequent itemsets found: " + this.total);
    }

    /**
     * Closes the output file opened by the three-argument constructor, if any.
     * Afterwards no file output is written. Calling it again has no effect.
     */
    public void close() {
        if (this.writer == null) {
            return;
        }
        try {
            this.writer.close();
        } catch (IOException e) {
            System.err.println("Could not close output file: " + e.getMessage());
        } finally {
            this.writer = null;
        }
    }

    /**
     * Loops through the trie recursively and writes every itemset below
     * {@code item} to the output file. Each line is formatted as
     * {@code "<labels separated by spaces> (<support>)"}. Each line is
     * flushed as it is written. A failed write is reported as
     * "no output file", and the walk continues.
     *
     * @param item the item where the recursion is
     * @param str  the labels of the gathered itemset so far, each followed by a space
     */
    public void print(Item item, String str) {
        Vector v = item.getChildren();
        for (Enumeration e = v.elements(); e.hasMoreElements(); ) {
            Item child = (Item) e.nextElement();
            try {
                this.writer.write(str + child.getLabel()
                        + " (" + child.getSupport() + ")\n");
                this.writer.flush();
            } catch (Exception x) {
                // Also covers a null writer when print() is called directly.
                System.out.println("no output file");
            }
            if (child.hasChildren()) {
                print(child, str + child.getLabel() + " ");
            }
        }
    }

    /**
     * Main method for testing the algorithm.
     *
     * @param args the arguments can contain the filename of the testfile,
     *             the minimal support threshold and a filename for output
     */
    public static void main(String[] args) {
        String testfile = "test.dat";
        String outfile = "";
        int support = 5;

        if (args.length > 0) {
            testfile = args[0];
        } else {
            System.out.println("Didn't get filename. Using '" + testfile + "'.");
        }

        boolean gotSupport = false;
        if (args.length > 1) {
            try {
                support = Integer.parseInt(args[1]);
                gotSupport = true;
            } catch (NumberFormatException ignored) {
                // not a number: keep the default and report it below
            }
        }
        if (!gotSupport) {
            System.out.println("Didn't get support threshold. Using '" + support + "'.");
        }

        if (args.length > 2) {
            outfile = args[2];
        } else {
            System.out.println("Didn't get output filename. Not printing.");
        }

        StopWatch sw = new StopWatch();
        sw.start();
        Apriori apriori = new Apriori(testfile, support, outfile);
        apriori.findFrequentSets();
        apriori.printFrequentSets();
        apriori.close();
        sw.stop();
        sw.print();
    }
}
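// Web-page residue (code viewer keyboard-shortcut help), kept as a comment:
// 快捷键说明 (keyboard shortcuts)
// 复制代码 (copy code): Ctrl + C
// 搜索代码 (search code): Ctrl + F
// 全屏模式 (full-screen mode): F11
// 切换主题 (toggle theme): Ctrl + Shift + D
// 显示快捷键 (show shortcuts): ?
// 增大字号 (increase font size): Ctrl + =
// 减小字号 (decrease font size): Ctrl + -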