📄 sphinx3saver.java
字号:
/* * Copyright 1999-2002 Carnegie Mellon University. * Portions Copyright 2002 Sun Microsystems, Inc. * Portions Copyright 2002 Mitsubishi Electric Research Laboratories. * All Rights Reserved. Use is subject to license terms. * * See the file "license.terms" for information on usage and * redistribution of this file, and for a DISCLAIMER OF ALL * WARRANTIES. * */package edu.cmu.sphinx.linguist.acoustic.tiedstate.trainer;import java.io.BufferedOutputStream;import java.io.DataOutputStream;import java.io.File;import java.io.FileNotFoundException;import java.io.IOException;import java.io.OutputStream;import java.io.PrintStream;import java.io.PrintWriter;import java.net.URL;import java.util.Enumeration;import java.util.Iterator;import java.util.LinkedHashMap;import java.util.Map;import java.util.Properties;import java.util.logging.Level;import java.util.logging.Logger;import java.util.zip.ZipException;import edu.cmu.sphinx.linguist.acoustic.AcousticModel;import edu.cmu.sphinx.linguist.acoustic.LeftRightContext;import edu.cmu.sphinx.linguist.acoustic.Unit;import edu.cmu.sphinx.linguist.acoustic.tiedstate.GaussianMixture;import edu.cmu.sphinx.linguist.acoustic.tiedstate.HMMManager;import edu.cmu.sphinx.linguist.acoustic.tiedstate.Loader;import edu.cmu.sphinx.linguist.acoustic.tiedstate.MixtureComponent;import edu.cmu.sphinx.linguist.acoustic.tiedstate.Pool;import edu.cmu.sphinx.linguist.acoustic.tiedstate.Saver;import edu.cmu.sphinx.linguist.acoustic.tiedstate.Senone;import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneHMM;import edu.cmu.sphinx.linguist.acoustic.tiedstate.SenoneSequence;import edu.cmu.sphinx.linguist.acoustic.tiedstate.TiedStateAcousticModel;import edu.cmu.sphinx.util.LogMath;import edu.cmu.sphinx.util.SphinxProperties;import edu.cmu.sphinx.util.StreamFactory;import edu.cmu.sphinx.util.Utilities;/** * An acoustic model saver that saves sphinx3 ascii data. * * Mixture weights and transition probabilities are saved in linear scale. 
*/
class Sphinx3Saver implements Saver {

    /**
     * The logger for this class
     */
    private static Logger logger =
        Logger.getLogger(AcousticModel.PROP_PREFIX + "AcousticModel");

    // Pool feature keys: the save/create methods of this class read the
    // model dimensions from a Pool through Pool.getFeature() using them.
    protected final static String NUM_SENONES = "num_senones";
    protected final static String NUM_GAUSSIANS_PER_STATE = "num_gaussians";
    protected final static String NUM_STREAMS = "num_streams";

    protected final static String FILLER = "filler";
    protected final static String SILENCE_CIPHONE = "SIL";

    protected final static int BYTE_ORDER_MAGIC = 0x11223344;

    public final static String MODEL_VERSION = "0.3";

    protected final static int CONTEXT_SIZE = 1;

    // Checksum setting and the flag derived from it; both are initialised
    // in the constructor.
    private String checksum;
    private boolean doCheckSum;

    // Parameter pools. The constructor takes the means, variance,
    // transition matrix, mixture weight and senone pools from the Loader.
    private Pool meansPool;
    private Pool variancePool;
    private Pool matrixPool;
    // NOTE(review): createSenonePool() reads the four transformation pools
    // below, but the constructor does not assign them -- confirm that they
    // are set somewhere before that method is used.
    private Pool meanTransformationMatrixPool;
    private Pool meanTransformationVectorPool;
    private Pool varianceTransformationMatrixPool;
    private Pool varianceTransformationVectorPool;
    private Pool mixtureWeightsPool;
    private Pool senonePool;

    // Number of values written per density vector; read from the
    // properties in the constructor.
    private int vectorLength;

    private Map contextIndependentUnits;
    private HMMManager hmmManager;
    private LogMath logMath;
    private SphinxProperties acousticProperties;

    // If true the *Binary save methods are used, otherwise the *Ascii ones.
    private boolean binary = false;

    // Save location; read from the properties by saveModelFiles().
    private String location;
    private boolean swap;

    protected final static String DENSITY_FILE_VERSION = "1.0";
    protected final static String MIXW_FILE_VERSION = "1.0";
    protected final static String TMAT_FILE_VERSION = "1.0";

    /**
     * Saves the sphinx3 models.
     *
     * @param modelName  the name of the model as specified in the
     *    props file.
* @param props the SphinxProperties object * @param binary if <code>true</code> the file is saved in binary * format * @param loader this acoustic model's loader */ public Sphinx3Saver(String modelName, SphinxProperties props, boolean binary, Loader loader) throws FileNotFoundException, IOException, ZipException { this.binary = binary; logMath = LogMath.getLogMath(props.getContext()); // extract the feature vector length String vectorLengthProp = TiedStateAcousticModel.PROP_VECTOR_LENGTH; if (modelName != null) { vectorLengthProp = AcousticModel.PROP_PREFIX + modelName + ".FeatureVectorLength"; } vectorLength = props.getInt (vectorLengthProp, TiedStateAcousticModel.PROP_VECTOR_LENGTH_DEFAULT); hmmManager = loader.getHMMManager(); meansPool = loader.getMeansPool(); variancePool = loader.getVariancePool(); mixtureWeightsPool = loader.getMixtureWeightPool(); matrixPool = loader.getTransitionMatrixPool(); senonePool = loader.getSenonePool(); contextIndependentUnits = new LinkedHashMap(); acousticProperties = loader.getModelProperties(); // TODO: read checksum from props; checksum = "no"; doCheckSum = (checksum != null && checksum.equals("yes")); swap = false; // do the actual acoustic model loading saveModelFiles(modelName, props); } /** * Return the checksum string. * * @return the checksum */ protected String getCheckSum() { return checksum; } /** * Return whether to do the dochecksum. If true, checksum is * performed. * * @return the dochecksum */ protected boolean getDoCheckSum() { return doCheckSum; } /** * Return the LogMath. * * @return the logMath */ protected LogMath getLogMath() { return logMath; } /** * Return the acousticProperties. * * @return the acousticProperties */ protected SphinxProperties getAcousticProperties() { return acousticProperties; } /** * Return the location. * * @return the location */ protected String getLocation() { return location; } /** * Saves the AcousticModel from a directory in the file system. 
* * @param modelName the name of the acoustic model; if null we just * save from the default location * @param props the SphinxProperties object to use */ private void saveModelFiles(String modelName, SphinxProperties props) throws FileNotFoundException, IOException, ZipException { String prefix, model, dataDir, propsFile; if (modelName == null) { prefix = TrainerAcousticModel.PROP_PREFIX; } else { prefix = TrainerAcousticModel.PROP_PREFIX + modelName + "."; } // System.out.println("Using prefix: " + prefix); location = props.getString (prefix + "location.save", TrainerAcousticModel.PROP_LOCATION_SAVE_DEFAULT); model = props.getString (prefix + "definition_file", TiedStateAcousticModel.PROP_MODEL_DEFAULT); dataDir = props.getString (prefix + "data_location", TiedStateAcousticModel.PROP_DATA_LOCATION_DEFAULT) + "/"; propsFile = props.getString (prefix + "properties_file", TiedStateAcousticModel.PROP_PROPERTIES_FILE_DEFAULT); float distFloor = props.getFloat(TiedStateAcousticModel.PROP_MC_FLOOR, TiedStateAcousticModel.PROP_MC_FLOOR_DEFAULT); float mixtureWeightFloor = props.getFloat(TiedStateAcousticModel.PROP_MW_FLOOR, TiedStateAcousticModel.PROP_MW_FLOOR_DEFAULT); float transitionProbabilityFloor = props.getFloat(TiedStateAcousticModel.PROP_TP_FLOOR, TiedStateAcousticModel.PROP_TP_FLOOR_DEFAULT); float varianceFloor = props.getFloat(TiedStateAcousticModel.PROP_VARIANCE_FLOOR, TiedStateAcousticModel.PROP_VARIANCE_FLOOR_DEFAULT); logger.info("Saving acoustic model: " + modelName); logger.info(" Path : " + location); logger.info(" modellName: " + model); logger.info(" dataDir : " + dataDir); // save the acoustic properties file (am.props), // create a different URL depending on the data format URL url = null; String format = StreamFactory.resolve(location); if (format.equals(StreamFactory.ZIP_FILE)) { url = new URL("jar:" + location + "!/" + propsFile); } else { File file = new File(location, propsFile); url = file.toURI().toURL(); } if (modelName == null) { prefix = 
props.getContext() + ".acoustic"; } else { prefix = props.getContext() + ".acoustic." + modelName; } saveAcousticPropertiesFile(acousticProperties, propsFile, false); if (binary) { // First, overwrite the previous file saveDensityFileBinary(meansPool, dataDir + "means", true); // From now on, append to previous file saveDensityFileBinary(variancePool, dataDir + "variances", true); saveMixtureWeightsBinary(mixtureWeightsPool, dataDir + "mixture_weights", true); saveTransitionMatricesBinary(matrixPool, dataDir + "transition_matrices", true); } else { saveDensityFileAscii(meansPool, dataDir + "means.ascii", true); saveDensityFileAscii(variancePool, dataDir + "variances.ascii", true); saveMixtureWeightsAscii(mixtureWeightsPool, dataDir + "mixture_weights.ascii", true); saveTransitionMatricesAscii(matrixPool, dataDir + "transition_matrices.ascii", true); } // senonePool = createSenonePool(distFloor); // save the HMM model file boolean useCDUnits = props.getBoolean (TiedStateAcousticModel.PROP_USE_CD_UNITS, TiedStateAcousticModel.PROP_USE_CD_UNITS_DEFAULT); saveHMMPool(useCDUnits, StreamFactory.getOutputStream(location, model, true), location + File.separator + model); } /** * Returns the map of context indepent units. The map can be * accessed by unit name. * * @return the map of context independent units. */ public Map getContextIndependentUnits() { return contextIndependentUnits; } /** * Creates the senone pool from the rest of the pools. 
* * @param distFloor the lowest allowed score * * @return the senone pool */ private Pool createSenonePool(float distFloor, float varianceFloor) { Pool pool = new Pool("senones"); int numMixtureWeights = mixtureWeightsPool.size(); int numMeans = meansPool.size(); int numVariances = variancePool.size(); int numGaussiansPerSenone = mixtureWeightsPool.getFeature(NUM_GAUSSIANS_PER_STATE, 0); int numSenones = mixtureWeightsPool.getFeature(NUM_SENONES, 0); int whichGaussian = 0; logger.fine("NG " + numGaussiansPerSenone); logger.fine("NS " + numSenones); logger.fine("NMIX " + numMixtureWeights); logger.fine("NMNS " + numMeans); logger.fine("NMNS " + numVariances); assert numGaussiansPerSenone > 0; assert numMixtureWeights == numSenones; assert numVariances == numSenones * numGaussiansPerSenone; assert numMeans == numSenones * numGaussiansPerSenone; for (int i = 0; i < numSenones; i++) { MixtureComponent[] mixtureComponents = new MixtureComponent[numGaussiansPerSenone]; for (int j = 0; j < numGaussiansPerSenone; j++) { mixtureComponents[j] = new MixtureComponent( logMath, (float[]) meansPool.get(whichGaussian), (float[][]) meanTransformationMatrixPool.get(0), (float[]) meanTransformationVectorPool.get(0), (float[]) variancePool.get(whichGaussian), (float[][]) varianceTransformationMatrixPool.get(0), (float[]) varianceTransformationVectorPool.get(0), distFloor, varianceFloor); whichGaussian++; } Senone senone = new GaussianMixture( logMath, (float[]) mixtureWeightsPool.get(i), mixtureComponents, i); pool.put(i, senone); } return pool; } /** * Loads the Sphinx 3 acoustic model properties file, which * is basically a normal system properties file. 
* * @param property the SphinxProperties object * @param path the path to the acoustic properties file * @param append if true, append to the current file, if ZIP file * * @throws FileNotFoundException if the file cannot be found * @throws IOException if an error occurs while saving the data */ private void saveAcousticPropertiesFile(SphinxProperties property, String path, boolean append) throws FileNotFoundException, IOException { logger.info("Saving acoustic properties file to:"); logger.info(path); OutputStream outputStream = StreamFactory.getOutputStream(location, path, append); if (outputStream == null) { throw new IOException("Error trying to write file " + path); } PrintStream ps = new PrintStream(outputStream, true); property.list(ps); outputStream.close(); } /** * Saves the sphinx3 densityfile, a set of density arrays are * created and placed in the given pool. * * @param pool the pool to be saved * @param path the name of the data * @param append is true, the file will be appended, useful if * saving to a ZIP or JAR file * * @throws FileNotFoundException if a file cannot be found * @throws IOException if an error occurs while saving the data */ private void saveDensityFileAscii(Pool pool, String path, boolean append) throws FileNotFoundException, IOException { int token_type; int numStates; int numStreams; int numGaussiansPerState; logger.info("Saving density file to: "); logger.info(path); OutputStream outputStream = StreamFactory.getOutputStream(location, path, append); if (outputStream == null) { throw new IOException("Error trying to write file " + location + path); } PrintWriter pw = new PrintWriter(outputStream, true); pw.print("param "); numStates = pool.getFeature(NUM_SENONES, -1); pw.print(numStates + " "); numStreams = pool.getFeature(NUM_STREAMS, -1); pw.print(numStreams + " "); numGaussiansPerState = pool.getFeature(NUM_GAUSSIANS_PER_STATE, -1); pw.println(numGaussiansPerState); for (int i = 0; i < numStates; i++) { pw.println("mgau " + i); 
pw.println("feat " + 0); for (int j = 0; j < numGaussiansPerState; j++) { pw.print("density" + " \t" + j); int id = i * numGaussiansPerState + j; float[] density = (float [])pool.get(id); for (int k = 0; k < vectorLength; k++) { pw.print(" " + density[k]); // System.out.println(" " + i + " " + j + " " + k + // " " + density[k]); } pw.println(); }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -