📄 bpx_train.m
字号:
%此为BP网络训练程序
function retstr = trainnet(NetPara,TrainPara,InputFun,OutputFun,DataDir)
NNTWARN OFF
retdouble = -1;
%保留原目录
olddir=pwd;
%进入数据所在目录
cd(DataDir);
deltalin(1);
deltalog(1);
deltatan(1);
% 网络参数
InputDim=NetPara(1); %输入层节点数
OutputDim=NetPara(2); %输出层节点数
MidDim=NetPara(3); %中间层节点数
data_num=NetPara(4); %训练数据组数
% 网络训练参数
if (TrainPara == -1)
df = 25; %显示间隔次数 25
me = 1000; %最大循环次数 1000
eg = 0.001; %目标误差 0.02
lr = 0.001; %学习速率 0.001
lr_inc = 0.001; %学习速率增加比率 0.001
lr_idec = 10; %学习速率减少比率 10
mom_const = 0.1; %动量常数 0.1
err_ratio = 1e10; %最大误差比率 1e10
else
df=TrainPara(1); %显示间隔次数 25
me=TrainPara(2); %最大循环次数 1000
eg=TrainPara(3); %目标误差 0.02
lr=TrainPara(4); %学习速率 0.001
lr_inc=TrainPara(5); %学习速率增加比率 0.001
lr_idec=TrainPara(6); %学习速率减少比率 10
mom_const=TrainPara(7); %动量常数 0.1
err_ratio=TrainPara(8); %最大误差比率 1e10
end
% 输入层到中间层的传递函数
if (length(InputFun)==0)
InputFun = 'tansig';
end
% 中间层到输出层的传递函数
if (length(OutputFun)==0)
OutputFun = 'purelin'
end
tp=[df me eg lr lr_inc lr_idec mom_const err_ratio];
%tp=[df me eg 0.001 0.001 10 0.1 1e10];
frin_para=fopen('input_para.txt','r'); %输入数据文件
frout_para=fopen('output_para.txt','r'); %输出数据文件
[p,count]=fscanf(frin_para,'%f',[InputDim,data_num]); %取输入数据
[t,count]=fscanf(frout_para,'%f',[OutputDim,data_num]); %取输出数据
fclose(frin_para);
fclose(frout_para);
[r,q]=size(p); [s2,q]=size(t);
[w1,b1]=rands(MidDim,r);
[w2,b2]=rands(s2,MidDim);
w1,b1,w2,b2,epochs,errors]=trainbpx(w1,b1,InputFun,w2,b2,OutputFun,p,t,tp);
%将参数值写入文件
fww1=fopen('w1.dat','w');
fwb1=fopen('b1.dat','w');
fww2=fopen('w2.dat','w');
fwb2=fopen('b2.dat','w');
fprintf(fww1,'%9.4f ',w1);
fprintf(fwb1,'%9.4f ',b1);
fprintf(fww2,'%9.4f\n',w2);
fprintf(fwb2,'%9.4f\n',b2);
fclose(fww1);
fclose(fwb1);
fclose(fww2);
fclose(fwb2);
cd(olddir)
retstr=epochs;
close all;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -