⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 irisrbf3.m

📁 RBF神经网络应用于IRIS数据集的例子
💻 M
字号:
% RBF network on the IRIS data: reset the session and start the CPU timer.
close all
clear
clc
t=cputime;

% allirisdata.mat is expected to provide IrisData (rows = samples) and
% Label. The indexing below assumes 150 rows laid out as three consecutive
% blocks of 50 (presumably one block per class - confirm against the file).
load allirisdata.mat;
N1=40;        % samples per class used for training
N2=50-N1;     % samples per class held out for testing

% Row indices: the first N1 rows of each 50-row block go to training, the
% remaining rows of each block go to testing.
TrainIdx=[1:N1, 50+(1:N1), 100+(1:N1)];
TestIdx=[(N1+1):50, (51+N1):100, (101+N1):150];

% Samples are stored column-wise (one column per sample); targets are rows.
Train=IrisData(TrainIdx,:)';
Test=IrisData(TestIdx,:)';
TrainOut=Label(TrainIdx)';
TestOut=Label(TestIdx)';

TrainNum=size(Train,2);     % number of training samples (columns)
TestNum=size(Test,2);       % number of test samples (columns)
InDim=size(IrisData,2);     % input dimension = number of features
UnitNum=3;                  % number of RBF hidden units
E0=0.1;                     % target sum-squared training error

% Training bookkeeping.
Epoches=0;          % number of completed (non-terminating) epochs
ErrHistory=[];      % SSE recorded at every epoch

% Random initial network parameters (draw order: centers, widths, weights).
Centers=8*rand(InDim,UnitNum);    % one center per column, entries in (0,8)
SP=5*rand(1,UnitNum);             % per-unit widths in (0,5)
W=-10+5*rand(1,UnitNum);          % output weights in (-10,-5)

% Gradient-descent learning rates for centers, widths and output weights.
[lrCent,lrSP,lrW]=deal(1);

% Batch gradient-descent training of the RBF network.
% Each epoch: a forward pass over all training samples, record the SSE,
% stop once SSE < E0, otherwise update every hidden unit's center, width
% and output weight from this epoch's error.
% radbas(n) = exp(-n.^2), so for unit i with phi = exp(-(d/SP(i)).^2):
%   d(out)/d(center) ~ W(i) * phi .* (x - center) / SP(i)^2
%   d(out)/d(SP(i))  ~ W(i) * phi .* d.^2        / SP(i)^3
%   d(out)/d(W(i))   =  phi
% (constant factors of 2 are absorbed into the learning rates).
MaxEpoch=10000;   % safety cap: the loop must terminate even if SSE never reaches E0
for epoch=1:MaxEpoch
    AllDist=dist(Centers',Train);        % UnitNum x TrainNum center-to-sample distances
    SPMat=repmat(SP',1,TrainNum);        % each unit's width replicated across samples
    UnitOut=radbas(AllDist./SPMat);      % hidden-layer activations
    NetOut=W*UnitOut;
    Error=TrainOut-NetOut;

    SSE=sumsqr(Error);
    ErrHistory=[ErrHistory SSE];

    if SSE<E0
        break
    end

    for i=1:UnitNum
        % Centers(:,i) and SP(i) are still unchanged at this point, so this
        % unit's distances and activations are the i-th rows computed above.
        ErrPhi=Error.*UnitOut(i,:);
        Cent=(Train-repmat(Centers(:,i),1,TrainNum))*ErrPhi';
        SPsum=ErrPhi*(AllDist(i,:).^2)';   % width gradient needs squared distance
        Wsum=sum(ErrPhi);

        CentGrad=Cent*W(i)/(SP(i)^2);
        SPGrad=SPsum*W(i)/(SP(i)^3);
        WGrad=Wsum;
        Centers(:,i)=Centers(:,i)+lrCent*CentGrad;
        SP(i)=SP(i)+lrSP*SPGrad;
        W(i)=W(i)+lrW*WGrad;
    end
    Epoches=Epoches+1;
end
if SSE>=E0
    warning('RBF training hit MaxEpoch=%d before SSE reached E0 (final SSE %g).',MaxEpoch,SSE);
end


% Evaluate the trained network on the training set (resubstitution) and on
% the held-out test set. A sample counts as misclassified when the rounded
% network output differs from its label (assumes Label holds integer class
% codes - confirm against allirisdata.mat).

% Resubstitution: run the training samples through the network again.
AllDist=dist(Centers',Train);
SPMat=repmat(SP',1,TrainNum);
UnitOut=radbas(AllDist./SPMat);
TrainNNOut=W*UnitOut;
ErrInd1=find(round(TrainNNOut)~=TrainOut);
errorrate1=length(ErrInd1)/TrainNum;
disp('训练样本数')                 % prints the number of training samples
disp(TrainNum)
disp('对训练样本的回判错误率为');   % prints the resubstitution error rate
disp(errorrate1);

% Test set: the same forward pass as training, with the unit widths (SP',
% a column) replicated across the TestNum test samples.
TestDistance=dist(Centers',Test);
TestSpreadsMat=repmat(SP',1,TestNum);
TestHiddenUnitOut=radbas(TestDistance./TestSpreadsMat);
TestNNOut=W*TestHiddenUnitOut;
ErrInd2=find(round(TestNNOut)~=TestOut);
errorrate2=length(ErrInd2)/TestNum;
disp('测试样本数')                 % prints the number of test samples
disp(TestNum)
disp('对测试样本判断的错误率为');   % prints the test-set error rate
disp(errorrate2);

% Elapsed CPU time since the timer started at the top of the script.
disp('Total time is');
disp(cputime-t)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -