⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 evaluator.java

📁 Multi-label classification 和weka集成
💻 JAVA
字号:
package mulan.evaluation;
import java.util.Random;

import mulan.classifier.AbstractMultiLabelClassifier;
import mulan.classifier.MultiLabelClassifier;
import mulan.classifier.Prediction;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;


/**
 * Evaluator - responsible for generating evaluation data for multi-label
 * classifiers: single evaluations on a dataset, cross-validations, and
 * evaluations repeated over a range of confidence thresholds.
 *
 * @author rofr
 */
public class Evaluator
{
	/** Number of folds used when a cross-validation is requested without a fold count. */
	public static final int DEFAULTFOLDS = 10;

	/**
	 * Seed to random number generator. Needed to reproduce crossvalidation randomization.
	 * Default is 1
	 */
	protected int seed;

	/**
	 * Creates an evaluator with the default seed (1).
	 */
	public Evaluator()
	{
		this(1);
	}

	/**
	 * Creates an evaluator with the given seed.
	 *
	 * @param seed seed for the random number generator used to shuffle data
	 *             before cross-validation
	 */
	public Evaluator(int seed)
	{
		this.seed = seed;
	}

	/**
	 * Cross-validates the classifier using {@link #DEFAULTFOLDS} folds.
	 *
	 * @param classifier the classifier to evaluate; it is copied for every fold
	 * @param dataset the data to cross-validate on
	 * @return the cross-validation result
	 * @throws Exception if copying, training or evaluating the classifier fails
	 */
	public Evaluation crossValidate(MultiLabelClassifier classifier, Instances dataset)
	throws Exception
	{
		return crossValidate(classifier, dataset, DEFAULTFOLDS);
	}

	/**
	 * Cross-validates the classifier, collecting label-based, example-based
	 * and ranking-based evaluations for every fold.
	 *
	 * @param classifier the classifier to evaluate; it is copied for every fold
	 * @param dataset the data to cross-validate on; it is copied, not modified
	 * @param numFolds the number of folds, or -1 to use one fold per instance
	 * @return the cross-validation result aggregated over all folds
	 * @throws Exception if copying, training or evaluating the classifier fails
	 */
	public Evaluation crossValidate(MultiLabelClassifier classifier, Instances dataset, int numFolds)
	throws Exception
	{
		numFolds = resolveNumFolds(numFolds, dataset);

		LabelBasedEvaluation[] labelBased = new LabelBasedEvaluation[numFolds];
		ExampleBasedEvaluation[] exampleBased = new ExampleBasedEvaluation[numFolds];
		LabelRankingBasedEvaluation[] rankingBased = new LabelRankingBasedEvaluation[numFolds];

		// Shuffle a copy so the caller's dataset keeps its order.
		Random random = new Random(seed);
		Instances workingSet = new Instances(dataset);
		workingSet.randomize(random);

		for (int fold = 0; fold < numFolds; fold++)
		{
			Instances train = workingSet.trainCV(numFolds, fold, random);
			Instances test = workingSet.testCV(numFolds, fold);
			AbstractMultiLabelClassifier trained = trainCopy(classifier, train);

			Evaluation evaluation = evaluate(trained, test);
			labelBased[fold] = evaluation.getLabelBased();
			exampleBased[fold] = evaluation.getExampleBased();
			rankingBased[fold] = evaluation.getRankingBased();
		}

		return new CrossValidation(
				new LabelBasedCrossValidation(labelBased),
				new ExampleBasedCrossValidation(exampleBased),
				new LabelRankingBasedCrossValidation(rankingBased),
				numFolds);
	}

	/**
	 * Cross-validates the classifier, producing one {@link IntegratedEvaluation}
	 * per fold.
	 *
	 * @param classifier the classifier to evaluate; it is copied for every fold
	 * @param dataset the data to cross-validate on; it is copied, not modified
	 * @param numFolds the number of folds, or -1 to use one fold per instance
	 * @return the cross-validation built from the per-fold evaluations
	 * @throws Exception if copying, training or evaluating the classifier fails
	 */
	public IntegratedCrossvalidation crossValidateAll(MultiLabelClassifier classifier, Instances dataset, int numFolds)
	throws Exception
	{
		numFolds = resolveNumFolds(numFolds, dataset);

		IntegratedEvaluation[] integrated = new IntegratedEvaluation[numFolds];

		// Shuffle a copy so the caller's dataset keeps its order.
		Random random = new Random(seed);
		Instances workingSet = new Instances(dataset);
		workingSet.randomize(random);

		for (int fold = 0; fold < numFolds; fold++)
		{
			Instances train = workingSet.trainCV(numFolds, fold, random);
			Instances test = workingSet.testCV(numFolds, fold);
			AbstractMultiLabelClassifier trained = trainCopy(classifier, train);
			integrated[fold] = evaluateAll(trained, test);
		}

		return new IntegratedCrossvalidation(integrated);
	}

	/**
	 * Re-evaluates already computed per-fold predictions at a series of
	 * confidence thresholds.
	 * <p>
	 * For every step, each label whose confidence is at least the current
	 * threshold is marked as predicted; if no label of an example reaches the
	 * threshold, the label with the greatest confidence is marked instead.
	 * The {@code predicted} flags of the supplied predictions are overwritten
	 * in place at every step.
	 *
	 * @param predictions predictions indexed by fold, example and label
	 * @param dataset only consulted when {@code numFolds} is -1
	 * @param start the threshold used for the first step
	 * @param increment the amount added to the threshold after every step
	 * @param steps the number of thresholds to evaluate
	 * @param numFolds the number of folds in {@code predictions}, or -1 for
	 *                 one fold per instance of {@code dataset}
	 * @return one cross-validation per threshold step, in step order
	 * @throws Exception if an evaluation fails
	 */
	public IntegratedCrossvalidation[] crossvalidateOverThreshold(
			BinaryPrediction[][][] predictions, Instances dataset, double start, double increment,
			int steps, int numFolds) throws Exception
	{
		// Same -1 convention as the cross-validation methods; dataset is not
		// touched for an explicit fold count.
		numFolds = resolveNumFolds(numFolds, dataset);

		IntegratedCrossvalidation[] crossvalidations = new IntegratedCrossvalidation[steps];

		double threshold = start;
		for (int step = 0; step < steps; step++)
		{
			IntegratedCrossvalidation stepResult = new IntegratedCrossvalidation(numFolds);
			for (int fold = 0; fold < numFolds; fold++)
			{
				applyThresholdWithFallback(predictions[fold], threshold);
				// NOTE(review): every step hands the same, re-thresholded array
				// to IntegratedEvaluation; this presumes the measures are
				// captured before the next step mutates it - confirm.
				stepResult.folds[fold] = new IntegratedEvaluation(predictions[fold]);
			}
			stepResult.computeMeasures();
			crossvalidations[step] = stepResult;
			threshold += increment;
		}

		return crossvalidations;
	}

	/**
	 * Cross-validates the classifier once and then re-evaluates the resulting
	 * predictions at a series of confidence thresholds.
	 *
	 * @param classifier the classifier to evaluate; it is copied for every fold
	 * @param dataset the data to cross-validate on
	 * @param start the threshold used for the first step
	 * @param increment the amount added to the threshold after every step
	 * @param steps the number of thresholds to evaluate
	 * @param numFolds the number of folds, or -1 to use one fold per instance
	 * @return one cross-validation per threshold step, in step order
	 * @throws Exception if copying, training or evaluating the classifier fails
	 */
	public IntegratedCrossvalidation[] crossvalidateOverThreshold(MultiLabelClassifier classifier,
			Instances dataset, double start, double increment, int steps, int numFolds)
			throws Exception
	{
		// Resolve -1 here as well: sizing the array below with the raw value
		// would throw NegativeArraySizeException for leave-one-out requests.
		int foldCount = resolveNumFolds(numFolds, dataset);

		// Create a crossvalidation of the classifier in order to get predictions.
		IntegratedCrossvalidation cv = crossValidateAll(classifier, dataset, numFolds);
		BinaryPrediction[][][] foldPredictions = new BinaryPrediction[foldCount][][];
		for (int fold = 0; fold < foldCount; fold++)
		{
			foldPredictions[fold] = cv.folds[fold].predictions;
		}

		return crossvalidateOverThreshold(foldPredictions, dataset, start, increment, steps, foldCount);
	}

	/**
	 * Runs the classifier on every instance of the dataset and pairs each
	 * label's prediction and confidence with the actual label value.
	 * <p>
	 * The labels are read from the last {@code classifier.getNumLabels()}
	 * attributes of the dataset; a label counts as actually present when its
	 * attribute value is the string "1".
	 *
	 * @param classifier the (already built) classifier to query
	 * @param dataset the instances to predict
	 * @return predictions indexed by instance and label
	 * @throws Exception if the classifier fails to predict an instance
	 */
	protected BinaryPrediction[][] getPredictions(MultiLabelClassifier classifier, Instances dataset)
	throws Exception
	{
		int numInstances = dataset.numInstances();
		int numLabels = classifier.getNumLabels();
		int firstLabelIdx = dataset.numAttributes() - numLabels;

		BinaryPrediction[][] predictions = new BinaryPrediction[numInstances][numLabels];

		for (int i = 0; i < numInstances; i++)
		{
			Instance instance = dataset.instance(i);
			Prediction result = classifier.predict(instance);
			for (int j = 0; j < numLabels; j++)
			{
				int classIdx = firstLabelIdx + j;
				String classValue = dataset.attribute(classIdx).value((int) instance.value(classIdx));
				boolean actual = classValue.equals("1");
				predictions[i][j] = new BinaryPrediction(
						result.getPrediction(j),
						actual,
						result.getConfidence(j));
			}
		}
		return predictions;
	}

	/**
	 * Re-evaluates already computed predictions at a series of confidence
	 * thresholds.
	 * <p>
	 * For every step, a label is marked as predicted exactly when its
	 * confidence is at least the current threshold. Unlike
	 * {@code crossvalidateOverThreshold}, no label is forced when none reaches
	 * the threshold. The {@code predicted} flags of the supplied predictions
	 * are overwritten in place at every step.
	 *
	 * @param predictions predictions indexed by example and label
	 * @param dataset not used by this method
	 * @param start the threshold used for the first step
	 * @param increment the amount added to the threshold after every step
	 * @param steps the number of thresholds to evaluate
	 * @return one evaluation per threshold step, in step order
	 * @throws Exception if an evaluation fails
	 */
	public IntegratedEvaluation[] evaluateOverThreshold(BinaryPrediction[][] predictions,
											  Instances dataset,
											  double start,
											  double increment,
											  int steps)
	throws Exception
	{
		IntegratedEvaluation[] evaluations = new IntegratedEvaluation[steps];

		double threshold = start;
		for (int step = 0; step < steps; step++)
		{
			for (int j = 0; j < predictions.length; j++)
			{
				for (int k = 0; k < predictions[0].length; k++)
				{
					predictions[j][k].predicted = predictions[j][k].confidenceTrue >= threshold;
				}
			}
			threshold += increment;
			// NOTE(review): every step hands the same, re-thresholded array to
			// IntegratedEvaluation; this presumes the measures are captured
			// before the next step mutates it - confirm.
			evaluations[step] = new IntegratedEvaluation(predictions);
		}

		return evaluations;
	}

	/**
	 * Obtains the classifier's predictions for the dataset and re-evaluates
	 * them at a series of confidence thresholds.
	 *
	 * @param classifier the (already built) classifier to evaluate
	 * @param dataset the instances to predict
	 * @param start the threshold used for the first step
	 * @param increment the amount added to the threshold after every step
	 * @param steps the number of thresholds to evaluate
	 * @return one evaluation per threshold step, in step order
	 * @throws Exception if predicting or evaluating fails
	 */
	public IntegratedEvaluation[] evaluateOverThreshold(MultiLabelClassifier classifier,
											  Instances dataset,
											  double start,
											  double increment,
											  int steps)
	throws Exception
	{
		BinaryPrediction[][] predictions = getPredictions(classifier, dataset);
		return evaluateOverThreshold(predictions, dataset, start, increment, steps);
	}

	/**
	 * Builds label-based, example-based and ranking-based evaluations from
	 * the given predictions.
	 *
	 * @param predictions predictions indexed by example and label
	 * @return the combined evaluation
	 * @throws Exception if an evaluation fails
	 */
	public Evaluation evaluate(BinaryPrediction[][] predictions)
	throws Exception
	{
		return new Evaluation(
				new LabelBasedEvaluation(predictions),
				new ExampleBasedEvaluation(predictions),
				new LabelRankingBasedEvaluation(predictions));
	}

	/**
	 * Evaluates the classifier on the dataset with label-based, example-based
	 * and ranking-based measures.
	 *
	 * @param classifier the (already built) classifier to evaluate
	 * @param dataset the instances to predict
	 * @return the combined evaluation
	 * @throws Exception if predicting or evaluating fails
	 */
	public Evaluation evaluate(MultiLabelClassifier classifier, Instances dataset)
	throws Exception
	{
		return evaluate(getPredictions(classifier, dataset));
	}

	/**
	 * Evaluates the classifier on the dataset as an {@link IntegratedEvaluation}.
	 *
	 * @param classifier the (already built) classifier to evaluate
	 * @param dataset the instances to predict
	 * @return the integrated evaluation of the predictions
	 * @throws Exception if predicting or evaluating fails
	 */
	public IntegratedEvaluation evaluateAll(MultiLabelClassifier classifier, Instances dataset)
	throws Exception
	{
		return new IntegratedEvaluation(getPredictions(classifier, dataset));
	}

	/**
	 * Evaluates the classifier on the dataset with example-based measures only.
	 *
	 * @param classifier the (already built) classifier to evaluate
	 * @param dataset the instances to predict
	 * @return the example-based evaluation of the predictions
	 * @throws Exception if predicting or evaluating fails
	 */
	public ExampleBasedEvaluation evaluateExample(MultiLabelClassifier classifier, Instances dataset)
	throws Exception
	{
		return new ExampleBasedEvaluation(getPredictions(classifier, dataset));
	}

	/**
	 * Evaluates the classifier on the dataset with ranking-based measures only.
	 *
	 * @param classifier the (already built) classifier to evaluate
	 * @param dataset the instances to predict
	 * @return the ranking-based evaluation of the predictions
	 * @throws Exception if predicting or evaluating fails
	 */
	public LabelRankingBasedEvaluation evaluateRanking(MultiLabelClassifier classifier, Instances dataset)
	throws Exception
	{
		return new LabelRankingBasedEvaluation(getPredictions(classifier, dataset));
	}

	/**
	 * Evaluates the classifier on the dataset with label-based measures only.
	 *
	 * @param classifier the (already built) classifier to evaluate
	 * @param dataset the instances to predict
	 * @return the label-based evaluation of the predictions
	 * @throws Exception if predicting or evaluating fails
	 */
	public LabelBasedEvaluation evaluateLabel(MultiLabelClassifier classifier, Instances dataset)
	throws Exception
	{
		return new LabelBasedEvaluation(getPredictions(classifier, dataset));
	}

	/**
	 * Translates the -1 "one fold per instance" convention into a concrete
	 * fold count; any other value is returned unchanged. The dataset is only
	 * accessed when numFolds is -1.
	 */
	private static int resolveNumFolds(int numFolds, Instances dataset)
	{
		return (numFolds == -1) ? dataset.numInstances() : numFolds;
	}

	/**
	 * Copies the classifier and builds the copy on the given training data,
	 * leaving the original classifier untouched.
	 */
	private static AbstractMultiLabelClassifier trainCopy(MultiLabelClassifier classifier, Instances train)
	throws Exception
	{
		AbstractMultiLabelClassifier copy =
			(AbstractMultiLabelClassifier) Classifier.makeCopy((Classifier) classifier);
		copy.buildClassifier(train);
		return copy;
	}

	/**
	 * Sets the predicted flag of every label of every example in one fold
	 * according to the threshold, and marks the most confident label of an
	 * example when none of its labels reaches the threshold.
	 */
	private static void applyThresholdWithFallback(BinaryPrediction[][] foldPredictions, double threshold)
	{
		for (int j = 0; j < foldPredictions.length; j++)
		{
			BinaryPrediction[] example = foldPredictions[j];
			double[] confidences = new double[example.length];
			boolean anyPredicted = false;

			for (int k = 0; k < example.length; k++)
			{
				confidences[k] = example[k].confidenceTrue;
				example[k].predicted = example[k].confidenceTrue >= threshold;
				if (example[k].predicted)
				{
					anyPredicted = true;
				}
			}

			// Assign the class with the greater confidence; skipped for an
			// example without labels, where there is nothing to select.
			if (!anyPredicted && example.length > 0)
			{
				example[Utils.maxIndex(confidences)].predicted = true;
			}
		}
	}

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -