📄 hmm.java
字号:
// NOTE(review): the statements down to the two closing braces below are the
// tail of a method (and its enclosing class) whose beginning lies before the
// visible part of this file.  They are reproduced unchanged, only re-broken
// onto separate lines.  The code looks like a self-test comparing
// HMMAlgo.logplus against direct evaluation -- confirm against the full file.
    final double EPS = 1E-14;
    Random rnd = new Random();
    int count = Integer.parseInt(args[0]);
    for (int k=200; k>=-200; k--)
      for (int i=0; i<count; i++) {
        double logp = Math.abs(rnd.nextDouble()) * Math.pow(10, k);
        double logq = Math.abs(rnd.nextDouble());
        double logpplusq = HMMAlgo.logplus(logp, logq);
        double p = Math.exp(logp), q = Math.exp(logq), pplusq = Math.exp(logpplusq);
        if (Math.abs(p+q-pplusq) > EPS * pplusq)
          System.out.println(p + "+" + q + "-" + pplusq);
      }
  }
}

// The Viterbi algorithm: find the most probable state path producing
// the observed outputs x

class Viterbi extends HMMAlgo {
  double[][] v;         // the matrix used to find the decoding
                        // v[i][k] = v_k(i) =
                        // log(max(P(pi in state k has sym i | path pi)))
  Traceback2[][] B;     // the traceback matrix
  Traceback2 B0;        // the start of the traceback

  // Fills the score matrix v and the traceback matrix B for sequence x.
  // All scores are log-probabilities.  Row 0 is initialised so that only
  // state 0 has probability 1, i.e. state 0 is used as the start state.
  public Viterbi(HMM hmm, String x) {
    super(hmm, x);
    final int L = x.length();
    v = new double[L+1][hmm.nstate];
    B = new Traceback2[L+1][hmm.nstate];
    v[0][0] = 0;                                // = log(1)
    for (int k=1; k<hmm.nstate; k++)
      v[0][k] = Double.NEGATIVE_INFINITY;       // = log(0)
    for (int i=1; i<=L; i++)
      v[i][0] = Double.NEGATIVE_INFINITY;       // = log(0)
    for (int i=1; i<=L; i++)
      for (int ell=0; ell<hmm.nstate; ell++) {
        // Find the predecessor state k maximising v[i-1][k] + loga[k][ell]
        int kmax = 0;
        double maxprod = v[i-1][kmax] + hmm.loga[kmax][ell];
        for (int k=1; k<hmm.nstate; k++) {
          double prod = v[i-1][k] + hmm.loga[k][ell];
          if (prod > maxprod) {
            kmax = k;
            maxprod = prod;
          }
        }
        v[i][ell] = hmm.loge[ell][x.charAt(i-1)] + maxprod;
        B[i][ell] = new Traceback2(i-1, kmax);
      }
    // The traceback starts from the best-scoring state in the last row
    int kmax = 0;
    double max = v[L][kmax];
    for (int k=1; k<hmm.nstate; k++) {
      if (v[L][k] > max) {
        kmax = k;
        max = v[L][k];
      }
    }
    B0 = new Traceback2(L, kmax);
  }

  // Returns the names of the states on the most probable path, concatenated
  // in sequence order (position 1 first).
  //
  // The traceback visits positions L, L-1, ..., 1.  The names are stored
  // from the end of an array and then concatenated front to back.  Reversing
  // the concatenated string instead would also reverse the characters inside
  // every multi-character state name ("A+" would come out as "+A").
  public String getPath() {
    // B0.i is the sequence length; every traceback step decreases i by one
    // and row 0 of B is never assigned, so the walk takes exactly B0.i steps.
    final int len = B0.i;
    String[] names = new String[len];
    int pos = len;
    Traceback2 tb = B0;
    int j = tb.j;
    while ((tb = B[tb.i][tb.j]) != null) {
      // String.valueOf yields the same text that append() would have added
      names[--pos] = String.valueOf(hmm.state[j]);
      j = tb.j;
    }
    StringBuffer res = new StringBuffer();
    for (int p=pos; p<len; p++)
      res.append(names[p]);
    return res.toString();
  }

  // Prints the matrix v, one output line per state, one column per position
  public void print(Output out) {
    for (int j=0; j<hmm.nstate; j++) {
      for (int i=0; i<v.length; i++)
        out.print(HMM.fmtlog(v[i][j]));
      out.println();
    }
  }
}

// The `Forward algorithm': find the probability of an observed sequence x

class Forward extends HMMAlgo {
  double[][] f;         // the matrix used to find the decoding
                        // f[i][k] = f_k(i) = log(P(x1..xi, pi_i=k))
  private int L;        // length of the observed sequence x

  // Fills the forward matrix f for sequence x (log-probabilities)
  public Forward(HMM hmm, String x) {
    super(hmm, x);
    L = x.length();
    f = new double[L+1][hmm.nstate];
    f[0][0] = 0;                                // = log(1)
    for (int k=1; k<hmm.nstate; k++)
      f[0][k] = Double.NEGATIVE_INFINITY;       // = log(0)
    for (int i=1; i<=L; i++)
      f[i][0] = Double.NEGATIVE_INFINITY;       // = log(0)
    for (int i=1; i<=L; i++)
      for (int ell=1; ell<hmm.nstate; ell++) {
        double sum = Double.NEGATIVE_INFINITY;  // = log(0)
        for (int k=0; k<hmm.nstate; k++)
          sum = logplus(sum, f[i-1][k] + hmm.loga[k][ell]);
        f[i][ell] = hmm.loge[ell][x.charAt(i-1)] + sum;
      }
  }

  // Combines the entries of the last row of f with logplus
  double logprob() {
    double sum = Double.NEGATIVE_INFINITY;      // = log(0)
    for (int k=0; k<hmm.nstate; k++)
      sum = logplus(sum, f[L][k]);
    return sum;
  }

  // Prints the matrix f, one output line per state, one column per position
  public void print(Output out) {
    for (int j=0; j<hmm.nstate; j++) {
      for (int i=0; i<f.length; i++)
        out.print(HMM.fmtlog(f[i][j]));
      out.println();
    }
  }
}

// The `Backward algorithm': find the probability of an observed sequence x

class Backward extends HMMAlgo {
  double[][] b;         // the matrix used to find the decoding
                        // b[i][k] = b_k(i) = log(P(x(i+1)..xL, pi_i=k))

  // Fills rows L down to 1 of the backward matrix b for sequence x.
  // Row 0 and entry b[L][0] are never assigned and keep the array default 0.
  public Backward(HMM hmm, String x) {
    super(hmm, x);
    int L = x.length();
    b = new double[L+1][hmm.nstate];
    for (int k=1; k<hmm.nstate; k++)
      b[L][k] = 0;              // = log(1)  // should be hmm.loga[k][0]
    for (int i=L-1; i>=1; i--)
      for (int k=0; k<hmm.nstate; k++) {
        double sum = Double.NEGATIVE_INFINITY;  // = log(0)
        for (int ell=1; ell<hmm.nstate; ell++)
          sum = logplus(sum, hmm.loga[k][ell]
                             + hmm.loge[ell][x.charAt(i)]
                             + b[i+1][ell]);
        b[i][k] = sum;
      }
  }

  // Combines, over all states ell, the transition from state 0 to ell, the
  // emission of the first symbol in ell, and b[1][ell].
  // Reads x.charAt(0) and b[1], so x must not be empty.
  double logprob() {
    double sum = Double.NEGATIVE_INFINITY;      // = log(0)
    for (int ell=0; ell<hmm.nstate; ell++)
      sum = logplus(sum, hmm.loga[0][ell]
                         + hmm.loge[ell][x.charAt(0)]
                         + b[1][ell]);
    return sum;
  }

  // Prints the matrix b, one output line per state, one column per position
  public void print(Output out) {
    for (int j=0; j<hmm.nstate; j++) {
      for (int i=0; i<b.length; i++)
        out.print(HMM.fmtlog(b[i][j]));
      out.println();
    }
  }
}

// Compute posterior probabilities using Forward and Backward

class PosteriorProb {
  Forward fwd;                  // result of the forward algorithm
  Backward bwd;                 // result of the backward algorithm
  private double logprob;       // fwd.logprob(), cached at construction

  PosteriorProb(Forward fwd, Backward bwd) {
    this.fwd = fwd;
    this.bwd = bwd;
    logprob = fwd.logprob();    // should equal bwd.logprob()
  }

  // Returns exp(f[i][k] + b[i][k] - logprob)
  double posterior(int i, int k) // i=index into the seq; k=the HMM state
  {
    return Math.exp(fwd.f[i][k] + bwd.b[i][k] - logprob);
  }
}

// Traceback objects

abstract class Traceback {
  int i, j;                     // absolute coordinates
}

// Traceback2 objects

class Traceback2 extends Traceback {
  public Traceback2(int i, int j) {
    this.i = i;
    this.j = j;
  }
}

// Auxiliary classes for output

abstract class Output {
  public abstract void print(String s);
  public abstract void println(String s);
  public abstract void println();
}

// Output implementation that forwards every call to System.out
class SystemOut extends Output {
  public void print(String s) {
    System.out.print(s);
  }

  public void println(String s) {
    System.out.println(s);
  }

  public void println() {
    System.out.println();
  }
}

// Driver: runs the dice example; the CpG example is disabled
public class Match3 {
  public static void main(String[] args) {
    dice();
    //    CpG();
  }

  // Two-state model ("F", "L") over the symbols 1..6; runs HMM.baumwelch
  // on a single observed sequence and prints the estimated model
  static void dice() {
    String[] state = { "F", "L" };
    double[][] aprob = { { 0.95, 0.05 },
                         { 0.10, 0.90 } };
    String esym = "123456";
    double[][] eprob = { { 1.0/6, 1.0/6, 1.0/6, 1.0/6, 1.0/6, 1.0/6 },
                         { 0.10, 0.10, 0.10, 0.10, 0.10, 0.50 } };

    HMM hmm = new HMM(state, aprob, esym, eprob);

    String x =
      "315116246446644245311321631164152133625144543631656626566666"
      + "651166453132651245636664631636663162326455236266666625151631"
      + "222555441666566563564324364131513465146353411126414626253356"
      + "366163666466232534413661661163252562462255265252266435353336"
      + "233121625364414432335163243633665562466662632666612355245242";

    //    Viterbi vit = new Viterbi(hmm, x);
    //    // vit.print(new SystemOut());
    //    System.out.println(vit.getPath());
    //    Forward fwd = new Forward(hmm, x);
    //    // fwd.print(new SystemOut());
    //    System.out.println(fwd.logprob());
    //    Backward bwd = new Backward(hmm, x);
    //    // bwd.print(new SystemOut());
    //    System.out.println(bwd.logprob());
    //    PosteriorProb postp = new PosteriorProb(fwd, bwd);
    //    for (int i=0; i<x.length(); i++)
    //      System.out.println(postp.posterior(i, 1));
    String[] xs = { x };
    HMM estimate = HMM.baumwelch(xs, state, esym, 0.00001);
    estimate.print(new SystemOut());
  }

  // Eight-state model (A/C/G/T in a "+" and a "-" variant); prints the
  // Viterbi matrix, the Viterbi path and the forward log-probability
  // for the sequence "CGCG"
  static void CpG() {
    String[] state = { "A+", "C+", "G+", "T+", "A-", "C-", "G-", "T-" };
    double p2m = 0.05;          // P(switch from plus to minus)
    double m2p = 0.01;          // P(switch from minus to plus)
    double[][] aprob = {
      { 0.180-p2m, 0.274-p2m, 0.426-p2m, 0.120-p2m, p2m, p2m, p2m, p2m },
      { 0.171-p2m, 0.368-p2m, 0.274-p2m, 0.188-p2m, p2m, p2m, p2m, p2m },
      { 0.161-p2m, 0.339-p2m, 0.375-p2m, 0.125-p2m, p2m, p2m, p2m, p2m },
      { 0.079-p2m, 0.335-p2m, 0.384-p2m, 0.182-p2m, p2m, p2m, p2m, p2m },
      { m2p, m2p, m2p, m2p, 0.300-m2p, 0.205-m2p, 0.285-m2p, 0.210-m2p },
      { m2p, m2p, m2p, m2p, 0.322-m2p, 0.298-m2p, 0.078-m2p, 0.302-m2p },
      { m2p, m2p, m2p, m2p, 0.248-m2p, 0.246-m2p, 0.298-m2p, 0.208-m2p },
      { m2p, m2p, m2p, m2p, 0.177-m2p, 0.239-m2p, 0.292-m2p, 0.292-m2p } };

    String esym = "ACGT";
    double[][] eprob = { { 1, 0, 0, 0 },
                         { 0, 1, 0, 0 },
                         { 0, 0, 1, 0 },
                         { 0, 0, 0, 1 },
                         { 1, 0, 0, 0 },
                         { 0, 1, 0, 0 },
                         { 0, 0, 1, 0 },
                         { 0, 0, 0, 1 } };

    HMM hmm = new HMM(state, aprob, esym, eprob);

    String x = "CGCG";
    Viterbi vit = new Viterbi(hmm, x);
    vit.print(new SystemOut());
    System.out.println(vit.getPath());
    Forward fwd = new Forward(hmm, x);
    System.out.println(fwd.logprob());
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -