📄 sphinx3saver.java
字号:
} outputStream.close(); } /** * Saves the sphinx3 densityfile, a set of density arrays are * created and placed in the given pool. * * @param pool the pool to be saved * @param path the name of the data * @param append is true, the file will be appended, useful if * saving to a ZIP or JAR file * * @throws FileNotFoundException if a file cannot be found * @throws IOException if an error occurs while saving the data */ private void saveDensityFileBinary(Pool pool, String path, boolean append) throws FileNotFoundException, IOException { int token_type; int numStates; int numStreams; int numGaussiansPerState; Properties props = new Properties(); int checkSum = 0; logger.info("Saving density file to: " ); logger.info(path); props.setProperty("version", DENSITY_FILE_VERSION); props.setProperty("chksum0", checksum); DataOutputStream dos = writeS3BinaryHeader(location, path, props, append); numStates = pool.getFeature(NUM_SENONES, -1); numStreams = pool.getFeature(NUM_STREAMS, -1); numGaussiansPerState = pool.getFeature(NUM_GAUSSIANS_PER_STATE, -1); writeInt(dos, numStates); writeInt(dos, numStreams); writeInt(dos, numGaussiansPerState); int rawLength = 0; int[] vectorLength = new int[numStreams]; for (int i = 0; i < numStreams; i++) { vectorLength[i] = this.vectorLength; writeInt(dos, vectorLength[i]); rawLength += numGaussiansPerState * numStates * vectorLength[i]; } assert numStreams == 1; assert rawLength == numGaussiansPerState * numStates * this.vectorLength; writeInt(dos, rawLength); //System.out.println("Nstates " + numStates); //System.out.println("Nstreams " + numStreams); //System.out.println("NgaussiansPerState " + numGaussiansPerState); //System.out.println("vectorLength " + vectorLength.length); //System.out.println("rawLength " + rawLength); int r = 0; for (int i = 0; i < numStates; i++) { for (int j = 0; j < numStreams; j++) { for (int k = 0; k < numGaussiansPerState; k++) { int id = i * numStreams * numGaussiansPerState + j * numGaussiansPerState + k; float[] density = (float [])pool.get(id); // Do checksum here? writeFloatArray(dos, density); } } } if (doCheckSum) { assert doCheckSum = false: "Checksum not supported"; } // S3 requires some number here.... writeInt(dos, checkSum); // BUG: not checking the check sum yet. dos.close(); } /** * Writes the S3 binary header to the given location+path. * * @param location the location of the file * @param path the name of the file * @param props the properties * @param append is true, the file will be appended, useful if * saving to a ZIP or JAR file * * @return the output stream positioned after the header * * @throws IOException on error */ protected DataOutputStream writeS3BinaryHeader(String location, String path, Properties props, boolean append) throws IOException { OutputStream outputStream = StreamFactory.getOutputStream(location, path, append); if (doCheckSum) { assert false: "Checksum not supported"; } DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(outputStream)); writeWord(dos, "s3\n"); for (Enumeration e = props.keys(); e.hasMoreElements(); ) { String name = (String) e.nextElement(); String value = props.getProperty(name); writeWord(dos, name + " " + value + "\n"); } writeWord(dos, "endhdr\n"); writeInt(dos, BYTE_ORDER_MAGIC); return dos; } /** * Writes the next word (without surrounding white spaces) to the * given stream. * * @param dos the output stream * @param word the next word * * @throws IOException on error */ void writeWord(DataOutputStream dos, String word) throws IOException { dos.writeBytes(word); } /** * Writes a single char to the stream * * @param dos the stream to read * @param character the next character on the stream * * @throws IOException if an error occurs */ private void writeChar(DataOutputStream dos, char character) throws IOException { dos.writeByte(character); } /** * swap a 32 bit word * * @param val the value to swap * * @return the swapped value */ private int byteSwap(int val) { return ((0xff & (val >>24)) | (0xff00 & (val >>8)) | (0xff0000 & (val <<8)) | (0xff000000 & (val <<24))); } /** * Writes an integer to the output stream, byte-swapping as * necessary * * @param dos the outputstream * @param val an integer value * * @throws IOException on error */ protected void writeInt(DataOutputStream dos, int val) throws IOException { if (swap) { dos.writeInt(Utilities.swapInteger(val)); } else { dos.writeInt(val); } } /** * Writes a float to the output stream, byte-swapping as * necessary * * @param dos the inputstream * @param val a float value * * @throws IOException on error */ protected void writeFloat(DataOutputStream dos, float val) throws IOException { if (swap) { dos.writeFloat(Utilities.swapFloat(val)); } else { dos.writeFloat(val); } } // Do we need the method nonZeroFloor?? /** * If a data point is non-zero and below 'floor' make * it equal to floor (don't floor zero values though). * * @param data the data to floor * @param floor the floored value */ private void nonZeroFloor(float[] data, float floor) { for (int i = 0; i < data.length; i++) { if (data[i] != 0.0 && data[i] < floor) { data[i] = floor; } } } /** * If a data point is below 'floor' make * it equal to floor. * * @param data the data to floor * @param floor the floored value */ private void floorData(float[] data, float floor) { for (int i = 0; i < data.length; i++) { if (data[i] < floor) { data[i] = floor; } } } /** * Normalize the given data * * @param data the data to normalize */ private void normalize(float[] data) { float sum = 0; for (int i = 0; i < data.length; i++) { sum += data[i]; } if (sum != 0.0f) { for (int i = 0; i < data.length; i++) { data[i] = data[i] / sum ; } } } /** * Dump the data * * @param name the name of the data * @param data the data itself * */ private void dumpData(String name, float[] data) { System.out.println(" ----- " + name + " -----------"); for (int i = 0; i < data.length; i++) { System.out.println(name + " " + i + ": " + data[i]); } } /** * Convert to log math * * @param data the data to normalize */ // linearToLog returns a float, so zero values in linear scale // should return -Float.MAX_VALUE. private void convertToLogMath(float[] data) { for (int i = 0; i < data.length; i++) { data[i] = logMath.linearToLog(data[i]); } } /** * Convert from log math * * @param in the data in log scale * @param out the data in linear scale */ protected void convertFromLogMath(float[] in, float[] out) { assert in.length == out.length; for (int i = 0; i < in.length; i++) { out[i] = (float)logMath.logToLinear(in[i]); } } /** * Writes the given number of floats from an array of floats to a * stream. * * @param dos the stream to write the data to * @param data the array of floats to write to the stream * * @throws IOException if an exception occurs */ protected void writeFloatArray(DataOutputStream dos, float[] data) throws IOException{ for (int i = 0; i < data.length; i++) { writeFloat(dos, data[i]); } } /** * Saves the sphinx3 densityfile, a set of density arrays are * created and placed in the given pool. * * @param useCDUnits if true, uses context dependent units * @param outputStream the open output stream to use * @param path the path to a density file * * @throws FileNotFoundException if a file cannot be found * @throws IOException if an error occurs while saving the data */ private void saveHMMPool(boolean useCDUnits, OutputStream outputStream, String path) throws FileNotFoundException, IOException { int token_type; int numBase; int numTri; int numStateMap; int numTiedState; int numStatePerHMM; int numContextIndependentTiedState; int numTiedTransitionMatrices; logger.info("Saving HMM file to: "); logger.info(path); if (outputStream == null) { throw new IOException("Error trying to write file " + location + path); } PrintWriter pw = new PrintWriter(outputStream, true); /* ExtendedStreamTokenizer est = new ExtendedStreamTokenizer (outputStream, '#', false); Pool pool = new Pool(path); */ // First, count the HMMs numBase = 0; numTri = 0; numContextIndependentTiedState = 0; numStateMap = 0; for (Iterator i = hmmManager.getIterator(); i.hasNext(); ) { SenoneHMM hmm = (SenoneHMM)i.next(); numStateMap += hmm.getOrder() + 1; if (hmm.isContextDependent()) { numTri++; } else { numBase++; numContextIndependentTiedState += hmm.getOrder(); } } pw.println(MODEL_VERSION); pw.println(numBase + " n_base"); pw.println(numTri + " n_tri"); pw.println(numStateMap + " n_state_map"); numTiedState = mixtureWeightsPool.getFeature(NUM_SENONES, 0); pw.println(numTiedState + " n_tied_state"); pw.println(numContextIndependentTiedState + " n_tied_ci_state"); numTiedTransitionMatrices = numBase; assert numTiedTransitionMatrices == matrixPool.size(); pw.println(numTiedTransitionMatrices + " n_tied_tmat"); pw.println("#"); pw.println("# Columns definitions"); pw.println("#base lft rt p attrib tmat ... state id's ..."); numStatePerHMM = numStateMap/(numTri+numBase); // Save the base phones for (Iterator i = hmmManager.getIterator(); i.hasNext(); ) { SenoneHMM hmm = (SenoneHMM)i.next(); if (hmm.isContextDependent()) { continue; } Unit unit = hmm.getUnit(); String name = unit.getName(); pw.print(name + "\t"); String left = "-"; pw.print(left + " "); String right = "-"; pw.print(right + " "); String position = hmm.getPosition().toString(); pw.print(position + "\t"); String attribute; if (unit.isFiller()) { attribute = FILLER; } else { attribute = "n/a"; } pw.print(attribute + "\t"); int tmat = matrixPool.indexOf(hmm.getTransitionMatrix()); assert tmat < numTiedTransitionMatrices; pw.print(tmat + "\t"); SenoneSequence ss = hmm.getSenoneSequence(); Senone[] senones = ss.getSenones(); for (int j = 0; j < senones.length; j++) { int index = senonePool.indexOf(senones[j]); assert index >= 0 && index < numContextIndependentTiedState; pw.print(index + "\t"); } pw.println("N"); if (logger.isLoggable(Level.FINE)) { logger.fine("Saved " + unit); } } // Save the context dependent phones. for (Iterator i = hmmManager.getIterator(); i.hasNext(); ) { SenoneHMM hmm = (SenoneHMM)i.next(); if (!hmm.isContextDependent()) { continue; } Unit unit = hmm.getUnit(); LeftRightContext context = (LeftRightContext)unit.getContext(); Unit[] leftContext = context.getLeftContext(); Unit[] rightContext = context.getRightContext(); assert leftContext.length == 1 && rightContext.length == 1; String name = unit.getName(); pw.print(name + "\t"); String left = leftContext[0].getName(); pw.print(left + " "); String right = rightContext[0].getName(); pw.print(right + " "); String position = hmm.getPosition().toString(); pw.print(position + "\t"); String attribute; if (unit.isFiller()) { attribute = FILLER; } else { attribute = "n/a"; } assert attribute.equals("n/a"); pw.print(attribute + "\t"); int tmat = matrixPool.indexOf(hmm.getTransitionMatrix()); assert tmat < numTiedTransitionMatrices;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -