📄 rbf.m
字号:
%径向基网络 采用均值聚类方法,隐含层输出函数采用高斯函数
a=load('a.txt');
t=load('t.txt')
p=a';
t1=t';
tic
samnum=528;
testsamnum=528;
indim=10;
outdim=4;
clusternum=2;%13个聚类中心
overdap=1.5;%重叠系数
centers=p(:,1:clusternum);%初始中心
numberinclusters=zeros(clusternum,1);%各类中的样本数,初始化为零
indexinclusters=zeros(clusternum,samnum);%各类所含样本的索引号
while 1
numberinclusters=zeros(clusternum,1);%各类中的样本数,初始化为零
indexinclusters=zeros(clusternum,samnum);%各类所含样本的索引号
%按最小距离原则对样本进行分类
for i=1:samnum
alldistance=dist(centers',p(:,i));
[mindist,pos]=min(alldistance);
numberinclusters(pos)=numberinclusters(pos)+1;%最小距离数的个数加1
indexinclusters(pos,numberinclusters(pos))=i;%索引号
end
oldcenters=centers; %保存旧的聚类中心
for i=1:clusternum %重新计算各类的聚类中心
index=indexinclusters(i,1:numberinclusters(i));
centers(:,i)=mean(p(:,index)')';
end
equalnum=sum(sum(centers==oldcenters));%判断聚类中心是否一致
if equalnum==indim*clusternum;
break,
end
end
%计算各隐节点的扩展常数
alldistances=dist(centers',centers);%计算隐节点数据中心的距离
maximum=max(max(alldistances));
for i=1:clusternum
alldistances(i,i)=maximum+1;
end
spreads=overdap*min(alldistances)';
distance=dist(centers',p);
spreadsmat=repmat(spreads,1,samnum);
hiddenunitout=radbas(distance./spreadsmat);
hiddenunitoutex=[hiddenunitout' ones(samnum,1)]';
w2ex=t1*pinv(hiddenunitoutex);
w2=w2ex(:,1:clusternum);
b2=w2ex(:,clusternum+1);
testdistance=dist(centers',p);
testspreadsmat=repmat(spreads,1,testsamnum);
testhiddenunitout=radbas(testdistance./testspreadsmat);
s=w2*testhiddenunitout+repmat(b2,1,testsamnum)
w2
b2
s %最后的输出
for i=1:4
for j=1:528
if s(i,j)>=0.52
s(i,j)=0.99;
else s(i,j)=0.01;
end
end
end
count(1,1:11)=0;
for j=1:528
if s(:,j)==t1(:,j)
if s(:,j)==[0.01;0.01;0.01;0.99]
count(1)=count(1)+1;
elseif s(:,j)==[0.01;0.01;0.99;0.01]
count(2)=count(2)+1;
elseif s(:,j)==[0.01;0.01;0.99;0.99]
count(3)=count(3)+1;
elseif s(:,j)==[0.01;0.99;0.01;0.01]
count(4)=count(4)+1;
elseif s(:,j)==[0.01;0.99;0.01;0.99]
count(5)=count(5)+1;
elseif s(:,j)==[0.01;0.99;0.99;0.01]
count(6)=count(6)+1;
elseif s(:,j)==[0.01;0.99;0.99;0.99]
count(7)=count(7)+1;
elseif s(:,j)==[0.99;0.01;0.01;0.01]
count(8)=count(8)+1;
elseif s(:,j)==[0.99;0.01;0.01;0.99]
count(9)=count(9)+1;
elseif s(:,j)==[0.99;0.01;0.99;0.01]
count(10)=count(10)+1;
elseif s(:,j)==[0.99;0.01;0.99;0.99]
count(11)=count(11)+1;
else end
end
end
COUNT(1,1:11)=0;
for j=1:528
if s(:,j)==[0.01;0.01;0.01;0.99]
COUNT(1)=COUNT(1)+1;
elseif s(:,j)==[0.01;0.01;0.99;0.01]
COUNT(2)=COUNT(2)+1;
elseif s(:,j)==[0.01;0.01;0.99;0.99]
COUNT(3)=COUNT(3)+1;
elseif s(:,j)==[0.01;0.99;0.01;0.01]
COUNT(4)=COUNT(4)+1;
elseif s(:,j)==[0.01;0.99;0.01;0.99]
COUNT(5)=COUNT(5)+1;
elseif s(:,j)==[0.01;0.99;0.99;0.01]
COUNT(6)=COUNT(6)+1;
elseif s(:,j)==[0.01;0.99;0.99;0.99]
COUNT(7)=COUNT(7)+1;
elseif s(:,j)==[0.99;0.01;0.01;0.01]
COUNT(8)=COUNT(8)+1;
elseif s(:,j)==[0.99;0.01;0.01;0.99]
COUNT(9)=COUNT(9)+1;
elseif s(:,j)==[0.99;0.01;0.99;0.01]
COUNT(10)=COUNT(10)+1;
elseif s(:,j)==[0.99;0.01;0.99;0.99]
COUNT(11)=COUNT(11)+1;
else m=0;
end
end
CIUNT(1,1:11)=0;
for j=1:528
if t1(:,j)==[0.01;0.01;0.01;0.99]
CIUNT(1)=CIUNT(1)+1;
elseif t1(:,j)==[0.01;0.01;0.99;0.01]
CIUNT(2)=CIUNT(2)+1;
elseif t1(:,j)==[0.01;0.01;0.99;0.99]
CIUNT(3)=CIUNT(3)+1;
elseif t1(:,j)==[0.01;0.99;0.01;0.01]
CIUNT(4)=CIUNT(4)+1;
elseif t1(:,j)==[0.01;0.99;0.01;0.99]
CIUNT(5)=CIUNT(5)+1;
elseif t1(:,j)==[0.01;0.99;0.99;0.01]
CIUNT(6)=CIUNT(6)+1;
elseif t1(:,j)==[0.01;0.99;0.99;0.99]
CIUNT(7)=CIUNT(7)+1;
elseif t1(:,j)==[0.99;0.01;0.01;0.01]
CIUNT(8)=CIUNT(8)+1;
elseif t1(:,j)==[0.99;0.01;0.01;0.99]
CIUNT(9)=CIUNT(9)+1;
elseif t1(:,j)==[0.99;0.01;0.99;0.01]
CIUNT(10)=CIUNT(10)+1;
elseif t1(:,j)==[0.99;0.01;0.99;0.99]
CIUNT(11)=CIUNT(11)+1;
else end
end
for i=1:11
pc(i)=count(i)/COUNT(i);
pe(i)=count(i)/CIUNT(i);
end
time=toc
pc
pe
shijian=cputime
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -