📄 bpnet.cs
字号:
x[i] = p[isamp, i] / in_rate;
}
for (int i = 0; i < outNum; i++)
{
yd[i] = t[isamp, i] / in_rate;
}
//计算隐层的输入和输出
for (int j = 0; j < hideNum; j++)
{
o1[j] = 0.0;
for (int i = 0; i < inNum; i++)
{
o1[j] += w[i, j] * x[i];
}
x1[j] = 1.0 / (1.0 + Math.Exp(-o1[j] - b1[j])); //隐层输出
}
//计算输出层的输入和输出
for (int k = 0; k < outNum; k++)
{
o2[k] = 0.0;
for (int j = 0; j < hideNum; j++)
{
o2[k] += v[j, k] * x1[j];
}
x2[k] = 1.0 / (1.0 + Math.Exp(-o2[k] - b2[k])); //网络输出
//x2[k] = o2[k] + b2[k]; //线性函数
}
//计算输出层误差和均方差
for (int k = 0; k < outNum; k++)
{
qq[k] = (yd[k] - x2[k]) * x2[k] * (1.0 - x2[k]); //输出层的误差,用于更新V,很重要
for (int j = 0; j < hideNum; j++)
{
dv[j, k] += qq[k] * x1[j];
//dv[j, k] += rate * qq[k] * x1[j];
}
db2[k] += qq[k];
e += (yd[k] - x2[k]) * (yd[k] - x2[k]);
}
//计算隐层误差
for (int j = 0; j < hideNum; j++)
{
pp[j] = 0.0;
for (int k = 0; k < outNum; k++)
{
pp[j] += qq[k] * v[j, k]; //误差的反向传递
}
pp[j] = pp[j] * x1[j] * (1 - x1[j]);
for (int i = 0; i < inNum; i++)
{
dw[i, j] += pp[j] * x[i];
//dw[i, j] += rate * pp[j] * x[i];
}
db1[j] += pp[j];
}
}//end isamp
computerSWV(dv0,dv,sv);
adjustWV(v, sv);
computerSWV(dw0,dw,sw);
adjustWV(w, sw);
computerSWV(db20, db2, sb2);
adjustWV(b2, sb2);
computerSWV(db10, db1, sb1);
adjustWV(b1, sb1);
e = Math.Sqrt(e);
}//end train
/// <summary>
/// Polak-Ribiere conjugate-gradient update of a 2-D search direction, in place.
/// Computes beta = sum(dv .* (dv - dv0)) / sum(dv0 .* dv0), then sets sv = beta * sv + dv.
/// </summary>
/// <param name="dv0">Step accumulated in the previous epoch (previous gradient term); only read.
/// Its dimensions define the range that is processed.</param>
/// <param name="dv">Step accumulated in the current epoch; only read. In train() this is built from
/// qq = (yd - x2) * ..., i.e. the descent direction; the PR beta is unchanged by that sign flip.</param>
/// <param name="sv">Conjugate search direction; overwritten in place.</param>
/// <remarks>
/// When dv0 is entirely zero (no previous step) the ratio is undefined, so beta is taken as 0 and
/// sv becomes dv (a steepest-descent restart) instead of being filled with Infinity/NaN.
/// </remarks>
public void computerSWV(double[,] dv0, double[,] dv, double[,] sv)
{
    int hnum = dv0.GetLength(0);
    int onum = dv0.GetLength(1);
    double numerator = 0.0;   // dv . (dv - dv0)
    double denominator = 0.0; // dv0 . dv0
    for (int k = 0; k < onum; k++)
    {
        for (int j = 0; j < hnum; j++)
        {
            // Both factors of the inner product must come from the same element (j, k).
            numerator += dv[j, k] * (dv[j, k] - dv0[j, k]);
            denominator += dv0[j, k] * dv0[j, k];
        }
    }
    double beta = denominator > 0.0 ? numerator / denominator : 0.0;
    for (int k = 0; k < onum; k++)
    {
        for (int j = 0; j < hnum; j++)
        {
            sv[j, k] = sv[j, k] * beta + dv[j, k];
        }
    }
}
/// <summary>
/// Polak-Ribiere conjugate-gradient update of a 1-D search direction (bias vectors), in place.
/// Computes beta = sum(dv .* (dv - dv0)) / sum(dv0 .* dv0), then sets sv = beta * sv + dv.
/// </summary>
/// <param name="dv0">Step accumulated in the previous epoch; only read. Its length defines the range processed.</param>
/// <param name="dv">Step accumulated in the current epoch; only read.</param>
/// <param name="sv">Conjugate search direction; overwritten in place.</param>
/// <remarks>
/// When dv0 is entirely zero the ratio is undefined, so beta is taken as 0 and sv becomes dv
/// (a steepest-descent restart) instead of being filled with Infinity/NaN.
/// </remarks>
public void computerSWV(double[] dv0, double[] dv, double[] sv)
{
    int hnum = dv0.Length;
    double numerator = 0.0;   // dv . (dv - dv0)
    double denominator = 0.0; // dv0 . dv0
    for (int k = 0; k < hnum; k++)
    {
        // Both factors of the inner product must come from the same index k.
        numerator += dv[k] * (dv[k] - dv0[k]);
        denominator += dv0[k] * dv0[k];
    }
    double beta = denominator > 0.0 ? numerator / denominator : 0.0;
    for (int k = 0; k < hnum; k++)
    {
        sv[k] = sv[k] * beta + dv[k];
    }
}
/// <summary>
/// Moves a 2-D parameter matrix along a step matrix, in place: w[i, j] += rate * dw[i, j].
/// </summary>
/// <param name="w">Parameters to update (e.g. w or v); modified in place. Its dimensions set the range.</param>
/// <param name="dw">Step / search direction with at least the dimensions of w.</param>
/// <remarks>Uses the instance field rate as the step scale (presumably the learning rate — confirm).</remarks>
public void adjustWV(double[,] w, double[,] dw)
{
    int rowCount = w.GetLength(0);
    int colCount = w.GetLength(1);
    for (int row = 0; row < rowCount; row++)
    {
        for (int col = 0; col < colCount; col++)
        {
            double step = rate * dw[row, col];
            w[row, col] += step;
        }
    }
}
/// <summary>
/// Moves a parameter vector along a step vector, in place: w[i] += rate * dw[i].
/// </summary>
/// <param name="w">Parameters to update (e.g. b1 or b2); modified in place. Its length sets the range.</param>
/// <param name="dw">Step / search direction with at least w.Length elements.</param>
/// <remarks>Uses the instance field rate as the step scale (presumably the learning rate — confirm).</remarks>
public void adjustWV(double[] w, double[] dw)
{
    int idx = 0;
    while (idx < w.Length)
    {
        w[idx] += rate * dw[idx];
        idx++;
    }
}
/// <summary>
/// Sets every element of a 2-D accumulator matrix to 0, in place.
/// </summary>
/// <param name="dw">Matrix to clear.</param>
public void resetWV(double[,] dw)
{
    // Array.Clear with the total Length zeroes every element of a multidimensional array.
    Array.Clear(dw, 0, dw.Length);
}
/// <summary>
/// Sets every element of an accumulator vector to 0, in place.
/// </summary>
/// <param name="dw">Vector to clear.</param>
public void resetWV(double[] dw)
{
    Array.Clear(dw, 0, dw.Length);
}
/// <summary>
/// Forward pass for one input sample (simulation / prediction).
/// </summary>
/// <param name="psim">Raw input sample; the first inNum values are used, each divided by in_rate
/// (the same input scaling train() applies).</param>
/// <returns>
/// A new array with the outNum network outputs multiplied by in_rate, undoing the scaling that
/// train() applies to the targets. The field x2 is also left holding these values.
/// </returns>
/// <exception cref="ArgumentNullException">psim is null.</exception>
/// <exception cref="ArgumentException">psim has fewer than inNum values.</exception>
public double[] sim(double[] psim)
{
    if (psim == null)
    {
        throw new ArgumentNullException("psim");
    }
    if (psim.Length < inNum)
    {
        throw new ArgumentException("Input sample has fewer than inNum values.", "psim");
    }

    // Normalize the input.
    for (int i = 0; i < inNum; i++)
    {
        x[i] = psim[i] / in_rate;
    }

    // Hidden layer: weighted sum plus bias b1 through the logistic sigmoid.
    for (int j = 0; j < hideNum; j++)
    {
        o1[j] = 0.0;
        for (int i = 0; i < inNum; i++)
        {
            o1[j] += w[i, j] * x[i];
        }
        x1[j] = 1.0 / (1.0 + Math.Exp(-o1[j] - b1[j]));
    }

    // Output layer: weighted sum plus bias b2 through the sigmoid, then scaled back by in_rate.
    for (int k = 0; k < outNum; k++)
    {
        o2[k] = 0.0;
        for (int j = 0; j < hideNum; j++)
        {
            o2[k] += v[j, k] * x1[j];
        }
        x2[k] = in_rate * (1.0 / (1.0 + Math.Exp(-o2[k] - b2[k])));
    }

    // x2 is a shared field reused by every sim()/train() call; returning it directly would let the
    // next call silently overwrite results the caller already holds, so hand out a copy.
    return (double[])x2.Clone();
}
/// <summary>
/// Saves a 2-D matrix (e.g. the weights w, v) to a text file: one matrix row per line, each value
/// followed by a single space. Values use the current culture, matching readMatrixW's parsing.
/// </summary>
/// <param name="w">Matrix to write.</param>
/// <param name="filename">Destination path; created or overwritten.</param>
public void saveMatrix(double[,] w, string filename)
{
    // The using block releases the file handle even if a write throws.
    using (StreamWriter writer = File.CreateText(filename))
    {
        int rows = w.GetLength(0);
        int cols = w.GetLength(1);
        for (int i = 0; i < rows; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                // "G17" always round-trips a double; the default ToString() on .NET Framework keeps
                // only 15 significant digits, so reloaded weights would not match the trained ones.
                writer.Write(w[i, j].ToString("G17") + " ");
            }
            writer.WriteLine();
        }
    }
}
/// <summary>
/// Saves a vector (e.g. the biases b1, b2) to a text file on a single line, each value followed by
/// a single space, with no trailing newline. Values use the current culture.
/// </summary>
/// <param name="b">Vector to write.</param>
/// <param name="filename">Destination path; created or overwritten.</param>
public void saveMatrix(double[] b, string filename)
{
    // The using block releases the file handle even if a write throws.
    using (StreamWriter writer = File.CreateText(filename))
    {
        for (int i = 0; i < b.Length; i++)
        {
            // "G17" guarantees the value parses back to exactly the same double.
            writer.Write(b[i].ToString("G17") + " ");
        }
    }
}
/// <summary>
/// Loads a 2-D matrix (e.g. the weights w, v) from a text file into the caller's array, in place.
/// Each non-blank line is one matrix row, and values are separated by spaces or tabs, which is the
/// format saveMatrix writes. Values are parsed with the current culture.
/// </summary>
/// <param name="w">Destination matrix; must be large enough for the file's rows and columns.</param>
/// <param name="filename">Source path.</param>
/// <remarks>
/// On any error the message is printed to the console and the method returns normally, so
/// w may be left partially filled.
/// </remarks>
public void readMatrixW(double[,] w, string filename)
{
    char[] separators = { ' ', '\t' };
    try
    {
        // The content is plain ASCII numbers, which UTF-8 (the encoding File.CreateText writes) decodes
        // exactly like gb2312 did, without the code-pages provider that gb2312 requires on .NET Core/.NET 5+.
        // The using block closes the file even when parsing throws part-way through.
        using (StreamReader sr = new StreamReader(filename, Encoding.UTF8))
        {
            string line;
            int i = 0;
            while ((line = sr.ReadLine()) != null)
            {
                // RemoveEmptyEntries tolerates repeated or trailing separators.
                string[] tokens = line.Split(separators, StringSplitOptions.RemoveEmptyEntries);
                if (tokens.Length == 0)
                {
                    continue; // blank line: not a matrix row
                }
                for (int j = 0; j < tokens.Length; j++)
                {
                    w[i, j] = Convert.ToDouble(tokens[j]);
                }
                i++;
            }
        }
    }
    catch (Exception e)
    {
        // Let the user know what went wrong.
        Console.WriteLine("The file could not be read:");
        Console.WriteLine(e.Message);
    }
}
/// <summary>
/// Loads a vector (e.g. the biases b1, b2) from a text file into the caller's array, in place.
/// Values are separated by spaces, tabs or newlines and read in order. This accepts both the
/// single-line format saveMatrix(double[], ...) writes and one-value-per-line files. Values are
/// parsed with the current culture.
/// </summary>
/// <param name="b">Destination vector; must have room for every value in the file.</param>
/// <param name="filename">Source path.</param>
/// <remarks>
/// On any error the message is printed to the console and the method returns normally, so
/// b may be left partially filled.
/// </remarks>
public void readMatrixB(double[] b, string filename)
{
    char[] separators = { ' ', '\t' };
    try
    {
        // ASCII numeric content decodes identically as UTF-8 (what File.CreateText writes) and avoids the
        // code-pages provider gb2312 needs on .NET Core/.NET 5+. The using block closes the file on errors.
        using (StreamReader sr = new StreamReader(filename, Encoding.UTF8))
        {
            string line;
            int i = 0;
            while ((line = sr.ReadLine()) != null)
            {
                // A line may hold several values (saveMatrix writes them all on one line) or none.
                string[] tokens = line.Split(separators, StringSplitOptions.RemoveEmptyEntries);
                for (int t = 0; t < tokens.Length; t++)
                {
                    b[i] = Convert.ToDouble(tokens[t]);
                    i++;
                }
            }
        }
    }
    catch (Exception e)
    {
        // Let the user know what went wrong.
        Console.WriteLine("The file could not be read:");
        Console.WriteLine(e.Message);
    }
}
}//end bpnet
} //end namespace
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -