sreestimate.c
来自「General Hidden Markov Model Library 一个通用」· C语言 代码 · 共 907 行 · 第 1/2 页
C
907 行
if (fabs (r->mue_u_denom[i][m]) <= DBL_MIN) { /* < EPS_PREC ? */#if MCI ighmm_mes (MESCONTR, "u[%d][%d]: denominator == 0.0!\n", i, m);#endif ; /* smo->s[i].u[m] unchanged! */ } else { u_im = r->u_num[i][m] / r->mue_u_denom[i][m]; if (u_im <= GHMM_EPS_U) u_im = (double) GHMM_EPS_U; smo->s[i].u[m] = u_im; } /* modification for truncated normal density: 1-dim optimization for mue, calculate u directly note: if denom == 0 --> mue and u not recalculated above */ if (smo->s[i].density[m] == normal_right && fabs (r->mue_u_denom[i][m]) > DBL_MIN) { A = smo->s[i].mue[m]; B = r->sum_gt_otot[i][m] / r->mue_u_denom[i][m]; /* A^2 ~ B -> search max at border of EPS_U */ if (B - A * A < GHMM_EPS_U) { mue_left = -GHMM_EPS_NDT; /* attention: only works if EPS_NDT > EPS_U ! */ mue_right = A; if ((ighmm_gtail_pmue_umin (mue_left, A, B, GHMM_EPS_NDT) > 0.0 && ighmm_gtail_pmue_umin (mue_right, A, B, GHMM_EPS_NDT) > 0.0) || (ighmm_gtail_pmue_umin (mue_left, A, B, GHMM_EPS_NDT) < 0.0 && ighmm_gtail_pmue_umin (mue_right, A, B, GHMM_EPS_NDT) < 0.0)) fprintf (stderr, "umin:fl:%.3f\tfr:%.3f\t; left %.3f\t right %3f\t A %.3f\t B %.3f\n", ighmm_gtail_pmue_umin (mue_left, A, B, GHMM_EPS_NDT), ighmm_gtail_pmue_umin (mue_right, A, B, GHMM_EPS_NDT), mue_left, mue_right, A, B); mue_im = ghmm_zbrent_AB (ighmm_gtail_pmue_umin, mue_left, mue_right, ACC, A, B, GHMM_EPS_NDT); u_im = GHMM_EPS_U; } else { Atil = A + GHMM_EPS_NDT; Btil = B + GHMM_EPS_NDT * A; mue_left = (-C_PHI * sqrt (Btil + GHMM_EPS_NDT * Atil + CC_PHI * m_sqr (Atil) / 4.0) - CC_PHI * Atil / 2.0 - GHMM_EPS_NDT) * 0.99; mue_right = A; if (A < Btil * ighmm_rand_normal_density_pos (-GHMM_EPS_NDT, 0, Btil)) mue_right = m_min (GHMM_EPS_NDT, mue_right); else mue_left = m_max (-GHMM_EPS_NDT, mue_left); mue_im = ghmm_zbrent_AB (ighmm_gtail_pmue_interpol, mue_left, mue_right, ACC, A, B, GHMM_EPS_NDT); u_im = Btil - mue_im * Atil; } /* set modified values of mue and u */ smo->s[i].mue[m] = mue_im; if (u_im < (double) GHMM_EPS_U) u_im = (double) GHMM_EPS_U; smo->s[i].u[m] = u_im; } /* end modifikation truncated density */ } /* for (m ..) */ /* adjusting weights for fixed mixture components if necessary */ if (fix_flag == 1) { for (m = 0; m < smo->s[i].M; m++) { if (smo->s[i].mixture_fix[m] == 0) { smo->s[i].c[m] = (smo->s[i].c[m] * fix_w) / unfix_w; } } }#if MCI if (!c_num_pos) ighmm_mes (MESCONTR, "all numerators c[%d][m] == 0 (denominator=%.4f)!\n", i, r->c_denom[i]);#endif } /* for (i = 0 .. < smo->N) */ res = 0;STOP: /* Label STOP from ARRAY_[CM]ALLOC */ return (res);# undef CUR_PROC} /* sreestimate_setlambda *//*----------------------------------------------------------------------------*/int sreestimate_one_step (ghmm_cmodel * smo, local_store_t * r, int seq_number, int *T, double **O, double *log_p, double *seq_w){# define CUR_PROC "sreestimate_one_step" int res = -1; int k, i, j, m, t, j_id, valid_parameter, valid_logp, osc; double **alpha = NULL; double **beta = NULL; double *scale = NULL; double ***b = NULL; int T_k = 0, T_k_max = 0; double c_t, sum_alpha_a_ji, gamma, gamma_ct, f_im, quot; double log_p_k; double contrib_t; *log_p = 0.0; valid_parameter = valid_logp = 0; /*scan for max T_k: alloc of alpha, beta, scale and b only once */ T_k_max = T[0]; for (k = 1; k < seq_number; k++) if (T[k] > T_k_max) T_k_max = T[k]; if (sreestimate_alloc_matvek (&alpha, &beta, &scale, &b, T_k_max, smo->N, smo->M) == -1) { GHMM_LOG_QUEUED(LCONVERTED); goto STOP; } /* loop over all sequences */ for (k = 0; k < seq_number; k++) { /* Test: ignore sequences with very small weights */ /* if (seq_w[k] < 0.0001) continue; */ /* seq. is used for calculation of log_p */ valid_logp++; T_k = T[k]; /* precompute output densities */ sreestimate_precompute_b (smo, O[k], T_k, b); if (smo->cos > 1) { smo->class_change->k = k; } if ((ghmm_cmodel_forward (smo, O[k], T_k, b, alpha, scale, &log_p_k) == -1) || (ghmm_cmodel_backward (smo, O[k], T_k, b, beta, scale) == -1)) {#if MCI ighmm_mes (MESCONTR, "O(%2d) can't be build from smodel smo!\n", k);#endif /* penalty costs */ *log_p += seq_w[k] * (double) GHMM_PENALTY_LOGP; continue; } else /* printf("\n\nalpha:\n"); ighmm_cmatrix_print(stdout,alpha,T_k,smo->N,"\t", ",", ";"); printf("\n\nbeta:\n"); ighmm_cmatrix_print(stdout,beta,T_k,smo->N,"\t", ",", ";"); printf("\n\n"); */ /* weighted error function */ *log_p += log_p_k * seq_w[k]; /* seq. is used for parameter estimation */ valid_parameter++; /* loop over all states */ for (i = 0; i < smo->N; i++) { /* Pi */ r->pi_num[i] += seq_w[k] * alpha[0][i] * beta[0][i]; r->pi_denom += seq_w[k] * alpha[0][i] * beta[0][i]; /* sum over all i */ /* loop over t (time steps of seq.) */ for (t = 0; t < T_k; t++) { c_t = 1 / scale[t]; if (t > 0) { if (smo->cos == 1) { osc = 0; } else { if (!smo->class_change->get_class) { printf ("ERROR: get_class not initialized\n"); goto STOP; } osc = smo->class_change->get_class (smo, O[k], k, t - 1); /*printf("osc=%d : cos = %d, k = %d, t = %d, state=%d\n",osc,smo->cos,smo->class_change->k,t,i); */ if (osc >= smo->cos){ printf("ERROR: get_class returned index %d but model has only %d classes !\n",osc,smo->cos); goto STOP; } } /* A: starts at t=1 !!! */ for (j = 0; j < smo->s[i].out_states; j++) { j_id = smo->s[i].out_id[j]; contrib_t = (seq_w[k] * alpha[t - 1][i] * smo->s[i].out_a[osc][j] * b[t][j_id][smo->s[i].M] * beta[t][j_id] * c_t); /* c[t] = 1/scale[t] */ r->a_num[i][osc][j] += contrib_t; /* printf("r->a_num[%d][%d][%d] += (alpha[%d][%d] * smo->s[%d].out_a[%d][%d] * b[%d]%d][%d] * beta[%d][%d] * c_t = %f * %f * %f * %f * %f = %f\n", i,osc,j,t-1,i,i,osc,j,t,j_id,smo->M,t,j_id,alpha[t - 1][i], smo->s[i].out_a[osc][j], b[t][j_id][smo->M], beta[t][j_id], c_t,r->a_num[i][osc][j]); */ r->a_denom[i][osc] += contrib_t; /* printf("r->a_denom[%d][%d] += %f\n",i,osc,r->a_denom[i][osc]); */ } /* calculate sum (j=1..N){alp[t-1][j]*a_jc(t-1)i} */ sum_alpha_a_ji = 0.0; for (j = 0; j < smo->s[i].in_states; j++) { j_id = smo->s[i].in_id[j]; sum_alpha_a_ji += alpha[t-1][j_id] * smo->s[i].in_a[osc][j]; } } /* if t>0 */ else { /* calculate sum(j=1..N){alpha[t-1][j]*a_jci}, which is used below for (t=1) = pi[i] (alpha[-1][i] not defined) !!! */ sum_alpha_a_ji = smo->s[i].pi; } /* if t>0 */ /* ========= if state fix, continue;====================== */ if (smo->s[i].fix) continue; /* C-denominator: */ r->c_denom[i] += seq_w[k] * alpha[t][i] * beta[t][i]; /* if sum_alpha_a_ji == 0.0, all following values are 0! */ if (sum_alpha_a_ji == 0.0) continue; /* next t */ /* loop over no of density functions for C-numer., mue and u */ for (m = 0; m < smo->s[i].M; m++) { /* c_im * b_im */ f_im = b[t][i][m]; gamma = seq_w[k] * sum_alpha_a_ji * f_im * beta[t][i]; gamma_ct = gamma * c_t; /* c[t] = 1/scale[t] */ /* numerator C: */ r->c_num[i][m] += gamma_ct; /* numerator Mue: */ r->mue_num[i][m] += (gamma_ct * O[k][t]); /* denom. Mue/U: */ r->mue_u_denom[i][m] += gamma_ct; /* numerator U: */ r->u_num[i][m] += (gamma_ct * m_sqr (O[k][t] - smo->s[i].mue[m])); /* sum gamma_ct * O[k][t] * O[k][t] (truncated normal density): */ r->sum_gt_otot[i][m] += (gamma_ct * m_sqr (O[k][t])); /* sum gamma_ct * log(b_im) */ if (gamma_ct > 0.0) { quot = b[t][i][m] / smo->s[i].c[m]; r->sum_gt_logb[i][m] += (gamma_ct * log (quot)); } } } /* for (t=0, t<T) */ } /* for (i=0, i<smo->N) */ } /* for (k = 0; k < seq_number; k++) */ /* reset class_change->k to default value */ if (smo->cos > 1) { smo->class_change->k = -1; } if (valid_parameter) { /* new parameter lambda: set directly in model */ if (sreestimate_setlambda (r, smo) == -1) { GHMM_LOG_QUEUED(LCONVERTED); return (-1); } /* only test : */ /* printf("scale:\n"); for(t=0;t<T[0];t++){ printf("%f, ",scale[t]); } printf("\n\n"); for(osc =0;osc<smo->cos;osc++) { for(i=0;i<smo->N;i++){ for(j=0;j<smo->N;j++){ printf("osc= %d, i = %d, j= %d : %f / %f = %f\n",osc,i,j,r->a_num[i][osc][j], r->a_denom[i][osc],(r->a_num[i][osc][j] / r->a_denom[i][osc])); } } } ghmm_cmodel_print(stdout,smo); */ if (ghmm_cmodel_check(smo) == -1) { GHMM_LOG_QUEUED(LCONVERTED); printf("Model invalid !\n\n"); goto STOP; } /* else { printf("Model is ok.\n"); } */ } else { /* NO sequence can be build from smodel smo! */ /* diskret: *log_p = +1; */ ighmm_mes (MES_WIN, " NO sequence can be build from smodel smo!\n"); return (-1); } sreestimate_free_matvec (alpha, beta, scale, b, T_k_max, smo->N); return (valid_logp); /* return(valid_parameter); */STOP: /* Label STOP from ARRAY_[CM]ALLOC */ sreestimate_free_matvec (alpha, beta, scale, b, T_k, smo->N); return (res);# undef CUR_PROC} /* sreestimate_one_step *//*============================================================================*//* int ghmm_cmodel_baum_welch(ghmm_cmodel *smo, ghmm_cseq *sqd) {*/int ghmm_cmodel_baum_welch (ghmm_cmodel_baum_welch_context * cs){# define CUR_PROC "ghmm_cmodel_baum_welch" int i, j, n, valid, valid_old, max_iter_bw; double log_p, log_p_old, diff, eps_iter_bw; local_store_t *r = NULL; char *str; /* truncated normal density needs static varialbles C_PHI and CC_PHI */ for (i = 0; i < cs->smo->N; i++){ for (j = 0; j < cs->smo->s[i].M; j++){ if (cs->smo->s[i].density[j] == normal_right) { C_PHI = ighmm_rand_get_xPHIless1 (); CC_PHI = m_sqr (C_PHI); break; } } } /* local store for all iterations */ r = sreestimate_alloc (cs->smo); if (!r) { GHMM_LOG_QUEUED(LCONVERTED); goto STOP; }; sreestimate_init (r, cs->smo); log_p_old = -DBL_MAX; valid_old = cs->sqd->seq_number; n = 1; max_iter_bw = m_min (GHMM_MAX_ITER_BW, cs->max_iter); eps_iter_bw = m_max (GHMM_EPS_ITER_BW, cs->eps); /*printf(" *** ghmm_cmodel_baum_welch %d, %f \n",max_iter_bw,eps_iter_bw );*/ while (n <= max_iter_bw) { valid = sreestimate_one_step (cs->smo, r, cs->sqd->seq_number, cs->sqd->seq_len, cs->sqd->seq, &log_p, cs->sqd->seq_w); /* to follow convergence of bw: uncomment next line */ printf ("\tBW Iter %d\t log(p) %.4f\n", n, log_p); if (valid == -1) { str = ighmm_mprintf (NULL, 0, "sreestimate_one_step false (%d.step)\n", n); GHMM_LOG(LCONVERTED, str); m_free (str); goto STOP; }#if MCI if (n == 1) ighmm_mes (MESINFO, "%8.5f (-log_p input smodel)\n", -log_p); else ighmm_mes (MESINFO, "\n%8.5f (-log_p)\n", -log_p);#endif diff = log_p - log_p_old; if (diff < -GHMM_EPS_PREC) { if (valid > valid_old) { str = ighmm_mprintf (NULL, 0, "log P < log P-old (more sequences (%d) , n = %d)\n", valid - valid_old, n); GHMM_LOG(LCONVERTED, str); m_free (str); } /* no convergence */ else { str = ighmm_mprintf (NULL, 0, "NO convergence: log P(%e) < log P-old(%e)! (n = %d)\n", log_p, log_p_old, n); GHMM_LOG(LCONVERTED, str); m_free (str); break; /* goto STOP; ? */ } } /* stop iteration */ if (diff >= 0.0 && diff < fabs (eps_iter_bw * log_p)) {#if MCI ighmm_mes (MESINFO, "Convergence after %d steps\n", n);#endif break; } else { /* for next iteration */ valid_old = valid; log_p_old = log_p; /* set values to zero */ sreestimate_init (r, cs->smo); n++; } } /* while (n <= MAX_ITER_BW) */#if MCI ighmm_mes (MESINFO, "%8.5f (-log_p optimized smodel)\n", -log_p);#endif /* log_p outside this function */ *cs->logp = log_p; /* test plausibility of new parameters */ /* if (ghmm_cmodel_check(mo) == -1) { GHMM_LOG_QUEUED(LCONVERTED); goto STOP; } */ sreestimate_free (&r, cs->smo->N); return (0);STOP: /* Label STOP from ARRAY_[CM]ALLOC */ sreestimate_free (&r, cs->smo->N); return (-1);# undef CUR_PROC} /* ghmm_cmodel_baum_welch */#undef ACC#undef MCI#undef MESCONTR#undef MESINFO
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?