⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nettrain.m

📁 用神经网络进行训练和仿真的接口函数
💻 M
字号:
%此为BP网络训练程序,采用BP神经网络算法中的LM算法实现
%样本数据来源于数据.xls,数据保存在input_para1.txt文件中,output_para1.txt文件对应分类结果,示例
%中为5类
function retstr = NetTrain(ModelNo,NetPara,TrainPara,InputFun,OutputFun,DataDir)
NNTWARN OFF

retstr=-1;
%%%%%%%%%%主要调整以下三个参数%%%%%%%%%%%%%%%%
ModelNo='1';     %神经网络模型编号
NetPara(1)=6;    %输入层节点数
NetPara(2)=1;    %输出层节点数
NetPara(3)=4;    %隐层节点数
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 样本数据文件
FileName='sample.xls';

% 训练参数
TrainPara(1)=25;
TrainPara(2)=500000;
TrainPara(3)=0.0001;
TrainPara(4)=0.001;
TrainPara(5)=0.001;
TrainPara(6)=10;
TrainPara(7)=0.1;
TrainPara(8)=1e10;
DataDir='.';
InputFun = 'tansig';
OutputFun = 'purelin';
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%保留原目录
olddir=pwd;
%进入数据所在目录
cd(DataDir);

% 网络参数
InputDim=NetPara(1);       %输入层节点数
OutputDim=NetPara(2);      %输出层节点数
MidDim=NetPara(3);         %中间层节点数

% 网络训练参数
if (TrainPara == -1)
    df = 25;            %显示间隔次数 25
    me = 1000;             %最大循环次数 1000
    eg = 0.001;             %目标误差 0.02

    lr = 0.001;            %学习速率 0.001
    lr_inc = 0.001;        %学习速率增加比率 0.001
    lr_idec = 10;       %学习速率减少比率 10
    mom_const = 0.1;     %动量常数 0.1
    err_ratio = 1e10;     %最大误差比率 1e10
else
    df=TrainPara(1);            %显示间隔次数 25
    me=TrainPara(2);             %最大循环次数 1000
    eg=TrainPara(3);             %目标误差 0.02

    lr=TrainPara(4);            %学习速率 0.001
    lr_inc=TrainPara(5);        %学习速率增加比率 0.001
    lr_idec=TrainPara(6);       %学习速率减少比率 10
    mom_const=TrainPara(7);     %动量常数 0.1
    err_ratio=TrainPara(8);     %最大误差比率 1e10
end
% 输入层到中间层的传递函数
if (length(InputFun)==0)
    InputFun = 'tansig';
end
% 中间层到输出层的传递函数  
if (length(OutputFun)==0)
    OutputFun = 'purelin';
end    
tp=[df me eg lr lr_inc lr_idec mom_const err_ratio];

%取样本数据
[sample_data] = XLSREAD(FileName);


%输入向量归一化
%用线性函数把数据转换为0.001-0.0.9995之间
%p和t分别表示网络的输入和输出
for i=1:7
 c=sample_data(:,i);
 d=(c-min(c))/(max(c)-min(c));
 d=((0.5-0.001)/0.5)*d+0.001;
 if(i<7)
   p(:,i)=d;
 else
   t=d;  
 end
end 
% 反归一化
%for i=1:6
% c=sample_data(:,i);
% d=((p(:,i)-0.001)*0.5)/(0.5-0.001);  
% d=d*(max(c)-min(c))+min(c);
% a(:,i)=d;
%end 

[r,q]=size(p'); [s2,q]=size(t');

[w1,b1]=rands(MidDim,r);
[w2,b2]=rands(s2,MidDim);

NNTWARN OFF
[w1,b1,w2,b2,epochs,errors]=trainlm(w1,b1,InputFun,w2,b2,OutputFun,p',t',tp);

%将参数值写入文件
fww1=fopen(sprintf('w%s%s',ModelNo,'1.dat'),'w');
fwb1=fopen(sprintf('b%s%s',ModelNo,'1.dat'),'w');
fww2=fopen(sprintf('w%s%s',ModelNo,'2.dat'),'w');
fwb2=fopen(sprintf('b%s%s',ModelNo,'2.dat'),'w');

fprintf(fww1,'%9.4f ',w1);
fprintf(fwb1,'%9.4f ',b1);
fprintf(fww2,'%9.4f\n',w2);
fprintf(fwb2,'%9.4f\n',b2);

fclose(fww1);
fclose(fwb1);
fclose(fww2);
fclose(fwb2);

ferr=fopen(sprintf('lm_err%s%s',ModelNo,'.dat'),'w');
fprintf(ferr,'%10.6f\n',errors);
fclose(ferr);

cd(olddir);

retstr=epochs;
close all;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -