⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crossvalidation.java

📁 Short description: GUI Ant-Miner is a tool for extracting classification rules from data. It is an u
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
package guiAntMiner;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Random;

import javax.swing.JOptionPane;


/**
 * Copyright (C) 2007  Fernando Meyer
 * 
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * A full copy of the license is available in gpl.txt and online at
 * http://www.gnu.org/licenses/gpl.txt
 */

public class CrossValidation implements Runnable {
	/** All attributes of the data set; the last element is treated as the class attribute. */
	private Attribute [] attributesArray;
	/** The complete data set to be cross-validated (stored by reference). */
	private DataInstance [] dataInstancesArray;
	
	/** Number of cross-validation folds; initialize() rejects values below 2. */
	private int folds;
	
	/** Test partition of the current fold (presumably filled by splitDataSet()). */
	private DataInstance [] testSet;
	/** Training partition of the current fold; cases covered by a discovered rule are removed during rule discovery. */
	private DataInstance [] trainingSet;
	/** Number of ants constructed in every iteration of the rule search. */
	private int numAnts;
	/** Pheromone level per [attribute][value] term. */
	private double [][] pheromoneArray;
	/** Number of values of the class attribute (the last entry of attributesArray). */
	private int numClasses;
	/** Frequency table indexed [attribute][value][class]. */
	private int [][][] freqTij;
	/** Per [attribute][value] values, presumably filled by calculateInfoTij() — confirm. */
	private double [][] infoTij;
	/** One int per class value (presumably class frequencies — confirm). */
	private int [] freqT;
	/** One entry per fold, initialized to the computed group size; presumably consumed by group() — confirm. */
	private int [] control;
	/** Per [attribute][value] values, presumably the heuristic computed by calculateHeuristicFunction() — confirm. */
	private double [][] hArray;
	/** Selection probability of each [attribute][value] term, used for roulette-wheel term selection. */
	private double [][] probabilitiesArray;
	/** Terms the current ant has already failed to add and must not try again. */
	private boolean [][] unusableAttributeVsValueArray;
	/** Minimum-cases-per-rule parameter (presumably the minimum number of cases a rule must cover). */
	private int minCasesRule;
	/** Number of consecutive iterations with an unchanged best rule quality after which the rule search stops. */
	private int convergenceTest;
	/** Maximum number of iterations of the rule search for a single rule. */
	private int numIterations;
	/** Rule discovery in a fold continues while more than this many training cases remain. */
	private int maxUncoveredCases;
	/** GUI frame whose progress bar, text area and classifying flag this run updates. */
	private GUIAntMinerJFrame caller;
	/**
	 * Set by stop(). Volatile because stop() is called from a different thread than the worker
	 * running run().
	 */
	private volatile boolean interrupted;
	/**
	 * The worker thread; run() keeps looping only while Thread.currentThread() == cvThread, and
	 * stop() clears it. Volatile so that the worker thread is guaranteed to observe the write made
	 * by stop() on the calling thread (without it the stop request may never become visible).
	 */
	private volatile Thread cvThread;
	
	/**
	 * Creates a cross-validation run that reports back to the given GUI frame.
	 *
	 * @param caller the frame whose progress bar, text area and classifying flag this run updates
	 */
	public CrossValidation(GUIAntMinerJFrame caller) {
		this.interrupted = false;
		this.caller = caller;
	}
	
	/**
	 * Supplies the attribute descriptions; the last element is treated as the class attribute.
	 *
	 * @param attributesArray the attributes, kept by reference (not copied)
	 */
	public void setAttributesArray(final Attribute [] attributesArray) { this.attributesArray = attributesArray; }
	/**
	 * Supplies the complete data set that will be split into folds.
	 *
	 * @param dataInstancesArray the data instances, kept by reference (not copied)
	 */
	public void setDataInstancesArray(final DataInstance [] dataInstancesArray) { this.dataInstancesArray = dataInstancesArray; }
	/**
	 * Sets how many ants are constructed in each iteration of the rule search.
	 *
	 * @param numAnts number of ants per iteration; should be at least 1
	 */
	public void setNumAnts(final int numAnts) { this.numAnts = numAnts; }
	/**
	 * Sets the number of cross-validation folds.
	 *
	 * @param folds number of folds; should be at least 2
	 */
	public void setFolds(final int folds) { this.folds = folds; }
	/**
	 * Sets the minimum-cases-per-rule parameter (presumably the minimum number of cases a
	 * discovered rule must cover).
	 *
	 * @param minCasesRule the minimum number of cases; should be at least 1
	 */
	public void setMinCasesRule(final int minCasesRule) { this.minCasesRule = minCasesRule; }
	/**
	 * Sets how many consecutive iterations may yield the same best rule quality before the
	 * search for the current rule stops.
	 *
	 * @param convergenceTest the stagnation limit; should be at least 1
	 */
	public void setConvergenceTest(final int convergenceTest) { this.convergenceTest = convergenceTest; }
	/**
	 * Sets the maximum number of iterations spent searching for a single rule.
	 *
	 * @param numIterations the iteration limit; should be at least 1
	 */
	public void setNumIterations(final int numIterations) { this.numIterations = numIterations; }
	/**
	 * Sets the stopping threshold for rule discovery: rules keep being searched for while more
	 * than this many training cases remain uncovered.
	 *
	 * @param maxUncoveredCases the number of cases that may stay uncovered; should be at least 1
	 */
	public void setMaxUncoveredCases(final int maxUncoveredCases) { this.maxUncoveredCases = maxUncoveredCases; }
	
	/**
	 * Validates the configuration and, when it is valid, launches the cross-validation on a new
	 * worker thread and switches the caller's progress bar to indeterminate mode.
	 *
	 * <p>If anything inside the try block throws (including the invalid-parameter case, for which
	 * initialize() has already shown an error dialog), the stack trace is printed and the worker
	 * thread is not started.
	 */
	public void start() {
		cvThread = new Thread(this);
		try {
			initialize();
			cvThread.start();
			caller.getJProgressBar1().setIndeterminate(true);
		}
		catch (Exception e) {
			e.printStackTrace();
		}
	}
	
	/**
	 * Requests that the running cross-validation stop.
	 *
	 * <p>The loops in run() continue only while the current thread is still {@code cvThread}, so
	 * clearing that field makes them exit at their next check; the {@code interrupted} flag is
	 * set as well.
	 */
	public void stop() {
		cvThread = null;
		interrupted = true;
	}
	
	/**
	 * Validates the run parameters and allocates the per-attribute working arrays.
	 *
	 * <p>The last entry of {@code attributesArray} is the class attribute: its number of types
	 * becomes {@code numClasses}. Every other attribute gets one row in the pheromone, frequency,
	 * information, heuristic, usability and probability arrays, with one column per attribute
	 * value.
	 *
	 * <p>Non-positive values are rejected, not only zero. A negative ant count would otherwise
	 * fail later on the worker thread when the ant array is allocated. A negative
	 * {@code maxUncoveredCases} would make the rule-discovery loop condition in run() always true.
	 *
	 * @throws Exception an {@code InvalidArgumentException} when any numeric parameter is below 1
	 *         or fewer than two folds are requested; before throwing, the progress bar and the
	 *         caller's classifying flag are reset and an error dialog is shown
	 */
	private void initialize() throws Exception{
		if(numAnts < 1 || folds < 2 || minCasesRule < 1 ||
				convergenceTest < 1 || numIterations < 1 || maxUncoveredCases < 1){
			caller.getJProgressBar1().setIndeterminate(false);
			caller.setIsClassifying(false);
			JOptionPane.showMessageDialog(null, "<html><font face=Dialog size=3>At least one of the parameters is invalid.</font></html>", "Error", JOptionPane.ERROR_MESSAGE);
			throw new InvalidArgumentException();
		}
		
		// the last attribute is the class attribute; all others are predictor attributes
		int arraysSize = attributesArray.length - 1;
		numClasses = attributesArray[arraysSize].getTypes().length;
		freqT = new int[numClasses];
		
		// determine the size of the data sets for cross-validation
		// NOTE(review): groupSize is decremented when the data divides evenly into the folds;
		// presumably control[] holds a last index / counter consumed by group() — confirm.
		int groupSize = dataInstancesArray.length / folds;
		if(dataInstancesArray.length % folds == 0)
			groupSize--;
		
		control = new int[folds];
		for(int n = 0; n < folds; n++)
			control[n] = groupSize;
		
		pheromoneArray = new double[arraysSize][];
		freqTij = new int[arraysSize][][]; //freqTij[noOfAttributes][noOfValues][noOfClasses]
		infoTij = new double[arraysSize][];
		hArray = new double[arraysSize][];
		unusableAttributeVsValueArray = new boolean[arraysSize][];
		probabilitiesArray = new double[arraysSize][];
		
		for(int n = 0; n < arraysSize; n++){
			// number of distinct values of predictor attribute n
			int numValues = attributesArray[n].getTypes().length;
			pheromoneArray[n] = new double[numValues];
			freqTij[n] = new int[numValues][numClasses];
			infoTij[n] = new double[numValues];
			hArray[n] = new double[numValues];
			unusableAttributeVsValueArray[n] = new boolean[numValues];
			probabilitiesArray[n] = new double[numValues];
		}
	}
		
	public void run(){
		Thread currentThread = Thread.currentThread();
		
		Random random = new Random();
		boolean goOn;
		
		double [][] pheromoneTempArray = new double[attributesArray.length-1][];
		
		List<Double> accuracyRatesList = new LinkedList<Double>();
		List<Double> numberOfTermsList = new LinkedList<Double>();
		List<Double> numberOfRulesList = new LinkedList<Double>();
		
		double totalTestAccuracyRate=0.0;
		double totalTrainingAccuracyRate=0.0;
		
		Date date = new Date();
		
		caller.getJTextArea1().setLineWrap(true);
		printHeader();
		
		group();
		
		for(int crossValidation=0; crossValidation < folds && currentThread == cvThread; crossValidation++){
			System.out.println("Cross Validation "+(crossValidation+1)+" of "+folds);
			
			Date date2 = new Date();
			
			splitDataSet(crossValidation);
			DataInstance [] trainingSetClone = (DataInstance[])trainingSet.clone();
			
			int bestAntIndex=-1;
			List<Object> bestIterationAntsList = new ArrayList<Object>();
			List antsFoundRuleList = new ArrayList();
			
			while(trainingSet.length > maxUncoveredCases && currentThread == cvThread){
				System.out.print(".");
				bestAntIndex=0;
				
				initializePheromoneTrails();
				
				calculateFreqTij();
				calculateInfoTij();

				int iteration, deltaCount;
				iteration=deltaCount=0;
				
				while(iteration < numIterations && deltaCount < convergenceTest) {
					
					double bestQuality = 0;
					bestAntIndex = 0;
					
					Ant [] antsArray = new Ant[numAnts];
					for(int x=0; x < numAnts; x++){
						antsArray[x] = new Ant(attributesArray.length-1);
					}
					
					for(int antIndex=0; antIndex < numAnts; antIndex++){
						Ant currentAnt = antsArray[antIndex];
						
						for(int n=0; n < pheromoneTempArray.length; n++)
							pheromoneTempArray[n] = (double[]) pheromoneArray[n].clone();
						
						//attDistinct[attribute i] contains the number of distinct values for attribute i
						int [] attDistinctLeft = new int[attributesArray.length-1];
						for(int i=0; i <attDistinctLeft.length; i++)
							attDistinctLeft[i] = attributesArray[i].getTypes().length-1;
						
						for(int i=0; i < unusableAttributeVsValueArray.length; i++)
							for(int j=0; j < unusableAttributeVsValueArray[i].length; j++)
								unusableAttributeVsValueArray[i][j] = false;
						
						goOn = true;
						while(goOn && currentThread == cvThread){
							calculateHeuristicFunction(currentAnt);
							calculateProbabilities(currentAnt, pheromoneTempArray);
							
							float randomNumber = (random.nextInt() << 1 >>> 1) % 101;
							randomNumber /= 100;
							boolean found = false;
							double sum = 0.0;
							for(int i=0; i < probabilitiesArray.length; i++){
								for(int j=0; j < probabilitiesArray[i].length && !found; j++){
									sum += probabilitiesArray[i][j];
									if(sum >= randomNumber){
										if(!ruleConstructor(currentAnt, i, j)){
											//set to true so that this ant does not try to use term ij again
											unusableAttributeVsValueArray[i][j] = true;
											attDistinctLeft[i]--;
											pheromoneTempArray[i][j] = 0.0;
										}else{
											attDistinctLeft[i] = -1;
										}
										found = true;
									}
								}
							}
							
							//determine if the ant already tried to use all the possible values of attribute a
							for(int a=0; a < attDistinctLeft.length; a++){
								if(attDistinctLeft[a] <= 0){
									currentAnt.getMemory()[a] = 2;
								}
							}
							goOn = false;
							int a=0;
							do{
								if(currentAnt.getMemory()[a] == 0)
									goOn = true;
								a++;
							}while(!goOn && a < currentAnt.getMemory().length);
						}
						
						determineRuleConsequent(currentAnt);
						calculateRuleQuality(currentAnt);
						
						try {
							currentAnt = pruneRule(currentAnt);
						} catch (CloneNotSupportedException e) {
							e.printStackTrace();
						}
						antsArray[antIndex] = currentAnt;
						
						if(currentAnt.getRuleQuality() >= bestQuality){
							bestQuality = currentAnt.getRuleQuality();
							bestAntIndex = antIndex;
						}
					}
					
					try {
						bestIterationAntsList.add(antsArray[bestAntIndex].clone());
					} catch (CloneNotSupportedException e) {
						e.printStackTrace();
					}
					
					//check if rule quality has stagnated by comparing the last best quality with the previous one
					if(bestIterationAntsList.size() > 1){
						if(((Ant) bestIterationAntsList.get(bestIterationAntsList.size()-1)).getRuleQuality() == ((Ant) bestIterationAntsList.get(bestIterationAntsList.size()-2)).getRuleQuality())
							deltaCount++;
						else
							deltaCount = 0;
					}else
						deltaCount++;
					
					updatePheromone(antsArray[bestAntIndex]);
					
					iteration++;
				}
				
				
				//determine which ant has the best quality
				ListIterator li = bestIterationAntsList.listIterator();
				int index=0;
				bestAntIndex=0;
				double bestQuality = 0.0;
				while(li.hasNext()){
					Object temp = li.next();
					if(((Ant) temp).getRuleQuality() >= bestQuality){
						bestQuality = ((Ant) temp).getRuleQuality();
						bestAntIndex = index;
					}
					index++;
				}
				
				try {
					antsFoundRuleList.add(((Ant) bestIterationAntsList.get(bestAntIndex)).clone());
				} catch (CloneNotSupportedException e) {
					e.printStackTrace();
				}

				//remove covered cases from the trainingSet
				int count=0;
				if(bestAntIndex != -1){
					li = ((Ant) bestIterationAntsList.get(bestAntIndex)).getInstancesIndexList().listIterator(((Ant) bestIterationAntsList.get(bestAntIndex)).getInstancesIndexList().size());
					while(li.hasPrevious()){
						Object temp = li.previous();
						trainingSet[((Integer) temp).intValue()] = null;
						count++;
					}
				}
				DataInstance [] tempTrainingSet = new DataInstance[trainingSet.length-count];
				count=0;
				for(int x=0; x < trainingSet.length; x++){
					if(trainingSet[x] != null)
						tempTrainingSet[count++] = trainingSet[x];
				}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -