⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 labelbasedevaluation.java

📁 Multi-label classification 和weka集成
💻 JAVA
字号:
package mulan.evaluation;
import weka.core.Utils;





/**
 * Calculate measures separately for each label and  average the results. Macroaveraging.
 * @author  rofr
 */
public class LabelBasedEvaluation extends EvaluationBase
{

	/**
	 * Constant used as array index to retrieve.
	 */
	public final static int MACRO = 0;
	public final static int MICRO = 1;
	

	//TODO: is this the appropriate default?
	protected final static int DEFAULT = MACRO; 
	
	/**
	 * Keep track of which type of measure to return.
	 */
	protected int averagingMethod;
	
	/**
	 * store both micro and macro averaged measures.
	 */
	protected double[] recall    = new double[2];
	protected double[] precision = new double[2];
	protected double[] fmeasure  = new double[2];
	
	//I think we should leave this measure for completeness even
	//if it is equivalent to hammingloss. Perhaps a user will choose
	//either label or example based for a specific scenario.
	protected double[] accuracy  = new double[2];
	
	
	//Per label measures
	protected double[] labelAccuracy;
	protected double[] labelRecall;  
	protected double[] labelPrecision;
	protected double[] labelFmeasure;
	
	public double accuracy(int label)
	{
		return labelAccuracy[label];
	}

	public double recall(int label)
	{
		return labelRecall[label];
	}
	
	public double precision(int label)
	{
		return labelPrecision[label];
	}
	
	public double fmeasure(int label)
	{
		return labelFmeasure[label];
	}
	
	protected LabelBasedEvaluation(BinaryPrediction[][] predictions)
	throws Exception
	{
		this(predictions, DEFAULT);
	}
	
	/**
	 * Used by crossvalidation subclass.
	 *
	 */
	protected LabelBasedEvaluation()
	{
		super(null);
	}
	
	protected LabelBasedEvaluation(BinaryPrediction[][] predictions, int averagingMethod)
	throws Exception
	{
		super(predictions);
		setAveragingMethod(averagingMethod);
		computeMeasures();
	}

	/**
	 * Compute both micro and macro averages.
	 */
	protected void computeMeasures()
	{
            int numInstances = numInstances();
            int numLabels   = numLabels();

            //Counters are doubles to avoid typecasting
            //when performing divisions. It makes the code a
            //little cleaner but:
            //TODO: run performance tests on counting with doubles
            double[] falsePositives    = new double[numLabels];
            double[] truePositives     = new double[numLabels];
            double[] falseNegatives    = new double[numLabels];
            double[] trueNegatives     = new double[numLabels];

            this.labelAccuracy         = new double[numLabels];
            this.labelRecall           = new double[numLabels];
            this.labelPrecision        = new double[numLabels];
            this.labelFmeasure         = new double[numLabels];


            //Count TP, TN, FP, FN
            for(int i = 0; i < numInstances; i++)
            {
                    for(int j = 0; j < numLabels; j++)
                    {	
                            boolean actual = predictions[i][j].actual;
                            boolean predicted = predictions[i][j].predicted;

                            if (actual && predicted) truePositives[j]++;
                            else if (!actual && !predicted) trueNegatives[j]++;
                            else if (predicted) falsePositives[j]++;
                            else falseNegatives[j]++;
                    }
            }

            //Compute macro averaged measures
            for(int i = 0; i < numLabels; i++)
            {
                    labelAccuracy[i] = (truePositives[i] + trueNegatives[i]) / numInstances; 

                    labelRecall[i] = truePositives[i] + falseNegatives[i] == 0 ? 0
                                    :truePositives[i] / (truePositives[i] + falseNegatives[i]);

                    labelPrecision[i] = truePositives[i] + falsePositives[i] == 0 ? 0
                                    :truePositives[i] / (truePositives[i] + falsePositives[i]);

                    labelFmeasure[i] = computeFMeasure(labelPrecision[i], labelRecall[i]);
            }
        this.accuracy[MACRO]  = Utils.mean(labelAccuracy);
	    this.recall[MACRO]    = Utils.mean(labelRecall);
	    this.precision[MACRO] = Utils.mean(labelPrecision);
	    this.fmeasure[MACRO]  = Utils.mean(labelFmeasure);

	    //Compute micro averaged measures
	    double tp = Utils.sum(truePositives);
	    double tn = Utils.sum(trueNegatives);
	    double fp = Utils.sum(falsePositives);
	    double fn = Utils.sum(falseNegatives);
	    
	    this.accuracy[MICRO]  = (tp + tn) / (numInstances * numLabels);
	    this.recall[MICRO]    = tp + fn == 0 ? 0 : tp / (tp + fn);
	    this.precision[MICRO] = tp + fp == 0 ? 0 : tp / (tp + fp);
	    this.fmeasure[MICRO]  = computeFMeasure(this.precision[MICRO], this.recall[MICRO]);
	}

	/**
	 * @param averagingMethod must be one of MICROAVERAGED or MACROAVERAGED
	 *
	 */
	//TODO: Use Enum instead...
	public void setAveragingMethod(int averagingMethod)
	{
		this.averagingMethod = averagingMethod;
	}

	/**
	 * @return the averagingMethod
	 */
	public int getAveragingMethod()
	{
		return averagingMethod;
	}

	@Override
	public double accuracy()
	{
		return accuracy[averagingMethod];
	}

	@Override
	public double fmeasure()
	{
		return fmeasure[averagingMethod];
	}

	@Override
	public double precision()
	{
		return precision[averagingMethod];
	}

	@Override
	public double recall()
	{
		return recall[averagingMethod];
	}
	
	public String toString() {
		String description = "";

		description += "========Label Based Measures========\n";
		setAveragingMethod(LabelBasedEvaluation.MICRO);
		description += "MICRO\n";
		description += "Precision : " + this.precision() + "\n";
		description += "Recall    : " + this.recall() + "\n";
		description += "F1        : " + this.fmeasure() + "\n";
		setAveragingMethod(LabelBasedEvaluation.MACRO);
		description += "MACRO\n";
		description += "Precision : " + this.precision() + "\n";
		description += "Recall    : " + this.recall() + "\n";
		description += "F1        : " + this.fmeasure() + "\n";

		return description;
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -