profilehmmparameterestimator.java

来自「马尔科夫模型的c语言实现」· Java 代码 · 共 580 行

JAVA
580
字号
import java.io.*;import java.util.*;class ProfileHMMParameterEstimator{    /* this class implements the setting of the parameters for a profile hmm     * built by modhmm using the profile7 architecture */    final double MATCH_PRIOR_SCALER = 10.0;    final double INSERT_PRIOR_SCALER = 1000.0;    LinkedList priors;    public ProfileHMMParameterEstimator()    {	    }        public void setProfileHMMParameters(ModelMaker modelmaker, MSA msa, String[] prifiles, LinkedList alphabets,					String weightingScheme, String alignmentScheme)    {	readPrifiles(prifiles);	boolean done = false;	while(!done) {	    setEmissionParameters(modelmaker, msa, alphabets);	    if(alignmentScheme.startsWith("G")) {		setTransitionParametersGlobal(modelmaker, msa);	    }	    else {		//setTransitionParametersLocal(modelmaker, msa);	    }	    done = adjustSequenceWeights();	}    }    public void dumpPriors()    {	for(int i = 0; i < priors.size();i++) {	    DirichletComponent[] dc = ((DirichletComponent[])priors.get(i));	    for(int j = 0; j < dc.length; j++) {		dc[j].dump();	    }	}    }    private void setEmissionParameters(ModelMaker modelmaker, MSA msa, LinkedList alphabets)    {	/* set match emission parameters by traversing the msa */	int state = 4;	for(int i = 0; i < msa.matchColumns.length; i++) {	    if(msa.matchColumns[i]) {		Vertex v = modelmaker.getVertex(state);		for(int j = 0; j < alphabets.size(); j++) {		    String[] alphabet = ((String[])alphabets.get(j));		    double[] counts = new double[alphabet.length];		    double[] probs;		    double totCounts = 0.0;		    for(int k = 0; k < counts.length; k++) {			counts[k] = 0.0;		    }		    for(int k = 0; k < msa.nrRows; k++) {			int alphaIndex = getAlphabetIndex(msa.theMsa[k][i].getLetter(j+1), alphabet);			if(alphaIndex >= 0) {			    counts[alphaIndex] += 1.0 * msa.sequenceWeights[k];			    totCounts += 1.0;			}		    }		    probs = getPriorizedProbs(counts, alphabet, ((DirichletComponent[])priors.get(j)), totCounts, 					      MATCH_PRIOR_SCALER);		    v.setInitialEmissionProbs(j+1, probs); 		    		}				state += 3;	    }	}		/* set insert emission parameters by traversing the msa */	state = 2;		LinkedList allCounts = new LinkedList();	double[] allTotCounts = new double[alphabets.size()];		for(int i = 0; i < alphabets.size(); i++) {	    String [] alphabet = ((String[])alphabets.get(i));	    double[] counts = new double[alphabet.length];	    for(int j = 0; j < counts.length; j++) {		counts[j] = 0.0;	    }	    allCounts.add(counts);	    allTotCounts[i] = 0.0;	}		for(int i = 0; i <= msa.matchColumns.length; i++) {	    if(i == msa.matchColumns.length) {		state = modelmaker.getNrVertices() - 2;		Vertex v = modelmaker.getVertex(state);		for(int j = 0; j < alphabets.size(); j++) {		    double[] counts = ((double[])allCounts.get(j));		    double[] probs;		    double totCounts = allTotCounts[j];		    		    probs = getPriorizedProbs(counts, ((String[])alphabets.get(j)), ((DirichletComponent[])priors.get(j)),					      totCounts, INSERT_PRIOR_SCALER);		    v.setInitialEmissionProbs(j+1, probs); 		    for(int k = 0; k < counts.length; k++) {			counts[k] = 0.0;		    }		    allTotCounts[j] = 0.0;		}	    }	    else if(msa.matchColumns[i]) {		Vertex v = modelmaker.getVertex(state);		for(int j = 0; j < alphabets.size(); j++) {		    double[] counts = ((double[])allCounts.get(j));		    double[] probs;		    double totCounts = allTotCounts[j];		    		    probs = getPriorizedProbs(counts, ((String[])alphabets.get(j)), ((DirichletComponent[])priors.get(j)),					      totCounts, INSERT_PRIOR_SCALER);		    v.setInitialEmissionProbs(j+1, probs); 		    for(int k = 0; k < counts.length; k++) {			counts[k] = 0.0;		    }		    allTotCounts[j] = 0.0;		}				state += 3;	    }	    else {		for(int j = 0; j < alphabets.size(); j++) {		    double[] counts = ((double[])allCounts.get(j));		    double totCounts = allTotCounts[j];		    for(int k = 0; k < msa.nrRows; k++) {			int alphaIndex = getAlphabetIndex(msa.theMsa[k][i].getLetter(j+1), ((String[])alphabets.get(j)));			if(alphaIndex >= 0) {			    counts[alphaIndex] += 1.0 * msa.sequenceWeights[k];			    totCounts += 1.0;			}		    }		}	    }	}    }        private void setTransitionParametersGlobal(ModelMaker modelmaker, MSA msa)    {	/* count transitions per sequence by walking it through the hmm */	double[][] transitionCounts = new double[modelmaker.getNrVertices()][modelmaker.getNrVertices()];	for(int row = 0; row < modelmaker.getNrVertices(); row++) {	    for(int col = 0; col < modelmaker.getNrVertices(); col++) {		if(transitionExists(modelmaker,row,col)) {		    transitionCounts[row][col] = 10.0;		    //P.MESSAGE("exists: " + row + " to " + col); 		}		else {		    transitionCounts[row][col] = 0.0;		}	    }	}	for(int row = 0; row < msa.nrRows; row++) {	    int col = 0;	    int match = 4;	    int insert = 5;	    int delete = 3;	    char curState = 'S';	    while(true) {		if(msa.matchColumns[col]) {		    if(isAcid(msa.theMsa[row][col].letter_1)) {			transitionCounts[1][match] += msa.sequenceWeights[row] * 1.0;			curState = 'M';			break;		    }		    else {			transitionCounts[1][delete] += msa.sequenceWeights[row] * 1.0;			curState = 'D';		    }		}		else {		    if(isAcid(msa.theMsa[row][col].letter_1)) {			transitionCounts[1][2] += msa.sequenceWeights[row] * 1.0;			curState = 'I';			break;		    }		    else {					    }		}		col++;	    }	    col++;	    if(curState == 'I') {		while(true) {		    if(msa.matchColumns[col]) {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    transitionCounts[2][1] += msa.sequenceWeights[row] * 1.0;			    transitionCounts[1][match] += msa.sequenceWeights[row] * 1.0;			    curState = 'M';			    break;			}			else {			    transitionCounts[2][1] += msa.sequenceWeights[row] * 1.0;			    transitionCounts[1][delete] += msa.sequenceWeights[row] * 1.0;			    curState = 'D';			    break;			}		    }		    else {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    transitionCounts[2][2] += msa.sequenceWeights[row] * 1.0;			}			else {			    			}		    }		    col++;		}		col++;	    }	    	    while(col < msa.nrColumns) {		if(curState == 'M') {		    if(msa.matchColumns[col]) {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    transitionCounts[match][match+3] += msa.sequenceWeights[row] * 1.0;			    curState = 'M';			}			else {			    transitionCounts[match][delete + 3] += msa.sequenceWeights[row] * 1.0;			    curState = 'D';			}			match += 3;			insert += 3;			delete += 3;			if(match == modelmaker.getNrVertices() - 5) {			    insert += 1;			    break;			}		    }		    else {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    transitionCounts[match][insert] += msa.sequenceWeights[row] * 1.0;			    curState = 'I';			}			else {			    			}		    }		}				else if(curState == 'I') {		    if(msa.matchColumns[col]) {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    transitionCounts[insert][match+3] += msa.sequenceWeights[row] * 1.0;			    curState = 'M';			}			else {			    curState = 'D';			}			match += 3;			insert += 3;			delete += 3;			if(match == modelmaker.getNrVertices() - 5) {			    insert += 1;			    break;			}					    }		    else {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    transitionCounts[insert][insert] += msa.sequenceWeights[row] * 1.0;			    curState = 'I';			}			else {			    			}		    }		}				else if(curState == 'D') {		    if(msa.matchColumns[col]) {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    transitionCounts[delete][match+3] += msa.sequenceWeights[row] * 1.0;			    curState = 'M';			}			else {			    transitionCounts[delete][delete+3] += msa.sequenceWeights[row] * 1.0;			    curState = 'D';			}			match += 3;			insert += 3;			delete += 3;			if(match == modelmaker.getNrVertices() - 5) {			    			    insert += 1;			    break;			}		    }		    else {			if(isAcid(msa.theMsa[row][col].letter_1)) {			    curState = 'I';			}			else {			    			}		    }		}		col++;	    }	    col++;	    transitionCounts[delete+2][delete+4] += msa.sequenceWeights[row] * 1.0;	    boolean first = true;	    while(col < msa.nrColumns) {		if(msa.matchColumns[col]) {		    P.MESSAGE("Error, matchcolumn outside scope of hmm");		    System.exit(0);		}		else {		    if(first) {			transitionCounts[delete+2][insert] += msa.sequenceWeights[row] * 1.0;			transitionCounts[insert][delete +2] += msa.sequenceWeights[row] * 1.0;			first = false;		    }		    transitionCounts[insert][insert] += msa.sequenceWeights[row] * 1.0;		}		col++;	    }	   	}       	updateTransitionProbs(modelmaker, transitionCounts, modelmaker.getNrVertices(), modelmaker.getNrVertices());		    }        private void updateTransitionProbs(ModelMaker modelmaker, double [][] transitionCounts, int nrRows, int nrCols)    {	/* set start with match, insert or delete probs */	Vertex v = modelmaker.getVertex(1);	v.setTransitionProbability(2, 0.99);	double rowSum = transitionCounts[1][3] + transitionCounts[1][4];	v.setTransitionProbability(3, transitionCounts[1][3] * 0.01 / rowSum);	v.setTransitionProbability(4, transitionCounts[1][4] * 0.01 / rowSum);	/* set N-terminal loop length prob for first insert loop */	v = modelmaker.getVertex(2);	v.setTransitionProbability(2, 0.99);	v.setTransitionProbability(1, 0.01);			/* set model inhouse probs using statistics from the alignment */	for(int row = 3; row < nrRows - 4; row++) {	    rowSum = 0.0;	    v = modelmaker.getVertex(row);	    for(int col = 0; col < nrCols; col++) {		rowSum += transitionCounts[row][col];	    }	    for(int col = 0; col < nrCols; col++) {		if(transitionCounts[row][col] > 0.0) {		    v.setTransitionProbability(col, transitionCounts[row][col]/rowSum);		}	    }	}	/* set C-terminal loop length prob for first insert loop */	v = modelmaker.getVertex(nrRows - 3);	v.setTransitionProbability(nrRows - 3, 0.99);	v.setTransitionProbability(nrRows - 4, 0.01);    	/* set end with insert or not probs */	v = modelmaker.getVertex(nrRows - 4);	rowSum = transitionCounts[1][3] + transitionCounts[1][4];	v.setTransitionProbability(nrRows - 3, 0.99);	v.setTransitionProbability(nrRows - 2, 0.01);    }    private double[] getPriorizedProbs(double[] counts, String[] alphabet, DirichletComponent[] prior, double totCounts,				       double priorScaler)    {	double[] probs = new double[alphabet.length];	double[] logbetaAnValues = new double[prior.length];	double scalingFactor = -100000000000.0;	double X_sum = 0.0;	double[] X_values = new double[alphabet.length];		for(int comp = 0; comp < prior.length; comp++) {	    double ed_res_1 = 0.0;	    double countSums = 0.0;	    for(int a_index = 0; a_index < alphabet.length; a_index++) {		ed_res_1 += BetaGamma.logGamma(prior[comp].prior_values[a_index] + counts[a_index]);		countSums += counts[a_index];	    }	    ed_res_1 -= BetaGamma.logGamma(prior[comp].alpha_sum + countSums);	    logbetaAnValues[comp] = ed_res_1;	    if(ed_res_1 - prior[comp].logbeta_value > scalingFactor) {		scalingFactor = ed_res_1 - prior[comp].logbeta_value;	    }	}		for(int a_index = 0; a_index < alphabet.length; a_index++) {	    X_values[a_index] = 0;	    for(int comp = 0; comp < prior.length; comp++) {		double q_value = prior[comp].q_value;		double exponent = logbetaAnValues[comp] - prior[comp].logbeta_value - scalingFactor;		double priorprob = prior[comp].prior_values[a_index] * priorScaler + counts[a_index];		double totPriorProb = prior[comp].alpha_sum + totCounts;		X_values[a_index] += q_value * Math.exp(exponent) * priorprob / totPriorProb;			    }	    X_sum += X_values[a_index];	}	for(int a_index = 0; a_index < alphabet.length; a_index++) {	    probs[a_index] = X_values[a_index] / X_sum;	}	return probs;    }    private void readPrifiles(String[] prifiles)    {	priors = new LinkedList(); 	for(int i = 0; i < prifiles.length; i++) {	    try {		BufferedReader br = new BufferedReader(new FileReader(prifiles[i]));		String prirow = br.readLine();		while(true) {		    if(prirow.equals("") || prirow.startsWith("#")) {					    }		    else {			break;		    }		    prirow = br.readLine();		}				int nrComponents = Integer.parseInt(prirow);		DirichletComponent[] components = new DirichletComponent[nrComponents];		priors.add(components);				for(int j = 0; j < nrComponents; j++) {		    prirow = br.readLine();		    while(true) {			if(prirow.equals("") || prirow.startsWith("#")) {			    			}			else {			    break;			}			prirow = br.readLine();		    }		    double q_value = Double.parseDouble(prirow);		    		    prirow = br.readLine();		    while(true) {			if(prirow.equals("") || prirow.startsWith("#")) {			    			}			else {			    break;			}			prirow = br.readLine();		    }		    StringTokenizer st = new StringTokenizer(prirow, "\t ");		    double[] priorValues = new double[st.countTokens()];		    double alpha_sum = 0.0;		    int k = 0;		    while(st.hasMoreTokens()) {			String s = st.nextToken();			double alpha_value = Double.parseDouble(s);			alpha_sum += alpha_value;			priorValues[k] = alpha_value;			k++;		    }		    		    double logbeta = 0.0;		    for(k = 0; k < priorValues.length; k++) {			logbeta += BetaGamma.logGamma(priorValues[k]);		    }		    logbeta -= BetaGamma.logGamma(alpha_sum);		    DirichletComponent dc = new DirichletComponent("", q_value, priorValues, logbeta, alpha_sum);		    components[j] = dc;		}							br.close();			    }	    catch(IOException e) {		System.out.println("Could not read prifile");		System.exit(0);	    }	}    }    private int getAlphabetIndex(String letter, String[] alphabet)    {		for(int i = 0; i < alphabet.length; i++) {	    if(letter.equals(alphabet[i])) {		return i;	    }	}	return -1;    }    private boolean adjustSequenceWeights()    {	/* set sequence weights according to model likelihood maximization */			return true;    }    private boolean isAcid(String s)    {	if(s.equals(" ") || s.equals("_") || s.equals("-") || s.equals(".")) {	    return false;	}	return true;    }    private boolean transitionExists(ModelMaker modelmaker, int row, int col)    {	if(modelmaker.transitionExists(row,col)) {	    return true;	}	else {	    return false;	}    }    class DirichletComponent    {	public String name;	public double q_value;	public double alpha_sum;	public double logbeta_value;	public double[] prior_values;	public DirichletComponent(String n, double q_v, double[]p_v, double lb_v, double a_s)	{	    name = n;	    q_value = q_v;	    alpha_sum = a_s;	    logbeta_value = lb_v;	    prior_values = p_v;	}	public void dump()	{	    P.MESSAGE("name = " + name);	    P.MESSAGE("q value = " + q_value);	    P.MESSAGE("alpha sum = " + alpha_sum);	    P.MESSAGE("logbeta value = " + logbeta_value);	    System.out.print("prior values: ");	    for(int i = 0; i < prior_values.length; i++) {		System.out.print(prior_values[i] + " ");	    }	    P.MESSAGE("");	    P.MESSAGE("");	}    }}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?