⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ieevaluator.java

📁 这是一个MALLET的java实现。里面有许多内容。请大家慢慢琢磨。
💻 JAVA
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/*
    Fuchun Peng <a href="mailto:fuchun@cs.umass.edu">fuchun@cs.umass.edu</a>,
 */

package edu.umass.cs.mallet.projects.seg_plus_coref.ie;

import edu.umass.cs.mallet.projects.seg_plus_coref.ie.IEInterface3;
import edu.umass.cs.mallet.base.types.*;
import edu.umass.cs.mallet.base.fst.*;
import edu.umass.cs.mallet.base.pipe.*;
import java.io.*;
import java.util.logging.*;
import java.util.regex.*;

/**
 * Token-level evaluator for a transducer.
 *
 * For each of the training, validation and testing lists it runs Viterbi
 * decoding on every instance, compares the predicted label of each token with
 * the true label, and prints to <code>System.out</code>: overall token
 * accuracy, whole-instance accuracy, per-label precision / recall / F1, and a
 * confusion matrix.  Tokens whose source text is a single punctuation
 * character (see {@link #PUNT}) are left out of all counts.
 *
 * Optionally it also writes checkpoints of the transducer and a per-token
 * trace of the Viterbi output.
 */
public class IEEvaluator extends TransducerEvaluator
{
	private static final Logger logger = Logger.getLogger(IEEvaluator.class.getName());

	/** Iterations below this number are not evaluated (unless training has finished). */
	int numIterationsToWait = 0;
	/** Evaluate only on iterations that are a multiple of this; values &lt;= 0 mean "every iteration". */
	int numIterationsToSkip = 0;
	/** When true, a call with <code>finishedTraining</code> set is always evaluated. */
	boolean alwaysEvaluateWhenFinished = true;
	boolean printCrfAtEnd = true;
	/** When true, the transducer is written to a checkpoint file periodically. */
	boolean checkpointTransducer = false;
	String checkpointFilePrefix = null;
	/** Checkpoint interval in iterations; values &lt;= 0 mean "every iteration". */
	int checkpointIterationsToSkip = 30;
	/** Prefix of the Viterbi trace file; <code>null</code> sends the trace to <code>System.out</code>. */
	String viterbiOutputFilePrefix = null;
	int viterbiOutputIterationsToWait = 10;
	/** Viterbi trace interval in iterations; values &lt;= 0 mean "every iteration". */
	int viterbiOutputIterationsToSkip = 10;
	/** Character encoding of the Viterbi trace file; <code>null</code> uses the platform default. */
	String viterbiOutputEncoding = null;

	/** Regular expression matching a token that consists of exactly one punctuation character. */
	private static final String PUNT = "[,\\.;:?!()*]";
	// Compiled once for the class; the pattern is immutable and safe to share.
	private static final Pattern puntPattern = Pattern.compile(PUNT);

	/** Creates an evaluator that traces Viterbi output to <code>System.out</code> every 10 iterations. */
	public IEEvaluator()
	{
		viterbiOutputFilePrefix = null;
		viterbiOutputIterationsToSkip = 10;
	}

	/**
	 * @param viterbiOutputFilePrefix prefix of the Viterbi trace files, or
	 *        <code>null</code> to trace to <code>System.out</code>
	 * @param viterbiOutputIterationsToSkip trace interval in iterations
	 */
	public IEEvaluator(String viterbiOutputFilePrefix, int viterbiOutputIterationsToSkip)
	{
		this.viterbiOutputFilePrefix = viterbiOutputFilePrefix;
		this.viterbiOutputIterationsToSkip = viterbiOutputIterationsToSkip;
	}

	/**
	 * Prints, for every token of every instance that has a source token
	 * sequence, the token text, its true label and its feature vector.
	 * Does nothing for a <code>null</code> or empty list.
	 *
	 * @param training the instances to dump
	 */
	public void printFeatures (InstanceList training)
	{
		// The original dereferenced instance 0 unconditionally and failed on an empty list.
		if (training == null || training.size() == 0)
			return;

		Pipe p = training.getInstance(0).getPipe();
		Alphabet targets = p.getTargetAlphabet();
		assert (targets != null);

		for (int i = 0; i < training.size(); i++) {
			System.out.println ("Viterbi path for "+training+" instance #"+i);
			Instance instance = training.getInstance(i);
			Sequence input = (Sequence) instance.getData();
			Sequence trueOutput = (Sequence) instance.getTarget();
			assert (input.size() == trueOutput.size());
			TokenSequence sourceTokenSequence = (TokenSequence) instance.getSource();
			if (sourceTokenSequence == null)
				continue; // nothing is printed for tokens without source text

			for (int j = 0; j < trueOutput.size(); j++) {
				FeatureVector fv = (FeatureVector) input.get(j);
				System.out.print (sourceTokenSequence.getToken(j).getText()+": ");
				System.out.println (trueOutput.get(j).toString()+ ": "+ fv.toString(true));
			}
		}
	}

	/** Intentionally empty; all evaluation is done in {@link #evaluate}. */
	public void test(Transducer transducer, InstanceList data,
	                 String description, PrintStream viterbiOutputStream)
	{
	}

	/**
	 * Evaluates <code>crf</code> on the given instance lists and prints the
	 * results to <code>System.out</code>.
	 *
	 * @param crf the transducer to evaluate
	 * @param finishedTraining whether training has finished
	 * @param iteration the current training iteration
	 * @param converged not used by this evaluator
	 * @param cost the current training cost; only printed
	 * @param training training instances, may be <code>null</code>
	 * @param validation validation instances, may be <code>null</code>
	 * @param testing testing instances, may be <code>null</code>
	 * @return always <code>true</code>
	 */
	public boolean evaluate (Transducer crf, boolean finishedTraining, int iteration,
	                         boolean converged, double cost,
	                         InstanceList training, InstanceList validation, InstanceList testing)
	{
		System.out.println ("Iteration="+iteration+" Cost="+cost);
		boolean forcedByFinish = alwaysEvaluateWhenFinished && finishedTraining;

		// Don't evaluate if it is too early in training to matter
		if (iteration < numIterationsToWait && !forcedByFinish)
			return true;
		// Only evaluate every numIterationsToSkip-th iteration.  The original
		// guarded this with numIterationsToWait, which both ignored the skip
		// interval when the wait was 0 and divided by zero when the skip was 0.
		if (!isOnInterval (iteration, numIterationsToSkip) && !forcedByFinish)
			return true;

		// Possibly write CRF to a file
		if (checkpointTransducer && iteration > 0
				&& isOnInterval (iteration, checkpointIterationsToSkip)) {
			String prefix = checkpointFilePrefix == null ? "" : checkpointFilePrefix + '.';
			writeTransducer (crf, new File (prefix + "checkpoint"+iteration+".crf"));
		}

		// NOTE(review): the transducer is also written to "tempCRF" in the
		// working directory on every evaluation; this looks like a leftover
		// from the disabled IEInterface3 experiment but is kept because other
		// tools may read that file -- confirm before removing.
		writeTransducer (crf, new File ("tempCRF"));

		InstanceList[] lists = new InstanceList[] {training, validation, testing};
		String[] listnames = new String[] {"Training", "Validation", "Testing"};

		PrintStream viterbiOutputStream = null;
		if ((iteration >= viterbiOutputIterationsToWait
				 && isOnInterval (iteration, viterbiOutputIterationsToSkip))
				|| forcedByFinish)
			viterbiOutputStream = openViterbiOutputStream (iteration);

		try {
			// Find the vocabulary of targets from the first list that has an
			// instance (the original required the training list to be non-empty).
			Alphabet targets = null;
			for (int k = 0; k < lists.length && targets == null; k++) {
				if (lists[k] != null && lists[k].size() > 0) {
					Pipe p = lists[k].getInstance(0).getPipe();
					targets = p.getTargetAlphabet();
					assert (targets != null);
				}
			}

			if (targets == null) {
				logger.warning ("No instances to evaluate at iteration "+iteration+"; skipping evaluation.");
			} else {
				for (int k = 0; k < lists.length; k++) {
					if (lists[k] != null)
						evaluateList (crf, lists[k], listnames[k], targets, viterbiOutputStream);
				}
			}
		} finally {
			// Close the trace file even if decoding throws; never close System.out.
			if (viterbiOutputStream != null && viterbiOutputFilePrefix != null
					&& viterbiOutputStream != System.out)
				viterbiOutputStream.close();
		}

		if (printCrfAtEnd && finishedTraining)
			System.out.println("Finished!");

		return true;
	}

	/** Returns true when <code>iteration</code> falls on <code>interval</code>; a non-positive interval matches every iteration. */
	private static boolean isOnInterval (int iteration, int interval)
	{
		return interval <= 0 || iteration % interval == 0;
	}

	/** Writes <code>crf</code> to <code>file</code> if it is a CRF or CRF3; other transducers are silently skipped. */
	private static void writeTransducer (Transducer crf, File file)
	{
		if (crf instanceof CRF)
			((CRF) crf).write (file);
		else if (crf instanceof CRF3)
			((CRF3) crf).write (file);
	}

	/**
	 * Opens the stream for the Viterbi trace of this iteration: System.out
	 * when no file prefix is configured, otherwise a new file.  Returns null
	 * (after logging a warning) when the file cannot be opened.
	 */
	private PrintStream openViterbiOutputStream (int iteration)
	{
		if (viterbiOutputFilePrefix == null)
			return System.out;

		String viterbiFilename = viterbiOutputFilePrefix + ".viterbi_"+iteration;
		FileOutputStream fos = null;
		try {
			fos = new FileOutputStream (viterbiFilename);
			if (viterbiOutputEncoding == null)
				return new PrintStream (fos);
			return new PrintStream (fos, true, viterbiOutputEncoding);
		} catch (IOException e) {
			logger.warning ("Couldn't open Viterbi output file '"+viterbiFilename+"'; continuing without Viterbi output trace.");
			// An unsupported encoding fails after the file is already open; release it.
			if (fos != null) {
				try {
					fos.close();
				} catch (IOException ignored) {
					// best effort: the trace is being abandoned anyway
				}
			}
			return null;
		}
	}

	/**
	 * Decodes every instance of <code>list</code>, accumulates token-level
	 * counts (ignoring punctuation tokens) and prints the report for that list.
	 */
	private void evaluateList (Transducer crf, InstanceList list, String listname,
	                           Alphabet targets, PrintStream viterbiOutputStream)
	{
		int numLabels = targets.size();
		int[] numTrueSegments = new int[numLabels];
		int[] numPredictedSegments = new int[numLabels];
		int[] numCorrectSegments = new int[numLabels];
		// matrixEntry[true label][predicted label]
		int[][] matrixEntry = new int[numLabels][numLabels];
		int totalTokens = 0;
		int numCorrectTokens = 0;
		int numCorrectWholeInstance = 0;
		// The in-vocabulary / OOV breakdown of the original was disabled (it
		// depended on HashFile.allLexicons), so its counters are not kept here.

		for (int i = 0; i < list.size(); i++) {
			Instance instance = list.getInstance(i);
			Sequence input = (Sequence) instance.getData();
			Sequence trueOutput = (Sequence) instance.getTarget();
			assert (input.size() == trueOutput.size());

			Transducer.ViterbiPath viterbiP = crf.viterbiPath(input);
			Sequence predOutput = viterbiP.output();
			double viterbi_cost = viterbiP.getCost();
			// exp of the negated per-token cost
			double viterbi_p = Math.exp(-viterbi_cost/predOutput.size());
			if (viterbiOutputStream != null)
				viterbiOutputStream.println ("Viterbi path for "+listname+" instance #"+i + " p=" + viterbi_p);
			assert (predOutput.size() == trueOutput.size());

			TokenSequence sourceTokenSequence = (TokenSequence) instance.getSource();
			boolean wholeInstanceCorrect = true;
			for (int j = 0; j < trueOutput.size(); j++) {
				int predIndex = targets.lookupIndex(predOutput.get(j));
				int trueIndex = targets.lookupIndex(trueOutput.get(j));

				// Without source text the token cannot be recognised as
				// punctuation, so it is counted (the original threw a
				// NullPointerException here).
				String tokenStr = sourceTokenSequence == null
					? null : sourceTokenSequence.getToken(j).getText();
				if (tokenStr != null && puntPattern.matcher(tokenStr).matches())
					continue; // ignore punct; it is also left out of the trace

				totalTokens ++;
				numTrueSegments[trueIndex] ++;
				numPredictedSegments[predIndex] ++;
				matrixEntry[trueIndex][predIndex] ++;
				if (predIndex == trueIndex) {
					numCorrectTokens ++;
					numCorrectSegments[trueIndex] ++;
				} else {
					wholeInstanceCorrect = false;
				}

				if (viterbiOutputStream != null) {
					if (tokenStr != null)
						viterbiOutputStream.print (tokenStr+": ");
					viterbiOutputStream.println (trueOutput.get(j).toString()+ '/'+predOutput.get(j).toString());
				}
			}
			if (wholeInstanceCorrect)
				numCorrectWholeInstance ++;
		}

		double accuracy = (double)numCorrectTokens/totalTokens;
		System.out.println ("\n" + listname +" accuracy=" + numCorrectTokens +"/"+ totalTokens + " = " +accuracy);
		double wholeInstanceAccuracy = (double)numCorrectWholeInstance/list.size();
		System.out.println ("Whole instance accuracy = " + numCorrectWholeInstance + "/" + list.size() + " = " + wholeInstanceAccuracy);

		for (int t = 0; t < numLabels; t++) {
			// A label that was never predicted (or never true) scores 1 by convention here.
			double precision = numPredictedSegments[t] == 0 ? 1 : ((double)numCorrectSegments[t]) / numPredictedSegments[t];
			double recall = numTrueSegments[t] == 0 ? 1 : ((double)numCorrectSegments[t]) / numTrueSegments[t];
			double f1 = recall+precision == 0.0 ? 0.0 : (2.0 * recall * precision) / (recall + precision);
			// Tokens on which "is label t" agrees between truth and prediction, over all tokens.
			double accuracy_individual = (double)(totalTokens-numPredictedSegments[t]-numTrueSegments[t] + 2*numCorrectSegments[t] )/totalTokens;
			System.out.println (targets.lookupObject(t) + " precision="+precision+" recall="+recall+" f1="+f1 + " accuracy=" + accuracy_individual);
			System.out.println ("segments true="+numTrueSegments[t]+" pred="+numPredictedSegments[t]+" correct="+numCorrectSegments[t]+" misses="+(numTrueSegments[t]-numCorrectSegments[t])+" alarms="+(numPredictedSegments[t]-numCorrectSegments[t]) + "\n");
		}

		System.out.println("\n Confusion Matrix (row: true label, col: predicted label)");
		System.out.print("\t");
		for (int t = 0; t < numLabels; t++) {
			System.out.print(targets.lookupObject(t) + "\t");
		}
		System.out.println();
		for (int t = 0; t < numLabels; t++) {
			System.out.print(targets.lookupObject(t)+"\t");
			for (int tt = 0; tt < numLabels; tt++) {
				System.out.print(matrixEntry[t][tt] + "\t");
			}
			System.out.println();
		}
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -