
📄 rnn.cc

📁 Source code for a Random Neural Network
💻 C++
📖 Page 1 of 2
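This file appears to implement Gelenbe's random neural network (RNN) model, not a "recurrent" network in the modern deep-learning sense. As a reading aid inferred from the code itself, not from any notes by the original author: r[i] is neuron i's total firing rate, lp[i]/lm[i] are external excitatory/inhibitory arrival rates, and wp/wm are the excitatory/inhibitory weight matrices. The solvers compute each neuron's steady-state excitation probability

  q[i] = num[i] / den[i], clipped at 1, with
  num[i] = lp[i]*r[i] + sum_j q[j]*wp[j][i]
  den[i] = r[i] + lm[i]*r[i] + sum_j q[j]*wm[j][i]

genop() iterates these equations to a fixed point (stopping once no q[i] moves by more than 0.003), which works for any topology; ffop() evaluates them in a single layer-ordered pass, which is valid only for the feed-forward case.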
#include "rnn.hh"/* Internal Function */void RNN::free_fmat(float **data,int row, int col){  delete data[0];  delete data;}/* Internal Function */void RNN::free_cmat(u_char **data,int row, int col){  delete data[0];  delete data;}/* Internal Function */float ** RNN::fmat(int row, int col){  float **T;  int i,j;  T = new float *[row];  T[0] = new float [row*col];  memset(T[0],0,row*col*sizeof(float));  for (i=1;i<row;i++)     T[i] = &(T[0][i*col]);  return(T);}/* Internal Function */u_char ** RNN::cmat(int row, int col){  u_char **tmpN;  int i,j;  tmpN = new u_char *[row];  tmpN[0] = new u_char [row*col];  memset(tmpN[0],0,row*col*sizeof(u_char));  for (i=0;i<row;i++)     tmpN[i] = &(tmpN[0][i*col]);  return(tmpN);}/*Internal FunctionFunction: new_rnnPurpose: initializes a neural networkArguments:  ip: number of "input" neurons (neurons which accept inputs)  hn: number of "hidden" neurons (neurons accepting no i/o)  op: number of "output" neurons (neurons which produce o/p)  ETA: learning rate  RATE: rate of o/p neuron firing  rnntype: select topology of rnn    0: feed-forward (optimized for speed)    1: fully topologically recurrent     2: undefined topology which must be defined using connect*/void RNN::new_rnn(int ip, int hn, int op, float ETA, float RATE, u_char rnntype){  int i,j;  N = ip+hn+op;  IP = ip;  HN = hn;  OP = op;  eta = ETA;  rate = RATE;  RNNTYPE = rnntype;  weights_changed = 1;  wp = fmat(N,N);  wm = fmat(N,N);  w = fmat(N,N);  for (i=0;i<N;i++)    w[i][i] = 1.0;					      C = cmat(N,(int)((N+7)/8.0));  GOP = new float[N];  num = new float[N];  memset(GOP,0,N*sizeof(float));  typ = new u_char[N];  for (i=0;i<N;i++)    {      if (i < IP) typ[i] = 1;      else if (i < IP + HN) typ[i] = 2;      else typ[i] = 3;    }  q = new float[N];  memset(q,0,N*sizeof(float));  den = new float[N];  memset(den,0,N*sizeof(float));  r = new float[N];  memset(r,0,N*sizeof(float));  lp = new float[N];  memset(lp,0,N*sizeof(float));  lm = new float[N];  memset(lm,0,N*sizeof(float));  if (RNNTYPE == 0)     {      for (i=0;i<N;i++)	for (j=0;j<N;j++)	  if (typ[i] + 1 == typ[j]) connect(i,j);    }  else if (RNNTYPE == 1)    {      for (i=0;i<N;i++)	for (j=0;j<N;j++)	  if (i != j) connect(i,j);    }  else {  }  randomweights();}/* Internal Function */void RNN::ff_inv(){  int i,j,k;  float z;  for (i=0;i<IP;i++)    for (j=IP;j<IP+HN;j++)      w[i][j] = (wp[i][j] - wm[i][j]*q[j])/den[j];  for (i=IP;i<IP+HN;i++)    for (j=IP+HN;j<N;j++)      w[i][j] = (wp[i][j] - wm[i][j]*q[j])/den[j];  for (i=0;i<IP;i++)     for (j=IP+HN;j<N;j++) {      z = 0;      for (k=IP;k<IP+HN;k++)	z+= w[i][k] * w[k][j];      w[i][j] = z;    }}/* Internal Function */void RNN::gen_inv(){  int r,c,i,st,j,k,rn;  float z,z2;  float **x,**y;    x = fmat(N,N);  y = fmat(N,N);    for (c=0;c<N;c++)    for (r=0;r<N;r++) {      if (c == r) {	y[r][c] = 1;	x[r][c] = 1 - (wp[r][c] - wm[r][c]*q[c])/den[c];	for (i=0;i<c;i++)	  x[r][c] -= x[r][i]*y[i][c];      }      if (r > c) {	x[r][c] = - (wp[r][c] - wm[r][c]*q[c])/den[c];	for (i=0;i<c;i++)	  x[r][c] -= x[r][i]*y[i][c];      }      if (r < c) {	y[r][c] = - (wp[r][c] - wm[r][c]*q[c])/den[c];	for (i=0;i<r;i++)	  y[r][c] -= x[r][i]*y[i][c];	y[r][c] = y[r][c] / x[r][r];      }    }    for (rn=0;rn<N;rn++) {    for (j=0;j<N;j++)      {	z = x[j][rn];	z2 = y[j][N-1-rn];	if (j > rn)	  for (k=0;k<=rn;k++)	    if (rn==k) x[j][k] = -z*x[rn][rn];	    else x[j][k] -= z*x[rn][k];	if (j < N-1-rn)	  for (k=N-1-rn;k<N;k++)	    if (N-1-rn == k) y[j][k] = -z2 * y[N-1-rn][N-1-rn];	    else y[j][k] -= 
z2*y[N-1-rn][k];      }  }    for (i=0;i<N;i++)    for (j=0;j<N;j++) {      z = 0;      for (k=(i > j ? i : j);k<N;k++)	z += x[k][j] * y[i][k];      w[i][j] = z;    }    free_fmat(x,N,N);  free_fmat(y,N,N);}/* Internal Function */void RNN::ff_computews(){  int u,v,k,l;  float *gp, *gm;  float deltwp, deltwm;  float tmp_wp, tmp_wm;  float gpu, gpv, gmu, gmv;  weights_changed = 1;  for (u=0;u<IP;u++) {    for (v=IP;v<IP+HN;v++) {      gpu = -1.0 / den[u];      gmu = -1.0 / den[u];      gpv = 1.0 / den[v];      gmv = -q[v]/den[v];      tmp_wp = tmp_wm = 0;      for (k=N-OP;k<N;k++)	{	  deltwp = gpu * q[u] * w[u][k];	  deltwm = gmu * q[u] * w[u][k];	  deltwp += gpv * q[u] * w[v][k];	  deltwm += gmv * q[u] * w[v][k];	  tmp_wp += (q[k]-GOP[k-(N-OP)])*deltwp;	  tmp_wm += (q[k]-GOP[k-(N-OP)])*deltwm;	}      wp[u][v] -= eta*tmp_wp;      wm[u][v] -= eta*tmp_wm;      if (wp[u][v] < 0) wp[u][v] = 0;      if (wm[u][v] < 0) wm[u][v] = 0;    }  }  for (u=IP;u<IP+HN;u++) {    for (v=IP+HN;v<N;v++) {      gpu = -1.0 / den[u];      gmu = -1.0 / den[u];      gpv = 1.0 / den[v];      gmv = -q[v]/den[v];      tmp_wp = tmp_wm = 0;      for (k=N-OP;k<N;k++)	{	  deltwp = gpu * q[u] * w[u][k];	  deltwm = gmu * q[u] * w[u][k];	  deltwp += gpv * q[u] * w[v][k];	  deltwm += gmv * q[u] * w[v][k];	  tmp_wp += (q[k]-GOP[k-(N-OP)])*deltwp;	  tmp_wm += (q[k]-GOP[k-(N-OP)])*deltwm;	}      wp[u][v] -= eta*tmp_wp;      wm[u][v] -= eta*tmp_wm;      if (wp[u][v] < 0) wp[u][v] = 0;      if (wm[u][v] < 0) wm[u][v] = 0;    }  }}/* Internal Function */void RNN::gen_computews(){  int u,v,k,l;  float *gp, *gm, deltwp, deltwm;  float tmp_wm, tmp_wp;    weights_changed = 1;  gp = new float[N];  gm = new float[N];  for (u=0;u<N;u++) {    for (v=0;v<N;v++)      if (C[u][v/8] & 1 << (v%8))        {	  gp[u] = gp[v] = 0;	  gm[u] = gm[v] = 0;	  gp[u] += -1/den[u];	  gm[u] += -1/den[u];	  gp[v] += 1/den[v];	  gm[v] += -q[v]/den[v];	  tmp_wp = tmp_wm = 0;          for (k=N-OP;k<N;k++)	    {	      deltwp = gp[u] * q[u] * w[u][k];	      deltwm = gm[u] * q[u] * w[u][k];	      deltwp += gp[v] * q[u] * w[v][k];	      deltwm += gm[v] * q[u] * w[v][k];	      tmp_wp += (q[k]-GOP[k-(N-OP)])*deltwp;	      tmp_wm += (q[k]-GOP[k-(N-OP)])*deltwm;	    }	  wp[u][v] -= eta*tmp_wp;	  wm[u][v] -= eta*tmp_wm;          if (wp[u][v] < 0) wp[u][v] = 0;          if (wm[u][v] < 0) wm[u][v] = 0;        }  }  delete gp;  delete gm;}/* Internal Function */void RNN::genop(){  int i, j, k;  int d;  float *qp;    qp = new float[N];  for (i=0;i<N;i++)    q[i] = 0;  do    {      d = 1;      for (i=0;i<N;i++)        {	  if (weights_changed) {	    r[i] = 0;	    for (j=0;j<N;j++)	      r[i] += wp[i][j] + wm[i][j];	    if (r[i] == 0) r[i] = rate;	  }          den[i] = r[i] + lm[i]*r[i];          num[i] = lp[i] * r[i];          for (j=0;j<N;j++)	    {	      num[i] += q[j] * wp[j][i];	      den[i] += q[j] * wm[j][i];	    }        }      for (i=0;i<N;i++) {	qp[i] = num[i]/den[i];	if (qp[i] > 1) qp[i] = 1;	if (fabs(qp[i] - q[i]) > 0.003) d = 0;	q[i] = qp[i];      }    } while (d == 0);  delete qp;  weights_changed = 0;}/* Internal Function */void RNN::ffop(){  int i, j, k;  int d,cnt = 0;  int b1,b2,e1,e2;    for (i=0;i<N;i++)    {      if (typ[i] == 1) {        b1 = IP;        e1 = IP+HN;        b2 = 0;        e2 = 0;      }      else if (typ[i] == 2) {        e1 = N;        b1 = IP+HN;        b2 = 0;        e2 = IP;      }      else if (typ[i] == 3) {        b1 = 0;        e1 = 0;        b2 = IP;        e2 = IP+HN;      }      if (weights_changed) {        r[i] = 0;        for 
(j=b1;j<e1;j++)          r[i] += wp[i][j] + wm[i][j];        if (r[i] == 0) r[i] = rate;      }      den[i] = r[i] + lm[i] * r[i];      num[i] = lp[i] * r[i];      for (j=b2;j<e2;j++)        {          num[i] += (float)q[j] * (float)wp[j][i];          den[i] += (float)q[j] * (float)wm[j][i];        }      q[i] = num[i]/den[i];      if (q[i] > 1) q[i] = 1;    }  weights_changed = 0;}/* Internal Function */void RNN::genop(int quant){  int i, j, k;  int d;  float *qp;    qp = new float[N];  for (i=0;i<N;i++)    q[i] = 0;  do    {      d = 1;      for (i=0;i<N;i++)        {	  if (weights_changed) {	    r[i] = 0;	    for (j=0;j<N;j++)	      r[i] += wp[i][j] + wm[i][j];	    if (r[i] == 0) r[i] = rate;	  }          den[i] = r[i] + lm[i]*r[i];          num[i] = lp[i] * r[i];          for (j=0;j<N;j++)	    {	      num[i] += q[j] * wp[j][i];	      den[i] += q[j] * wm[j][i];	    }        }      for (i=0;i<N;i++) {	qp[i] = num[i]/den[i];	if (qp[i] > 1) qp[i] = 1;	if (i >= IP) {	  qp[i] = (int) (qp[i] * (float)quant) / (float)quant;	  if (qp[i] > 1) qp[i] = 1;	}	if (fabs(qp[i] - q[i]) > 0.003) d = 0;	q[i] = qp[i];      }    } while (d == 0);  delete qp;  weights_changed = 0;}/* Internal function */void RNN::ffop(int quant){  int i, j;  int b1,b2,e1,e2;    for (i=0;i<N;i++)    {      switch(typ[i]) {      case 1:  {	b1 = IP;	e1 = IP + HN;	b2 = 0;	e2 = 0;	break;      }      case 2: {	e1 = N;	b1 = IP + HN;	b2 = 0;	e2 = IP;	break;      }      case 3: {	b1 = 0;	e1 = 0;	b2 = IP;	e2 = b2 + HN;	break;      }      }      if (weights_changed) {	r[i] = 0;	for (j=b1;j<e1;j++)	  r[i] += wp[i][j] + wm[i][j];	if (r[i] == 0) r[i] = rate;      }      den[i] = r[i] + lm[i] * r[i];      num[i] = lp[i] * r[i];      for (j=b2;j<e2;j++)        {          num[i] += q[j] * wp[j][i];          den[i] += q[j] * wm[j][i];        }      q[i] = num[i]/den[i];      if (q[i] > 1) q[i] = 1;      if (i >= IP) q[i] = ( (int) (q[i] * quant) / (float)quant);    }  weights_changed = 0;}/* create empty rnn */RNN::RNN() {  N = 0;  IP = 0;  HN = 0;  OP = 0;  rate = 0;  eta = 0;  wp = (float **)NULL;  wm = (float **)NULL;  w = (float **)NULL;  C = (u_char **)NULL;  den = (float *)NULL;  GOP = (float *)NULL;  typ = (u_char *)NULL;  q = (float *)NULL;  r = (float *)NULL;  lp = (float *)NULL;  lm = (float *)NULL;}RNN::~RNN(){  if (wp != (float **) NULL) free_fmat(wp,N,N);  if (wm != (float **) NULL) free_fmat(wm,N,N);  if (w != (float **) NULL) free_fmat(w,N,N);  if (C != (u_char **) NULL) free_cmat(C,N,(int)((N+7)/8.0));  delete GOP;  delete typ;  delete q;  delete den;  delete r;  delete lp;  delete lm;}/* Function: RNNPurpose:  create defined rnn Arguments:  ip: number of input type neurons  hn: number of hidden type neurons  op: number of output type neurons  ETA: learning rate for the RNN  RATE: defined "rate" for o/p neurons (usually 1)  rnntype: architecture type of RNN    0: feed-forward    1: fully recurrant    2: free form - define architecture using function connect*/RNN::RNN(int ip, int hn, int op, float ETA, float RATE, u_char rnntype){  new_rnn(ip,hn,op,ETA,RATE,rnntype);
