derivatives.c
来自「C-package of "Long Short-Term Memory" fo」· C语言 代码 · 共 75 行
C
75 行
/*
 * $Id: derivatives.c 1186 2006-10-04 08:40:07 +0000 (Wed, 04 Oct 2006) mhe $
 */

#include "lstm.h"
#include "derivatives.h"

/*
 * Accumulate the current step's contributions into the running sums SI and
 * SC for every cell `cell` of every block `blk`.
 *
 * Following the labels used in the original source, SI collects the terms
 * for "weights to input gate" and SC the terms for "weights to cell input".
 * All data (num_blocks, block_size, Y_in, G, inp, inp_idx, element, Yk_old,
 * in_mod_b, in_nn_mod, cell_mod, SI, SC) comes from globals declared in
 * lstm.h; the function takes no arguments and returns nothing.
 *
 * Per cell three factors are formed:
 *   gated     = Y_in[blk] * G[blk][cell]
 *   gate_term = (1 - Y_in[blk]) * gated          -> scales the SI updates
 *   cell_term = (1 - 0.5 * G[blk][cell]) * gated -> scales the SC updates
 *
 * Build configurations:
 *   FF defined     : the first in_mod_b entries of inp[element] are added to
 *                    SC at the same index k.
 *   FF not defined : the first in_nn_mod entries of inp[element] are added
 *                    to SC at the index given by inp_idx[element][k].
 * In both configurations SI is updated only for indices
 * in_mod_b .. cell_mod-1, using Yk_old, and SC[blk][cell][in_mod_b]
 * receives cell_term with no input factor (presumably the bias weight --
 * confirm against lstm.h).
 */
void derivatives()
{
    int blk;

    for (blk = 0; blk < num_blocks; blk++) {
        int cell;

        for (cell = 0; cell < block_size[blk]; cell++) {
            const double gated = Y_in[blk] * G[blk][cell];
            const double gate_term = (1.0 - Y_in[blk]) * gated;
            const double cell_term = (1.0 - 0.5 * G[blk][cell]) * gated;
            int k;

#ifdef FF
            /* Cell-input terms: external inputs, stored at their own index. */
            for (k = 0; k < in_mod_b; k++) {
                SC[blk][cell][k] += cell_term * inp[element][k];
            }
            /* Slot in_mod_b gets the term without an input factor. */
            SC[blk][cell][in_mod_b] += cell_term;
#else
            /* Cell-input terms: external inputs, scattered through inp_idx. */
            for (k = 0; k < in_nn_mod; k++) {
                SC[blk][cell][inp_idx[element][k]] += cell_term * inp[element][k];
            }
#endif

            /* Input-gate terms: driven only by the Yk_old entries. */
            for (k = in_mod_b; k < cell_mod; k++) {
                SI[blk][cell][k] += gate_term * Yk_old[k];
            }

#ifndef FF
            /*
             * Slot in_mod_b gets the term without an input factor. Kept
             * after the SI loop, as in the original non-FF code path.
             */
            SC[blk][cell][in_mod_b] += cell_term;
#endif
        }
    }
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?