📄 simpletagger.java
字号:
/* Copyright (C) 2003 University of Pennsylvania.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/**
   @author Fernando Pereira <a href="mailto:pereira@cis.upenn.edu">pereira@cis.upenn.edu</a>
 */

package edu.umass.cs.mallet.base.fst;

import edu.umass.cs.mallet.base.types.*;
import edu.umass.cs.mallet.base.fst.*;
import edu.umass.cs.mallet.base.minimize.*;
import edu.umass.cs.mallet.base.minimize.tests.*;
import edu.umass.cs.mallet.base.pipe.*;
import edu.umass.cs.mallet.base.pipe.iterator.*;
import edu.umass.cs.mallet.base.pipe.tsf.*;
import edu.umass.cs.mallet.base.util.*;
import junit.framework.*;
import java.util.Iterator;
import java.util.Random;
import java.util.regex.*;
import java.util.logging.*;
import java.io.*;

/**
 * This class's main method trains, tests, or runs a generic CRF-based
 * sequence tagger. The tagger expects instances in the form required
 * by {@link SimpleTaggerSentence2FeatureVectorSequence}. A variety of
 * command line options control the operation of the main program, as
 * described in the comments for {@link #main main}.
 *
 * @author <a href="mailto:pereira@cis.upenn.edu">Fernando Pereira</a>
 * @version 1.0
 */
public class SimpleTagger
{
  /** Logger shared by the static methods of this class, named after the class. */
  private static Logger logger =
    MalletLogger.getLogger(SimpleTagger.class.getName());

  /**
   * No <code>SimpleTagger</code> objects allowed; the constructor is
   * private and does nothing.
   */
  private SimpleTagger()
  {
  }

  /**
   * Converts an external encoding of a sequence of elements with binary
   * features to a {@link FeatureVectorSequence}.  If target processing
   * is on, extracts element labels from the external encoding to create
   * a target {@link LabelSequence}. Two external encodings are
   * supported:
   *
   * 1) A {@link String} containing lines of whitespace-separated tokens.
* 2) a {@link String}[][] * * Both represent rows of tokens, If target processing, the last token * in each row is the label of the sequence element represented by * this row. All other tokens in the row, or all tokens in the row if * not target processing, are the names of features that are on for * the row's element. * */ public static class SimpleTaggerSentence2FeatureVectorSequence extends Pipe { /** * Creates a new * <code>SimpleTaggerSentence2FeatureVectorSequence</code> instance. */ public SimpleTaggerSentence2FeatureVectorSequence () { super (Alphabet.class, LabelAlphabet.class); } /** * Parses a string representing a sequence of rows of tokens into an * array of arrays of tokens. * * @param sentence a <code>String</code> * @return the corresponding array of arrays of tokens. */ private String[][] parseSentence(String sentence) { String[] lines = sentence.split("\n"); String[][] tokens = new String[lines.length][]; for (int i = 0; i < lines.length; i++) tokens[i] = lines[i].split(" "); return tokens; } public Instance pipe (Instance carrier) { Object inputData = carrier.getData(); Alphabet features = getDataAlphabet(); LabelAlphabet labels; LabelSequence target = null; String [][] tokens; if (inputData instanceof String) tokens = parseSentence((String)inputData); else if (inputData instanceof String[][]) tokens = (String[][])inputData; else throw new IllegalArgumentException("Not a String or String[][]"); FeatureVector[] fvs = new FeatureVector[tokens.length]; if (isTargetProcessing()) { labels = (LabelAlphabet)getTargetAlphabet(); target = new LabelSequence (labels, tokens.length); } for (int l = 0; l < tokens.length; l++) { int nFeatures; if (isTargetProcessing()) { if (tokens[l].length < 1) throw new IllegalStateException ("Missing label at line " + l); nFeatures = tokens[l].length - 1; target.add(tokens[l][nFeatures]); } else nFeatures = tokens[l].length; int featureIndices[] = new int[nFeatures]; for (int f = 0; f < nFeatures; f++) featureIndices[f] = 
features.lookupIndex(tokens[l][f]); fvs[l] = new FeatureVector(features, featureIndices); } carrier.setData(new FeatureVectorSequence(fvs)); if (isTargetProcessing()) carrier.setTarget(target); return carrier; } } private static final CommandOption.Double gaussianVarianceOption = new CommandOption.Double (SimpleTagger.class, "gaussian-variance", "DECIMAL", true, 10.0, "The gaussian prior variance used for training.", null); private static final CommandOption.Boolean trainOption = new CommandOption.Boolean (SimpleTagger.class, "train", "true|false", true, false, "Whether to train", null); private static final CommandOption.String testOption = new CommandOption.String (SimpleTagger.class, "test", "lab or seg=start-1.continue-1,...,start-n.continue-n", true, null, "Test measuring labeling or segmentation (start-i, continue-i) accuracy", null); private static final CommandOption.File modelOption = new CommandOption.File (SimpleTagger.class, "model-file", "FILENAME", true, null, "The filename for reading (train/run) or saving (train) the model.", null); private static final CommandOption.Double trainingFractionOption = new CommandOption.Double (SimpleTagger.class, "training-proportion", "DECIMAL", true, 0.5, "Fraction of data to use for training in a random split.", null); private static final CommandOption.Integer randomSeedOption = new CommandOption.Integer (SimpleTagger.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null); private static final CommandOption.IntegerArray ordersOption = new CommandOption.IntegerArray (SimpleTagger.class, "orders", "COMMA-SEP-DECIMALS", true, new int[]{1}, "List of label Markov orders (main and backoff) ", null); private static final CommandOption.String forbiddenOption = new CommandOption.String( SimpleTagger.class, "forbidden", "REGEXP", true, "\\s", "label1,label2 transition forbidden if it matches this", null); private static final 
CommandOption.String allowedOption =
    new CommandOption.String(
      SimpleTagger.class,
      "allowed",
      "REGEXP",
      true,
      ".*",
      "label1,label2 transition allowed only if it matches this",
      null);

  // --default-label: label used for the initial context and for
  // uninteresting tokens.
  private static final CommandOption.String defaultOption =
    new CommandOption.String(
      SimpleTagger.class,
      "default-label",
      "STRING",
      true,
      "O",
      "Label for initial context and uninteresting tokens",
      null);

  // --iterations: number of training iterations.
  private static final CommandOption.Integer iterationsOption =
    new CommandOption.Integer(
      SimpleTagger.class,
      "iterations",
      "INTEGER",
      true,
      500,
      "Number of training iterations",
      null);

  // --viterbi-output: print Viterbi output periodically while training.
  private static final CommandOption.Boolean viterbiOutputOption =
    new CommandOption.Boolean(
      SimpleTagger.class,
      "viterbi-output",
      "true|false",
      true,
      false,
      "Print Viterbi periodically during training",
      null);

  // --fully-connected: include allowed transitions even when they do not
  // occur in the training data.
  private static final CommandOption.Boolean connectedOption =
    new CommandOption.Boolean(
      SimpleTagger.class,
      "fully-connected",
      "true|false",
      true,
      true,
      "Include all allowed transitions, even those not in training data",
      null);

  // --continue-training: resume training from the model named by
  // --model-file.
  private static final CommandOption.Boolean continueTrainingOption =
    new CommandOption.Boolean(
      SimpleTagger.class,
      "continue-training",
      "true|false",
      false,
      false,
      "Continue training from model specified by --model-file",
      null);

  // Every option understood by the tagger, gathered into one list.
  private static final CommandOption.List commandOptions =
    new CommandOption.List (
      "Training, testing and running a generic tagger.",
      new CommandOption[] {
        gaussianVarianceOption,
        trainOption,
        iterationsOption,
        testOption,
        trainingFractionOption,
        modelOption,
        randomSeedOption,
        ordersOption,
        forbiddenOption,
        allowedOption,
        defaultOption,
        viterbiOutputOption,
        connectedOption,
        continueTrainingOption,
      });

  /**
   * Create and train a CRF model from the given training data,
   * optionally testing it on the given test data.
* * @param training training data * @param testing test data (possibly <code>null</code>) * @param eval accuracy evaluator (possibly <code>null</code>) * @param orders label Markov orders (main and backoff) * @param defaultLabel default label * @param forbidden regular expression specifying impossible label * transitions <em>current</em><code>,</code><em>next</em> * (<code>null</code> indicates no forbidden transitions) * @param allowed regular expression specifying allowed label transitions * (<code>null</code> indicates everything is allowed that is not forbidden) * @param connected whether to include even transitions not * occurring in the training data. * @param iterations number of traning iterations * @param var Gaussian prior variance * @return the trained model */ public static CRF4 train(InstanceList training, InstanceList testing, TransducerEvaluator eval, int[] orders, String defaultLabel, String forbidden, String allowed, boolean connected, int iterations, double var, CRF4 crf) { Pattern forbiddenPat = Pattern.compile(forbidden); Pattern allowedPat = Pattern.compile(allowed); if (crf == null) { crf = new CRF4(training.getPipe(), null); String startName = crf.addOrderNStates(training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected); crf.setGaussianPriorVariance (var); for (int i = 0; i < crf.numStates(); i++) crf.getState(i).setInitialCost (Double.POSITIVE_INFINITY); crf.getState(startName).setInitialCost (0.0); } logger.info("Training on " + training.size() + " instances"); if (testing != null) logger.info("Testing on " + testing.size() + " instances"); crf.train (training, null, testing, eval, iterations); return crf; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -