learnabletokenedaffine.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 1,105 行 · 第 1/3 页
JAVA
1,105 行
if (stringProb == 0.0) { System.out.println("TROUBLE!!!! s1=" + ts1 + " s2=" + ts2); printMatrices(ts1,ts2); return 0; } m_endAtSubOccs += lambda; m_endAtGapOccs += 2*lambda; for (int i = 1; i < l1; i++) { for (int j = 1; j < l2; j++) { s1_i = s1[i-1]; s2_j = s2[j-1]; if (s1_i == s2_j) { subTokenLogProb = m_matchLogProb; } else { subTokenLogProb = m_nonMatchLogProb; } // substituting or matching occsSubst = Math.exp(fMatrix[i-1][j-1][0] + subTokenLogProb + m_subLogProb + bMatrix[i][j][0] - stringProb); if (s1_i == s2_j) { m_matchOccs += occsSubst; } else { m_nonMatchOccs += occsSubst; } m_subOccs += occsSubst; // starting a gap occsStartGap_1 = Math.exp(fMatrix[i-1][j][0] + m_gapTokenLogProb + m_gapStartLogProb + bMatrix[i][j][1] - stringProb); occsStartGap_2 = Math.exp(fMatrix[i][j-1][0] + m_gapTokenLogProb + m_gapStartLogProb + bMatrix[i][j][2] - stringProb); m_gapStartOccs += occsStartGap_1 + occsStartGap_2; // extendinuing a gap occsExtendGap_1 = Math.exp(fMatrix[i-1][j][1] + m_gapTokenLogProb + m_gapExtendLogProb + bMatrix[i][j][1] - stringProb); occsExtendGap_2 = Math.exp(fMatrix[i][j-1][2] + m_gapTokenLogProb + m_gapExtendLogProb + bMatrix[i][j][2] - stringProb); m_gapExtendOccs += occsExtendGap_1 + occsExtendGap_2; // ending a gap occsEndGap_1 = Math.exp(fMatrix[i-1][j-1][1] + subTokenLogProb + m_gapEndLogProb + bMatrix[i][j][0] - stringProb); if (s1_i == s2_j) { m_matchOccs += occsEndGap_1; // TODO - check!!! , also if's above and below } else { m_nonMatchOccs += occsEndGap_1; } occsEndGap_2 = Math.exp(fMatrix[i-1][j-1][2] + subTokenLogProb + m_gapEndLogProb + bMatrix[i][j][0] - stringProb); if (s1_i == s2_j) { m_matchOccs += occsEndGap_2; } else { m_nonMatchOccs += occsEndGap_2; } m_gapEndOccs += occsEndGap_1 + occsEndGap_2; } } // border rows. 
We can't end gap, and can start/extend gap only one way for (int i = 1; i < l1; i++) { s1_i = s1[i-1]; s2_j = s2[l2-1]; if (s1_i == s2_j) { subTokenLogProb = m_matchLogProb; } else { subTokenLogProb = m_nonMatchLogProb; } occsSubst = Math.exp(fMatrix[i-1][l2-1][0] + subTokenLogProb + m_subLogProb + bMatrix[i][l2][0] - stringProb); if (s1_i == s2_j) { m_matchOccs += occsSubst; } else { m_nonMatchOccs += occsSubst; } m_subOccs += occsSubst; occsStartGap_1 = Math.exp(fMatrix[i-1][l2][0] + m_gapTokenLogProb + m_gapStartLogProb + bMatrix[i][l2][1] - stringProb); m_gapStartOccs += occsStartGap_1; occsExtendGap_1 = Math.exp(fMatrix[i-1][l2][1] + m_gapTokenLogProb + m_gapExtendLogProb + bMatrix[i][l2][1] - stringProb); m_gapExtendOccs += occsExtendGap_1; // DO WE NEED THIS??? WE HAD NO CHOICE! } for (int j = 1; j < l2; j++) { s1_i = s1[l1-1]; s2_j = s2[j-1]; if (s1_i == s2_j) { subTokenLogProb = m_matchLogProb; } else { subTokenLogProb = m_nonMatchLogProb; } occsSubst = Math.exp(fMatrix[l1-1][j-1][0] + subTokenLogProb + m_subLogProb + bMatrix[l1][j][0] - stringProb); if (s1_i == s2_j) { m_matchOccs += occsSubst; } else { m_nonMatchOccs += occsSubst; } m_subOccs += occsSubst; occsStartGap_2 = Math.exp(fMatrix[l1][j-1][0] + m_gapTokenLogProb + m_gapStartLogProb + bMatrix[l1][j][2] - stringProb); m_gapStartOccs += occsStartGap_2; occsExtendGap_2 = Math.exp(fMatrix[l1][j-1][2] + m_gapTokenLogProb + m_gapExtendLogProb + bMatrix[l1][j][2] - stringProb); m_gapExtendOccs += occsExtendGap_2; // DO WE NEED THIS??? WE HAD NO CHOICE! 
}
        return stringProb;
    }

    /**
     * Maximization step of the EM algorithm.
     *
     * Re-estimates the transition probabilities and the match probability
     * from the accumulated occurrence counts, clamps every estimate from
     * below by m_clampProb, then renormalizes and refreshes the cached
     * log-probabilities.
     *
     * If an occurrence total is exactly zero the corresponding estimates are
     * left at their current values: dividing would yield NaN, which passes
     * every "&lt; m_clampProb" test below and would then propagate into all
     * log-probabilities.
     */
    protected void maximizationStep () {
        // TODO: when trying to incorporate discriminative training, see EditDistance.java
        // in old codebase for an earlier attempt to deal with negative expectations

        // Sum up expectations for transitions in substitution state.
        // A gap start is counted twice in the total (same weighting as in
        // normalizeTransitionProbs).
        double subStateTotal = m_subOccs + 2 * m_gapStartOccs + m_endAtSubOccs;
        if (subStateTotal != 0.0) {
            m_subProb = m_subOccs / subStateTotal;
            m_gapStartProb = m_gapStartOccs / subStateTotal;
            m_endAtSubProb = m_endAtSubOccs / subStateTotal;
        }

        // Sum up expectations for occurrences in deletion/insertion states
        double gapStateTotal = m_gapExtendOccs + m_gapEndOccs + m_endAtGapOccs;
        if (gapStateTotal != 0.0) {
            m_gapExtendProb = m_gapExtendOccs / gapStateTotal;
            m_gapEndProb = m_gapEndOccs / gapStateTotal;
            m_endAtGapProb = m_endAtGapOccs / gapStateTotal;
        }

        // regularize if necessary
        if (m_subProb < m_clampProb) {
            m_subProb = m_clampProb;
        }
        if (m_gapStartProb < m_clampProb) {
            m_gapStartProb = m_clampProb;
        }
        if (m_endAtSubProb < m_clampProb) {
            m_endAtSubProb = m_clampProb;
        }
        if (m_gapExtendProb < m_clampProb) {
            m_gapExtendProb = m_clampProb;
        }
        if (m_gapEndProb < m_clampProb) {
            m_gapEndProb = m_clampProb;
        }
        if (m_endAtGapProb < m_clampProb) {
            m_endAtGapProb = m_clampProb;
        }

        // Emission estimate: share of emitted token pairs that were matches,
        // kept inside [m_clampProb, 1 - m_clampProb].
        double emissionTotal = m_matchOccs + m_nonMatchOccs;
        if (emissionTotal != 0.0) {
            m_matchProb = m_matchOccs / emissionTotal;
        }
        if (m_matchProb < m_clampProb) {
            m_matchProb = m_clampProb;
        }
        if (1.0 - m_matchProb < m_clampProb) {
            m_matchProb = 1.0 - m_clampProb;
        }

        normalizeTransitionProbs();
        normalizeEmissionProbs();
        updateLogProbs();
    }

    /**
     * Normalize the probabilities of emission editops so that they sum to 1
     * for each state.
     *
     * Spreads the non-match mass (1 - m_matchProb) evenly over the
     * n*n - n ordered pairs of distinct entries, where n is the size of
     * m_stringTokenStringMap.
     */
    protected void normalizeEmissionProbs() {
        // NOTE(review): n is the number of entries in the string -> TokenString
        // cache, which looks like a count of strings rather than of distinct
        // tokens - confirm this is the intended vocabulary size.
        // Computed in double: the int product n*n silently overflows once n
        // exceeds 46340.
        double numTokens = m_stringTokenStringMap.size();
        // NOTE(review): for n < 2 the divisor is 0 and the result is not finite
        // (same as before this change); callers presumably never get here with
        // fewer than two entries - confirm.
        m_nonMatchProb = (1.0 - m_matchProb) / (numTokens * numTokens - numTokens);
    }

    /**
     * Normalize the probabilities of transitions so that they sum to 1
     * for each state.
     */
    protected void normalizeTransitionProbs() {
        // M-state: the gap-start probability is counted twice in the total
        double P = m_subProb + 2 * m_gapStartProb + m_endAtSubProb;
        m_subProb /= P;
        m_gapStartProb /= P;
        m_endAtSubProb /= P;

        // I/D states
        P = m_gapExtendProb + m_gapEndProb + m_endAtGapProb;
        m_gapExtendProb /= P;
        m_gapEndProb /= P;
        m_endAtGapProb /= P;
    }

    /**
     * reset the number of occurrences of all ops in the set
     */
    protected void resetOccurrences () {
        m_matchOccs = 0;
        m_nonMatchOccs = 0;
        m_endAtSubOccs = 0;
        m_endAtGapOccs = 0;
        m_gapStartOccs = 0;
        m_gapExtendOccs = 0;
        m_gapEndOccs = 0;
        m_subOccs = 0;
    }

    /**
     * initialize the probabilities to some startup values
     *
     * Only the transition probabilities and m_matchProb are set here;
     * m_nonMatchProb is derived in normalizeEmissionProbs().
     */
    protected void initProbs () {
        m_endAtSubProb = 0.05;
        m_endAtGapProb = 0.1;
        m_gapStartProb = 0.05;
        m_gapExtendProb = 0.5;
        m_gapEndProb = 0.4;
        m_subProb = 0.85;
        m_matchProb = 0.9;
    }

    /**
     * initialize the costs using current values of the probabilities
     *
     * Every cost is the negated log-probability of the corresponding
     * operation. When m_verbose is set, the costs are printed twice: scaled
     * by the gap-extension cost, and unscaled.
     */
    protected void initCosts () {
        m_gapStartCost = -m_gapStartLogProb;
        m_gapExtendCost = -m_gapExtendLogProb;
        m_endAtSubCost = -m_endAtSubLogProb;
        m_endAtGapCost = -m_endAtGapLogProb;
        m_gapEndCost = -m_gapEndLogProb;
        m_subCost = -m_subLogProb;
        m_matchCost = -m_matchLogProb;
        m_nonMatchCost = -m_nonMatchLogProb;

        if (m_verbose) {
            System.out.println("\nScaled by extend cost:\nGapStrt=" + (m_gapStartCost/m_gapExtendCost) +
                               "\tGapExt=" + (m_gapExtendCost/m_gapExtendCost) +
                               "\tGapEnd=" + (m_gapEndCost/m_gapExtendCost) +
                               "\tSub=" + (m_subCost/m_gapExtendCost) +
                               "\tNoop=" + (m_matchCost/m_gapExtendCost));
            System.out.println("\nActual costs:\nGapStrt=" + (m_gapStartCost) +
                               "\tGapExt=" + (m_gapExtendCost) +
                               "\tGapEnd=" + (m_gapEndCost) +
                               "\tSub=" + (m_subCost) +
                               "\tNoop=" + (m_matchCost));
        }
    }

    /**
     * Recompute the natural logarithm of every emission and transition
     * probability and store it in the matching m_...LogProb field.
     * When m_verbose is set, the probabilities and their logs are printed.
     */
    protected void updateLogProbs() {
        m_matchLogProb = Math.log(m_matchProb);
        m_nonMatchLogProb = Math.log(m_nonMatchProb);
        m_gapTokenLogProb = Math.log(m_gapTokenProb);
        m_endAtSubLogProb = Math.log(m_endAtSubProb);
        m_endAtGapLogProb = Math.log(m_endAtGapProb);
        m_gapStartLogProb = Math.log(m_gapStartProb);
        m_gapExtendLogProb = Math.log(m_gapExtendProb);
        m_gapEndLogProb = Math.log(m_gapEndProb);
        m_subLogProb = Math.log(m_subProb);

        if (m_verbose) {
            // the formatter is only needed for the diagnostic output
            DecimalFormat fmt = new DecimalFormat ("0.0000");
            System.out.println("After update:\tNOOP=" + fmt.format(m_matchProb) + "=" + fmt.format(m_matchLogProb) +
                               "\tSUB=" + fmt.format(m_subProb) + "=" + fmt.format(m_subLogProb) +
                               "\n\t\tGAPst=" + fmt.format(m_gapStartProb) + "=" + fmt.format(m_gapStartLogProb) +
                               "\tGAPcont=" + fmt.format(m_gapExtendProb) + "=" + fmt.format(m_gapExtendLogProb) +
                               "\tGAPend=" + fmt.format(m_gapEndProb) + "=" + fmt.format(m_gapEndLogProb) +
                               "\n\t\tendAtGap=" + fmt.format(m_endAtGapProb) + "=" + fmt.format(m_endAtGapLogProb) +
                               "\tendAtSub=" + fmt.format(m_endAtSubProb) + "=" + fmt.format(m_endAtSubLogProb));
        }
    }

    /**
     * Get the distance between two strings
     *
     * With m_useGenerativeModel set, both strings are tokenized (token strings
     * are cached in m_stringTokenStringMap) and the result is the negated
     * [0][0][0] entry of backward(ts1, ts2). Otherwise the call is delegated
     * to costDistance(s1, s2).
     *
     * @param s1 first string
     * @param s2 second string
     * @return a value of this distance between these two strings
     */
    public double distance (String s1, String s2) {
        if (!m_useGenerativeModel) {
            return costDistance(s1, s2);
        }

        // retrieve the tokenstring's, tokenizing and caching on a miss
        TokenString ts1 = (TokenString) m_stringTokenStringMap.get(s1);
        if (ts1 == null) {
            ts1 = m_tokenizer.getTokenString(s1);
            m_stringTokenStringMap.put(s1, ts1);
        }
        TokenString ts2 = (TokenString) m_stringTokenStringMap.get(s2);
        if (ts2 == null) {
            ts2 = m_tokenizer.getTokenString(s2);
            m_stringTokenStringMap.put(s2, ts2);
        }

        double d = backward(ts1,ts2)[0][0][0];
        if (m_normalized) {
            // Normalization is currently a no-op; the earlier attempt is kept
            // below for reference.
            // for (int i = 0; i < (s1.length() + s2.length()); i++)
            // TODO: fix the posteriors; don't care for now - we always use the additive model
            // d -= m_noopLogProb + m_subLogProb;
            // for (int i = 0; i < s1.length(); i++) {
            //     d -= m_editopLogProbs[blank][s1.charAt(i)];
            // }
            // for (int i = 0; i < s2.length(); i++) {
            //     d -= m_editopLogProbs[blank][s2.charAt(i)];
            // }
        }
        return -d;
    }

    /**
     * Method: recordCosts
     * Record probability matrix for further MatLab use
     *
     * At present this only opens /tmp/probs/ProbAffineCosts.txt in append
     * mode and closes it again; nothing is written. Failures are reported on
     * standard error and otherwise ignored, so the call stays best-effort.
     *
     * @param id currently unused
     */
    void recordCosts(int id) {
        DataOutputStream out = null;
        try {
            FileOutputStream fstr = new FileOutputStream ("/tmp/probs/ProbAffineCosts.txt", true);
            out = new DataOutputStream (fstr);
        } catch (Exception x) {
            System.err.println("recordCosts: cannot open /tmp/probs/ProbAffineCosts.txt: " + x);
        } finally {
            // close in finally so that output code added later cannot leak the stream
            if (out != null) {
                try {
                    out.close();
                } catch (Exception x) {
                    System.err.println("recordCosts: cannot close /tmp/probs/ProbAffineCosts.txt: " + x);
                }
            }
        }
    }

    /**
     * Render a matrix of doubles as text: one line per row, every value
     * formatted with the pattern "0.00" and followed by a space.
     * Rows may differ in length.
     *
     * @param matrix the matrix to render
     * @return the textual form of the matrix
     */
    static String MatrixToString (double matrix[][]) {
        DecimalFormat fmt = new DecimalFormat ("0.00");
        // StringBuffer rather than StringBuilder: the visible code uses only pre-Java-5 idioms
        StringBuffer text = new StringBuffer();
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                text.append(fmt.format(matrix[i][j])).append(" ");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Render a matrix of ints as text: one line per row, every value
     * followed by a space. Rows may differ in length.
     *
     * @param matrix the matrix to render
     * @return the textual form of the matrix
     */
    static String intMatrixToString (int matrix[][]) {
        StringBuffer text = new StringBuffer();
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                text.append(matrix[i][j]).append(" ");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Render a matrix of doubles as text: one line per row, every value
     * formatted with the pattern "0.0E000" and followed by a space.
     * Rows may differ in length.
     *
     * @param matrix the matrix to render
     * @return the textual form of the matrix
     */
    static String doubleMatrixToString (double matrix[][]) {
        java.text.DecimalFormat de = new java.text.DecimalFormat("0.0E000");
        StringBuffer text = new StringBuffer();
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                text.append(de.format(matrix[i][j])).append(" ");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Render one layer of a three-dimensional matrix as text: for every
     * (i, j) the element matrix[i][j][k] is formatted with the pattern
     * "0.0E000" and followed by a space; one line per i.
     * Rows may differ in length.
     *
     * @param matrix the three-dimensional matrix
     * @param k index into the third dimension selecting the layer to render
     * @return the textual form of the selected layer
     */
    static String doubleMatrixToString0 (double matrix[][][], int k) {
        java.text.DecimalFormat de = new java.text.DecimalFormat("0.0E000");
        StringBuffer text = new StringBuffer();
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                text.append(de.format(matrix[i][j][k])).append(" ");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Render a matrix of chars as text: one line per row, every character
     * followed by a space. Rows may differ in length.
     *
     * @param matrix the matrix to render
     * @return the textual form of the matrix
     */
    static String charMatrixToString (char matrix[][]) {
        StringBuffer text = new StringBuffer();
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                text.append(matrix[i][j]).append(" ");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /** Calculation of log(a+b) with a correction for machine precision
     * @param _a number log(a)
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?