backward_pass.c

来自「C-package of "Long Short-Term Memory" fo」· C语言 代码 · 共 128 行

C
128
字号
/* * $Id: backward_pass.c 1269 2007-05-09 07:49:08Z mhe $ */ #include <stdio.h>#include "lstm.h"#include "backward_pass.h"void backward_pass() {    int k,i,j,u,v;    register double sum;    static double tmp;        /* output units */    for (k=cell_mod,j=0;k<ges_mod;k++,j++) {        e[k] = error[j] * (1.0 - Yk_new[k]) * Yk_new[k]; // log        /* weight update contribution */                tmp = alpha * e[k];                   i=in_mod_b+1;                for (u=0;u<num_blocks;u++) {            i++;            for (v=0;v<block_size[u];v++) {                i++;                DW[k][i] += tmp*Yk_new[i];            }            i++;        }                /* bias to output unit */        DW[k][in_mod_b] += tmp;            }        /* error to memory cells ec[][] and internal states es[][] */    i=in_mod_b+1;        /* input gates */    for (u=0;u<num_blocks;u++) {        i++;        for (v=0;v<block_size[u];v++) {            i++;            sum = 0;            for (k=cell_mod;k<ges_mod;k++) {                sum+= W[k][i]*e[k];            }            ec[u][v] = sum;            es[u][v] = 0.5 * Y_out[u] * (1.0 + H[u][v]) * (1.0 - H[u][v]) * sum;        }        i++;    }    /* output gates */    for (u=0;u<num_blocks;u++) {        sum=0;        for (v=0;v<block_size[u];v++) {            sum+= H[u][v]*ec[u][v];        }        eo[u]=sum*(1.0-Y_out[u])*Y_out[u];    }    /* Derivatives of the internal state */    derivatives();    /* updates for weights to input and output gates and memory cells */    i = in_mod_b + 1;    for (u=0;u<num_blocks;u++) {                        /* input gate */        /* input */        /*        for (j=0;j<in_mod_b;j++) {            sum = 0;            for (v=0;v<block_size[u];v++) {                sum += es[u][v]*SI[u][v][j];              }            DW[i][j] += sum;        }        */        /* lstm */        for (j=in_mod_b;j<cell_mod;j++) {            sum = 0;            for (v=0;v<block_size[u];v++) {                sum += es[u][v]*SI[u][v][j];              }            DW[i][j] += alpha * sum;        }                /*output gate */        i++;        tmp = alpha * eo[u];                /* input */        /*        for (j=0;j<in_mod;j++) {            DW[i][j] += tmp * Yk_old[j];        }        */        /* lstm */        for (j=in_mod_b;j<cell_mod;j++) {            DW[i][j] += tmp * Yk_old[j];        }                /* memory states */        for (v=0;v<block_size[u];v++) {            i++;            tmp = alpha * es[u][v];            /* input */            for (j=0;j<=in_mod_b;j++) {                DW[i][j] += tmp*SC[u][v][j];            }            /* lstm & bias */            /*            for (j=in_mod_b;j<cell_mod;j++) {                DW[i][j] += tmp*SC[u][v][j];                       }            */        }        i++;    } }

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?