⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 irisrbf2.m

📁 RBF神经网络应用于IRIS数据集的例子
💻 M
字号:
% Load the IRIS data and split it into training and test sets.
% Assumes allirisdata.mat provides IrisData (one sample per row, three
% consecutive 50-row class blocks) and Label as a column vector -- TODO confirm.
close all
clear
clc
load allirisdata.mat;

N1=40;      % samples per class used for training
N2=50-N1;   % samples per class used for testing

% First N1 rows of each 50-row class block go to training, the last N2 to testing.
trainRows=[1:N1, 50+(1:N1), 100+(1:N1)];
testRows=[N1+(1:N2), 50+N1+(1:N2), 100+N1+(1:N2)];

% Samples are stored column-wise (features x samples); targets as row vectors.
Train=IrisData(trainRows,:)';
Test=IrisData(testRows,:)';
TrainOut=Label(trainRows)';
TestOut=Label(testRows)';

TrainNum=numel(trainRows);   % number of training samples
TestNum=numel(testRows);     % number of test samples
clear trainRows testRows

InDim=size(IrisData,2);      % input dimension (feature count)
UnitNum=3;                   % number of RBF hidden units

% --- Network initialisation ----------------------------------------------
% UnitNum Gaussian hidden units feed one linear output unit. Parameters are
% drawn with rand and no seed is set, so results differ between runs.
Epoches=0;                      % number of completed parameter updates

Centers=8*rand(InDim,UnitNum);  % hidden-unit centres, one per column, entries in [0,8)
SP=5*rand(1,UnitNum);           % hidden-unit spreads, entries in [0,5)
W=5*rand(1,UnitNum)-10;         % output weights, entries in [-10,-5)
% NOTE(review): W starts strictly negative; if Label holds positive class
% indices a symmetric range (e.g. 10*rand-5) may have been intended -- confirm.
lrCent=0.01;                    % learning rate for the centres
lrSP=0.01;                      % learning rate for the spreads
lrW=0.01;                       % learning rate for the output weights
ErrHistory=[];                  % SSE recorded at the start of each epoch

% Stopping criteria: without an upper bound the training loop never ends.
MaxEpoch=2000;                  % maximum number of training epochs
ErrGoal=1e-3;                   % stop early once the training SSE reaches this

% --- Batch gradient-descent training -------------------------------------
% Hidden activation: phi_ij = radbas(d_ij/s_i) = exp(-d_ij^2/s_i^2),
%                    with d_ij = ||x_j - c_i||
% Output:            y_j = sum_i W(i)*phi_ij
% Loss:              SSE = sum_j e_j^2,  e_j = t_j - y_j
% Ascent directions on -SSE (constant factors of 2 folded into the rates):
%   c_i : W(i)/s_i^2 * sum_j e_j*phi_ij*(x_j - c_i)
%   s_i : W(i)/s_i^3 * sum_j e_j*phi_ij*d_ij^2
%   W(i):              sum_j e_j*phi_ij
for epoch=1:MaxEpoch
    AllDist=dist(Centers',Train);       % UnitNum x TrainNum distances d_ij
    SPMat=repmat(SP',1,TrainNum);
    UnitOut=radbas(AllDist./SPMat);     % UnitNum x TrainNum activations phi_ij
    NetOut=W*UnitOut;                   % 1 x TrainNum network outputs
    Error=TrainOut-NetOut;

    SSE=sumsqr(Error);
    ErrHistory=[ErrHistory SSE];

    if ~isfinite(SSE)
        warning('RBF training diverged at epoch %d (SSE is not finite).',epoch);
        break
    end
    if SSE<=ErrGoal
        break
    end

    for i=1:UnitNum
        % Only units 1..i-1 have been updated so far in this epoch, so row i
        % of AllDist/UnitOut still matches unit i's current parameters.
        ErrAct=Error.*UnitOut(i,:);                              % e_j*phi_ij
        Cent=(Train-repmat(Centers(:,i),1,TrainNum))*ErrAct';    % InDim x 1
        SPsum=(AllDist(i,:).^2)*ErrAct';                         % uses d_ij^2
        Wsum=sum(ErrAct);
        CentGrad=Cent*W(i)/(SP(i)^2);
        SPGrad=SPsum*W(i)/(SP(i)^3);
        WGrad=Wsum;
        Centers(:,i)=Centers(:,i)+lrCent*CentGrad;
        SP(i)=SP(i)+lrSP*SPGrad;
        W(i)=W(i)+lrW*WGrad;
    end
    Epoches=Epoches+1;
end


% Forward pass of the trained network on the test set (samples are the
% columns of Test). The spreads are the row vector SP, expanded to one row
% per hidden unit.
TestDistance=dist(Centers',Test);
TestSpreadsMat=repmat(SP',1,TestNum);
TestHiddenUnitOut=radbas(TestDistance./TestSpreadsMat);
TestNNOut=W*TestHiddenUnitOut;


% Resubstitution check: classify the training samples with the trained
% network. The scalar output is rounded to the nearest integer and compared
% with the target label.
AllDist=dist(Centers',Train);
SPMat=repmat(SP',1,TrainNum);
UnitOut=radbas(AllDist./SPMat);
TrainNNOut=W*UnitOut;
ErrInd1=find(round(TrainNNOut)~=TrainOut);
errorrate1=length(ErrInd1)/TrainNum;
disp('训练样本数')
disp(TrainNum)
disp('对训练样本的回判错误率为');
disp(errorrate1);

% Test-set error rate, using the same rounding rule. The spreads come from
% the trained row vector SP, expanded to one row per hidden unit.
TestDistance=dist(Centers',Test);
TestSpreadsMat=repmat(SP',1,TestNum);
TestHiddenUnitOut=radbas(TestDistance./TestSpreadsMat);
TestNNOut=W*TestHiddenUnitOut;
ErrInd2=find(round(TestNNOut)~=TestOut);
errorrate2=length(ErrInd2)/TestNum;
disp('测试样本数')
disp(TestNum)
disp('对测试样本判断的错误率为');
disp(errorrate2);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -