profilehmmparameterestimator.java
来自「马尔科夫模型的c语言实现」· Java 代码 · 共 580 行
JAVA
580 行
import java.io.*;import java.util.*;class ProfileHMMParameterEstimator{ /* this class implements the setting of the parameters for a profile hmm * built by modhmm using the profile7 architecture */ final double MATCH_PRIOR_SCALER = 10.0; final double INSERT_PRIOR_SCALER = 1000.0; LinkedList priors; public ProfileHMMParameterEstimator() { } public void setProfileHMMParameters(ModelMaker modelmaker, MSA msa, String[] prifiles, LinkedList alphabets, String weightingScheme, String alignmentScheme) { readPrifiles(prifiles); boolean done = false; while(!done) { setEmissionParameters(modelmaker, msa, alphabets); if(alignmentScheme.startsWith("G")) { setTransitionParametersGlobal(modelmaker, msa); } else { //setTransitionParametersLocal(modelmaker, msa); } done = adjustSequenceWeights(); } } public void dumpPriors() { for(int i = 0; i < priors.size();i++) { DirichletComponent[] dc = ((DirichletComponent[])priors.get(i)); for(int j = 0; j < dc.length; j++) { dc[j].dump(); } } } private void setEmissionParameters(ModelMaker modelmaker, MSA msa, LinkedList alphabets) { /* set match emission parameters by traversing the msa */ int state = 4; for(int i = 0; i < msa.matchColumns.length; i++) { if(msa.matchColumns[i]) { Vertex v = modelmaker.getVertex(state); for(int j = 0; j < alphabets.size(); j++) { String[] alphabet = ((String[])alphabets.get(j)); double[] counts = new double[alphabet.length]; double[] probs; double totCounts = 0.0; for(int k = 0; k < counts.length; k++) { counts[k] = 0.0; } for(int k = 0; k < msa.nrRows; k++) { int alphaIndex = getAlphabetIndex(msa.theMsa[k][i].getLetter(j+1), alphabet); if(alphaIndex >= 0) { counts[alphaIndex] += 1.0 * msa.sequenceWeights[k]; totCounts += 1.0; } } probs = getPriorizedProbs(counts, alphabet, ((DirichletComponent[])priors.get(j)), totCounts, MATCH_PRIOR_SCALER); v.setInitialEmissionProbs(j+1, probs); } state += 3; } } /* set insert emission parameters by traversing the msa */ state = 2; LinkedList allCounts = new LinkedList(); double[] allTotCounts = new double[alphabets.size()]; for(int i = 0; i < alphabets.size(); i++) { String [] alphabet = ((String[])alphabets.get(i)); double[] counts = new double[alphabet.length]; for(int j = 0; j < counts.length; j++) { counts[j] = 0.0; } allCounts.add(counts); allTotCounts[i] = 0.0; } for(int i = 0; i <= msa.matchColumns.length; i++) { if(i == msa.matchColumns.length) { state = modelmaker.getNrVertices() - 2; Vertex v = modelmaker.getVertex(state); for(int j = 0; j < alphabets.size(); j++) { double[] counts = ((double[])allCounts.get(j)); double[] probs; double totCounts = allTotCounts[j]; probs = getPriorizedProbs(counts, ((String[])alphabets.get(j)), ((DirichletComponent[])priors.get(j)), totCounts, INSERT_PRIOR_SCALER); v.setInitialEmissionProbs(j+1, probs); for(int k = 0; k < counts.length; k++) { counts[k] = 0.0; } allTotCounts[j] = 0.0; } } else if(msa.matchColumns[i]) { Vertex v = modelmaker.getVertex(state); for(int j = 0; j < alphabets.size(); j++) { double[] counts = ((double[])allCounts.get(j)); double[] probs; double totCounts = allTotCounts[j]; probs = getPriorizedProbs(counts, ((String[])alphabets.get(j)), ((DirichletComponent[])priors.get(j)), totCounts, INSERT_PRIOR_SCALER); v.setInitialEmissionProbs(j+1, probs); for(int k = 0; k < counts.length; k++) { counts[k] = 0.0; } allTotCounts[j] = 0.0; } state += 3; } else { for(int j = 0; j < alphabets.size(); j++) { double[] counts = ((double[])allCounts.get(j)); double totCounts = allTotCounts[j]; for(int k = 0; k < msa.nrRows; k++) { int alphaIndex = getAlphabetIndex(msa.theMsa[k][i].getLetter(j+1), ((String[])alphabets.get(j))); if(alphaIndex >= 0) { counts[alphaIndex] += 1.0 * msa.sequenceWeights[k]; totCounts += 1.0; } } } } } } private void setTransitionParametersGlobal(ModelMaker modelmaker, MSA msa) { /* count transitions per sequence by walking it through the hmm */ double[][] transitionCounts = new double[modelmaker.getNrVertices()][modelmaker.getNrVertices()]; for(int row = 0; row < modelmaker.getNrVertices(); row++) { for(int col = 0; col < modelmaker.getNrVertices(); col++) { if(transitionExists(modelmaker,row,col)) { transitionCounts[row][col] = 10.0; //P.MESSAGE("exists: " + row + " to " + col); } else { transitionCounts[row][col] = 0.0; } } } for(int row = 0; row < msa.nrRows; row++) { int col = 0; int match = 4; int insert = 5; int delete = 3; char curState = 'S'; while(true) { if(msa.matchColumns[col]) { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[1][match] += msa.sequenceWeights[row] * 1.0; curState = 'M'; break; } else { transitionCounts[1][delete] += msa.sequenceWeights[row] * 1.0; curState = 'D'; } } else { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[1][2] += msa.sequenceWeights[row] * 1.0; curState = 'I'; break; } else { } } col++; } col++; if(curState == 'I') { while(true) { if(msa.matchColumns[col]) { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[2][1] += msa.sequenceWeights[row] * 1.0; transitionCounts[1][match] += msa.sequenceWeights[row] * 1.0; curState = 'M'; break; } else { transitionCounts[2][1] += msa.sequenceWeights[row] * 1.0; transitionCounts[1][delete] += msa.sequenceWeights[row] * 1.0; curState = 'D'; break; } } else { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[2][2] += msa.sequenceWeights[row] * 1.0; } else { } } col++; } col++; } while(col < msa.nrColumns) { if(curState == 'M') { if(msa.matchColumns[col]) { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[match][match+3] += msa.sequenceWeights[row] * 1.0; curState = 'M'; } else { transitionCounts[match][delete + 3] += msa.sequenceWeights[row] * 1.0; curState = 'D'; } match += 3; insert += 3; delete += 3; if(match == modelmaker.getNrVertices() - 5) { insert += 1; break; } } else { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[match][insert] += msa.sequenceWeights[row] * 1.0; curState = 'I'; } else { } } } else if(curState == 'I') { if(msa.matchColumns[col]) { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[insert][match+3] += msa.sequenceWeights[row] * 1.0; curState = 'M'; } else { curState = 'D'; } match += 3; insert += 3; delete += 3; if(match == modelmaker.getNrVertices() - 5) { insert += 1; break; } } else { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[insert][insert] += msa.sequenceWeights[row] * 1.0; curState = 'I'; } else { } } } else if(curState == 'D') { if(msa.matchColumns[col]) { if(isAcid(msa.theMsa[row][col].letter_1)) { transitionCounts[delete][match+3] += msa.sequenceWeights[row] * 1.0; curState = 'M'; } else { transitionCounts[delete][delete+3] += msa.sequenceWeights[row] * 1.0; curState = 'D'; } match += 3; insert += 3; delete += 3; if(match == modelmaker.getNrVertices() - 5) { insert += 1; break; } } else { if(isAcid(msa.theMsa[row][col].letter_1)) { curState = 'I'; } else { } } } col++; } col++; transitionCounts[delete+2][delete+4] += msa.sequenceWeights[row] * 1.0; boolean first = true; while(col < msa.nrColumns) { if(msa.matchColumns[col]) { P.MESSAGE("Error, matchcolumn outside scope of hmm"); System.exit(0); } else { if(first) { transitionCounts[delete+2][insert] += msa.sequenceWeights[row] * 1.0; transitionCounts[insert][delete +2] += msa.sequenceWeights[row] * 1.0; first = false; } transitionCounts[insert][insert] += msa.sequenceWeights[row] * 1.0; } col++; } } updateTransitionProbs(modelmaker, transitionCounts, modelmaker.getNrVertices(), modelmaker.getNrVertices()); } private void updateTransitionProbs(ModelMaker modelmaker, double [][] transitionCounts, int nrRows, int nrCols) { /* set start with match, insert or delete probs */ Vertex v = modelmaker.getVertex(1); v.setTransitionProbability(2, 0.99); double rowSum = transitionCounts[1][3] + transitionCounts[1][4]; v.setTransitionProbability(3, transitionCounts[1][3] * 0.01 / rowSum); v.setTransitionProbability(4, transitionCounts[1][4] * 0.01 / rowSum); /* set N-terminal loop length prob for first insert loop */ v = modelmaker.getVertex(2); v.setTransitionProbability(2, 0.99); v.setTransitionProbability(1, 0.01); /* set model inhouse probs using statistics from the alignment */ for(int row = 3; row < nrRows - 4; row++) { rowSum = 0.0; v = modelmaker.getVertex(row); for(int col = 0; col < nrCols; col++) { rowSum += transitionCounts[row][col]; } for(int col = 0; col < nrCols; col++) { if(transitionCounts[row][col] > 0.0) { v.setTransitionProbability(col, transitionCounts[row][col]/rowSum); } } } /* set C-terminal loop length prob for first insert loop */ v = modelmaker.getVertex(nrRows - 3); v.setTransitionProbability(nrRows - 3, 0.99); v.setTransitionProbability(nrRows - 4, 0.01); /* set end with insert or not probs */ v = modelmaker.getVertex(nrRows - 4); rowSum = transitionCounts[1][3] + transitionCounts[1][4]; v.setTransitionProbability(nrRows - 3, 0.99); v.setTransitionProbability(nrRows - 2, 0.01); } private double[] getPriorizedProbs(double[] counts, String[] alphabet, DirichletComponent[] prior, double totCounts, double priorScaler) { double[] probs = new double[alphabet.length]; double[] logbetaAnValues = new double[prior.length]; double scalingFactor = -100000000000.0; double X_sum = 0.0; double[] X_values = new double[alphabet.length]; for(int comp = 0; comp < prior.length; comp++) { double ed_res_1 = 0.0; double countSums = 0.0; for(int a_index = 0; a_index < alphabet.length; a_index++) { ed_res_1 += BetaGamma.logGamma(prior[comp].prior_values[a_index] + counts[a_index]); countSums += counts[a_index]; } ed_res_1 -= BetaGamma.logGamma(prior[comp].alpha_sum + countSums); logbetaAnValues[comp] = ed_res_1; if(ed_res_1 - prior[comp].logbeta_value > scalingFactor) { scalingFactor = ed_res_1 - prior[comp].logbeta_value; } } for(int a_index = 0; a_index < alphabet.length; a_index++) { X_values[a_index] = 0; for(int comp = 0; comp < prior.length; comp++) { double q_value = prior[comp].q_value; double exponent = logbetaAnValues[comp] - prior[comp].logbeta_value - scalingFactor; double priorprob = prior[comp].prior_values[a_index] * priorScaler + counts[a_index]; double totPriorProb = prior[comp].alpha_sum + totCounts; X_values[a_index] += q_value * Math.exp(exponent) * priorprob / totPriorProb; } X_sum += X_values[a_index]; } for(int a_index = 0; a_index < alphabet.length; a_index++) { probs[a_index] = X_values[a_index] / X_sum; } return probs; } private void readPrifiles(String[] prifiles) { priors = new LinkedList(); for(int i = 0; i < prifiles.length; i++) { try { BufferedReader br = new BufferedReader(new FileReader(prifiles[i])); String prirow = br.readLine(); while(true) { if(prirow.equals("") || prirow.startsWith("#")) { } else { break; } prirow = br.readLine(); } int nrComponents = Integer.parseInt(prirow); DirichletComponent[] components = new DirichletComponent[nrComponents]; priors.add(components); for(int j = 0; j < nrComponents; j++) { prirow = br.readLine(); while(true) { if(prirow.equals("") || prirow.startsWith("#")) { } else { break; } prirow = br.readLine(); } double q_value = Double.parseDouble(prirow); prirow = br.readLine(); while(true) { if(prirow.equals("") || prirow.startsWith("#")) { } else { break; } prirow = br.readLine(); } StringTokenizer st = new StringTokenizer(prirow, "\t "); double[] priorValues = new double[st.countTokens()]; double alpha_sum = 0.0; int k = 0; while(st.hasMoreTokens()) { String s = st.nextToken(); double alpha_value = Double.parseDouble(s); alpha_sum += alpha_value; priorValues[k] = alpha_value; k++; } double logbeta = 0.0; for(k = 0; k < priorValues.length; k++) { logbeta += BetaGamma.logGamma(priorValues[k]); } logbeta -= BetaGamma.logGamma(alpha_sum); DirichletComponent dc = new DirichletComponent("", q_value, priorValues, logbeta, alpha_sum); components[j] = dc; } br.close(); } catch(IOException e) { System.out.println("Could not read prifile"); System.exit(0); } } } private int getAlphabetIndex(String letter, String[] alphabet) { for(int i = 0; i < alphabet.length; i++) { if(letter.equals(alphabet[i])) { return i; } } return -1; } private boolean adjustSequenceWeights() { /* set sequence weights according to model likelihood maximization */ return true; } private boolean isAcid(String s) { if(s.equals(" ") || s.equals("_") || s.equals("-") || s.equals(".")) { return false; } return true; } private boolean transitionExists(ModelMaker modelmaker, int row, int col) { if(modelmaker.transitionExists(row,col)) { return true; } else { return false; } } class DirichletComponent { public String name; public double q_value; public double alpha_sum; public double logbeta_value; public double[] prior_values; public DirichletComponent(String n, double q_v, double[]p_v, double lb_v, double a_s) { name = n; q_value = q_v; alpha_sum = a_s; logbeta_value = lb_v; prior_values = p_v; } public void dump() { P.MESSAGE("name = " + name); P.MESSAGE("q value = " + q_value); P.MESSAGE("alpha sum = " + alpha_sum); P.MESSAGE("logbeta value = " + logbeta_value); System.out.print("prior values: "); for(int i = 0; i < prior_values.length; i++) { System.out.print(prior_values[i] + " "); } P.MESSAGE(""); P.MESSAGE(""); } }}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?