gradescent.c
来自「General Hidden Markov Model Library 一个通用」· C语言 代码 · 共 556 行 · 第 1/2 页
C
556 行
return -1; /* loop over all sequences */ for (k = 0; k < sq->seq_number; k++) { seq_len = sq->seq_len[k]; if (-1 == ighmm_reestimate_alloc_matvek (&alpha, &beta, &scale, seq_len, mo->N)) continue; /* calculate forward and backward variables without labels: */ if (-1 == ghmm_dmodel_forward (mo, sq->seq[k], seq_len, alpha, scale, &log_p)) { printf ("forward error!\n"); goto FREE; } if (-1 == ghmm_dmodel_backward (mo, sq->seq[k], seq_len, beta, scale)) { printf ("backward error!\n"); goto FREE; } /* compute n matrices (no labels): */ if (-1 == ghmm_dmodel_label_gradient_expectations (mo, alpha, beta, scale, sq->seq[k], seq_len, m_b, m_a, m_pi)) printf ("Error in sequence %d, length %d (no labels)\n", k, seq_len); /* calculate forward and backward variables with labels: */ if (-1 == ghmm_dmodel_label_forward (mo, sq->seq[k], sq->state_labels[k], seq_len, alpha, scale, &log_p)) { printf ("forward labels error!\n"); goto FREE; } if (-1 == ghmm_dmodel_label_backward (mo, sq->seq[k], sq->state_labels[k], seq_len, beta, scale, &log_p)) { printf ("backward labels error!\n"); goto FREE; } /* compute m matrices (labels): */ if (-1 == ghmm_dmodel_label_gradient_expectations (mo, alpha, beta, scale, sq->seq[k], seq_len, m_b, m_a, m_pi)) printf ("Error in sequence %d, length %d (with labels)\n", k, seq_len); /* reestimate model parameters: */ /* PI */ pi_sum = 0; /* update */ for (i = 0; i < mo->N; i++) { if (mo->s[i].pi > 0.0) { gradient = eta * (m_pi[i] - n_pi[i]); if (mo->s[i].pi + gradient > GHMM_EPS_PREC) mo->s[i].pi += gradient; else mo->s[i].pi = GHMM_EPS_PREC; } /* sum over new PI vector */ pi_sum += mo->s[i].pi; } if (pi_sum < GHMM_EPS_PREC) { /* never get here */ fprintf (stderr, "Training ruined the model. You lose.\n"); k = sq->seq_number; goto FREE; } /* normalise */ for (i = 0; i < mo->N; i++) mo->s[i].pi /= pi_sum; /* A */ for (i = 0; i < mo->N; i++) { a_row_sum = 0; /* update */ for (j = 0; j < mo->s[i].out_states; j++) { j_id = mo->s[i].out_id[j]; gradient = eta * (m_a[i * mo->N + j_id] - n_a[i * mo->N + j_id]) / (seq_len - 1); if (mo->s[i].out_a[j] + gradient > GHMM_EPS_PREC) mo->s[i].out_a[j] += gradient; else mo->s[i].out_a[j] = GHMM_EPS_PREC; /* sum over rows of new A matrix */ a_row_sum += mo->s[i].out_a[j]; } if (a_row_sum < GHMM_EPS_PREC) { /* never get here */ fprintf (stderr, "Training ruined the model. You lose.\n"); k = sq->seq_number; goto FREE; } /* normalise */ for (j = 0; j < mo->s[i].out_states; j++) { mo->s[i].out_a[j] /= a_row_sum; /* mirror out_a to corresponding in_a */ j_id = mo->s[i].out_id[j]; for (g = 0; g < mo->s[j_id].in_states; g++) if (i == mo->s[j_id].in_id[g]) { mo->s[j_id].in_a[g] = mo->s[i].out_a[j]; break; } } } /* B */ for (i = 0; i < mo->N; i++) { /* don't update fix states */ if (mo->s[i].fix) continue; /* update */ size = ghmm_ipow (mo, mo->M, mo->order[i]); for (h = 0; h < size; h++) { b_block_sum = 0; for (g = 0; g < mo->M; g++) { hist = h * mo->M + g; gradient = eta * (m_b[i][hist] - n_b[i][hist]) / seq_len; /* printf("gradient[%d][%d] = %g, m_b = %g, n_b = %g\n" , i, hist, gradient, m_b[i][hist], n_b[i][hist]); */ if (gradient + mo->s[i].b[hist] > GHMM_EPS_PREC) mo->s[i].b[hist] += gradient; else mo->s[i].b[hist] = GHMM_EPS_PREC; /* sum over M-length blocks of new B matrix */ b_block_sum += mo->s[i].b[hist]; } if (b_block_sum < GHMM_EPS_PREC) { /* never get here */ fprintf (stderr, "Training ruined the model. You lose.\n"); k = sq->seq_number; goto FREE; } /* normalise */ for (g = 0; g < mo->M; g++) { hist = h * mo->M + g; mo->s[i].b[hist] /= b_block_sum; } } } /* restore "tied_to" property */ if (mo->model_type & GHMM_kTiedEmissions) ghmm_dmodel_update_tie_groups (mo); FREE: ighmm_reestimate_free_matvek (alpha, beta, scale, seq_len); } gradient_descent_gfree (m_b, m_a, m_pi, mo->N); gradient_descent_gfree (n_b, n_a, n_pi, mo->N); return 0;#undef CUR_PROC}/*----------------------------------------------------------------------------*//** Trains the ghmm_dmodel with a set of annotated sequences till convergence using gradient descent. Model must not have silent states. (checked in Python wrapper) @return trained model/NULL pointer success/error @param mo: pointer to a ghmm_dmodel @param sq: struct of annotated sequences @param eta: intial parameter eta (learning rate) @param no_steps number of training steps */ghmm_dmodel* ghmm_dmodel_label_gradient_descent (ghmm_dmodel* mo, ghmm_dseq * sq, double eta, int no_steps){#define CUR_PROC "ghmm_dmodel_label_gradient_descent" char * str; int runs = 0; double cur_perf, last_perf; ghmm_dmodel *last; last = ghmm_dmodel_copy(mo); last_perf = compute_performance (last, sq); while (eta > GHMM_EPS_PREC && runs < no_steps) { runs++; if (-1 == gradient_descent_onestep(mo, sq, eta)) { ghmm_dmodel_free(&last); return NULL; } cur_perf = compute_performance(mo, sq); if (last_perf < cur_perf) { /* if model is degenerated, lower eta and try again */ if (cur_perf > 0.0) { str = ighmm_mprintf(NULL, 0, "current performance = %g", cur_perf); GHMM_LOG(LINFO, str); m_free(str); ghmm_dmodel_free(&mo); mo = ghmm_dmodel_copy(last); eta *= .5; } else { /* Improvement insignificant, assume convergence */ if (fabs (last_perf - cur_perf) < cur_perf * (-1e-8)) { ghmm_dmodel_free(&last); str = ighmm_mprintf(NULL, 0, "convergence after %d steps.", runs); GHMM_LOG(LINFO, str); m_free(str); return 0; } if (runs < 175 || 0 == runs % 50) { str = ighmm_mprintf(NULL, 0, "Performance: %g\t improvement: %g\t step %d", cur_perf, cur_perf - last_perf, runs); GHMM_LOG(LINFO, str); m_free(str); } /* significant improvement, next iteration */ ghmm_dmodel_free(&last); last = ghmm_dmodel_copy(mo); last_perf = cur_perf; eta *= 1.07; } } /* no improvement */ else { if (runs < 175 || 0 == runs % 50) { str = ighmm_mprintf(NULL, 0, "Performance: %g\t !IMPROVEMENT: %g\t step %d", cur_perf, cur_perf - last_perf, runs); GHMM_LOG(LINFO, str); m_free(str); } /* try another training step */ runs++; eta *= .85; if (-1 == gradient_descent_onestep(mo, sq, eta)) { ghmm_dmodel_free(&last); return NULL; } cur_perf = compute_performance (mo, sq); str = ighmm_mprintf(NULL, 0, "Performance: %g\t ?Improvement: %g\t step %d", cur_perf, cur_perf - last_perf, runs); GHMM_LOG(LINFO, str); m_free(str); /* improvement, save and proceed with next iteration */ if (last_perf < cur_perf && cur_perf < 0.0) { ghmm_dmodel_free (&last); last = ghmm_dmodel_copy(mo); last_perf = cur_perf; } /* still no improvement, revert to saved model */ else { runs--; ghmm_dmodel_free(&mo); mo = ghmm_dmodel_copy(last); eta *= .9; } } } ghmm_dmodel_free(&last); return mo;#undef CUR_PROC}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?