📄 forward_pass.c
字号:
/* * $Id: forward_pass.c 1273 2007-05-09 16:28:04Z mhe $ */ #include <math.h>#include <stdio.h>#include <stdlib.h>#include <string.h>#include "lstm.h"#include "forward_pass.h"static union { double d; struct { #ifdef LITTLE_ENDIAN int j, i; #else int i, j; #endif } n;} _eco;void forward_pass(int train, int element, int **inp_idx, double **inp) { static int i,j,u,v,k; static double sum; // , sumblock[100]; static char actfilename[255]; /* ### memory cells ### */ i=in_mod_b+1; for (u=0;u<num_blocks;u++) { /* input gates */ sum = 0; /* input */ /* for (j=0;j<in_nn_mod;j++) { sum += W[i][inp_idx[element][j]] * inp[element][j]; } */ /* lstm */ for (j=in_mod_b;j<cell_mod;j++) { sum += W[i][j] * Yk_old[j]; } Y_in[u] = 1/(1+EXP(-sum)); Yk_new[i]= Y_in[u]; /* output gate */ i++; sum = 0; /* input */ /* for (j=0;j<in_nn_mod;j++) { sum += W[i][inp_idx[element][j]] * inp[element][j]; } */ /* lstm */ for (j=in_mod_b;j<cell_mod;j++) { sum += W[i][j] * Yk_old[j]; } /* peep hole */ //sum -= S[u][0]; Y_out[u] = 1/(1+EXP(-sum)); Yk_new[i]= Y_out[u]; /* uth memory cell block */ for (v=0;v<block_size[u];v++) { /* activation of function g of vth memory cell of block u */ i++; sum = W[i][in_mod_b]; // bias #ifdef FF for (j = 0; j < in_mod_b; j++) { sum += W[i][j] * inp[element][j]; } #else for (j=0;j<in_nn_mod;j++) { sum += W[i][inp_idx[element][j]] * inp[element][j]; } #endif /* peep hole */ //sum -= S[u][v]; //sumblock[u] = sum; G[u][v] = 2.0/(1+EXP(-sum)); /* update internal state */ S[u][v] += Y_in[u]*G[u][v]; /* activation function h */ if (S[u][v] > 500.0) { fprintf(stderr, "S drifted\n"); H[u][v] = 1.0; } else { H[u][v] = 2.0/(1+EXP(-S[u][v]))-1.0; } /* activation of vth memory cell of block u */ Yc[u][v] = H[u][v]*Y_out[u]; Yk_new[i] = Yc[u][v]; } i++; } /* ### output units activation ### */ if (targ) /* only if target for this input */ { for (k=cell_mod;k<ges_mod;k++) { /* bias */ sum = W[k][in_mod_b]; /* memory cells input */ i=in_mod_b+1; for (u=0;u<num_blocks;u++) { i++; for (v=0;v<block_size[u];v++) { i++; sum += W[k][i]*Yk_new[i]; } i++; } /* activation */ Yk_new[k] = 1/(1+EXP(-sum)); } } // output of memory activity if (0) { if (element == 0) { //sprintf(actfilename, "%d%s", prand_t[prot_current_idx], ".possibleprotea.memoryactivity.train.dat"); sprintf(actfilename, "%d.pos.test.dat", prot_current_idx); fprintf(stderr, "%s\n", actfilename); ma = fopen(actfilename, "w"); } //for (u=0;u<num_blocks;u++) { //fprintf(ma, "%f %f %f %f ", Y_in[u], Y_out[u], sumblock[u], S[u][0]); //fprintf(ma, "%f ", S[u][0]); //} //fprintf(ma, "\n"); fprintf(ma, "%f\n", S[4][0]); //fprintf(ma, "%f\n", Yk_new[cell_mod]); fflush(ma); /* * * See in execute_act_test() for closing the file * */ } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -