⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hmm.java

📁 mallet是自然语言处理、机器学习领域的一个开源项目。
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    String sep = "";    StringBuffer buf = new StringBuffer();    for (int i = 0; i < labels.length; i++)    {      buf.append(sep).append(labels[i]);      sep = LABEL_SEPARATOR;    }    return buf.toString();  }    private String nextKGram(String[] history, int k, String next)  {    String sep = "";    StringBuffer buf = new StringBuffer();    int start = history.length + 1 - k;    for (int i = start; i < history.length; i++)    {      buf.append(sep).append(history[i]);      sep = LABEL_SEPARATOR;    }    buf.append(sep).append(next);    return buf.toString();  }    private boolean allowedTransition(String prev, String curr,                                    Pattern no, Pattern yes)  {    String pair = concatLabels(new String[]{prev, curr});    if (no != null && no.matcher(pair).matches())      return false;    if (yes != null && !yes.matcher(pair).matches())      return false;    return true;  }      private boolean allowedHistory(String[] history, Pattern no, Pattern yes) {    for (int i = 1; i < history.length; i++)      if (!allowedTransition(history[i-1], history[i], no, yes))        return false;    return true;  }  /**   * Assumes that the HMM's output alphabet contains   * <code>String</code>s. Creates an order-<em>n</em> HMM with input   * predicates and output labels given by <code>trainingSet</code>   * and order, connectivity, and weights given by the remaining   * arguments.   *   * @param trainingSet the training instances   * @param orders an array of increasing non-negative numbers giving   * the orders of the features for this HMM. The largest number   * <em>n</em> is the Markov order of the HMM. States are   * <em>n</em>-tuples of output labels. Each of the other numbers   * <em>k</em> in <code>orders</code> represents a weight set shared   * by all destination states whose last (most recent) <em>k</em>   * labels agree. If <code>orders</code> is <code>null</code>, an   * order-0 HMM is built.   * @param defaults If non-null, it must be the same length as   * <code>orders</code>, with <code>true</code> positions indicating   * that the weight set for the corresponding order contains only the   * weight for a default feature; otherwise, the weight set has   * weights for all features built from input predicates.   * @param start The label that represents the context of the start of   * a sequence. It may be also used for sequence labels.   * @param forbidden If non-null, specifies what pairs of successive   * labels are not allowed, both for constructing <em>n</em>order   * states or for transitions. A label pair (<em>u</em>,<em>v</em>)   * is not allowed if <em>u</em> + "," + <em>v</em> matches   * <code>forbidden</code>.   * @param allowed If non-null, specifies what pairs of successive   * labels are allowed, both for constructing <em>n</em>order   * states or for transitions. A label pair (<em>u</em>,<em>v</em>)   * is allowed only if <em>u</em> + "," + <em>v</em> matches   * <code>allowed</code>.   * @param fullyConnected Whether to include all allowed transitions,   * even those not occurring in <code>trainingSet</code>,   * @returns The name of the start state.   *    */  public String addOrderNStates(InstanceList trainingSet, int[] orders,                                boolean[] defaults, String start,                                Pattern forbidden, Pattern allowed,                                boolean fullyConnected)  {    boolean[][] connections = null;    if (!fullyConnected)      connections = labelConnectionsIn (trainingSet);    int order = -1;    if (defaults != null && defaults.length != orders.length)      throw new IllegalArgumentException("Defaults must be null or match orders");    if (orders == null)      order = 0;    else    {      for (int i = 0; i < orders.length; i++)        if (orders[i] <= order)          throw new IllegalArgumentException("Orders must be non-negative and in ascending order");        else           order = orders[i];      if (order < 0) order = 0;    }    if (order > 0)    {      int[] historyIndexes = new int[order];      String[] history = new String[order];      String label0 = (String)outputAlphabet.lookupObject(0);      for (int i = 0; i < order; i++)        history[i] = label0;      int numLabels = outputAlphabet.size();      while (historyIndexes[0] < numLabels)      {        logger.info("Preparing " + concatLabels(history));        if (allowedHistory(history, forbidden, allowed))        {          String stateName = concatLabels(history);          int nt = 0;          String[] destNames = new String[numLabels];          String[] labelNames = new String[numLabels];          for (int nextIndex = 0; nextIndex < numLabels; nextIndex++)          {            String next = (String)outputAlphabet.lookupObject(nextIndex);            if (allowedTransition(history[order-1], next, forbidden, allowed)                && (fullyConnected ||                    connections[historyIndexes[order-1]][nextIndex]))            {              destNames[nt] = nextKGram(history, order, next);              labelNames[nt] = next;              nt++;            }          }          if (nt < numLabels)          {            String[] newDestNames = new String[nt];            String[] newLabelNames = new String[nt];            for (int t = 0; t < nt; t++)            {              newDestNames[t] = destNames[t];              newLabelNames[t] = labelNames[t];            }            destNames = newDestNames;            labelNames = newLabelNames;          }          addState (stateName, 0.0, 0.0, destNames, labelNames);        }        for (int o = order-1; o >= 0; o--)           if (++historyIndexes[o] < numLabels)          {            history[o] = (String)outputAlphabet.lookupObject(historyIndexes[o]);            break;          } else if (o > 0)          {            historyIndexes[o] = 0;            history[o] = label0;          }      }      for (int i = 0; i < order; i++)        history[i] = start;      return concatLabels(history);    }    else    {      String[] stateNames = new String[outputAlphabet.size()];      for (int s = 0; s < outputAlphabet.size(); s++)        stateNames[s] = (String)outputAlphabet.lookupObject(s);      for (int s = 0; s < outputAlphabet.size(); s++)        addState(stateNames[s], 0.0, 0.0, stateNames, stateNames);      return start;    }  }	public State getState (String name)	{		return (State) name2state.get(name);	}		public int numStates () { return states.size(); }	public Transducer.State getState (int index) {		return (Transducer.State) states.get(index); }		public Iterator initialStateIterator () {		return initialStates.iterator (); }	public boolean isTrainable () { return trainable; }	public void reset ()	{		throw new UnsupportedOperationException ("Not used in HMMs");	}	public void estimate ()	{		if (!trainable)			throw new IllegalStateException ("This transducer not currently trainable.");		// xxx Put stuff in here. EM training.		throw new UnsupportedOperationException ("Not yet implemented.  Never?");	}	public boolean train (InstanceList ilist)	{		return train (ilist, (InstanceList)null, (InstanceList)null);	}	public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing)	{		return train (ilist, validation, testing, (TransducerEvaluator)null);	}		public boolean train (InstanceList ilist, InstanceList validation, InstanceList testing,												TransducerEvaluator eval)	{		assert (ilist.size() > 0);		if (emissionEstimator == null) {			emissionEstimator = new Multinomial.LaplaceEstimator[numStates()];			transitionEstimator = new Multinomial.LaplaceEstimator[numStates()];			emissionMultinomial = new Multinomial[numStates()];			transitionMultinomial = new Multinomial[numStates()];			Alphabet transitionAlphabet = new Alphabet ();			for (int i=0; i < numStates(); i++) 				transitionAlphabet.lookupIndex (((State)states.get(i)).getName(), true);			for (int i=0; i < numStates(); i++) {				emissionEstimator[i] = new Multinomial.LaplaceEstimator(inputAlphabet);								transitionEstimator[i] = new Multinomial.LaplaceEstimator(transitionAlphabet);				emissionMultinomial[i] = new Multinomial (getUniformArray (inputAlphabet.size()), inputAlphabet);				transitionMultinomial[i] = new Multinomial (getUniformArray (transitionAlphabet.size()), transitionAlphabet);			}			initialEstimator = new Multinomial.LaplaceEstimator (transitionAlphabet);		}		for (int i=0; i < ilist.size(); i++) {			Instance instance = ilist.getInstance(i);			FeatureSequence input = (FeatureSequence) instance.getData();			FeatureSequence output = (FeatureSequence) instance.getTarget();			forwardBackward (input, output, true);					}		initialMultinomial = initialEstimator.estimate();		for (int i=0; i < numStates(); i++) {			emissionMultinomial[i] = emissionEstimator[i].estimate();			transitionMultinomial[i] = transitionEstimator[i].estimate();			getState (i).setInitialCost (-initialMultinomial.logProbability (getState(i).getName()));		}					return true;	}	public void write (File f) {		try {			ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));			oos.writeObject(this);			oos.close();		}		catch (IOException e) {			System.err.println("Exception writing file " + f + ": " + e);		}	}	private double[] getUniformArray (int size) {		double[] ret = new double[size];		for (int i=0; i < size; i++)			ret[i] = 1.0 / (double)size;		return ret;	}	// Serialization	// For HMM class	private static final long serialVersionUID = 1;	private static final int CURRENT_SERIAL_VERSION = 1;	static final int NULL_INTEGER = -1;	/* Need to check for null pointers. */	private void writeObject (ObjectOutputStream out) throws IOException {		int i, size;		out.writeInt (CURRENT_SERIAL_VERSION);		out.writeObject(inputPipe);		out.writeObject(outputPipe);		out.writeObject (inputAlphabet);		out.writeObject (outputAlphabet);		size = states.size();		out.writeInt(size);		for (i = 0; i<size; i++)			out.writeObject(states.get(i));		size = initialStates.size();		out.writeInt(size);		for (i = 0; i <size; i++)			out.writeObject(initialStates.get(i));		out.writeObject(name2state);		if (emissionEstimator != null) {			size = emissionEstimator.length;			for (i=0; i<size; i++)				out.writeObject(emissionEstimator[i]);		} else			out.writeInt(NULL_INTEGER);				if (transitionEstimator != null) {			size = transitionEstimator.length;			for (i=0; i<size; i++)				out.writeObject(transitionEstimator[i]);		} else			out.writeInt(NULL_INTEGER);				out.writeBoolean(trainable);	}		private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {		int size, i;		int version = in.readInt ();		inputPipe = (Pipe) in.readObject();		outputPipe = (Pipe) in.readObject();		inputAlphabet = (Alphabet) in.readObject();		outputAlphabet = (Alphabet) in.readObject();		size = in.readInt();		states = new ArrayList();		for (i=0; i<size; i++) {			State s = (HMM.State) in.readObject();			states.add(s);		}		size = in.readInt();		initialStates = new ArrayList();		for (i=0; i<size; i++) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -