📄 baumwelchlearner.java
字号:
/* * Copyright 1999-2002 Carnegie Mellon University. * Portions Copyright 2002 Sun Microsystems, Inc. * Portions Copyright 2002 Mitsubishi Electric Research Laboratories. * All Rights Reserved. Use is subject to license terms. * * See the file "license.terms" for information on usage and * redistribution of this file, and for a DISCLAIMER OF ALL * WARRANTIES. * */package edu.cmu.sphinx.trainer;import java.io.FileInputStream;import java.io.IOException;import java.io.InputStream;import java.util.ArrayList;import java.util.Collection;import java.util.Iterator;import java.util.List;import java.util.logging.Logger;import edu.cmu.sphinx.frontend.Data;import edu.cmu.sphinx.frontend.DataEndSignal;import edu.cmu.sphinx.frontend.DataProcessingException;import edu.cmu.sphinx.frontend.DataProcessor;import edu.cmu.sphinx.frontend.DataStartSignal;import edu.cmu.sphinx.frontend.FrontEnd;import edu.cmu.sphinx.frontend.FrontEndFactory;import edu.cmu.sphinx.frontend.Signal;import edu.cmu.sphinx.frontend.util.StreamCepstrumSource;import edu.cmu.sphinx.frontend.util.StreamDataSource;import edu.cmu.sphinx.linguist.acoustic.HMM;import edu.cmu.sphinx.linguist.acoustic.HMMState;import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMM;import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMMState;import edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer.TrainerScore;import edu.cmu.sphinx.util.LogMath;import edu.cmu.sphinx.util.SphinxProperties;import edu.cmu.sphinx.util.Utilities;/** * Provides mechanisms for computing statistics given a set of states * and input data. */public class BaumWelchLearner implements Learner { private final static String PROP_PREFIX = "edu.cmu.sphinx.trainer."; /** * The SphinxProperty name for the input data type. */ public final static String PROP_INPUT_TYPE = PROP_PREFIX+"inputDataType"; /** * The default value for the property PROP_INPUT_TYPE. */ public final static String PROP_INPUT_TYPE_DEFAULT = "cepstrum"; /** * The sphinx property for the front end class. */ public final static String PROP_FRONT_END = PROP_PREFIX + "frontend"; /** * The default value of PROP_FRONT_END. */ public final static String PROP_FRONT_END_DEFAULT = "edu.cmu.sphinx.frontend.SimpleFrontEnd"; /* * The logger for this class */ private static Logger logger = Logger.getLogger("edu.cmu.sphinx.trainer.BaumWelch"); private FrontEnd frontEnd; private DataProcessor dataSource; private String context; private String inputDataType; private SphinxProperties props; private LogMath logMath; private Data curFeature; private UtteranceGraph graph; private Object[] scoreArray; private int lastFeatureIndex; private int currentFeatureIndex; private float[] alphas; private float[] betas; private float[] outputProbs; private float[] componentScores; private float[] probCurrentFrame; private float totalLogScore; /** * Constructor for this learner. */ public BaumWelchLearner(SphinxProperties props) throws IOException { this.props = props; context = props.getContext(); logMath = LogMath.getLogMath(context); initialize(); } /** * Initializes the Learner with the proper context and frontend. * * @throws IOException */ private void initialize() throws IOException { inputDataType = props.getString(PROP_INPUT_TYPE, PROP_INPUT_TYPE_DEFAULT); if (inputDataType.equals("audio")) { dataSource = new StreamDataSource(); dataSource.initialize("batchAudioSource", null, props, null); } else if (inputDataType.equals("cepstrum")) { dataSource = new StreamCepstrumSource(); dataSource.initialize("batchCepstrumSource", null, props, null); } else { throw new Error("Unsupported data type: " + inputDataType + "\n" + "Only audio and cepstrum are supported\n"); } frontEnd = getFrontEnd(); } // Cut and paste from e.c.s.d.Recognizer.java /** * Initialize and return the frontend based on the given sphinx * properties. */ protected FrontEnd getFrontEnd() { String path = null; try { FrontEnd fe = null; Collection names = FrontEndFactory.getNames(props); assert names.size() == 1; for (Iterator i = names.iterator(); i.hasNext();) { String feName = (String) i.next(); fe = FrontEndFactory.getFrontEnd(props, feName); } return fe; } catch (InstantiationException ie) { throw new Error("IE: Can't create front end " + path, ie); } } /** * Sets the learner to use a utterance. * * @param utterance the utterance * * @throws IOException */ public void setUtterance(Utterance utterance) throws IOException { String file = utterance.toString(); InputStream is = new FileInputStream(file); inputDataType = props.getString(PROP_INPUT_TYPE, PROP_INPUT_TYPE_DEFAULT); if (inputDataType.equals("audio")) { ((StreamDataSource) dataSource).setInputStream(is, file); } else if (inputDataType.equals("cepstrum")) { boolean bigEndian = Utilities.isCepstraFileBigEndian(file); ((StreamCepstrumSource) dataSource).setInputStream(is, bigEndian); } } /** * Returns a single frame of speech. * * @return a feature frame * * @throws IOException */ private boolean getFeature() { try { curFeature = frontEnd.getData(); if (curFeature == null) { return false; } if (curFeature instanceof DataStartSignal) { curFeature = frontEnd.getData(); if (curFeature == null) { return false; } } if (curFeature instanceof DataEndSignal) { return false; } if (curFeature instanceof Signal) { throw new Error("Can't score non-content feature"); } } catch (DataProcessingException dpe) { System.out.println("DataProcessingException " + dpe); dpe.printStackTrace(); return false; } return true; } /** * Starts the Learner. */ public void start(){ } /** * Stops the Learner. */ public void stop(){ } /** * Initializes computation for current utterance and utterance graph. * * @param utterance the current utterance * @param graph the current utterance graph * * @throws IOException */ public void initializeComputation(Utterance utterance, UtteranceGraph graph) throws IOException { setUtterance(utterance); setGraph(graph); } /** * Implements the setGraph method. * * @param graph the graph */ public void setGraph(UtteranceGraph graph) { this.graph = graph; } /** * Prepares the learner for returning scores, one at a time. To do * so, it performs the full forward pass, but returns the scores * for the backward pass one feature frame at a time. */ private Object[] prepareScore() { // scoreList will contain a list of score, which in turn are a // vector of TrainerScore elements. List scoreList = new ArrayList(); int numStates = graph.size(); TrainerScore[] score = new TrainerScore[numStates]; alphas = new float[numStates]; betas = new float[numStates]; outputProbs = new float[numStates]; // First we do the forward pass. We need this before we can // return any probability. When we're doing the backward pass, // we can finally return a score for each call of this method. probCurrentFrame = new float[numStates]; // Initialization of probCurrentFrame for the alpha computation Node initialNode = graph.getInitialNode(); int indexInitialNode = graph.indexOf(initialNode); for (int i = 0; i < numStates; i++) { probCurrentFrame[i] = LogMath.getLogZero(); } // Overwrite in the right position probCurrentFrame[indexInitialNode] = 0.0f; for (initialNode.startOutgoingEdgeIterator(); initialNode.hasMoreOutgoingEdges(); ) { Edge edge = initialNode.nextOutgoingEdge(); Node node = edge.getDestination(); int index = graph.indexOf(node); if (!node.isType("STATE")) { // Certainly non-emitting, if it's not in an HMM. probCurrentFrame[index] = 0.0f; } else { // See if it's the last state in the HMM, i.e., if // it's non-emitting. HMMState state = (HMMState) node.getObject(); HMM hmm = state.getHMM(); if (!state.isEmitting()) { probCurrentFrame[index] = 0.0f; } assert false; } } // If getFeature() is true, curFeature contains a valid // Feature. If not, a problem or EOF was encountered. lastFeatureIndex = 0; while (getFeature()) { forwardPass(score); scoreList.add(score); lastFeatureIndex++; } logger.info("Feature frames read: " + lastFeatureIndex); // Prepare for beta computation for (int i = 0; i < probCurrentFrame.length; i++) { probCurrentFrame[i] = LogMath.getLogZero(); } Node finalNode = graph.getFinalNode(); int indexFinalNode = graph.indexOf(finalNode); // Overwrite in the right position probCurrentFrame[indexFinalNode] = 0.0f; for (finalNode.startIncomingEdgeIterator(); finalNode.hasMoreIncomingEdges(); ) { Edge edge = finalNode.nextIncomingEdge(); Node node = edge.getSource(); int index = graph.indexOf(node); if (!node.isType("STATE")) { // Certainly non-emitting, if it's not in an HMM. probCurrentFrame[index] = 0.0f; assert false; } else { // See if it's the last state in the HMM, i.e., if // it's non-emitting. HMMState state = (HMMState) node.getObject(); HMM hmm = state.getHMM(); if (!state.isEmitting()) { probCurrentFrame[index] = 0.0f; } } } return scoreList.toArray(); } /** * Gets the TrainerScore for the next frame * * @return the TrainerScore, or null if EOF was found */ public TrainerScore[] getScore() { TrainerScore[] score; if (scoreArray == null) { // Do the forward pass, and create the necessary arrays scoreArray = prepareScore(); currentFeatureIndex = lastFeatureIndex; } currentFeatureIndex--; if (currentFeatureIndex >= 0) { float logScore = LogMath.getLogZero(); score = (TrainerScore []) scoreArray[currentFeatureIndex]; assert score.length == betas.length; backwardPass(score); for (int i = 0; i < betas.length; i++) { score[i].setGamma(); logScore = logMath.addAsLinear(logScore, score[i].getGamma()); } if (currentFeatureIndex == lastFeatureIndex - 1) { TrainerScore.setLogLikelihood(logScore); totalLogScore = logScore; } else { if (Math.abs(totalLogScore - logScore) > Math.abs(totalLogScore)) { System.out.println("WARNING: log probabilities differ: " + totalLogScore + " and " + logScore); } } return score; } else { // We need to clear this, so we start the next iteration // on a clean plate. scoreArray = null; return null; } } /** * Computes the acoustic scores using the current Feature and a * given node in the graph. * * @param index the graph index * * @return the overall acoustic score */ private float calculateScores(int index) { float logScore; // Find the HMM state for this node SenoneHMMState state = (SenoneHMMState) graph.getNode(index).getObject(); if ((state != null) && (state.isEmitting())) { // Compute the scores for each mixture component in this state componentScores = state.calculateComponentScore(curFeature); // Compute the overall score for this state logScore = state.getScore(curFeature); // For CI models, for now, we only try to use mixtures // with one component assert componentScores.length == 1; } else { componentScores = null; logScore = 0.0f; } return logScore; } /** * Does the forward pass, one frame at a time. * * @param score the objects transferring info to the buffers */ private void forwardPass(TrainerScore[] score) { // Let's precompute the acoustic probabilities and create the // score object, one for each state for (int i = 0; i < graph.size(); i++) { outputProbs[i] = calculateScores(i); score[i] = new TrainerScore(curFeature, outputProbs[i], (HMMState) graph.getNode(i).getObject(), componentScores); score[i].setAlpha(probCurrentFrame[i]); } // Now, the forward pass. float[] probPreviousFrame = probCurrentFrame; probCurrentFrame = new float[graph.size()];
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -