⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hmm.java

📁 mallet是自然语言处理、机器学习领域的一个开源项目。
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).   http://www.cs.umass.edu/~mccallum/mallet   This software is provided under the terms of the Common Public License,   version 1.0, as published by http://www.opensource.org.  For further   information, see the file `LICENSE' included with this distribution. *//** 		@author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>		@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a> */package edu.umass.cs.mallet.base.fst;import edu.umass.cs.mallet.base.types.*;import edu.umass.cs.mallet.base.pipe.Pipe;import edu.umass.cs.mallet.base.maximize.*;import edu.umass.cs.mallet.base.maximize.tests.*;import edu.umass.cs.mallet.base.util.Maths;import edu.umass.cs.mallet.base.util.MalletLogger;import java.util.ArrayList;import java.util.HashMap;import java.util.Iterator;import java.util.Arrays;import java.util.BitSet;import java.util.Random;import java.util.regex.*;import java.util.logging.*;import java.io.*;import java.lang.reflect.Constructor;/** Hidden Markov Model */public class HMM extends Transducer implements Serializable{	private static Logger logger = MalletLogger.getLogger(HMM.class.getName());  static final String LABEL_SEPARATOR = ",";	Alphabet inputAlphabet;	Alphabet outputAlphabet;	ArrayList states = new ArrayList ();	ArrayList initialStates = new ArrayList ();	HashMap name2state = new HashMap ();	Multinomial.Estimator [] transitionEstimator;	Multinomial.Estimator [] emissionEstimator;	Multinomial.Estimator initialEstimator;	Multinomial [] transitionMultinomial;	Multinomial [] emissionMultinomial;	Multinomial initialMultinomial;		boolean trainable = false;	public HMM (Pipe inputPipe, Pipe outputPipe)	{		this.inputPipe = inputPipe;		this.outputPipe = outputPipe;		this.inputAlphabet = inputPipe.getDataAlphabet();		this.outputAlphabet = 
inputPipe.getTargetAlphabet();	}		public HMM (Alphabet inputAlphabet,							 Alphabet outputAlphabet)	{		inputAlphabet.stopGrowth();		logger.info ("HMM input dictionary size = "+inputAlphabet.size());		this.inputAlphabet = inputAlphabet;		this.outputAlphabet = outputAlphabet;	}	public Alphabet getInputAlphabet () { return inputAlphabet; }	public Alphabet getOutputAlphabet () { return outputAlphabet; }	public void print () {		StringBuffer sb = new StringBuffer();		for (int i = 0; i < numStates(); i++) {			State s = (State) getState (i);			sb.append ("STATE NAME=\"");			sb.append (s.name);	sb.append ("\" ("); sb.append (s.destinations.length); sb.append (" outgoing transitions)\n");			sb.append ("  "); sb.append ("initialCost = "); sb.append (s.initialCost); sb.append ('\n');			sb.append ("  "); sb.append ("finalCost = "); sb.append (s.finalCost); sb.append ('\n');			sb.append ("Emission distribution:\n" + emissionMultinomial[i] + "\n\n"); 			sb.append ("Transition distribution:\n" + transitionMultinomial[i].toString());		}				System.out.println (sb.toString());			}	 	public void addState (String name, double initialCost, double finalCost,												String[] destinationNames,												String[] labelNames)	{		assert (labelNames.length == destinationNames.length);		setTrainable (false);		if (name2state.get(name) != null)			throw new IllegalArgumentException ("State with name `"+name+"' already exists.");		State s = new State (name, states.size(), initialCost, finalCost,												 destinationNames, labelNames, this);		s.print ();		states.add (s);		if (initialCost < INFINITE_COST)			initialStates.add (s);		name2state.put (name, s);	}	// Add a state with parameters equal zero, and labels on out-going arcs	// the same name as their destination state names.	
public void addState (String name, String[] destinationNames)
	{
		// Zero initial/final cost; each arc is labelled with its destination's name.
		this.addState (name, 0, 0,
									 destinationNames, destinationNames);
	}

	/**
	 * Add a group of states that are fully connected with each other,
	 * with parameters equal zero, and labels on their out-going arcs
	 * the same name as their destination state names.
	 *
	 * @param stateNames names of the states to create; the same array is
	 *                   passed as every state's destination list
	 */
	public void addFullyConnectedStates (String[] stateNames)
	{
		for (int i = 0; i < stateNames.length; i++)
			addState (stateNames[i], stateNames);
	}

	/**
	 * Adds one state per entry of the output alphabet, fully connected.
	 * Entries are cast to String, so a non-String entry raises
	 * ClassCastException.
	 */
	public void addFullyConnectedStatesForLabels ()
	{
		String[] labels = new String[outputAlphabet.size()];
		// This assumes the entries in the outputAlphabet are Strings!
		for (int i = 0; i < outputAlphabet.size(); i++) {
			labels[i] = (String) outputAlphabet.lookupObject(i);
		}
		addFullyConnectedStates (labels);
	}

	/**
	 * Records which label bigrams occur in the targets of a training set.
	 *
	 * @param trainingSet instances whose targets are cast to FeatureSequence
	 * @return a square matrix over output-alphabet indices in which
	 *         <code>[a][b]</code> is true when label <code>b</code> directly
	 *         follows label <code>a</code> in at least one target sequence
	 */
	private boolean[][] labelConnectionsIn (InstanceList trainingSet)
	{
		int numLabels = outputAlphabet.size();
		boolean[][] connections = new boolean[numLabels][numLabels];
		for (int i = 0; i < trainingSet.size(); i++) {
			Instance instance = trainingSet.getInstance(i);
			FeatureSequence output = (FeatureSequence) instance.getTarget();
			// Start at 1: each step pairs a label with its predecessor.
			for (int j = 1; j < output.size(); j++) {
				// NOTE(review): assumes every target label is already in
				// outputAlphabet; confirm lookupIndex cannot grow it here.
				int sourceIndex = outputAlphabet.lookupIndex (output.get(j-1));
				int destIndex = outputAlphabet.lookupIndex (output.get(j));
				assert (sourceIndex >= 0 && destIndex >= 0);
				connections[sourceIndex][destIndex] = true;
			}
		}
		return connections;
	}

	/**
	 * Returns, in label-index order, the names of the labels flagged true in
	 * one row of a connection matrix.  Entries are cast to String.
	 *
	 * @param outgoing row of the matrix built by labelConnectionsIn
	 * @return names of the labels reachable from that row's source label
	 */
	private String[] namesOfConnectedLabels (boolean[] outgoing)
	{
		// First pass sizes the array, second pass fills it.
		int numDestinations = 0;
		for (int j = 0; j < outgoing.length; j++)
			if (outgoing[j]) numDestinations++;
		String[] destinationNames = new String[numDestinations];
		int destinationIndex = 0;
		for (int j = 0; j < outgoing.length; j++)
			if (outgoing[j])
				destinationNames[destinationIndex++] = (String)outputAlphabet.lookupObject(j);
		return destinationNames;
	}

	/** Add states to create a first-order Markov model on labels,
			adding only those transitions that occur in the given
			trainingSet. */
	public void addStatesForLabelsConnectedAsIn (InstanceList trainingSet)
	{
		int numLabels = outputAlphabet.size();
		boolean[][] connections = labelConnectionsIn (trainingSet);
		for (int i = 0; i < numLabels; i++) {
			String[] destinationNames = namesOfConnectedLabels (connections[i]);
			addState ((String)outputAlphabet.lookupObject(i), destinationNames);
		}
	}

	/** Add as many states as there are labels, but don't create separate weights
			for each source-destination pair of states.  Instead have all the incoming
			transitions to a state share the same weights.
			As written here, the states created are the same as those created by
			addStatesForLabelsConnectedAsIn: zero costs, arcs labelled with their
			destination names. */
	public void addStatesForHalfLabelsConnectedAsIn (InstanceList trainingSet)
	{
		int numLabels = outputAlphabet.size();
		boolean[][] connections = labelConnectionsIn (trainingSet);
		for (int i = 0; i < numLabels; i++) {
			String[] destinationNames = namesOfConnectedLabels (connections[i]);
			addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
								destinationNames, destinationNames);
		}
	}

	/** Add as many states as there are labels, but don't create
			separate observational-test-weights for each source-destination
			pair of states---instead have all the incoming transitions to a
			state share the same observational-feature-test weights.
			However, do create separate default feature for each transition,
			(which acts as an HMM-style transition probability).
			As written here, no per-transition structure is created: the states
			added are the same as those added by addStatesForHalfLabelsConnectedAsIn. */
	public void addStatesForThreeQuarterLabelsConnectedAsIn (InstanceList trainingSet)
	{
		int numLabels = outputAlphabet.size();
		boolean[][] connections = labelConnectionsIn (trainingSet);
		for (int i = 0; i < numLabels; i++) {
			String[] destinationNames = namesOfConnectedLabels (connections[i]);
			addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
								destinationNames, destinationNames);
		}
	}

	/**
	 * Adds one state per label, each connected to every label's state, with
	 * zero costs and arcs labelled with their destination names.
	 *
	 * @param trainingSet not used by this method
	 */
	public void addFullyConnectedStatesForThreeQuarterLabels (InstanceList trainingSet)
	{
		int numLabels = outputAlphabet.size();
		for (int i = 0; i < numLabels; i++) {
			// A fresh destination array is built for every state.
			String[] destinationNames = new String[numLabels];
			for (int j = 0; j < numLabels; j++) {
				String labelName = (String)outputAlphabet.lookupObject(j);
				destinationNames[j] = labelName;
			}
			addState ((String)outputAlphabet.lookupObject(i), 0.0, 0.0,
								destinationNames, destinationNames);
		}
	}

	/**
	 * Adds a state for every ordered pair of labels.  State "a,b" has one arc
	 * per label c, leading to state "b,c" and labelled c.
	 */
	public void addFullyConnectedStatesForBiLabels ()
	{
		String[] labels = new String[outputAlphabet.size()];
		// Entries are converted with toString(), so they need not be Strings.
		for (int i = 0; i < outputAlphabet.size(); i++) {
			labels[i] =  outputAlphabet.lookupObject(i).toString();
		}
		for (int i = 0; i < labels.length; i++) {
			for (int j = 0; j < labels.length; j++) {
				String[] destinationNames = new String[labels.length];
				for (int k = 0; k < labels.length; k++)
					destinationNames[k] = labels[j]+LABEL_SEPARATOR+labels[k];
				addState (labels[i]+LABEL_SEPARATOR+labels[j], 0.0, 0.0,
									destinationNames, labels);
			}
		}
	}

	/** Add states to create a second-order Markov model on labels,
			adding only those transitions that occur in the given
			trainingSet. */
	public void addStatesForBiLabelsConnectedAsIn (InstanceList trainingSet)
	{
		int numLabels = outputAlphabet.size();
		boolean[][] connections = labelConnectionsIn (trainingSet);
		for (int i = 0; i < numLabels; i++) {
			for (int j = 0; j < numLabels; j++) {
				// Only create state "i,j" if label j was seen following label i.
				if (!connections[i][j])
					continue;
				int numDestinations = 0;
				for (int k = 0; k < numLabels; k++)
					if (connections[j][k]) numDestinations++;
				String[] destinationNames = new String[numDestinations];
				String[] labels = new String[numDestinations];
				int destinationIndex = 0;
				for (int k = 0; k < numLabels; k++)
					if (connections[j][k]) {
						// Arc to state "j,k", labelled with k alone.
						destinationNames[destinationIndex] =
							(String)outputAlphabet.lookupObject(j)+LABEL_SEPARATOR+(String)outputAlphabet.lookupObject(k);
						labels[destinationIndex] = (String)outputAlphabet.lookupObject(k);
						destinationIndex++;
					}
				addState ((String)outputAlphabet.lookupObject(i)+LABEL_SEPARATOR+
									(String)outputAlphabet.lookupObject(j), 0.0, 0.0,
									destinationNames, labels);
			}
		}
	}

	/**
	 * Adds a state for every ordered triple of labels.  State "a,b,c" has one
	 * arc per label d, leading to state "b,c,d" and labelled d.  The class of
	 * each alphabet entry is logged at info level.
	 */
	public void addFullyConnectedStatesForTriLabels ()
	{
		String[] labels = new String[outputAlphabet.size()];
		// Entries are converted with toString(), so they need not be Strings.
		for (int i = 0; i < outputAlphabet.size(); i++) {
			logger.info ("HMM: outputAlphabet.lookup class = "+
									 outputAlphabet.lookupObject(i).getClass().getName());
			labels[i] =  outputAlphabet.lookupObject(i).toString();
		}
		for (int i = 0; i < labels.length; i++) {
			for (int j = 0; j < labels.length; j++) {
				for (int k = 0; k < labels.length; k++) {
					String[] destinationNames = new String[labels.length];
					for (int l = 0; l < labels.length; l++)
						destinationNames[l] = labels[j]+LABEL_SEPARATOR+labels[k]+LABEL_SEPARATOR+labels[l];
					addState (labels[i]+LABEL_SEPARATOR+labels[j]+LABEL_SEPARATOR+labels[k], 0.0, 0.0,
										destinationNames, labels);
				}
			}
		}
	}

	/**
	 * Adds a single state whose arcs all lead back to itself, one arc per
	 * entry of the output alphabet, labelled with that entry's toString().
	 *
	 * @param name name of the state to create
	 */
	public void addSelfTransitioningStateForAllLabels (String name)
	{
		String[] labels = new String[outputAlphabet.size()];
		String[] destinationNames  = new String[outputAlphabet.size()];
		for (int i = 0; i < outputAlphabet.size(); i++) {
			labels[i] =  outputAlphabet.lookupObject(i).toString();
			destinationNames[i] = name;
		}
		addState (name, 0.0, 0.0, destinationNames, labels);
	}

  private String concatLabels(String[] labels)  {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -