📄 crossvalidation.java
字号:
package guiAntMiner;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Random;
import javax.swing.JOptionPane;
/**
* Copyright (C) 2007 Fernando Meyer
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* A full copy of the license is available in gpl.txt and online at
* http://www.gnu.org/licenses/gpl.txt
*/
public class CrossValidation implements Runnable {
/** Attribute descriptors; initialize() treats the last element as the class attribute. */
private Attribute [] attributesArray;
/** Complete data set; initialize() sizes the cross-validation groups from its length. */
private DataInstance [] dataInstancesArray;
/** Number of cross-validation folds; initialize() rejects values below 2. */
private int folds;
/** Test partition of the current fold (presumably filled by splitDataSet() - TODO confirm). */
private DataInstance [] testSet;
/** Training partition of the current fold; run() removes cases covered by discovered rules. */
private DataInstance [] trainingSet;
/** Number of ants created in every iteration of the rule search. */
private int numAnts;
/** Pheromone level of every term, indexed [attribute][value]. */
private double [][] pheromoneArray;
/** Number of distinct values of the class attribute. */
private int numClasses;
/** Term/class frequency counts, indexed [attribute][value][class]. */
private int [][][] freqTij;
/** Per-term values indexed [attribute][value] (presumably filled by calculateInfoTij() - TODO confirm). */
private double [][] infoTij;
/** One counter per class value (presumably class frequencies - TODO confirm). */
private int [] freqT;
/** One entry per fold, initialised by initialize() to the base group size. */
private int [] control;
/** Heuristic value of every term, indexed [attribute][value] (presumably set by calculateHeuristicFunction()). */
private double [][] hArray;
/** Term-selection probabilities, indexed [attribute][value]; run() samples a term by accumulating them. */
private double [][] probabilitiesArray;
/** Terms [attribute][value] the current ant must not try again; reset for every ant. */
private boolean [][] unusableAttributeVsValueArray;
/** Presumably the minimum number of cases a rule must cover - TODO confirm; validated by initialize(). */
private int minCasesRule;
/** Stagnation threshold: the per-rule search stops once this many iterations repeat the best quality. */
private int convergenceTest;
/** Maximum number of iterations spent searching for each rule. */
private int numIterations;
/** Rule discovery for a fold continues while the training set holds more than this many cases. */
private int maxUncoveredCases;
/** GUI frame whose progress bar and text area report the run. */
private GUIAntMinerJFrame caller;
/**
 * Set by stop(). stop() is presumably called from the GUI thread while run() executes on
 * cvThread, so the field is volatile to make the write visible across threads.
 */
private volatile boolean interrupted;
/**
 * Handle of the worker thread. run() keeps working only while this equals its own thread;
 * stop() clears it to request termination. It is volatile so the polling worker is
 * guaranteed to observe the change (a plain field may be read stale indefinitely).
 */
private volatile Thread cvThread;
/**
 * Creates a cross-validation runner bound to the given GUI frame.
 *
 * @param caller frame whose progress bar and text area are updated during the run
 */
public CrossValidation(GUIAntMinerJFrame caller) {
    this.interrupted = false;
    this.caller = caller;
}
/**
 * Supplies the attribute descriptors.
 *
 * @param attributesArray descriptors, stored by reference; initialize() treats the
 *                        last element as the class attribute
 */
public void setAttributesArray(Attribute [] attributesArray) {
    this.attributesArray = attributesArray;
}
/**
 * Supplies the complete data set that is split into cross-validation folds.
 *
 * @param dataInstancesArray data instances, stored by reference
 */
public void setDataInstancesArray(DataInstance [] dataInstancesArray) {
    this.dataInstancesArray = dataInstancesArray;
}
/**
 * Sets how many ants are created per iteration.
 *
 * @param numAnts ant count; validated by initialize() when the run starts
 */
public void setNumAnts(int numAnts) {
    this.numAnts = numAnts;
}
/**
 * Sets the number of cross-validation folds.
 *
 * @param folds fold count; initialize() rejects values below 2
 */
public void setFolds(int folds) {
    this.folds = folds;
}
/**
 * Sets the minimum-cases-per-rule parameter (presumably the minimum number of
 * cases a rule must cover - TODO confirm).
 *
 * @param minCasesRule parameter value; validated by initialize() when the run starts
 */
public void setMinCasesRule(int minCasesRule) {
    this.minCasesRule = minCasesRule;
}
/**
 * Sets the stagnation threshold: the search for a rule stops once the best rule
 * quality has repeated for this many iterations.
 *
 * @param convergenceTest threshold; validated by initialize() when the run starts
 */
public void setConvergenceTest(int convergenceTest) {
    this.convergenceTest = convergenceTest;
}
/**
 * Sets the maximum number of iterations spent searching for each rule.
 *
 * @param numIterations iteration cap; validated by initialize() when the run starts
 */
public void setNumIterations(int numIterations) {
    this.numIterations = numIterations;
}
/**
 * Sets the stopping point for rule discovery: rules keep being searched while the
 * training set of a fold still holds more than this many cases.
 *
 * @param maxUncoveredCases case threshold; validated by initialize() when the run starts
 */
public void setMaxUncoveredCases(int maxUncoveredCases) {
    this.maxUncoveredCases = maxUncoveredCases;
}
/**
 * Validates the parameters, allocates the working arrays and launches the
 * cross-validation on a new worker thread.
 *
 * <p>When initialize() rejects the parameters, it has already reset the GUI and shown an
 * error dialog, so the worker is simply not started. Any other failure is printed.
 */
public void start() {
    // Kept before initialize(), as in the original: replacing the handle also makes any
    // previously running worker leave its loops, because run() continues only while
    // cvThread equals its own thread.
    cvThread = new Thread(this);
    // NOTE(review): 'interrupted' is not reset here; if one instance is restarted after
    // stop(), it stays true - confirm whether run() depends on it.
    try {
        initialize();
        // Switch the progress bar on BEFORE starting the worker. Doing it afterwards races
        // with a worker that finishes quickly and (presumably) switches the bar off,
        // which would leave the bar spinning forever.
        caller.getJProgressBar1().setIndeterminate(true);
        cvThread.start();
    } catch (InvalidArgumentException ignored) {
        // Expected for bad user input: initialize() already reported it via a dialog.
    } catch (Exception e) {
        e.printStackTrace();
    }
}
/**
 * Requests that the running cross-validation stop.
 *
 * <p>This neither blocks nor calls Thread.interrupt(); the worker notices the request
 * at its next comparison of its own thread with {@code cvThread}.
 */
public void stop() {
    // Raise the flag before clearing the handle, so a worker that observes the cleared
    // handle and then reads the flag does not see a stale 'false'. The Java Memory Model
    // guarantees this ordering only when both fields are volatile.
    interrupted = true;
    cvThread = null;
}
/**
 * Validates the run parameters and allocates the per-run working arrays.
 *
 * <p>All count parameters must be positive and {@code folds} must be at least 2. On
 * invalid input this method:
 * <ul>
 *   <li>stops the progress bar,</li>
 *   <li>clears the caller's classifying flag,</li>
 *   <li>shows an error dialog,</li>
 *   <li>and throws.</li>
 * </ul>
 *
 * <p>The last entry of {@code attributesArray} is the class attribute. Every other
 * attribute gets one row in the per-term arrays, sized by its number of values.
 *
 * @throws InvalidArgumentException if any parameter is out of range
 */
private void initialize() throws Exception {
    // Reject non-positive values, not just zero: a negative count would otherwise
    // fail later inside run() (e.g. a negative Ant[] size) instead of being reported here.
    if (numAnts <= 0 || folds < 2 || minCasesRule <= 0
            || convergenceTest <= 0 || numIterations <= 0 || maxUncoveredCases <= 0) {
        caller.getJProgressBar1().setIndeterminate(false);
        caller.setIsClassifying(false);
        JOptionPane.showMessageDialog(null, "<html><font face=Dialog size=3>At least one of the parameters is invalid.</font></html>", "Error", JOptionPane.ERROR_MESSAGE);
        throw new InvalidArgumentException();
    }

    // Number of predictive attributes; index arraysSize holds the class attribute.
    int arraysSize = attributesArray.length - 1;
    numClasses = attributesArray[arraysSize].getTypes().length;
    freqT = new int[numClasses];

    // Determine the size of the data sets for cross-validation.
    int groupSize = dataInstancesArray.length / folds;
    if (dataInstancesArray.length % folds == 0) {
        // NOTE(review): decremented only when the data divides evenly; the consumer of
        // control[] is not visible here - confirm the intended meaning.
        groupSize--;
    }
    control = new int[folds];
    for (int fold = 0; fold < folds; fold++) {
        control[fold] = groupSize;
    }

    pheromoneArray = new double[arraysSize][];
    freqTij = new int[arraysSize][][]; // freqTij[noOfAttributes][noOfValues][noOfClasses]
    infoTij = new double[arraysSize][];
    hArray = new double[arraysSize][];
    unusableAttributeVsValueArray = new boolean[arraysSize][];
    probabilitiesArray = new double[arraysSize][];
    for (int attr = 0; attr < arraysSize; attr++) {
        // Every per-term array row has one slot per value of this attribute.
        int valueCount = attributesArray[attr].getTypes().length;
        pheromoneArray[attr] = new double[valueCount];
        freqTij[attr] = new int[valueCount][numClasses];
        infoTij[attr] = new double[valueCount];
        hArray[attr] = new double[valueCount];
        unusableAttributeVsValueArray[attr] = new boolean[valueCount];
        probabilitiesArray[attr] = new double[valueCount];
    }
}
public void run(){
Thread currentThread = Thread.currentThread();
Random random = new Random();
boolean goOn;
double [][] pheromoneTempArray = new double[attributesArray.length-1][];
List<Double> accuracyRatesList = new LinkedList<Double>();
List<Double> numberOfTermsList = new LinkedList<Double>();
List<Double> numberOfRulesList = new LinkedList<Double>();
double totalTestAccuracyRate=0.0;
double totalTrainingAccuracyRate=0.0;
Date date = new Date();
caller.getJTextArea1().setLineWrap(true);
printHeader();
group();
for(int crossValidation=0; crossValidation < folds && currentThread == cvThread; crossValidation++){
System.out.println("Cross Validation "+(crossValidation+1)+" of "+folds);
Date date2 = new Date();
splitDataSet(crossValidation);
DataInstance [] trainingSetClone = (DataInstance[])trainingSet.clone();
int bestAntIndex=-1;
List<Object> bestIterationAntsList = new ArrayList<Object>();
List antsFoundRuleList = new ArrayList();
while(trainingSet.length > maxUncoveredCases && currentThread == cvThread){
System.out.print(".");
bestAntIndex=0;
initializePheromoneTrails();
calculateFreqTij();
calculateInfoTij();
int iteration, deltaCount;
iteration=deltaCount=0;
while(iteration < numIterations && deltaCount < convergenceTest) {
double bestQuality = 0;
bestAntIndex = 0;
Ant [] antsArray = new Ant[numAnts];
for(int x=0; x < numAnts; x++){
antsArray[x] = new Ant(attributesArray.length-1);
}
for(int antIndex=0; antIndex < numAnts; antIndex++){
Ant currentAnt = antsArray[antIndex];
for(int n=0; n < pheromoneTempArray.length; n++)
pheromoneTempArray[n] = (double[]) pheromoneArray[n].clone();
//attDistinct[attribute i] contains the number of distinct values for attribute i
int [] attDistinctLeft = new int[attributesArray.length-1];
for(int i=0; i <attDistinctLeft.length; i++)
attDistinctLeft[i] = attributesArray[i].getTypes().length-1;
for(int i=0; i < unusableAttributeVsValueArray.length; i++)
for(int j=0; j < unusableAttributeVsValueArray[i].length; j++)
unusableAttributeVsValueArray[i][j] = false;
goOn = true;
while(goOn && currentThread == cvThread){
calculateHeuristicFunction(currentAnt);
calculateProbabilities(currentAnt, pheromoneTempArray);
float randomNumber = (random.nextInt() << 1 >>> 1) % 101;
randomNumber /= 100;
boolean found = false;
double sum = 0.0;
for(int i=0; i < probabilitiesArray.length; i++){
for(int j=0; j < probabilitiesArray[i].length && !found; j++){
sum += probabilitiesArray[i][j];
if(sum >= randomNumber){
if(!ruleConstructor(currentAnt, i, j)){
//set to true so that this ant does not try to use term ij again
unusableAttributeVsValueArray[i][j] = true;
attDistinctLeft[i]--;
pheromoneTempArray[i][j] = 0.0;
}else{
attDistinctLeft[i] = -1;
}
found = true;
}
}
}
//determine if the ant already tried to use all the possible values of attribute a
for(int a=0; a < attDistinctLeft.length; a++){
if(attDistinctLeft[a] <= 0){
currentAnt.getMemory()[a] = 2;
}
}
goOn = false;
int a=0;
do{
if(currentAnt.getMemory()[a] == 0)
goOn = true;
a++;
}while(!goOn && a < currentAnt.getMemory().length);
}
determineRuleConsequent(currentAnt);
calculateRuleQuality(currentAnt);
try {
currentAnt = pruneRule(currentAnt);
} catch (CloneNotSupportedException e) {
e.printStackTrace();
}
antsArray[antIndex] = currentAnt;
if(currentAnt.getRuleQuality() >= bestQuality){
bestQuality = currentAnt.getRuleQuality();
bestAntIndex = antIndex;
}
}
try {
bestIterationAntsList.add(antsArray[bestAntIndex].clone());
} catch (CloneNotSupportedException e) {
e.printStackTrace();
}
//check if rule quality has stagnated by comparing the last best quality with the previous one
if(bestIterationAntsList.size() > 1){
if(((Ant) bestIterationAntsList.get(bestIterationAntsList.size()-1)).getRuleQuality() == ((Ant) bestIterationAntsList.get(bestIterationAntsList.size()-2)).getRuleQuality())
deltaCount++;
else
deltaCount = 0;
}else
deltaCount++;
updatePheromone(antsArray[bestAntIndex]);
iteration++;
}
//determine which ant has the best quality
ListIterator li = bestIterationAntsList.listIterator();
int index=0;
bestAntIndex=0;
double bestQuality = 0.0;
while(li.hasNext()){
Object temp = li.next();
if(((Ant) temp).getRuleQuality() >= bestQuality){
bestQuality = ((Ant) temp).getRuleQuality();
bestAntIndex = index;
}
index++;
}
try {
antsFoundRuleList.add(((Ant) bestIterationAntsList.get(bestAntIndex)).clone());
} catch (CloneNotSupportedException e) {
e.printStackTrace();
}
//remove covered cases from the trainingSet
int count=0;
if(bestAntIndex != -1){
li = ((Ant) bestIterationAntsList.get(bestAntIndex)).getInstancesIndexList().listIterator(((Ant) bestIterationAntsList.get(bestAntIndex)).getInstancesIndexList().size());
while(li.hasPrevious()){
Object temp = li.previous();
trainingSet[((Integer) temp).intValue()] = null;
count++;
}
}
DataInstance [] tempTrainingSet = new DataInstance[trainingSet.length-count];
count=0;
for(int x=0; x < trainingSet.length; x++){
if(trainingSet[x] != null)
tempTrainingSet[count++] = trainingSet[x];
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -