📄 hmmpoolmanager.java
字号:
return; } int indexState = state.getState(); SenoneHMM hmm = (SenoneHMM) state.getHMM(); float[][] matrix = hmm.getTransitionMatrix(); // Find the index for current matrix in the transition matrix pool // int indexMatrix = matrixPool.indexOf(matrix); Integer indexInMap = (Integer) indexMap.get(matrix); int indexMatrix = indexInMap.intValue(); // Find the corresponding buffer Buffer[] bufferArray = (Buffer []) matrixBufferPool.get(indexMatrix); // Let's concentrate on the transitions *from* the current state float[] vector = matrix[indexState]; for (int i = 0; i < vector.length; i++) { // Make sure this is a valid transition if (vector[i] != LogMath.getLogZero()) { // We're assuming that if the states have position "a" // and "b" in the HMM, they'll have positions "k+a" // and "k+b" in the graph, that is, their relative // position is the same. // Distance between current state and "to" state in // the HMM int dist = i - indexState; // "to" state in the graph int indexNextScore = indexScore + dist; // Make sure the next state is non-emitting (the last // in the HMM), or in the same HMM. assert ((nextScore[indexNextScore].getState() == null) || (nextScore[indexNextScore].getState().getHMM() == hmm)); float alpha = score[indexScore].getAlpha(); float beta = nextScore[indexNextScore].getBeta(); float transitionProb = vector[i]; float outputProb = nextScore[indexNextScore].getScore(); float prob = alpha + beta + transitionProb + outputProb; prob -= currentLogLikelihood; // i is the index into the next state. bufferArray[indexState].logAccumulate(prob, i, logMath); /* if ((indexMatrix == 0) && (i == 2)) { // System.out.println("Out: " + outputProb); // bufferArray[indexState].dump(); } */ } } } /** * Accumulate transitions from a given state. 
* * @param indexState the state index * @param hmm the HMM * @param value the value to accumulate */ private void accumulateStateTransition(int indexState, SenoneHMM hmm, float value) { // Find the transition matrix in this hmm float[][] matrix = hmm.getTransitionMatrix(); // Find the vector with transitions from the current state to // other states. float[] stateVector = matrix[indexState]; // Find the index of the current transition matrix in the // transition matrix pool. // int indexMatrix = matrixPool.indexOf(matrix); Integer indexInMap = (Integer) indexMap.get(matrix); int indexMatrix = indexInMap.intValue(); // Find the buffer for the transition matrix. Buffer[] bufferArray = (Buffer []) matrixBufferPool.get(indexMatrix); // Accumulate for the transitions from current state for (int i = 0; i < stateVector.length; i++) { // Make sure we're not trying to accumulate in an invalid // transition. if (stateVector[i] != LogMath.getLogZero()) { bufferArray[indexState].logAccumulate(value, i, logMath); } } } /** * Accumulate the transition probabilities. */ private void accumulateTransition(int indexHmm, int indexScore, TrainerScore[] score, TrainerScore[] nextScore) { if (indexHmm == TrainerAcousticModel.ALL_MODELS) { // Well, special case... we want to add an amount to all // the states in all models for (Iterator i = hmmManager.getIterator(); i.hasNext(); ) { SenoneHMM hmm = (SenoneHMM) i.next(); for (int j = 0; j < hmm.getOrder(); j++) { accumulateStateTransition(j, hmm, score[indexScore].getScore()); } } } else { // For transition accumulation, we don't consider the last // time frame, since there's no transition from there to // anywhere... if (nextScore != null) { accumulateStateTransition(indexScore, score, nextScore); } } } /** * Update the log likelihood. This method should be called for * every utterance. */ protected void updateLogLikelihood() { // logLikelihood += currentLogLikelihood; } /** * Normalize the buffers. 
* * @return the log likelihood associated with the current training set */ protected float normalize() { normalizePool(meansBufferPool); normalizePool(varianceBufferPool); logNormalizePool(mixtureWeightsBufferPool); logNormalize2DPool(matrixBufferPool, matrixPool); return logLikelihood; } /** * Normalize a single buffer pool. * * @param pool the buffer pool to normalize */ private void normalizePool(Pool pool) { assert pool != null; for (int i = 0; i < pool.size(); i++) { Buffer buffer = (Buffer)pool.get(i); if (buffer.wasUsed()) { buffer.normalize(); } } } /** * Normalize a single buffer pool in log scale. * * @param pool the buffer pool to normalize */ private void logNormalizePool(Pool pool) { assert pool != null; for (int i = 0; i < pool.size(); i++) { Buffer buffer = (Buffer)pool.get(i); if (buffer.wasUsed()) { buffer.logNormalize(); } } } /** * Normalize a 2D buffer pool in log scale. Typically, this is the * case with the transition matrix, which also needs a mask for * values that are allowed, and therefor have to be updated, or * not allowed, and should be ignored. * * @param pool the buffer pool to normalize * @param maskPool pool containing a mask with zero/non-zero values. */ private void logNormalize2DPool(Pool pool, Pool maskPool) { assert pool != null; for (int i = 0; i < pool.size(); i++) { Buffer[] bufferArray = (Buffer []) pool.get(i); float[][] mask = (float[][]) maskPool.get(i); for (int j = 0; j < bufferArray.length; j++) { if (bufferArray[j].wasUsed()) { bufferArray[j].logNormalizeNonZero(mask[j]); } } } } /** * Update the models. */ protected void update() { updateMeans(); updateVariances(); recomputeMixtureComponents(); updateMixtureWeights(); updateTransitionMatrices(); } /** * Copy one vector onto another. * * @param in the source vector * @param out the destination vector */ private void copyVector(float[] in, float[] out) { assert in.length == out.length; for (int i = 0; i < in.length; i++) { out[i] = in[i]; } } /** * Update the means. 
*/ private void updateMeans() { assert meansPool.size() == meansBufferPool.size(); for (int i = 0; i < meansPool.size(); i++) { float[] means = (float [])meansPool.get(i); Buffer buffer = (Buffer) meansBufferPool.get(i); if (buffer.wasUsed()) { float[] meansBuffer = buffer.getValues(); copyVector(meansBuffer, means); } else { logger.info("Senone " + i + " not used."); } } } /** * Update the variances. */ private void updateVariances() { assert variancePool.size() == varianceBufferPool.size(); for (int i = 0; i < variancePool.size(); i++) { float[] means = (float [])meansPool.get(i); float[] variance = (float [])variancePool.get(i); Buffer buffer = (Buffer) varianceBufferPool.get(i); if (buffer.wasUsed()) { float[] varianceBuffer = (float [])buffer.getValues(); assert means.length == varianceBuffer.length; for (int j = 0; j < means.length; j++) { varianceBuffer[j] -= means[j] * means[j]; if (varianceBuffer[j] < varianceFloor) { varianceBuffer[j] = varianceFloor; } } copyVector(varianceBuffer, variance); } } } /** * Recompute the precomputed values in all mixture components. */ private void recomputeMixtureComponents() { for (int i = 0; i < senonePool.size(); i++) { GaussianMixture gMix = (GaussianMixture) senonePool.get(i); MixtureComponent[] mixComponent = gMix.getMixtureComponents(); for (int j = 0; j < mixComponent.length; j++) { mixComponent[j].precomputeDistance(); } } } /** * Update the mixture weights. */ private void updateMixtureWeights() { assert mixtureWeightsPool.size() == mixtureWeightsBufferPool.size(); for (int i = 0; i < mixtureWeightsPool.size(); i++) { float[] mixtureWeights = (float [])mixtureWeightsPool.get(i); Buffer buffer = (Buffer) mixtureWeightsBufferPool.get(i); if (buffer.wasUsed()) { if (buffer.logFloor(logMixtureWeightFloor)) { buffer.logNormalizeToSum(logMath); } float[] mixtureWeightsBuffer = (float [])buffer.getValues(); copyVector(mixtureWeightsBuffer, mixtureWeights); } } } /** * Update the transition matrices. 
*
 * <p>For every used row buffer, each entry that is not log zero is raised
 * to at least {@code logTransitionProbabilityFloor}. The row is then
 * normalized with {@code logNormalizeToSum} and copied into the matching
 * row of the transition matrix in {@code matrixPool}. Entries still at
 * log zero are left untouched.
 */
private void updateTransitionMatrices() {
    assert matrixPool.size() == matrixBufferPool.size();
    float logZero = LogMath.getLogZero();
    int matrixCount = matrixPool.size();
    for (int m = 0; m < matrixCount; m++) {
        float[][] transitions = (float[][]) matrixPool.get(m);
        Buffer[] rowBuffers = (Buffer[]) matrixBufferPool.get(m);
        for (int from = 0; from < transitions.length; from++) {
            Buffer row = rowBuffers[from];
            if (!row.wasUsed()) {
                continue;
            }
            float[] target = transitions[from];
            for (int to = 0; to < target.length; to++) {
                float logProb = row.getValue(to);
                if (logProb == logZero) {
                    continue;
                }
                // Only transitions allowed by the current matrix should
                // ever have received mass.
                assert target[to] != logZero;
                if (logProb < logTransitionProbabilityFloor) {
                    row.setValue(to, logTransitionProbabilityFloor);
                }
            }
            row.logNormalizeToSum(logMath);
            copyVector(row.getValues(), target);
        }
    }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -