perdocumentf1evaluator.java

来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 155 行

JAVA
155
字号
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.extract;import edu.umass.cs.mallet.base.types.Label;import edu.umass.cs.mallet.base.types.LabelAlphabet;import edu.umass.cs.mallet.base.types.MatrixOps;import java.io.PrintStream;import java.io.OutputStream;import java.io.PrintWriter;import java.io.OutputStreamWriter;import java.text.DecimalFormat;import java.util.Iterator;/** * Created: Oct 8, 2004 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: PerDocumentF1Evaluator.java,v 1.9 2005/03/06 19:41:52 casutton Exp $ */public class PerDocumentF1Evaluator implements ExtractionEvaluator {  private FieldComparator comparator = new ExactMatchComparator ();  private PrintStream errorOutputStream = null;  public FieldComparator getComparator ()  {    return comparator;  }  public void setComparator (FieldComparator comparator)  {    this.comparator = comparator;  }  public PrintStream getErrorOutputStream ()  {    return errorOutputStream;  }  public void setErrorOutputStream (OutputStream errorOutputStream)  {    // Work around java bug when wrapping System.out    if (errorOutputStream instanceof PrintStream) {      this.errorOutputStream = (PrintStream) errorOutputStream;    } else {      this.errorOutputStream = new PrintStream (errorOutputStream);    }  }  public void evaluate (Extraction extraction)  {    evaluate (extraction, System.out);  }  public void evaluate (Extraction extraction, PrintStream out)  {    evaluate ("", extraction, new PrintWriter (new OutputStreamWriter (out), true));  }  public void evaluate (Extraction extraction, PrintWriter out)  {    evaluate ("", extraction, out);  }  // Assumes that there are as many records as documents, indexed by docs.  // Assumes that extractor returns at most one value  public void evaluate (String description, Extraction extraction, PrintWriter out)  {    int numDocs = extraction.getNumDocuments ();    assert numDocs == extraction.getNumRecords ();    LabelAlphabet dict = extraction.getLabelAlphabet();    int numLabels = dict.size();    int[] numCorr = new int [numLabels];    int[] numPred = new int [numLabels];    int[] numTrue = new int [numLabels];    for (int docnum = 0; docnum < numDocs; docnum++) {      Record extracted = extraction.getRecord (docnum);      Record target = extraction.getTargetRecord (docnum);      // Calc precision      Iterator it = extracted.fieldsIterator ();      while (it.hasNext ()) {        Field predField = (Field) it.next ();        Label name = predField.getName ();        Field trueField = target.getField (name);        int idx = name.getIndex ();        numPred [idx]++;        if (predField.numValues() > 1)          System.err.println ("Warning: Field "+predField+" has more than one extracted value. Picking arbitrarily...");        if (trueField != null && trueField.isValue (predField.value (0), comparator)) {          numCorr [idx]++;        } else {          // We have an error, report if necessary          if (errorOutputStream != null) {            //xxx TODO: Display name of supporting document            errorOutputStream.println ("Error in extraction! Document "+extraction.getDocumentExtraction (docnum).getName ());            errorOutputStream.println ("Predicted "+predField);            errorOutputStream.println ("True "+trueField);            errorOutputStream.println ();          }        }      }      // Calc true      it = target.fieldsIterator ();      while (it.hasNext ()) {        Field trueField = (Field) it.next ();        Label name = trueField.getName ();        numTrue [name.getIndex ()]++;      }    }    DecimalFormat f = new DecimalFormat ("0.####");    double totalF1 = 0;    int totalFields = 0;    out.println (description+" per-document F1");    out.println ("Name\tP\tR\tF1");    for (int i = 0; i < numLabels; i++) {      double P = (numPred[i] == 0) ? 0 : ((double)numCorr[i]) / numPred [i];      double R = (numTrue[i] == 0) ? 1 : ((double)numCorr[i]) / numTrue [i];      double F1 = (P + R == 0) ? 0 : (2 * P * R) / (P + R);      if ((numPred[i] > 0) || (numTrue[i] > 0)) {        totalF1 += F1;        totalFields++;      }      Label name = dict.lookupLabel (i);      out.println (name+"\t"+f.format(P)+"\t"+f.format(R)+"\t"+f.format(F1));    }    int totalCorr = MatrixOps.sum (numCorr);    int totalPred = MatrixOps.sum (numPred);    int totalTrue = MatrixOps.sum (numTrue);    double P = ((double)totalCorr) / totalPred;    double R = ((double)totalCorr) / totalTrue;    double F1 = (2 * P * R) / (P + R);    out.println ("OVERALL (micro-averaged) P="+f.format(P)+" R="+f.format(R)+" F1="+f.format(F1));    out.println ("OVERALL (macro-averaged) F1="+f.format(totalF1/totalFields));    out.println ();  }}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?