⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 crossvalidation.java

📁 Short description: GUI Ant-Miner is a tool for extracting classification rules from data. It is an u
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
				trainingSet = tempTrainingSet;
				
				bestIterationAntsList.clear();
			}
			
			
			
			caller.getJTextArea1().append("\n------------------ Cross Validation #"+(crossValidation+1)+"------------------\n\n");
			
			caller.getJTextArea1().append("Cases in the training set: "+trainingSetClone.length+"\n");
			if(caller.getJCheckBox3IsSelected()){
				caller.getJTextArea1().append("\n");
				for(int x=0; x < trainingSetClone.length; x++){
					caller.getJTextArea1().append(getInstanceString(trainingSetClone[x].getValues())+"\n");
				}
			}
			
			caller.getJTextArea1().append("\nCases in the test set:     "+testSet.length+"\n");
			if(caller.getJCheckBox2IsSelected()){
				caller.getJTextArea1().append("\n");
				for(int x=0; x < testSet.length; x++){
					caller.getJTextArea1().append(getInstanceString(testSet[x].getValues())+"\n");
				}
			}
			
			numberOfRulesList.add(new Double((double)(antsFoundRuleList.size()+1)));
			
			int sum=0;
			ListIterator li = antsFoundRuleList.listIterator();
			while(li.hasNext()){
				sum += ruleSize(((Ant) li.next()).getRulesArray());
			}
			numberOfTermsList.add(new Double((double)sum));
			
			caller.getJTextArea1().append("\nRules: "+(antsFoundRuleList.size()+1)+"\n\n");
				
			
			//initializes freqT, which contains the number of cases that identify a class in the trainingSet
			for(int n=0; n < numClasses; n++){
				freqT[n] = 0;
			}
			int classIndex;
			int greatest=0, defaultClassIndex=0;
			for(int n=0; n < trainingSet.length; n++){
				classIndex = trainingSet[n].getValues()[trainingSet[n].getValues().length-1];
				freqT[classIndex]++;
				if(freqT[classIndex] > greatest){
					greatest = freqT[classIndex];
					defaultClassIndex = classIndex;
				}
			}
			
			double trainingAccuracyRate = calculateAccuracyRate(trainingSetClone, antsFoundRuleList, defaultClassIndex);
			totalTrainingAccuracyRate += trainingAccuracyRate;
			
			double testAccuracyRate = calculateAccuracyRate(testSet, antsFoundRuleList, defaultClassIndex);
			totalTestAccuracyRate += testAccuracyRate;
			
			accuracyRatesList.add(new Double(testAccuracyRate));
			
			for(ListIterator i=antsFoundRuleList.listIterator(); i.hasNext();){
				Object antObj = i.next();
				int [] rule = ((Ant)antObj).getRulesArray();
				caller.getJTextArea1().append(getRuleString(rule, ((Ant)antObj).getRuleConsequent()) + "\n");
			}
			caller.getJTextArea1().append("Default rule: "+attributesArray[attributesArray.length-1].getTypes()[defaultClassIndex]+"\n");
			
			System.out.println("\nAccuracy rate on the training set: "+trainingAccuracyRate+" %");
			System.out.println("Accuracy rate on the test set:     "+testAccuracyRate+" %");
			
			caller.getJTextArea1().append("\nAccuracy rate on the training set: "+trainingAccuracyRate+" %\n");
			caller.getJTextArea1().append("Accuracy rate on the test set:     "+testAccuracyRate+" %\n\n");
			caller.getJTextArea1().append("Time taken:                        "+((new Date().getTime() - date2.getTime())/1000.0)+" s.\n");

			System.out.println("Time taken: "+((new Date().getTime() - date2.getTime())/1000.0)+" s.\n");
			
		}		
		
		if(!interrupted){
			DecimalFormat myFormatter = new DecimalFormat("###.##");		
			
			caller.getJTextArea1().append("\n-------------------------------------------------------------------\n");
			caller.getJTextArea1().append("                 "+folds+"-Fold Cross Validation Results\n");
			caller.getJTextArea1().append("-------------------------------------------------------------------\n");
			caller.getJTextArea1().append("Accuracy Rate on Test Set |   Rules Number   | Conditions Number   \n");
			caller.getJTextArea1().append("-------------------------------------------------------------------\n");
			caller.getJTextArea1().append("    "+myFormatter.format(totalTestAccuracyRate/folds)+"%  +/- "+myFormatter.format(calculateVariance(accuracyRatesList,(totalTestAccuracyRate/folds),folds))+"%");		
			
			double total=0.0;
			ListIterator li = numberOfRulesList.listIterator();
			while(li.hasNext()){
				total += ((Double) li.next()).doubleValue();
			}
			caller.getJTextArea1().append("     |  "+myFormatter.format(total/folds)+"  +/- "+myFormatter.format(calculateVariance(numberOfRulesList,(total/folds),folds)));
			
			total=0.0;
			li = numberOfTermsList.listIterator();
			while(li.hasNext()){
				total += ((Double) li.next()).doubleValue();
			}
			caller.getJTextArea1().append("  |   "+myFormatter.format(total/folds)+"  +/- "+myFormatter.format(calculateVariance(numberOfTermsList,(total/folds),folds)));
			
			caller.getJTextArea1().append("\n\nTotal elapsed time: "+((new Date().getTime() - date.getTime())/1000)+" s.\n");
		}else
			caller.getJTextArea1().append("\nCLASSIFICATION HAS BEEN CANCELED!");
		
		caller.getJTextArea1().setCaretPosition(caller.getJTextArea1().getText().length());
		caller.getJProgressBar1().setIndeterminate(false);
		caller.setIsClassifying(false);
		
	}
	
	/**
	 * Writes the "Run Information" banner to the output text area: relation
	 * name, number of instances, every attribute name and the user-defined
	 * parameters. When check box 1 is selected the text area is cleared first.
	 */
	private void printHeader(){
		if(caller.getJCheckBox1IsSelected()){
			caller.getJTextArea1().setText(null);
		}

		StringBuffer header = new StringBuffer();
		header.append("=== Run Information ===\n\n");
		header.append("Relation:   ").append(caller.getJLabel2().getText()).append("\n");
		header.append("Instances:  ").append(dataInstancesArray.length).append("\n");
		header.append("Attributes: ").append(attributesArray.length).append("\n");
		int attr = 0;
		while(attr < attributesArray.length){
			header.append("            ").append(attributesArray[attr].getAttributeName()).append("\n");
			attr++;
		}

		header.append("\nUser-defined Parameters\n\n");
		header.append("Folds:                 ").append(folds).append("\n");
		header.append("Number of Ants:        ").append(numAnts).append("\n");
		header.append("Min. Cases per Rule:   ").append(minCasesRule).append("\n");
		header.append("Max. uncovered Cases:  ").append(maxUncoveredCases).append("\n");
		header.append("Rules for Convergence: ").append(convergenceTest).append("\n");
		header.append("Number of Iterations:  ").append(numIterations).append("\n");

		// one append of the assembled text leaves the same content as many small appends
		caller.getJTextArea1().append(header.toString());
	}
	
	/**
	 * Calculates the percentage of cases in {@code instancesArray} that the
	 * ordered rule list classifies correctly.
	 * <p>
	 * Each case is classified by the first rule (in list order) whose antecedent
	 * covers it; a rule term equal to -1 matches any value. A case covered by no
	 * rule is classified by the default rule, whose consequent is the class value
	 * stored at {@code defaultClassIndex} of the class attribute. The default rule
	 * is applied even when {@code antsList} is empty.
	 *
	 * @param instancesArray    cases to classify; the last element of each values array is the class
	 * @param antsList          discovered rules as {@code Ant} objects, in the order they are tried
	 * @param defaultClassIndex index into the class attribute's values used as the default-rule consequent
	 * @return accuracy in percent (0-100); 0 when {@code instancesArray} is empty
	 */
	private double calculateAccuracyRate(DataInstance [] instancesArray, List antsList, int defaultClassIndex){
		if(instancesArray.length == 0){
			return 0.0;
		}
		int correctlyClassified = 0;
		for(int x=0; x < instancesArray.length; x++){
			int [] values = instancesArray[x].getValues();
			int actualClass = values[values.length-1];
			Ant coveringAnt = findCoveringAnt(values, antsList);
			int predictedClass;
			if(coveringAnt != null){
				predictedClass = coveringAnt.getRuleConsequent();
			}else{
				// no discovered rule covers the case (or there are none): use the default rule
				predictedClass = attributesArray[attributesArray.length-1].getIntTypesArray()[defaultClassIndex];
			}
			if(predictedClass == actualClass){
				correctlyClassified++;
			}
		}
		return ((double)correctlyClassified)/((double)instancesArray.length)*100;
	}

	/**
	 * Returns the first ant in {@code antsList} whose rule covers the case
	 * described by {@code values}, or null when no rule covers it.
	 */
	private Ant findCoveringAnt(int [] values, List antsList){
		for(ListIterator li = antsList.listIterator(); li.hasNext();){
			Ant ant = (Ant) li.next();
			if(ruleCoversCase(ant.getRulesArray(), values)){
				return ant;
			}
		}
		return null;
	}

	/**
	 * Returns true when every used term of {@code rulesArray} (any entry other
	 * than -1) equals the case's value for the same attribute position.
	 */
	private boolean ruleCoversCase(int [] rulesArray, int [] values){
		for(int i=0; i < rulesArray.length; i++){
			if(rulesArray[i] != -1 && rulesArray[i] != values[i]){
				return false;
			}
		}
		return true;
	}
	
	/**
	 * Calculates the standard error of the mean of the values in
	 * {@code valuesList}: sqrt( sum((v - average)^2) / (folds - 1) / folds ).
	 * Despite its name the result is not a variance; it is the spread printed
	 * after "+/-" in the cross-validation summary.
	 *
	 * @param valuesList list of {@code Double} values, one per fold
	 * @param average    mean of the values in {@code valuesList}
	 * @param folds      sample size used in the divisor
	 * @return the standard error, or 0 when {@code folds} is less than 2
	 *         (the sample estimate is undefined there and would otherwise be
	 *         NaN or infinite because of the division by folds - 1 or folds)
	 */
	private double calculateVariance(List valuesList, double average, int folds){
		if(folds < 2){
			return 0.0;
		}
		double sumOfSquares = 0.0;
		for(ListIterator li = valuesList.listIterator(); li.hasNext();){
			sumOfSquares += Math.pow(((Double) li.next()).doubleValue() - average, 2.0);
		}
		return Math.sqrt(sumOfSquares / (folds - 1) / folds);
	}
	
	/**
	 * Assigns each case a cross-validation group number between 0 and folds - 1.
	 * <p>
	 * Previous assignments are cleared first. For each case a group is drawn
	 * uniformly at random and accepted only while {@code control} for that group
	 * is non-negative; accepting a case decrements that entry.
	 * NOTE(review): {@code control} presumably holds the remaining capacity per
	 * group and is initialized elsewhere. If every entry becomes negative before
	 * all cases are assigned, the while loop below never terminates.
	 */
	private void group(){
		Random random = new Random();

		loosenGroups();

		for(int n=0; n < dataInstancesArray.length; n++){
			while(dataInstancesArray[n].getCrossValidationGroup() == -1){
				// nextInt(bound) is uniform over [0, folds), unlike sign-masking followed by modulo
				int candidateGroup = random.nextInt(folds);
				if(control[candidateGroup] >= 0){
					control[candidateGroup]--;
					dataInstancesArray[n].setCrossValidationGroup(candidateGroup);
				}
			}
		}
	}
	
	/**
	 * Counts the cases in dataInstancesArray that currently belong to the given
	 * cross-validation group.
	 * @param group the group number to look for
	 * @return how many cases have {@code group} as their cross-validation group
	 */
	private int noOfInstancesInGroup(int group){
		int members = 0;
		int idx = 0;
		while(idx < dataInstancesArray.length){
			if(group == dataInstancesArray[idx].getCrossValidationGroup()){
				members++;
			}
			idx++;
		}
		return members;
	}
	
	/**
	 * Splits dataInstancesArray into testSet and trainingSet. Cases whose
	 * cross-validation group equals {@code crossValidation} are cloned into
	 * testSet; all other cases are cloned into trainingSet. Both keep the
	 * original case order.
	 * @param crossValidation the group number that forms the test set for this fold
	 * @throws IllegalStateException if a case cannot be cloned (failing fast
	 *         instead of leaving a null entry in one of the sets)
	 */
	private void splitDataSet(int crossValidation){
		// computed once: it is a full scan of dataInstancesArray
		int testSize = noOfInstancesInGroup(crossValidation);
		testSet = new DataInstance[testSize];
		trainingSet = new DataInstance[dataInstancesArray.length - testSize];
		int testSetIndex=0, trainingSetIndex=0;
		for(int n=0; n < dataInstancesArray.length; n++){
			DataInstance copy;
			try {
				copy = (DataInstance)dataInstancesArray[n].clone();
			} catch (CloneNotSupportedException e) {
				throw new IllegalStateException("Could not clone case " + n
						+ " for cross validation #" + crossValidation, e);
			}
			if(dataInstancesArray[n].getCrossValidationGroup() == crossValidation)
				testSet[testSetIndex++] = copy;
			else
				trainingSet[trainingSetIndex++] = copy;
		}
	}
	
	/**
	 * Clears every case's cross-validation group by setting it to -1, the value
	 * group() treats as "not yet assigned".
	 */
	private void loosenGroups(){
		int index = 0;
		while(index < dataInstancesArray.length){
			dataInstancesArray[index].setCrossValidationGroup(-1);
			index++;
		}
	}
	
	/**
	 * Gives every trail the same initial amount of pheromone,
	 * log2(numClasses) / totalDistinct(). Row n of pheromoneArray is filled
	 * for as many entries as attribute n has values (attributesArray[n].getTypes()).
	 */
	private void initializePheromoneTrails(){
		int distinctTotal = totalDistinct();
		// every cell receives the same value, so compute it once
		double initialPheromone = log2(numClasses)/distinctTotal;
		for(int att=0; att < pheromoneArray.length; att++){
			int valueCount = attributesArray[att].getTypes().length;
			for(int val=0; val < valueCount; val++){
				pheromoneArray[att][val] = initialPheromone;
			}
		}
	}
	
	/**
	 * Initializes freqTij from the training set: freqTij[a][v][c] becomes the
	 * number of training cases whose predictive attribute a has value index v
	 * and whose class is c. Missing values (index -1) are not counted.
	 * The array is zeroed first, just as freqT is reset before it is recounted,
	 * so that, if the array is reused across folds, earlier counts do not leak in.
	 */
	private void calculateFreqTij(){
		for(int a=0; a < freqTij.length; a++){
			for(int v=0; v < freqTij[a].length; v++){
				for(int c=0; c < freqTij[a][v].length; c++){
					freqTij[a][v][c] = 0;
				}
			}
		}
		for(int n=0; n < trainingSet.length; n++){
			int [] values = trainingSet[n].getValues();
			int classIndex = trainingSet[n].getClassValue();
			// the last element of values is the class, so only predictive attributes are counted
			for(int attIndex=0; attIndex < values.length-1; attIndex++){
				int attValueIndex = values[attIndex];
				if(attValueIndex > -1){
					freqTij[attIndex][attValueIndex][classIndex]++;
				}
			}
		}
	}
	
	/**
	 * Computes infoTij[a][v], the entropy of the class distribution among the
	 * training cases counted in freqTij[a][v]:
	 *   H(W | A_a = V_v) = - sum_c p_c * log2(p_c),  p_c = freqTij[a][v][c] / total.
	 * Classes with a zero count contribute nothing, and a pair with no cases
	 * at all gets an entropy of 0.
	 */
	private void calculateInfoTij(){
		for(int att=0; att < freqTij.length; att++){
			for(int val=0; val < freqTij[att].length; val++){
				int total=0;
				for(int cls=0; cls < numClasses; cls++){
					total += freqTij[att][val][cls];
				}
				double entropy=0;
				if(total != 0){
					for(int cls=0; cls < numClasses; cls++){
						if(freqTij[att][val][cls] == 0){
							continue;
						}
						double p = (double) freqTij[att][val][cls]/total;
						entropy -= p * log2(p);
					}
				}
				infoTij[att][val] = entropy;
			}
		}
	}
	
	/**
	 * Calculates the heuristic function, given by:
	 * hArray ->  hij = (log2 k - H(W|Ai = Vij)) / (S xm (S log2 k - H(W|Am = Vmn)))
	 * @param ant
	 */
	private void calculateHeuristicFunction(Ant ant){
		double sum=0.0;
		boolean termOccurs;
		int instanceClass;
		for(int c=0; c < attributesArray.length-1; c++){
			if(ant.getMemory()[c] == 0)  //if the attribute hasn't been used...
				for(int d=0; d < infoTij[c].length; d++)
					sum += log2(numClasses) - infoTij[c][d];
		}
		for(int i=0; i < hArray.length; i++){
			for(int j=0; j < hArray[i].length; j++){
				if(!unusableAttributeVsValueArray[i][j]){
					termOccurs = false;
					//if all cases with term ij belong to the same class, then infoTij should be zero									
					instanceClass = trainingSet[0].getClassValue();
					boolean isEqual = true;
					for(int c=0; c < trainingSet.length && isEqual; c++){
						if(trainingSet[c].getValues()[i] == attributesArray[i].getIntTypesArray()[j]){
							termOccurs = true;
							//compare the last instance class with the current instance class
							if(instanceClass == trainingSet[c].getClassValue())
								instanceClass = trainingSet[c].getClassValue();

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -