⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 whk_n.m

📁 魏海坤编著的《神经网络结构设计的理论与方法》 国防工业出版社出版
💻 M
字号:
function main()
% MAIN  Train a 5-5-5-1 back-propagation network with pruning during training.
%   The target is the 5-input Boolean function
%       y = (x1 | x2) & (x3 | x4 | x5)
%   evaluated on all 2^5 input patterns. TrainDataNum randomly chosen
%   patterns are used for training and TestDataNum of the remaining ones
%   for testing. Once the training SSE drops below ErrCombine, hidden units
%   with strongly correlated outputs are merged (CombineTwoUnits) and units
%   with near-constant output are folded into the next layer's bias
%   (CombineUnitToBias). Uses rands/logsig (Neural Network Toolbox).
InDim=5;
OutDim=1;
Hidden1UnitNum=5;
Hidden2UnitNum=5;
AllSamNum=32;
TrainDataNum=24;
TestDataNum=8;

% Build every sample: column k is the InDim-bit binary code of k-1,
% most significant bit in row 1.
AllSamIn=(dec2bin(0:AllSamNum-1,InDim)-'0')';
AllSamOut=(AllSamIn(1,:)|AllSamIn(2,:))&...
          (AllSamIn(3,:)|AllSamIn(4,:)|AllSamIn(5,:));

% Random train/test split: the test set is the TestDataNum samples that
% follow the training samples in the permutation.
PermPos=randperm(AllSamNum);
TrainPos=PermPos(1:TrainDataNum);
TestPos=PermPos(TrainDataNum+1:TrainDataNum+TestDataNum);
TrainDataIn=AllSamIn(:,TrainPos);
TrainDataOut=AllSamOut(:,TrainPos);
TestDataIn=AllSamIn(:,TestPos);
TestDataOut=AllSamOut(:,TestPos);

% Random initial weights and biases (0.5*rands)
W1=0.5*rands(Hidden1UnitNum,InDim);
B1=0.5*rands(Hidden1UnitNum,1);
W2=0.5*rands(Hidden2UnitNum,Hidden1UnitNum);
B2=0.5*rands(Hidden2UnitNum,1);
W3=0.5*rands(OutDim,Hidden2UnitNum);
B3=0.5*rands(OutDim,1);

lr=0.9;                     % learning rate
Alpha=0.9;                  % momentum coefficient
MaxEpoch=2000;
ErrCombine=0.001;           % try to prune once SSE is below this
ErrGoal=0.00005;            % stop training once SSE is below this
UnitsCombineThreshold=0.8;  % |correlation| needed to merge two units
BiasCombineThreshold=0.001; % output variance below which a unit becomes bias

% "Ex" weight matrices carry the bias as an extra last column; the inputs
% are extended with a matching row of ones.
W1Ex=[W1 B1];
W2Ex=[W2 B2];
W3Ex=[W3 B3];
TrainDataInEx=[TrainDataIn; ones(1,TrainDataNum)];
ErrHistory=[];
ReSizeFlag=1;
for epoch=1:MaxEpoch
    if (ReSizeFlag==1)
        % First epoch, or a unit was just removed: refresh the layer sizes
        % and the bias-free weight views used in back-propagation, and
        % reset the previous weight changes (their shapes may have changed).
        [Hidden2UnitNum,Hidden1UnitNum]=size(W2Ex);
        Hidden1UnitNum=Hidden1UnitNum-1;
        W2=W2Ex(:,1:Hidden1UnitNum);
        W3=W3Ex(:,1:Hidden2UnitNum);
        dW1Ex=zeros(size(W1Ex));
        dW2Ex=zeros(size(W2Ex));
        dW3Ex=zeros(size(W3Ex));
        ReSizeFlag=0;
    end

    % Forward pass
    Hidden1Out=logsig(W1Ex*TrainDataInEx);
    Hidden1OutEx=[Hidden1Out; ones(1,TrainDataNum)];
    Hidden2Out=logsig(W2Ex*Hidden1OutEx);
    Hidden2OutEx=[Hidden2Out; ones(1,TrainDataNum)];
    NetworkOut=logsig(W3Ex*Hidden2OutEx);

    % Training error: sum of squared errors over all outputs and samples
    Error=TrainDataOut-NetworkOut;
    SSE=sum(Error(:).^2);
    ErrHistory=[ErrHistory SSE];

    % Pruning check
    if (SSE<ErrCombine)
        % Variance of each hidden unit's output over the training samples
        Hidden1Var=var(Hidden1Out')';
        Hidden2Var=var(Hidden2Out')';

        % Correlation coefficients between hidden unit outputs
        Hidden1Corr=corrcoef(Hidden1Out');
        Hidden2Corr=corrcoef(Hidden2Out');

        % Does the first hidden layer have units to merge?
        [Hidden1unit1,Hidden1unit2]=FindUnitCombine(Hidden1Corr,...
                 Hidden1Var,UnitsCombineThreshold,BiasCombineThreshold);
        if (Hidden1unit1>0)
            if (Hidden1unit2>0)
                % Two correlated units: express unit2 as a*unit1+b
                [a,b]=LinearReg(Hidden1Out(Hidden1unit1,:),...
                    Hidden1Out(Hidden1unit2,:));
                epoch
                CombineType=11
                DrawCorrelatedUnitsOut(Hidden1Out...
                    (Hidden1unit1,:),Hidden1Out(Hidden1unit2,:));
                [W1Ex,W2Ex]=CombineTwoUnits(Hidden1unit1,...
                    Hidden1unit2,a,b,W1Ex,W2Ex);
            else
                % Near-constant unit: replace it by its mean output
                epoch
                CombineType=12
                DrawBiasedUnitOut(Hidden1Out(Hidden1unit1,:));
                UnitMean=mean(Hidden1Out(Hidden1unit1,:));
                [W1Ex,W2Ex]=CombineUnitToBias...
                    (Hidden1unit1,UnitMean,W1Ex,W2Ex);
            end
            ReSizeFlag=1;
            continue;
        end
        % Does the second hidden layer have units to merge?
        [Hidden2unit1,Hidden2unit2]=FindUnitCombine(Hidden2Corr,...
                 Hidden2Var,UnitsCombineThreshold,BiasCombineThreshold);
        if (Hidden2unit1>0)
            if (Hidden2unit2>0)
                epoch
                CombineType=21
                [a,b]=LinearReg(Hidden2Out(Hidden2unit1,:),...
                    Hidden2Out(Hidden2unit2,:));
                DrawCorrelatedUnitsOut(Hidden2Out...
                    (Hidden2unit1,:),Hidden2Out(Hidden2unit2,:));
                [W2Ex,W3Ex]=CombineTwoUnits(Hidden2unit1,...
                    Hidden2unit2,a,b,W2Ex,W3Ex);
            else
                epoch
                CombineType=22
                DrawBiasedUnitOut(Hidden2Out(Hidden2unit1,:));
                UnitMean=mean(Hidden2Out(Hidden2unit1,:));
                [W2Ex,W3Ex]=CombineUnitToBias...
                    (Hidden2unit1,UnitMean,W2Ex,W3Ex);
            end
            ReSizeFlag=1;
            continue;
        end
    end

    if (SSE<ErrGoal),break,end

    % Back-propagated error terms (logsig derivative is y.*(1-y))
    Delta3=Error.*NetworkOut.*(1-NetworkOut);
    Delta2=W3'*Delta3.*Hidden2Out.*(1-Hidden2Out);
    Delta1=W2'*Delta2.*Hidden1Out.*(1-Hidden1Out);

    % Previous epoch's weight change, scaled by lr (momentum term)
    dW1Ex0=lr*dW1Ex;
    dW2Ex0=lr*dW2Ex;
    dW3Ex0=lr*dW3Ex;

    % Weight changes for this epoch
    dW3Ex=Delta3*Hidden2OutEx';
    dW2Ex=Delta2*Hidden1OutEx';
    dW1Ex=Delta1*TrainDataInEx';

    % Weight update with momentum
    W1Ex=W1Ex+lr*dW1Ex+Alpha*dW1Ex0;
    W2Ex=W2Ex+lr*dW2Ex+Alpha*dW2Ex0;
    W3Ex=W3Ex+lr*dW3Ex+Alpha*dW3Ex0;

    % Bias-free weight views for the next epoch's back-propagation
    W2=W2Ex(:,1:Hidden1UnitNum);
    W3=W3Ex(:,1:Hidden2UnitNum);
end

% Final layer sizes, read from the weights themselves: a merge on the last
% epoch leaves the loop before the resize block can update the counters.
Hidden1UnitNum=size(W1Ex,1)
Hidden2UnitNum=size(W2Ex,1)
W1=W1Ex(:,1:InDim);
B1=W1Ex(:,InDim+1);
W2=W2Ex(:,1:Hidden1UnitNum);
B2=W2Ex(:,1+Hidden1UnitNum);
W3=W3Ex(:,1:Hidden2UnitNum);
B3=W3Ex(:,1+Hidden2UnitNum);
TestNNOut=BPNet(TestDataIn,W1,B1,W2,B2,W3,B3);
BinOut=TestNNOut>0.5;
% Number of misclassified test outputs (a signed difference would let
% false positives and false negatives cancel out).
ErrNum=sum(TestDataOut(:)~=BinOut(:))

% Learning-error curve
figure
echo off
axis on
grid
hold on
semilogy(1:numel(ErrHistory),ErrHistory,'k-');
% hold is already on, so make sure the y axis really is logarithmic.
set(gca,'YScale','log');

function Out=BPNet(In,W1,B1,W2,B2,W3,B3)
% BPNET  Forward pass of the two-hidden-layer logsig network.
%   In        : input matrix, one sample per column.
%   W1,B1     : first hidden layer weights and bias column.
%   W2,B2     : second hidden layer weights and bias column.
%   W3,B3     : output layer weights and bias column.
%   Out       : network output, one column per sample.
OnesRow=ones(1,size(In,2));   % spreads each bias column across all samples
H1=logsig(W1*In+B1*OnesRow);
H2=logsig(W2*H1+B2*OnesRow);
Out=logsig(W3*H2+B3*OnesRow);

% Merge two correlated hidden units into one
function [W1Ex,W2Ex]=CombineTwoUnits(unit1,unit2,a,b,W1Ex,W2Ex)
% COMBINETWOUNITS  Remove hidden unit UNIT2 by folding it into UNIT1.
%   Treats UNIT2's output as a*out(UNIT1)+b (the fit LinearReg produces).
%   UNIT2's outgoing weights are redistributed: a times them is added to
%   UNIT1's column of W2Ex, and b times them to the bias (last) column.
%   Then UNIT2's row of W1Ex and column of W2Ex are deleted.
%
%   W1Ex : weights into the hidden layer, one row per hidden unit.
%   W2Ex : weights out of the hidden layer, one column per hidden unit
%          followed by a bias column.
BiasCol=size(W2Ex,2);
W2Ex(:,unit1)=W2Ex(:,unit1)+a*W2Ex(:,unit2);
W2Ex(:,BiasCol)=W2Ex(:,BiasCol)+b*W2Ex(:,unit2);
W1Ex(unit2,:)=[];
W2Ex(:,unit2)=[];

% Fold a hidden unit into the next layer's bias
function [W1Ex,W2Ex]=CombineUnitToBias(unit,UnitMean,W1Ex,W2Ex)
% COMBINEUNITTOBIAS  Remove hidden UNIT, treating its output as the
%   constant UnitMean. UnitMean times its outgoing weights (column UNIT of
%   W2Ex) is added to the bias (last) column of W2Ex. Then UNIT's row of
%   W1Ex and column of W2Ex are deleted.
%
%   W1Ex : weights into the hidden layer, one row per hidden unit.
%   W2Ex : weights out of the hidden layer, one column per hidden unit
%          followed by a bias column.
BiasCol=size(W2Ex,2);
W2Ex(:,BiasCol)=W2Ex(:,BiasCol)+UnitMean*W2Ex(:,unit);
W1Ex(unit,:)=[];
W2Ex(:,unit)=[];

% Find hidden units that should be merged
function [unit1,unit2]=FindUnitCombine(HiddenCorr,HiddenVar,...
                         UnitsCombineThreshold,BiasCombineThreshold)
% FINDUNITCOMBINE  Choose hidden units to merge.
%   The pair search comes first. Among pairs with
%   |corr| >= UnitsCombineThreshold, it looks at the most correlated pair
%   whose output variances both exceed BiasCombineThreshold. For that pair
%   it returns [unit1,unit2] with unit1 < unit2.
%   Otherwise, if the smallest variance is below BiasCombineThreshold, it
%   returns [unit,0] for that unit. Failing both, it returns [0,0].
%   A layer with fewer than two units is never reduced further. Pairs
%   whose correlation is exactly zero are never merged.
%
%   HiddenCorr : square correlation matrix of the unit outputs.
%   HiddenVar  : variance of each unit's output.
unit1=0;
unit2=0;
if (numel(HiddenVar)<2)
    % Removing the last unit would leave an empty layer. With empty input,
    % the pair search below would also never terminate.
    return;
end

% Strict upper triangle: each pair once, diagonal excluded. NaN entries
% (units with constant output) are treated as uncorrelated.
CorrTri=abs(triu(HiddenCorr,1));
CorrTri(isnan(CorrTri))=0;
while (1)
    [ColMax,RowPos]=max(CorrTri);
    [MaxCorr,col]=max(ColMax);
    if (MaxCorr<UnitsCombineThreshold || MaxCorr==0)
        break;
    end
    row=RowPos(col);
    if (HiddenVar(row)>BiasCombineThreshold && ...
            HiddenVar(col)>BiasCombineThreshold)
        unit1=row;
        unit2=col;
        return;
    end
    % The pair includes a near-constant unit: skip it and search again.
    CorrTri(row,col)=0;
end

% No mergeable pair: look for a single near-constant unit.
[MinVar,unit]=min(HiddenVar);
if (MinVar<BiasCombineThreshold)
    unit1=unit;
end

% Linear regression
function [a,b]=LinearReg(vect1,vect2)
% LINEARREG  Least-squares fit of vect2 ~= a*vect1 + b.
%   vect1, vect2 : vectors of equal length (row or column).
%   If vect1 is constant the slope is undetermined. In that case a=0 is
%   returned, so b is simply the mean of vect2.
x=vect1(:);
y=vect2(:);
mean_v1=mean(x);
mean_v2=mean(y);
dx=x-mean_v1;
Sxx=dx'*dx;
if (Sxx==0)
    a=0;
else
    % The centred form avoids the cancellation in E[xy]-E[x]E[y].
    a=(dx'*(y-mean_v2))/Sxx;
end
b=mean_v2-a*mean_v1;

% Plot the outputs of two correlated hidden units over all samples
function DrawCorrelatedUnitsOut(UnitOut1,UnitOut2)
% DRAWCORRELATEDUNITSOUT  Open a new figure and plot UnitOut1 (blue) and
%   UnitOut2 (black) against the sample index, with y limited to [0,1].
PtNum=size(UnitOut1,2);
SampleIdx=1:PtNum;
figure
echo off
axis([0 PtNum 0 1])
axis on
grid
hold on
plot(SampleIdx,UnitOut1,'b-',SampleIdx,UnitOut2,'k-')

% Plot the output of a single hidden unit with small output variance
function DrawBiasedUnitOut(UnitOut)
% DRAWBIASEDUNITOUT  Open a new figure at Position [300 300 400 300] and
%   plot UnitOut (black) against the sample index, with y limited to [0,1].
[xxx,PtNum]=size(UnitOut);
figure('Position',[300 300 400 300])
echo off
axis([0 PtNum 0 1])
axis on
grid
hold on
plot(1:PtNum,UnitOut,'k-')

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -