📄 crossvalidation.java
字号:
trainingSet = tempTrainingSet;
bestIterationAntsList.clear();
}
caller.getJTextArea1().append("\n------------------ Cross Validation #"+(crossValidation+1)+"------------------\n\n");
caller.getJTextArea1().append("Cases in the training set: "+trainingSetClone.length+"\n");
if(caller.getJCheckBox3IsSelected()){
caller.getJTextArea1().append("\n");
for(int x=0; x < trainingSetClone.length; x++){
caller.getJTextArea1().append(getInstanceString(trainingSetClone[x].getValues())+"\n");
}
}
caller.getJTextArea1().append("\nCases in the test set: "+testSet.length+"\n");
if(caller.getJCheckBox2IsSelected()){
caller.getJTextArea1().append("\n");
for(int x=0; x < testSet.length; x++){
caller.getJTextArea1().append(getInstanceString(testSet[x].getValues())+"\n");
}
}
numberOfRulesList.add(new Double((double)(antsFoundRuleList.size()+1)));
int sum=0;
ListIterator li = antsFoundRuleList.listIterator();
while(li.hasNext()){
sum += ruleSize(((Ant) li.next()).getRulesArray());
}
numberOfTermsList.add(new Double((double)sum));
caller.getJTextArea1().append("\nRules: "+(antsFoundRuleList.size()+1)+"\n\n");
//initializes freqT, which contains the number of cases that identify a class in the trainingSet
for(int n=0; n < numClasses; n++){
freqT[n] = 0;
}
int classIndex;
int greatest=0, defaultClassIndex=0;
for(int n=0; n < trainingSet.length; n++){
classIndex = trainingSet[n].getValues()[trainingSet[n].getValues().length-1];
freqT[classIndex]++;
if(freqT[classIndex] > greatest){
greatest = freqT[classIndex];
defaultClassIndex = classIndex;
}
}
double trainingAccuracyRate = calculateAccuracyRate(trainingSetClone, antsFoundRuleList, defaultClassIndex);
totalTrainingAccuracyRate += trainingAccuracyRate;
double testAccuracyRate = calculateAccuracyRate(testSet, antsFoundRuleList, defaultClassIndex);
totalTestAccuracyRate += testAccuracyRate;
accuracyRatesList.add(new Double(testAccuracyRate));
for(ListIterator i=antsFoundRuleList.listIterator(); i.hasNext();){
Object antObj = i.next();
int [] rule = ((Ant)antObj).getRulesArray();
caller.getJTextArea1().append(getRuleString(rule, ((Ant)antObj).getRuleConsequent()) + "\n");
}
caller.getJTextArea1().append("Default rule: "+attributesArray[attributesArray.length-1].getTypes()[defaultClassIndex]+"\n");
System.out.println("\nAccuracy rate on the training set: "+trainingAccuracyRate+" %");
System.out.println("Accuracy rate on the test set: "+testAccuracyRate+" %");
caller.getJTextArea1().append("\nAccuracy rate on the training set: "+trainingAccuracyRate+" %\n");
caller.getJTextArea1().append("Accuracy rate on the test set: "+testAccuracyRate+" %\n\n");
caller.getJTextArea1().append("Time taken: "+((new Date().getTime() - date2.getTime())/1000.0)+" s.\n");
System.out.println("Time taken: "+((new Date().getTime() - date2.getTime())/1000.0)+" s.\n");
}
if(!interrupted){
DecimalFormat myFormatter = new DecimalFormat("###.##");
caller.getJTextArea1().append("\n-------------------------------------------------------------------\n");
caller.getJTextArea1().append(" "+folds+"-Fold Cross Validation Results\n");
caller.getJTextArea1().append("-------------------------------------------------------------------\n");
caller.getJTextArea1().append("Accuracy Rate on Test Set | Rules Number | Conditions Number \n");
caller.getJTextArea1().append("-------------------------------------------------------------------\n");
caller.getJTextArea1().append(" "+myFormatter.format(totalTestAccuracyRate/folds)+"% +/- "+myFormatter.format(calculateVariance(accuracyRatesList,(totalTestAccuracyRate/folds),folds))+"%");
double total=0.0;
ListIterator li = numberOfRulesList.listIterator();
while(li.hasNext()){
total += ((Double) li.next()).doubleValue();
}
caller.getJTextArea1().append(" | "+myFormatter.format(total/folds)+" +/- "+myFormatter.format(calculateVariance(numberOfRulesList,(total/folds),folds)));
total=0.0;
li = numberOfTermsList.listIterator();
while(li.hasNext()){
total += ((Double) li.next()).doubleValue();
}
caller.getJTextArea1().append(" | "+myFormatter.format(total/folds)+" +/- "+myFormatter.format(calculateVariance(numberOfTermsList,(total/folds),folds)));
caller.getJTextArea1().append("\n\nTotal elapsed time: "+((new Date().getTime() - date.getTime())/1000)+" s.\n");
}else
caller.getJTextArea1().append("\nCLASSIFICATION HAS BEEN CANCELED!");
caller.getJTextArea1().setCaretPosition(caller.getJTextArea1().getText().length());
caller.getJProgressBar1().setIndeterminate(false);
caller.setIsClassifying(false);
}
/**
 * Prints the "Run Information" section that precedes the classification output.
 * <p>
 * When {@code caller.getJCheckBox1IsSelected()} is true the output area is cleared first.
 * Then the relation name, the number of instances, each attribute name and the
 * user-defined parameters are appended, one item per line.
 */
private void printHeader(){
    if(caller.getJCheckBox1IsSelected()){
        caller.getJTextArea1().setText(null);
    }
    caller.getJTextArea1().append("=== Run Information ===\n\n");
    caller.getJTextArea1().append("Relation: " + caller.getJLabel2().getText() + "\n");
    caller.getJTextArea1().append("Instances: " + dataInstancesArray.length + "\n");
    caller.getJTextArea1().append("Attributes: " + attributesArray.length + "\n");
    int attribute = 0;
    while(attribute < attributesArray.length){
        caller.getJTextArea1().append(" " + attributesArray[attribute].getAttributeName() + "\n");
        attribute++;
    }
    caller.getJTextArea1().append("\nUser-defined Parameters\n\n");
    // label / value pairs, printed in this order
    String[][] parameters = {
        {"Folds: ", String.valueOf(folds)},
        {"Number of Ants: ", String.valueOf(numAnts)},
        {"Min. Cases per Rule: ", String.valueOf(minCasesRule)},
        {"Max. uncovered Cases: ", String.valueOf(maxUncoveredCases)},
        {"Rules for Convergence: ", String.valueOf(convergenceTest)},
        {"Number of Iterations: ", String.valueOf(numIterations)}
    };
    for(int p = 0; p < parameters.length; p++){
        caller.getJTextArea1().append(parameters[p][0] + parameters[p][1] + "\n");
    }
}
/**
 * Calculates the percentage of instances classified correctly by the ordered rule list in
 * {@code antsList}, followed by the default rule.
 * <p>
 * Rules are tried in list order. The first rule whose antecedent covers an instance
 * predicts that instance's class ({@code getRuleConsequent()}). An instance covered by no
 * rule, including every instance when {@code antsList} is empty, is predicted as the
 * default class. A rule term of -1 matches any attribute value.
 * @param instancesArray instances to classify; the last value of each instance is its class
 * @param antsList ordered list of {@code Ant} objects, each holding one discovered rule
 * @param defaultClassIndex position of the default class in the class attribute's
 *        {@code getIntTypesArray()}
 * @return accuracy as a percentage in [0, 100]; 0 when {@code instancesArray} is empty
 */
private double calculateAccuracyRate(DataInstance [] instancesArray, List antsList, int defaultClassIndex){
    if(instancesArray.length == 0)
        return 0.0;
    int defaultClass = attributesArray[attributesArray.length-1].getIntTypesArray()[defaultClassIndex];
    int correctlyCovered = 0;
    for(int x=0; x < instancesArray.length; x++){
        int [] values = instancesArray[x].getValues();
        // the class is stored after the predictive attribute values (same convention the
        // caller uses when it picks the default class)
        int actualClass = values[values.length-1];
        // start with the default rule's verdict; the first covering rule overrides it
        boolean correct = (actualClass == defaultClass);
        for(ListIterator liAnt = antsList.listIterator(); liAnt.hasNext();){
            Ant ant = (Ant) liAnt.next();
            if(ruleAntecedentCovers(ant.getRulesArray(), values)){
                correct = (actualClass == ant.getRuleConsequent());
                break;
            }
        }
        if(correct)
            correctlyCovered++;
    }
    return ((double) correctlyCovered / (double) instancesArray.length) * 100;
}

/**
 * Tells whether a rule antecedent covers an instance: every term is either -1 (any value)
 * or equal to the instance value of the same attribute.
 */
private static boolean ruleAntecedentCovers(int [] rulesArray, int [] values){
    for(int i=0; i < rulesArray.length; i++){
        if(rulesArray[i] != -1 && rulesArray[i] != values[i])
            return false;
    }
    return true;
}
/**
 * Computes the spread reported after "+/-" in the cross-validation summary.
 * <p>
 * Despite the name, the result is the standard error of the mean, not the variance:
 * {@code sqrt( sum((v - average)^2) / ((folds - 1) * folds) )}, i.e. the sample standard
 * deviation divided by {@code sqrt(folds)}.
 * @param valuesList list of {@code Double} values, one per fold
 * @param average mean of the values in {@code valuesList}
 * @param folds sample size used in the denominator
 * @return the standard error; 0 when {@code folds <= 1}, where no spread can be estimated
 *         (previously this divided by zero and produced NaN)
 */
private double calculateVariance(List valuesList, double average, int folds){
    if(folds <= 1)
        return 0.0;
    double sumOfSquares = 0.0;
    for(ListIterator li = valuesList.listIterator(); li.hasNext();){
        double deviation = ((Double) li.next()).doubleValue() - average;
        sumOfSquares += Math.pow(deviation, 2.0);
    }
    return Math.sqrt(sumOfSquares / (folds - 1) / folds);
}
/**
 * Assigns each case a random cross-validation group number between 0 and folds - 1.
 * <p>
 * Previous assignments are cleared first via {@code loosenGroups()}. A randomly drawn group
 * {@code g} accepts the current case while {@code control[g] >= 0}, and each accepted case
 * decrements {@code control[g]}. The per-group capacities must therefore already be set in
 * {@code control} before this method is called.
 * NOTE(review): because the test is {@code >= 0}, a group whose counter is 0 still accepts one
 * more case. Also, if every counter goes negative while cases remain unassigned, the inner
 * loop never terminates. Confirm how {@code control} is initialized.
 */
private void group(){
    Random random = new Random();
    loosenGroups();
    for(int n=0; n < dataInstancesArray.length; n++){
        while(dataInstancesArray[n].getCrossValidationGroup() == -1){
            // uniform draw in [0, folds); replaces the sign-bit mask plus modulo
            int candidate = random.nextInt(folds);
            if(control[candidate] >= 0){
                control[candidate]--;
                dataInstancesArray[n].setCrossValidationGroup(candidate);
            }
        }
    }
}
/**
 * Counts the cases in dataInstancesArray whose cross-validation group equals {@code group}.
 * @param group the group number to count
 * @return how many cases are currently assigned to {@code group}
 */
private int noOfInstancesInGroup(int group){
    int members = 0;
    int index = 0;
    while(index < dataInstancesArray.length){
        boolean inGroup = dataInstancesArray[index].getCrossValidationGroup() == group;
        if(inGroup){
            members++;
        }
        index++;
    }
    return members;
}
/**
 * Splits dataInstancesArray into testSet and trainingSet.
 * <p>
 * Cases whose cross-validation group equals {@code crossValidation} go to testSet; all other
 * cases go to trainingSet. Each case is cloned before it is stored. The original relative
 * order is kept within both arrays.
 * @param crossValidation the group number used as the test fold
 * @throws IllegalStateException if a case cannot be cloned. Previously the error was only
 *         printed, which left a null slot that failed later with an unrelated exception.
 */
private void splitDataSet(int crossValidation){
    // count once; the original scanned the whole data set twice for the same number
    int testSize = noOfInstancesInGroup(crossValidation);
    testSet = new DataInstance[testSize];
    trainingSet = new DataInstance[dataInstancesArray.length - testSize];
    int testSetIndex = 0, trainingSetIndex = 0;
    for(int n=0; n < dataInstancesArray.length; n++){
        boolean inTestFold = dataInstancesArray[n].getCrossValidationGroup() == crossValidation;
        DataInstance copy;
        try {
            copy = (DataInstance) dataInstancesArray[n].clone();
        } catch (CloneNotSupportedException e) {
            throw new IllegalStateException("Could not clone data instance " + n, e);
        }
        if(inTestFold)
            testSet[testSetIndex++] = copy;
        else
            trainingSet[trainingSetIndex++] = copy;
    }
}
/**
 * Clears every case's cross-validation group by setting it to -1.
 * -1 is the value that {@code group()} treats as "not yet assigned".
 */
private void loosenGroups(){
    int remaining = dataInstancesArray.length;
    int index = 0;
    while(index < remaining){
        dataInstancesArray[index].setCrossValidationGroup(-1);
        index++;
    }
}
/**
 * Gives every pheromone trail the same initial amount: {@code log2(numClasses) / totalDistinct()}.
 * For each row n, as many entries are filled as attribute n has values
 * ({@code attributesArray[n].getTypes().length}).
 * NOTE(review): totalDistinct() is presumably the total number of attribute values (terms);
 * confirm against its definition.
 */
private void initializePheromoneTrails(){
    // identical for every trail, so compute it once
    double initialPheromone = log2(numClasses) / totalDistinct();
    for(int row = 0; row < pheromoneArray.length; row++){
        int valueCount = attributesArray[row].getTypes().length;
        for(int col = 0; col < valueCount; col++){
            pheromoneArray[row][col] = initialPheromone;
        }
    }
}
/**
 * Accumulates freqTij[attribute][value][class], the number of training cases that have a
 * given value for a given attribute and belong to a given class.
 * <p>
 * Only the predictive attributes are scanned (every position except the last, which holds
 * the class). Entries that are -1 or lower, presumably missing values, are skipped.
 * Counts are added to the existing contents; freqTij is not reset here.
 */
private void calculateFreqTij(){
    for(int caseIdx = 0; caseIdx < trainingSet.length; caseIdx++){
        int [] values = trainingSet[caseIdx].getValues();
        int caseClass = trainingSet[caseIdx].getClassValue();
        int attributeCount = values.length - 1;
        for(int att = 0; att < attributeCount; att++){
            int valueIndex = values[att];
            if(valueIndex <= -1)
                continue;
            freqTij[att][valueIndex][caseClass]++;
        }
    }
}
/**
 * Fills infoTij[attribute][value] with the class entropy H(W | A_i = V_ij) of the training
 * cases having that attribute value. It is computed from freqTij with the log2 helper as
 * {@code -sum_w p_w * log2(p_w)}, where {@code p_w = freqTij[i][j][w] / (cases with A_i = V_ij)}.
 * Classes with zero frequency contribute nothing. A value seen in no case gets entropy 0.
 */
private void calculateInfoTij(){
    for(int att = 0; att < freqTij.length; att++){
        for(int val = 0; val < freqTij[att].length; val++){
            int total = 0;
            for(int w = 0; w < numClasses; w++)
                total += freqTij[att][val][w];
            double entropy = 0;
            if(total != 0){
                for(int w = 0; w < numClasses; w++){
                    if(freqTij[att][val][w] == 0)
                        continue;
                    double p = (double) freqTij[att][val][w] / total;
                    entropy -= p * log2(p);
                }
            }
            infoTij[att][val] = entropy;
        }
    }
}
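/**
* Calculates the heuristic function, given by:
* hArray -> h_{ij} = \frac{\log_2 k - H(W \mid A_i = V_{ij})}{\sum_{m} x_m \sum_{n} (\log_2 k - H(W \mid A_m = V_{mn}))}
* where k is the number of classes and x_m is 1 if attribute A_m has not yet been used by the ant, 0 otherwise.
* @param ant the ant whose memory marks which attributes have already been used
*/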
private void calculateHeuristicFunction(Ant ant){
double sum=0.0;
boolean termOccurs;
int instanceClass;
for(int c=0; c < attributesArray.length-1; c++){
if(ant.getMemory()[c] == 0) //if the attribute hasn't been used...
for(int d=0; d < infoTij[c].length; d++)
sum += log2(numClasses) - infoTij[c][d];
}
for(int i=0; i < hArray.length; i++){
for(int j=0; j < hArray[i].length; j++){
if(!unusableAttributeVsValueArray[i][j]){
termOccurs = false;
//if all cases with term ij belong to the same class, then infoTij should be zero
instanceClass = trainingSet[0].getClassValue();
boolean isEqual = true;
for(int c=0; c < trainingSet.length && isEqual; c++){
if(trainingSet[c].getValues()[i] == attributesArray[i].getIntTypesArray()[j]){
termOccurs = true;
//compare the last instance class with the current instance class
if(instanceClass == trainingSet[c].getClassValue())
instanceClass = trainingSet[c].getClassValue();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -