📄 constrainedviterbitransducercorrector.java

📁 Common machine learning algorithms implemented in Java, including widely used classification algorithms, with documentation
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/**
	@author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
*/

package edu.umass.cs.mallet.base.fst.confidence;

import edu.umass.cs.mallet.base.types.*;
import edu.umass.cs.mallet.base.fst.*;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.util.logging.*;
import java.util.ArrayList;

/**
 * Corrects a subset of the {@link Segment}s produced by a {@link
 * Transducer}. It's most useful to find the {@link Segment}s that the
 * {@link Transducer} is least confident in and correct those using
 * the true {@link Labeling}
 * (<code>correctLeastConfidentSegments</code>).  The corrected
 * segment then propagates to other labelings in the sequence using
 * "constrained viterbi" -- a viterbi calculation that requires the
 * path to pass through the corrected segment states.
 */
public class ConstrainedViterbiTransducerCorrector implements TransducerCorrector
{
	private static Logger logger = MalletLogger.getLogger(ConstrainedViterbiTransducerCorrector.class.getName());

	TransducerConfidenceEstimator confidenceEstimator;
	Transducer model;
	ArrayList leastConfidentSegments;

	public ConstrainedViterbiTransducerCorrector (TransducerConfidenceEstimator confidenceEstimator,
	                                              Transducer model) {
		this.confidenceEstimator = confidenceEstimator;
		this.model = model;
	}

	public ConstrainedViterbiTransducerCorrector (Transducer model) {
		this (new ConstrainedForwardBackwardConfidenceEstimator (model), model);
	}

	public java.util.Vector getSegmentConfidences () {
		return confidenceEstimator.getSegmentConfidences();
	}

	/**
		 Returns the least confident segments from each sequence in the
		 previous call to <code>correctLeastConfidentSegments</code>
	 */
	public ArrayList getLeastConfidentSegments () {
		return this.leastConfidentSegments;
	}

	/**
		 Returns the least confident segments in <code>ilist</code>
		 @param ilist test instances
		 @param startTags indicate the beginning of segments
		 @param continueTags indicate "inside" of segments
		 @return list of {@link Segment}s, one for each instance, that is least confident
	*/
	public ArrayList getLeastConfidentSegments (InstanceList ilist, Object[] startTags, Object[] continueTags) {
		ArrayList ret = new ArrayList ();
		for (int i=0; i < ilist.size(); i++) {
			Segment[] orderedSegments = confidenceEstimator.rankSegmentsByConfidence (
				ilist.getInstance (i), startTags, continueTags);
			ret.add (orderedSegments[0]);
		}
		return ret;
	}

	public ArrayList correctLeastConfidentSegments (InstanceList ilist,
	                                                Object[] startTags,
	                                                Object[] continueTags) {
		return correctLeastConfidentSegments (ilist, startTags, continueTags, false);
	}

	/**
		 Returns an ArrayList of corrected Sequences.  Also stores
		 leastConfidentSegments, an ArrayList of the segments to correct,
		 where null entries mean no segment was corrected for that
		 sequence.
		 @param ilist test instances
		 @param startTags indicate the beginning of segments
		 @param continueTags indicate "inside" of segments
		 @param findIncorrect true if we should cycle through least
		 confident segments until find an incorrect one
		 @return list of {@link Sequence}s corresponding to the corrected
		 tagging of each instance in <code>ilist</code>
	*/
	public ArrayList correctLeastConfidentSegments (InstanceList ilist, Object[] startTags,
	                                                Object[] continueTags, boolean findIncorrect) {
		ArrayList correctedPredictionList = new ArrayList ();
		this.leastConfidentSegments = new ArrayList ();

		logger.info (this.getClass().getName() + " ranking confidence using " +
		             confidenceEstimator.getClass().getName());
		for (int i=0; i < ilist.size(); i++) {
			logger.info ("correcting instance# " + i + " / " + ilist.size());
			Instance instance = ilist.getInstance (i);
			Segment[] orderedSegments = new Segment[1];
			Sequence input = (Sequence) instance.getData ();
			Sequence truth = (Sequence) instance.getTarget ();
			Sequence predicted = model.viterbiPath (input).output ();
			int numIncorrect = 0;
			for (int j=0; j < predicted.size(); j++)
				numIncorrect += (!predicted.get(j).equals (truth.get(j))) ? 1 : 0;
			if (numIncorrect == 0) { // nothing to correct
				this.leastConfidentSegments.add (null);
				correctedPredictionList.add (predicted);
				continue;
			}
			// rank segments by confidence
			orderedSegments = confidenceEstimator.rankSegmentsByConfidence (
				instance, startTags, continueTags);
			logger.fine ("Ordered Segments:\n");
			for (int j=0; j < orderedSegments.length; j++) {
				logger.fine (orderedSegments[j].toString());
			}
			logger.fine ("Correcting Segment: True Sequence:");
			for (int j=0; j < truth.size(); j++)
				logger.fine ((String)truth.get (j) + "\t");
			logger.fine ("");
			logger.fine ("Ordered Segments:\n");
			for (int j=0; j < orderedSegments.length; j++) {
				logger.fine (orderedSegments[j].toString());
			}
			// if <code>findIncorrect</code>, find the least confident
			// segment that is incorrectly labeled
			// else, use least confident segment
			Segment leastConfidentSegment = orderedSegments[0];
			if (findIncorrect) {
				for (int j=0; j < orderedSegments.length; j++) {
					if (!orderedSegments[j].correct()) {
						leastConfidentSegment = orderedSegments[j];
						break;
					}
				}
			}
			if (findIncorrect && leastConfidentSegment.correct()) {
				logger.warning ("cannot find incorrect segment, probably because error is in background state\n");
				this.leastConfidentSegments.add (null);
				correctedPredictionList.add (predicted);
				continue;
			}
			this.leastConfidentSegments.add (leastConfidentSegment);
			if (leastConfidentSegment == null) { // nothing extracted
				correctedPredictionList.add (predicted);
				continue;
			}
			// create segmentCorrectedOutput, which has the true labels for
			// the leastConfidentSegment and null for other positions
			String[] sequence = new String[truth.size()];
			int numCorrectedTokens = 0;
			for (int j=0; j < sequence.length; j++)
				sequence[j] = null;
			for (int j=0; j < truth.size(); j++) {
				// if in segment
				if (leastConfidentSegment.indexInSegment (j)) {
					sequence[j] = (String)truth.get (j);
					numCorrectedTokens++;
				}
			}
			if (leastConfidentSegment.endsPrematurely ()) {
				sequence[leastConfidentSegment.getEnd()+1] =
					(String)truth.get (leastConfidentSegment.getEnd()+1);
				numCorrectedTokens++;
			}
			logger.fine ("Constrained Segment Sequence\n");
			for (int j=0; j < sequence.length; j++) {
				logger.fine (sequence[j]);
			}
			ArraySequence segmentCorrectedOutput = new ArraySequence (sequence);
			// run constrained viterbi on this sequence with the
			// constraint that this segment is tagged correctly
			Sequence correctedPrediction = model.viterbiPath (
				orderedSegments[0].getInput (),	segmentCorrectedOutput).output();
			int numIncorrectAfterCorrection = 0;
			for (int j=0; j < truth.size(); j++)
				numIncorrectAfterCorrection += (!correctedPrediction.get(j).equals (truth.get(j))) ? 1 : 0;
			logger.fine ("Num incorrect tokens in original prediction: " + numIncorrect);
			logger.fine ("Num corrected tokens: " + numCorrectedTokens);
			logger.fine ("Num incorrect tokens after correction-propagation: " + numIncorrectAfterCorrection);
			// print sequence info
			logger.fine ("Correcting Segment: True Sequence:");
			for (int j=0; j < truth.size(); j++)
				logger.fine ((String)truth.get (j) + "\t");
			logger.fine ("\nOriginal prediction: ");
			for (int j=0; j < predicted.size(); j++)
				logger.fine ((String)predicted.get (j) + "\t");
			logger.fine ("\nCorrected prediction: ");
			for (int j=0; j < correctedPrediction.size(); j++)
				logger.fine ((String)correctedPrediction.get (j) + "\t");
			logger.fine ("");
			correctedPredictionList.add (correctedPrediction);
		}
		return correctedPredictionList;
	}
}
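
The class comment describes "constrained viterbi" as a Viterbi calculation whose path must pass through the corrected segment states. The standalone sketch below illustrates that idea on a toy two-state model. It does not use MALLET: the class name `ConstrainedViterbiSketch`, the `-1` "unconstrained" marker (standing in for the `null` entries of the partially filled output sequence), and all probabilities are assumptions made for illustration only, not the library's implementation.

```java
import java.util.Arrays;

// Hypothetical toy example; not part of MALLET.
class ConstrainedViterbiSketch {

    /** True if state s may be used at position t (-1 means "no constraint"). */
    static boolean allowed(int[] constraints, int t, int s) {
        return constraints[t] < 0 || constraints[t] == s;
    }

    /**
     * Log-space Viterbi. logEmit[t][s] scores state s at position t.
     * States ruled out by constraints get score -infinity, so the best
     * path is forced through every constrained position.
     */
    static int[] viterbi(double[] logInit, double[][] logTrans,
                         double[][] logEmit, int[] constraints) {
        int T = logEmit.length, S = logInit.length;
        double[][] score = new double[T][S];
        int[][] back = new int[T][S];
        for (int s = 0; s < S; s++) {
            score[0][s] = allowed(constraints, 0, s)
                ? logInit[s] + logEmit[0][s] : Double.NEGATIVE_INFINITY;
        }
        for (int t = 1; t < T; t++) {
            for (int s = 0; s < S; s++) {
                score[t][s] = Double.NEGATIVE_INFINITY;
                if (!allowed(constraints, t, s)) continue;
                for (int p = 0; p < S; p++) {
                    double cand = score[t - 1][p] + logTrans[p][s] + logEmit[t][s];
                    if (cand > score[t][s]) {
                        score[t][s] = cand;
                        back[t][s] = p;
                    }
                }
            }
        }
        // Pick the best final state, then follow back-pointers.
        int best = 0;
        for (int s = 1; s < S; s++) {
            if (score[T - 1][s] > score[T - 1][best]) best = s;
        }
        int[] path = new int[T];
        path[T - 1] = best;
        for (int t = T - 1; t > 0; t--) path[t - 1] = back[t][path[t]];
        return path;
    }

    /** Decodes a 4-position sequence with "sticky" transitions (assumed values). */
    static int[] demo(int[] constraints) {
        double stay = Math.log(0.9), change = Math.log(0.1);
        double[] logInit = {Math.log(0.5), Math.log(0.5)};
        double[][] logTrans = {{stay, change}, {change, stay}};
        double[][] logEmit = new double[4][];
        for (int t = 0; t < 4; t++) {
            // Every position mildly prefers state 0.
            logEmit[t] = new double[] {Math.log(0.6), Math.log(0.4)};
        }
        return viterbi(logInit, logTrans, logEmit, constraints);
    }

    public static void main(String[] args) {
        // Unconstrained decoding.
        System.out.println(Arrays.toString(demo(new int[] {-1, -1, -1, -1}))); // [0, 0, 0, 0]
        // Position 1 "corrected" to state 1; the change spreads to its neighbors.
        System.out.println(Arrays.toString(demo(new int[] {-1, 1, -1, -1})));  // [1, 1, 1, 1]
    }
}
```

In this toy model, fixing a single position flips the labels of the neighboring positions as well, because the transition scores make staying in one state much cheaper than switching. That is the correction-propagation effect the class comment refers to.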
