⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 addclassifiertokenpredictions.java

📁 mallet是自然语言处理、机器学习领域的一个开源项目。
💻 JAVA
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.pipe;import java.io.IOException;import java.io.ObjectInputStream;import java.io.ObjectOutputStream;import java.io.Serializable;import java.util.HashMap;import java.util.logging.Logger;import edu.umass.cs.mallet.base.classify.BalancedWinnowTrainer;import edu.umass.cs.mallet.base.classify.Classification;import edu.umass.cs.mallet.base.classify.Classifier;import edu.umass.cs.mallet.base.classify.ClassifierTrainer;import edu.umass.cs.mallet.base.classify.Trial;import edu.umass.cs.mallet.base.types.Alphabet;import edu.umass.cs.mallet.base.types.AugmentableFeatureVector;import edu.umass.cs.mallet.base.types.FeatureVector;import edu.umass.cs.mallet.base.types.FeatureVectorSequence;import edu.umass.cs.mallet.base.types.Instance;import edu.umass.cs.mallet.base.types.InstanceList;import edu.umass.cs.mallet.base.types.Label;import edu.umass.cs.mallet.base.types.LabelSequence;import edu.umass.cs.mallet.base.types.LabelVector;import edu.umass.cs.mallet.base.types.Labeling;import edu.umass.cs.mallet.base.util.MalletLogger;/** * This pipe uses a Classifier to label each token (i.e., using 0-th order Markov assumption),  * then adds the predictions as features to each token. *  * This pipe assumes the input Instance's data is of type FeatureVectorSequence  * (each an augmentable feature vector).  *  * Example usage:<pre> * 		1) Create and serialize a featurePipe that converts raw input to FeatureVectorSequences * 		2) Pipe input data through featurePipe, train a TokenClassifiers via cross validation, then serialize the classifiers * 		2) Pipe input data through featurePipe and this pipe (using the saved classifiers), and train a Transducer  * 		4) Serialize the trained Transducer  * </pre> * @author ghuang */public class AddClassifierTokenPredictions extends Pipe implements Serializable {	private static Logger logger = MalletLogger.getLogger(AddClassifierTokenPredictions.class.getName());		// Specify which predictions are to be added as features.  	// E.g., { 1, 2 } = add labels of the top 2 highest-scoring predictions as features.	int[] m_predRanks2add;		// The trained token classifier 	TokenClassifiers m_tokenClassifiers; 	// Whether to treat each instance's feature values as binary 	boolean m_binary;	// Whether the pipe is currently being used at production time 	// (i.e., not being used as pipeline for training a transducer)   	boolean m_inProduction;	// Augmented data alphabet that includes the class predictions	Alphabet m_dataAlphabet;			public AddClassifierTokenPredictions(InstanceList trainList)	{		this(trainList, null);	}		public AddClassifierTokenPredictions(InstanceList trainList, InstanceList testList)	{		this(new TokenClassifiers(convert(trainList, (Noop) trainList.getPipe())), new int[] { 1 }, true, 					convert(testList, (Noop) trainList.getPipe()));	}			public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add, 			boolean binary, InstanceList testList)	{		m_predRanks2add = predRanks2add;		m_binary = binary;		m_tokenClassifiers = tokenClassifiers;		m_inProduction = false;		m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();		Alphabet labelAlphabet = tokenClassifiers.getLabelAlphabet();				// add the token prediction features to the alphabet		for (int i = 0; i < m_predRanks2add.length; i++) {			for (int j = 0; j < labelAlphabet.size(); j++) {				String featName = "TOK_PRED=" + labelAlphabet.lookupObject(j).toString() + "_@_RANK_" + m_predRanks2add[i];				m_dataAlphabet.lookupIndex(featName, true);			}		}				// evaluate token classifier  		if (testList != null) {			Trial trial = new Trial(m_tokenClassifiers, testList);			logger.info("Token classifier accuracy on test set = " + trial.accuracy());		}	}		public void setInProduction(boolean inProduction) { m_inProduction = inProduction; }	public boolean getInProduction() { return m_inProduction; }	public static void setInProduction(Pipe p, boolean value)	{		if (p instanceof AddClassifierTokenPredictions) 			((AddClassifierTokenPredictions) p).setInProduction(value);		else if (p instanceof SerialPipes) {			SerialPipes sp = (SerialPipes) p;			for (int i = 0; i < sp.size(); i++)				setInProduction(sp.getPipe(i), value);		}	}		public Alphabet getDataAlphabet() { return m_dataAlphabet; }		/**	 * Add the token classifier's predictions as features to the instance.	 * This method assumes the input instance contains FeatureVectorSequence as data  	 */	public Instance pipe(Instance carrier) 	{		FeatureVectorSequence fvs = (FeatureVectorSequence) carrier.getData();		InstanceList ilist = convert(carrier, (Noop) m_tokenClassifiers.getInstancePipe());		assert (fvs.size() == ilist.size());		// For passing instances to the token classifier, each instance's data alphabet needs to 		// match that used by the token classifier at training time.  For the resulting piped 		// instance, each instance's data alphabet needs to contain token classifier's prediction 		// as features 		FeatureVector[] fva = new FeatureVector[fvs.size()];		for (int i = 0; i < ilist.size(); i++) {			Instance inst = ilist.getInstance(i);			Classification c = m_tokenClassifiers.classify(inst, ! m_inProduction);			LabelVector lv = c.getLabelVector();			AugmentableFeatureVector afv1 = (AugmentableFeatureVector) inst.getData();			int[] indices = afv1.getIndices();			AugmentableFeatureVector afv2 = new AugmentableFeatureVector(m_dataAlphabet, 					indices, afv1.getValues(), indices.length + m_predRanks2add.length);			for (int j = 0; j < m_predRanks2add.length; j++) {				Label label = lv.getLabelAtRank(m_predRanks2add[j]);				int idx = m_dataAlphabet.lookupIndex("TOK_PRED=" + label.toString() + "_@_RANK_" + m_predRanks2add[j]);				assert(idx >= 0);				afv2.add(idx, 1);			}			fva[i] = afv2; 		}		carrier.setData(new FeatureVectorSequence(fva));		return carrier;	}	/**	 * Converts each instance containing a FeatureVectorSequence to multiple instances, 	 * each containing an AugmentableFeatureVector as data.  	 *  	 * @param ilist Instances with FeatureVectorSequence as data field	 * @param alphabetsPipe a Noop pipe containing the data and target alphabets for the resulting InstanceList 	 * @return an InstanceList where each Instance contains one Token's AugmentableFeatureVector as data 	 */	public static InstanceList convert(InstanceList ilist, Noop alphabetsPipe)	{		if (ilist == null) return null;				// This monstrosity is necessary b/c Classifiers obtain the data/target alphabets via pipes		InstanceList ret = new InstanceList(alphabetsPipe);		for (int i = 0; i < ilist.size(); i++) {			ret.add(convert(ilist.getInstance(i), alphabetsPipe));		}		return ret;	}	/**	 * 	 * @param inst input instance, with FeatureVectorSequence as data.	 * @param alphabetsPipe a Noop pipe containing the data and target alphabets for 	 * the resulting InstanceList and AugmentableFeatureVectors	 * @return list of instances, each with one AugmentableFeatureVector as data	 */	public static InstanceList convert(Instance inst, Noop alphabetsPipe)	{		InstanceList ret = new InstanceList(alphabetsPipe);		Object obj = inst.getData();		assert(obj instanceof FeatureVectorSequence);		FeatureVectorSequence fvs = (FeatureVectorSequence) obj;		LabelSequence ls = (LabelSequence) inst.getTarget();		assert(fvs.size() == ls.size());		Object instName = (inst.getName() == null ? "NONAME" : inst.getName());				for (int j = 0; j < fvs.size(); j++) {			FeatureVector fv = fvs.getFeatureVector(j);			int[] indices = fv.getIndices();			FeatureVector data = new AugmentableFeatureVector (alphabetsPipe.getDataAlphabet(),					indices, fv.getValues(), indices.length); 			Labeling target = ls.getLabelAtPosition(j);			String name = instName.toString() + "_@_POS_" + (j + 1);			Object source = inst.getSource();			Instance toAdd = new Instance(data, target, name, source, alphabetsPipe);			ret.add(toAdd);		}		return ret;	}	// Serialization 	private static final long serialVersionUID = 1;	/**	 * This inner class represents the trained token classifiers.	 * @author ghuang	 */	public static class TokenClassifiers extends Classifier implements Serializable	{		// number of folds in cross-validation training 		int m_numCV;		// random seed to split training data for cross-validation		int m_randSeed;				// trainer for token classifier		ClassifierTrainer m_trainer;				// token classifier trained on the entirety of the training set		Classifier m_tokenClassifier;				// table storing instance name -->  out-of-fold classifier 		// Used to prevent overfitting to the token classifier's predictions		HashMap m_table;				/**		 * Train a token classifier using the given Instances with 5-fold cross validation		 * @param trainList training instances		 */		public TokenClassifiers(InstanceList trainList)		{			this(trainList, 0, 5);		}						public TokenClassifiers(InstanceList trainList, int randSeed, int numCV)		{//			this(new AdaBoostM2Trainer(new DecisionTreeTrainer(2), 10), trainList, randSeed, numCV);//			this(new NaiveBayesTrainer(), trainList, randSeed, numCV);			this(new BalancedWinnowTrainer(), trainList, randSeed, numCV);//			this(new SVMTrainer(), trainList, randSeed, numCV);		}						public TokenClassifiers(ClassifierTrainer trainer, InstanceList trainList, int randSeed, int numCV)		{			super(trainList.getPipe());			m_trainer = trainer;			m_randSeed = randSeed;			m_numCV = numCV;			m_table = new HashMap();			doTraining(trainList);		}		// train the token classifier		private void doTraining(InstanceList trainList)		{			// train a classifier on the entire training set			logger.info("Training token classifier on entire data set (size=" + trainList.size() + ")...");			m_tokenClassifier = m_trainer.train(trainList);			Trial t = new Trial(m_tokenClassifier, trainList);			logger.info("Training set accuracy = " + t.accuracy());						if (m_numCV == 0)				return;			// train classifiers using cross validation			InstanceList.CrossValidationIterator cvIter = trainList.new CrossValidationIterator(m_numCV, m_randSeed);			int f = 1;			while (cvIter.hasNext()) {				f++;				InstanceList[] fold = cvIter.nextSplit();				logger.info("Training token classifier on cv fold " + f + " / " + m_numCV + " (size=" + fold[0].size() + ")...");								Classifier foldClassifier = m_trainer.train(fold[0]);				Trial t1 = new Trial(foldClassifier, fold[0]);				Trial t2 = new Trial(foldClassifier, fold[1]);				logger.info("Within-fold accuracy = " + t1.accuracy());				logger.info("Out-of-fold accuracy = " + t2.accuracy());				/*for (int x = 0; x < t2.size(); x++) {					logger.info("xxx pred:" + t2.getClassification(x).getLabeling().getBestLabel() + " true:" + t2.getClassification(x).getInstance().getLabeling());				}*/								for (int i = 0; i < fold[1].size(); i++) {					Instance inst = fold[1].getInstance(i);					m_table.put(inst.getName(), foldClassifier);				}			}		}		public Classification classify(Instance instance)		{			return classify(instance, false);		}		/**		 * 		 * @param instance the instance to classify		 * @param useOutOfFold whether to check the instance name and use the out-of-fold classifier		 * if the instance name matches one in the training data		 * @return the token classifier's output		 */		public Classification classify(Instance instance, boolean useOutOfFold)		{			Object instName = instance.getName();						if (! useOutOfFold || ! m_table.containsKey(instName))				return m_tokenClassifier.classify(instance);						Classifier classifier = (Classifier) m_table.get(instName);			return classifier.classify(instance);		}		// serialization		private static final long serialVersionUID = 1;		private static final int CURRENT_SERIAL_VERSION = 1;				private void writeObject(ObjectOutputStream out) throws IOException		{			out.writeInt(CURRENT_SERIAL_VERSION);			out.writeObject(getInstancePipe());			out.writeInt(m_numCV);			out.writeInt(m_randSeed);			out.writeObject(m_table);			out.writeObject(m_tokenClassifier);			out.writeObject(m_trainer);		}				private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {			int version = in.readInt();			if (version != CURRENT_SERIAL_VERSION)				throw new ClassNotFoundException("Mismatched TokenClassifiers versions: wanted " +						CURRENT_SERIAL_VERSION + ", got " +						version);			instancePipe = (Pipe) in.readObject();			m_numCV = in.readInt();			m_randSeed = in.readInt();			m_table = (HashMap) in.readObject();			m_tokenClassifier = (Classifier) in.readObject();			m_trainer = (ClassifierTrainer) in.readObject();		}	}}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -