% File: whk_m.m
% (Header residue from the web code viewer, kept as a comment so the file parses.)
function main()
% MAIN  Train a one-hidden-layer BP network on a simulated nonlinear series
% and prune low-sensitivity hidden units and inputs during training.
%
% Steps performed:
%   1. Simulate the series y(i) driven by random input u(i).
%   2. Build samples [u(i-1); y(i-1); u(i-2); y(i-2); y(i-3)] -> y(i).
%   3. Train with batch gradient descent (logsig hidden layer, linear output).
%   4. Once the training SSE is below E0, ask FindUnitsBeDeleted whether a
%      hidden unit or an input can be removed and shrink the network.
%   5. Plot the error curve and the input-sensitivity history, then report
%      the SSE on the test samples.
%
% Uses rands, logsig and sumsqr (neural network toolbox functions) and the
% local functions FindUnitsBeDeleted and BPNet defined in this file.

AllDataNum=300;
TrainDataNum=100;
TestDataNum=100;

% Simulate the target time series.
u=rands(1,AllDataNum+10);
y=zeros(1,AllDataNum+10);
for i=2:AllDataNum+1
    numerator=16*u(i-1)+8*y(i-1);
    denominator=3+4*u(i-1)^2+4*y(i-1)^2;
    append=2/10*u(i-1)+2/10*y(i-1);
    y(i)=numerator/denominator+append;
end

% Build all input/output samples (one column per sample).
idx=4:AllDataNum+1;
AllDataIn=[u(idx-1);y(idx-1);u(idx-2);y(idx-2);y(idx-3)];
AllDataOut=y(idx);

TrainDataIn=AllDataIn(:,1:TrainDataNum);
TrainDataOut=AllDataOut(:,1:TrainDataNum);
TestDataIn=AllDataIn(:,TrainDataNum+1:TrainDataNum+TestDataNum);
TestDataOut=AllDataOut(:,TrainDataNum+1:TrainDataNum+TestDataNum);

% Network size and training parameters.
InDim=5;
OutDim=1;
HiddenUnitNum=20;
MaxEpochs=10000;
lr=0.0005;
E0=4.5;                  % pruning is only attempted once SSE < E0
HiddenSenseLimit=0.005;
% NOTE(review): InSenseLimit was used but never defined in the original.
% The value below mirrors HiddenSenseLimit -- confirm the intended setting.
InSenseLimit=0.005;
HiddenDeleteLen=50;
InDeleteLen=50;

% Small random initial weights; the "Ex" matrices carry the bias as an
% extra last column.
W1=0.1*rands(HiddenUnitNum,InDim);
B1=0.1*rands(HiddenUnitNum,1);
W2=0.1*rands(OutDim,HiddenUnitNum);
B2=0.1*rands(OutDim,1);
W1Ex=[W1 B1];
W2Ex=[W2 B2];

% Training inputs with a constant-1 row appended for the bias.
TrainDataInEx=[TrainDataIn;ones(1,TrainDataNum)];

ErrHistory=zeros(1,MaxEpochs);   % the loop never exits early
Sens1History=[];
Sens2History=[];
FilteredHiddenSens=zeros(1,HiddenUnitNum);
FilteredInSens=zeros(1,InDim);
DeleteHidEpoch=[];
DeleteInEpoch=[];

for epoch=1:MaxEpochs
    % Forward pass.
    HiddenOut=logsig(W1Ex*TrainDataInEx);
    HiddenOutEx=[HiddenOut;ones(1,TrainDataNum)];
    NetworkOut=W2Ex*HiddenOutEx;

    % Training error (no semicolon: SSE is echoed every epoch as progress).
    Error=TrainDataOut-NetworkOut;
    SSE=sumsqr(Error)

    % Record the training error of this epoch.
    ErrHistory(epoch)=SSE;

    % Sensitivity of each hidden unit and each input.
    ErrSign=sign(Error);
    SignDelta1=W2'*ErrSign.*HiddenOut.*(1-HiddenOut);
    HiddenSens=ErrSign*HiddenOut'.*W2;
    InSens=sum(SignDelta1'*W1.*TrainDataIn');

    % Exponential smoothing, then store the normalised sensitivities.
    FilteredHiddenSens=FilteredHiddenSens*0.8+HiddenSens*0.2;
    FilteredInSens=FilteredInSens*0.8+InSens*0.2;
    Sens1History=[Sens1History;FilteredHiddenSens/sum(abs(FilteredHiddenSens))];
    Sens2History=[Sens2History;FilteredInSens/sum(abs(FilteredInSens))];

    [val1,DeletePos1]=FindUnitsBeDeleted(Sens1History,...
        HiddenSenseLimit,HiddenDeleteLen);
    [val2,DeletePos2]=FindUnitsBeDeleted(Sens2History,...
        InSenseLimit,InDeleteLen);

    if SSE<E0
        % Short-circuit operators: val1/val2 are only compared when both
        % candidates exist.
        if (DeletePos1>0 && DeletePos2<0) ...
                || (DeletePos1>0 && DeletePos2>0 && val1<val2)
            % Remove one hidden unit.
            DeletePos=DeletePos1;
            keep=[1:DeletePos-1, DeletePos+1:HiddenUnitNum];
            FilteredHiddenSens=FilteredHiddenSens(:,keep);
            Sens1History=Sens1History(:,keep);
            W1Ex=W1Ex(keep,:);
            % Keep the bias column (last one) of the output layer.
            W2Ex=W2Ex(:,[keep, HiddenUnitNum+1]);
            HiddenUnitNum=HiddenUnitNum-1;
            W1=W1Ex(:,1:InDim);
            W2=W2Ex(:,1:HiddenUnitNum);
            DeleteHidEpoch=[DeleteHidEpoch epoch];
            continue;
        end
        if (DeletePos2>0 && DeletePos1<0) ...
                || (DeletePos2>0 && DeletePos1>0 && val2<val1)
            % Remove one input.
            DeletePos=DeletePos2;
            keep=[1:DeletePos-1, DeletePos+1:InDim];
            FilteredInSens=FilteredInSens(:,keep);
            Sens2History=Sens2History(:,keep);
            % Keep the bias column / constant-1 row (last one).
            W1Ex=W1Ex(:,[keep, InDim+1]);
            TrainDataInEx=TrainDataInEx([keep, InDim+1],:);
            TrainDataIn=TrainDataIn(keep,:);
            TestDataIn=TestDataIn(keep,:);
            InDim=InDim-1;
            W1=W1Ex(:,1:InDim);
            DeleteInEpoch=[DeleteInEpoch epoch];
            continue;
        end
    end

    % Back-propagated errors (linear output layer, logsig hidden layer).
    Delta2=Error;
    Delta1=W2'*Delta2.*HiddenOut.*(1-HiddenOut);

    % Weight increments.
    dW2Ex=Delta2*HiddenOutEx';
    dW1Ex=Delta1*TrainDataInEx';

    % Weight update.
    W1Ex=W1Ex+lr*dW1Ex;
    W2Ex=W2Ex+lr*dW2Ex;

    % Split the bias-free weights out again. W1 is refreshed as well so the
    % input-sensitivity computation above does not use stale weights.
    W1=W1Ex(:,1:InDim);
    W2=W2Ex(:,1:HiddenUnitNum);
end

% Plot the training error curve.
figure
grid
hold on
Num=numel(ErrHistory);
plot(1:Num,ErrHistory,'k-');

% Final weights; the unterminated lines below echo the results.
W1=W1Ex(:,1:InDim);
B1=W1Ex(:,InDim+1);
W2=W2Ex(:,1:HiddenUnitNum);
B2=W2Ex(:,1+HiddenUnitNum);
HiddenUnitNum
InDim
DeleteHidEpoch
DeleteInEpoch

% Plot the normalised input-sensitivity history (one curve per input).
figure
grid
hold on
[PtNum,InNum]=size(Sens2History);
for i=1:InNum
    plot(1:PtNum,Sens2History(:,i)','b-');
end

% Evaluate the pruned network on the test samples.
TestNNOut=BPNet(TestDataIn,W1,B1,W2,B2);
TestError=sumsqr(TestDataOut-TestNNOut)
% Find the unit (hidden node or input node) that may be deleted, if any.
function [Sense,DeletePos]=FindUnitsBeDeleted(SensHistory,...
    SenseDeLinit,DeleteWin)
% FINDUNITSBEDELETED  Pick the least sensitive unit over a recent window.
%
%   SensHistory  - one row per recorded epoch, one column per unit.
%   SenseDeLinit - sensitivity limit (the name is kept from the original
%                  signature); a unit is a deletion candidate only when its
%                  peak |sensitivity| in the window is below this limit.
%   DeleteWin    - number of most recent rows to examine.
%
%   Sense     - smallest per-unit peak |sensitivity| in the window, or Inf
%               when the history is shorter than the window.
%   DeletePos - column index of the unit to delete, or -1 for "none".
%
% NOTE(review): the original body never assigned Sense and referenced an
% undefined variable pos. The selection rule below is a reconstruction
% based on the argument names and on how the caller uses the outputs --
% confirm it against the intended algorithm.

[DataLen,UnitNum]=size(SensHistory);

% Defaults: no candidate; Inf keeps the caller's val1<val2 comparison safe.
Sense=Inf;
DeletePos=-1;

if DataLen<DeleteWin || DeleteWin<1 || UnitNum==0
    return
end

% Peak magnitude of each unit over the last DeleteWin rows. The explicit
% dimension keeps max() column-wise even when the window is a single row.
Recent=abs(SensHistory(DataLen-DeleteWin+1:DataLen,:));
PeakSens=max(Recent,[],1);

[Sense,pos]=min(PeakSens);
if Sense<SenseDeLinit
    DeletePos=pos;
end
function Out=BPNet(In,W1,B1,W2,B2)
% BPNET  Forward pass of a one-hidden-layer network.
%
%   In     - input samples, one column per sample.
%   W1, B1 - hidden-layer weights and bias column (logsig activation).
%   W2, B2 - output-layer weights and bias column (linear output).
%
%   Out    - network output, one column per sample.

InNum=size(In,2);
HiddenOut=logsig(W1*In+repmat(B1,1,InNum));
% Terminated with a semicolon so the output matrix is not echoed to the
% console on every call.
Out=W2*HiddenOut+repmat(B2,1,InNum);
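% --- Web code-viewer keyboard shortcuts (page residue, not part of the program) ---
% Copy code:          Ctrl + C
% Search code:        Ctrl + F
% Full-screen mode:   F11
% Switch theme:       Ctrl + Shift + D
% Show shortcuts:     ?
% Increase font size: Ctrl + =
% Decrease font size: Ctrl + -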