derivatives.c

来自「C-package of "Long Short-Term Memory" fo…」· C语言 代码 · 共 75 行

C
75
字号
/*
 * $Id: derivatives.c 1186 2006-10-04 08:40:07 +0000 (Wed, 04 Oct 2006) mhe $
 */

#include "lstm.h"
#include "derivatives.h"

/*
 * derivatives - accumulate the partial-derivative traces SI (input-gate
 * weights) and SC (cell-input weights) of every memory cell for the
 * current pattern `element`.
 *
 * Reads the globals num_blocks, block_size, Y_in, G, inp, element,
 * in_mod_b, cell_mod and Yk_old (plus inp_idx and in_nn_mod when FF is
 * not defined).  Writes only SI and SC, and only by adding to them (+=),
 * so any decay or reset of these traces must be done by the caller.
 *
 * For block u and cell v the per-source factors are
 *   d_in = (1 - Y_in[u]) * Y_in[u] * G[u][v]           (input-gate weights)
 *   d_c  = (1 - 0.5 * G[u][v]) * Y_in[u] * G[u][v]     (cell-input weights)
 * NOTE(review): these look like the chain-rule terms for a logistic input
 * gate (y * (1 - y)) and a cell-input squash g = 2 * sigmoid
 * (g * (1 - g / 2)); confirm against the activation functions used
 * elsewhere in the package.
 *
 * Connection pattern implemented here:
 *   - SC receives terms from the external input only: dense over indices
 *     [0, in_mod_b) when FF is defined, otherwise sparse over the in_nn_mod
 *     entries listed in inp_idx[element].  It also receives a term with
 *     source activation 1 at index in_mod_b (presumably the bias).
 *   - SI receives terms only for indices [in_mod_b, cell_mod), weighted by
 *     Yk_old[j] (presumably the previous step's activations).
 */
void derivatives(void)
{
    /* Plain automatic locals: every one is assigned before use on each
     * call, so there is no reason to keep them in static storage. */
    int u, v, j;

    for (u = 0; u < num_blocks; u++) {
        for (v = 0; v < block_size[u]; v++) {
            const double gated = Y_in[u] * G[u][v];
            const double d_in = (1.0 - Y_in[u]) * gated;
            const double d_c = (1.0 - 0.5 * G[u][v]) * gated;

            /* Cell-input weights from the external input.  This is the
             * only part that differs between the FF and non-FF builds. */
#ifdef FF
            for (j = 0; j < in_mod_b; j++) {
                SC[u][v][j] += d_c * inp[element][j];
            }
#else
            for (j = 0; j < in_nn_mod; j++) {
                SC[u][v][inp_idx[element][j]] += d_c * inp[element][j];
            }
#endif

            /* Cell-input weight at index in_mod_b has a constant source
             * activation of 1 (presumably the bias weight). */
            SC[u][v][in_mod_b] += d_c;

            /* Input-gate weights from sources in_mod_b .. cell_mod - 1. */
            for (j = in_mod_b; j < cell_mod; j++) {
                SI[u][v][j] += d_in * Yk_old[j];
            }
        }
    }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?