📄 crossvalidation.java
字号:
import java.text.DecimalFormat;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Random;
/*
* Created on 11/04/2005
*/
/**
* @author Fernando Meyer
*/
public class CrossValidation implements Runnable {
/** GUI frame that owns this task; run() reads the data, parameters and output widgets through it. */
private GUIAntMinerJFrame caller;
/** Number of cross-validation folds, parsed in run() from the caller's text field. */
private int folds;
/** Data instances obtained from the caller's file reader in run(). */
private List dataInstancesList;
/** Attributes obtained from the caller's file reader; run() treats the last entry as the class attribute. */
private List attributesList;
/** dataInstancesList.size() / folds, decremented by one when the division is exact (see run()). */
private int groupSize;
/** One counter per fold, each initialised to groupSize in run() before group() is called. */
private int [] control;
/** Training instances of the fold currently being processed; cleared and refilled for every fold. */
private MyList trainingSet;
/** Value parsed in run() from the caller's third text field. */
private int minCasesRule;
/** Number of distinct values of the last attribute in attributesList. */
private int numClasses;
/**
 * Set to true by stop(), false on construction.
 * Written from the thread that calls stop(), so it is volatile to make the
 * write visible to other threads.
 */
private volatile boolean interrupted;
/**
 * Worker thread created by start(); stop() sets it to null.
 * run() keeps iterating only while Thread.currentThread() == cvThread, so this
 * field is the stop signal and must be volatile: without it the worker is not
 * guaranteed to ever observe the update made by another thread.
 */
private volatile Thread cvThread;
/**
 * Creates a cross-validation task bound to the given GUI frame.
 * Nothing is started here; the worker thread is created by start().
 *
 * @param caller frame stored for later use by start() and run()
 */
public CrossValidation(GUIAntMinerJFrame caller){
    this.interrupted = false;
    this.caller = caller;
}
/**
 * Runs this task on a newly created thread and switches the caller's
 * progress bar to indeterminate mode.
 */
public void start() {
    Thread worker = new Thread(this);
    // run() compares the executing thread against cvThread, so the field
    // has to be assigned before the worker is started.
    cvThread = worker;
    worker.start();
    caller.getJProgressBar1().setIndeterminate(true);
}
/**
 * Asks the running cross-validation to stop.
 * The worker is not interrupted or killed: run() notices the request the next
 * time it evaluates one of its {@code currentThread == cvThread} loop
 * conditions and then leaves those loops.
 */
public void stop(){
    // Raise the flag before clearing the thread reference, so there is no
    // point in this method at which the stop signal (cvThread == null) is
    // already set while the flag still reads false.
    interrupted = true;
    cvThread = null;
}
public void run(){
Thread currentThread = Thread.currentThread();
try {
caller.getJTextArea1().setLineWrap(true);
Object temp, temp2;
int n, maxUncoveredCases, iteration, numIterations, deltaCount, convergenceTest, numAnts;
boolean goOn;
Random random = new Random();
ListIterator li, li2, liInfo, liInfo2;
List accuracyRatesList, numberOfTermsList, numberOfRulesList;
dataInstancesList = caller.getMyFileReader().getDataInstancesList();
attributesList = caller.getMyFileReader().getAttributesList();
accuracyRatesList = new LinkedList();
numberOfTermsList = new LinkedList();
numberOfRulesList = new LinkedList();
double totalTestAccuracyRate=0.0;
double totalTrainingAccuracyRate=0.0;
Date date = new Date();
//get screen values
folds = Integer.parseInt(caller.getjTextField2Value());
numAnts = Integer.parseInt(caller.getJTextField1().getText());
minCasesRule = Integer.parseInt(caller.getJTextField3().getText());
maxUncoveredCases = Integer.parseInt(caller.getJTextField4().getText());
convergenceTest = Integer.parseInt(caller.getJTextField5().getText());
numIterations = Integer.parseInt(caller.getJTextField6().getText());
printHeader();
//determine size of the data sets for cross-validation
groupSize = dataInstancesList.size() / folds;
if(dataInstancesList.size() % folds == 0)
groupSize--;
control = new int[folds];
for(n=0; n < folds; n++)
control[n] = groupSize;
//see method description below
group();
List testSet = new LinkedList();
trainingSet = new MyList();
//pheromoneList indicates the quantity of pheromone term ij (attribute vs. value) contains
MyList pheromoneList = new MyList(); //list corresponding to the attributes
MyList pheromone; //variable to contain a list that corresponds to the attributes' values
//freqTij indicates the number of cases in the trainingSet that identify class w
List freqTij = new LinkedList(); //list corresponding to the attributes
List attValuesF; //variable to contain a list that corresponds to the attributes' values
//infoTij indicates the quantity of information of term ij
List infoTij = new LinkedList();
List attValuesI;
//hList is to contain the value returned by the heuristic function for term ij
List hList = new LinkedList();
List attValuesh;
//probabilityList indicates the probability that term ij be selected to construct a rule
List probabilityList = new LinkedList();
List attValuesP;
//number of classes
numClasses = ((Attribute)attributesList.get(attributesList.size()-1)).getTypes().size();
li = attributesList.listIterator();
//block to build the list structures for pheromoneList, freqTij and infoTij,..., which have the same length as attributesList
while(li.hasNext()){
pheromone = new MyList();
attValuesF = new LinkedList();
attValuesI = new LinkedList();
attValuesh = new LinkedList();
attValuesP = new LinkedList();
int x2 = ((Attribute) li.next()).getTypes().size();
for(int x=0;x<x2;x++){
pheromone.add(new Object()); //just make list grow
attValuesF.add(new int[numClasses]);
attValuesI.add(new Double(0));
attValuesh.add(new Double(0));
attValuesP.add(new Double(0));
}
pheromoneList.add(pheromone);
freqTij.add(attValuesF);
infoTij.add(attValuesI);
hList.add(attValuesh);
probabilityList.add(attValuesP);
}
int [] freqT = new int[numClasses];
List trainingSetClone;
for(int PS=0; (PS < folds) && (currentThread == cvThread); PS++){
System.out.println("\nCross Validation "+(PS+1)+" of "+folds);
Date date2 = new Date();
li = dataInstancesList.listIterator();
testSet.clear();
trainingSet.clear();
//splits DataSet into testSet and trainingSet
while(li.hasNext()){
temp = li.next();
//try{
if(((DataInstance) temp).getCrossValidationGroup() == PS)
testSet.add(((DataInstance) temp).clone());
else
trainingSet.add(((DataInstance) temp).clone());
//}catch(Exception e){}
}
//clone trainingSet to calculate the classification accuracy rate later
trainingSetClone = (List) trainingSet.clone();
List bestIterationAntsList = new LinkedList();
int bestAntIndex=-1;
List antsFoundRuleList = new LinkedList();
while(trainingSet.size() > maxUncoveredCases && currentThread == cvThread){
System.out.print(".");
bestIterationAntsList.clear();
bestAntIndex=0;
//initialize trails with the same quantity of pheromone
li = pheromoneList.listIterator();
while(li.hasNext()){
List tempList = (List) li.next();
li2 = tempList.listIterator();
while(li2.hasNext()){
li2.next();
li2.set(new Double(log2(numClasses)/caller.getMyFileReader().totalDistinct()));
}
}
//initializes freqT, which contains the number of cases that identify a class in the trainingSet
li = trainingSet.listIterator();
while(li.hasNext()){
List tempList = ((DataInstance) li.next()).getValues();
int classIndex = ((Attribute) attributesList.get(attributesList.size()-1)).getTypes().indexOf(tempList.get(tempList.size()-1));
freqT[classIndex]++;
}
//initializes freqTij, which contains the number of cases that identify a class in the trainingSet
li = trainingSet.listIterator();
while(li.hasNext()){
int attIndex=0,attValueIndex,classIndex;
Object tempTr = li.next();
li2 = ((DataInstance) tempTr).getValues().listIterator();
tempTr = ((DataInstance) tempTr).getValues();
Object currentClass = ((List) tempTr).get(((List) tempTr).size()-1);
while(li2.hasNext()){
temp = li2.next();
attValueIndex = ((Attribute) attributesList.get(attIndex)).getTypes().indexOf(temp);
classIndex = ((Attribute) attributesList.get(attributesList.size()-1)).getTypes().indexOf(currentClass);
if(attValueIndex > -1) //if attribute is not missing, i. e. attribute different from '?'
((int[]) ((List) freqTij.get(attIndex)).get(attValueIndex))[classIndex]++;
attIndex++;
}
}
//initializes infoTij
li = freqTij.listIterator();
liInfo = infoTij.listIterator();
while(li.hasNext()){
int sum;
double hw;
li2 = ((List) li.next()).listIterator();
liInfo2 = ((List) liInfo.next()).listIterator();
while(li2.hasNext()){
temp = li2.next();
sum=0;
hw=0;
for(int x=0;x<numClasses;x++)
sum += ((int[]) temp)[x];
for(int x=0;x<numClasses;x++)
if(((int[]) temp)[x] != 0 && sum !=0)
hw -= (double) ((int[]) temp)[x]/sum * log2((double) ((int[]) temp)[x]/sum);
liInfo2.next();
liInfo2.set(new Double(hw));
}
}
List antsList = new LinkedList();
iteration = deltaCount = 0;
while(iteration < numIterations && deltaCount < convergenceTest){
antsList.clear();
for(int x=0;x<numAnts;x++){
antsList.add(new Ant(attributesList.size()));
}
double bestQuality = 0;
bestAntIndex = 0;
List unusablesList = new LinkedList();
ListIterator liAnt = antsList.listIterator();
int antIndex = 0;
while(liAnt.hasNext()){
unusablesList.clear();
//attDistinct[attribute i] contains the number of distinct values for attribute i
int [] attDistinctLeft = new int[attributesList.size()];
for(int x=0;x<attributesList.size();x++){
attDistinctLeft[x] = ((Attribute) attributesList.get(x)).getTypes().size();
}
attDistinctLeft[attributesList.size()-1] = 0;
List pheroTempList = new LinkedList();
li = pheromoneList.listIterator();
while(li.hasNext()){
pheroTempList.add(((MyList) li.next()).clone());
}
Ant currentAnt = (Ant) liAnt.next();
goOn = true;
while(goOn){
//heuristic function
//hList -> hij = (log2 k - H(W|Ai = Vij)) / (S xm (S log2 k - H(W|Am = Vmn)))
double sum=0.0;
for(int x=0;x<attributesList.size();x++){
li = ((List) infoTij.get(x)).listIterator();
if(currentAnt.getMemory()[x] == 0){ //if attribute hasn't been used...
while(li.hasNext()){
sum += (double) log2(numClasses) - ((Double) li.next()).doubleValue();
}
}
}
li = hList.listIterator();
liInfo = infoTij.listIterator();
int i=0,j;
while(li.hasNext()){
li2 = ((List) li.next()).listIterator();
liInfo2 = ((List) liInfo.next()).listIterator();
j=0;
while(li2.hasNext()){
temp = liInfo2.next();
if(!unusablesList.contains(String.valueOf(i)+String.valueOf(j))){
//if term ij doesn't occur in the trainingSet, then infoTij should have the greatest value of log2 (class w)
double greatestValue = (double) -9999;
boolean termOccurs = false;
ListIterator liTraining = trainingSet.listIterator();
while(liTraining.hasNext() && !termOccurs){
if(((DataInstance) liTraining.next()).getValues().get(i).equals(((List) ((Attribute) attributesList.get(i)).getTypes()).get(j)))
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -