📄 multclas.m
字号:
% General purpose Growing Cell Structure Visualisation and Classification
% multclas(train,test,NoClasses,NoNewNodes,epochspernode,smooth,metric,netname)
%
function multclas(train,test,NoClasses,NoNewNodes,epochspernode,smooth,metric,netname)
% MULTCLAS  Train a Growing Cell Structure (GCS) network with equal class
% priors, evaluating and saving the network each time a node is inserted.
%
%   train         - training matrix; columns 1..n hold the features and
%                   column n+1 holds the class label (integers 1..NoClasses).
%   test          - test matrix; its first n columns are scaled with the
%                   same factors as the training columns.
%                   NOTE(review): presumably consumed by the baycopt
%                   script - confirm.
%   NoClasses     - number of classes.
%   NoNewNodes    - number of node insertions (training epochs) to run.
%   epochspernode - not referenced in this file.
%   smooth        - not referenced in this file.
%                   NOTE(review): both may be read by the helper scripts,
%                   which share this workspace - confirm.
%   metric        - passed as the second argument of NORM when measuring
%                   the distance between an input and a node.
%   netname       - empty to start a new network; otherwise the name of a
%                   MAT-file that is loaded into the workspace.
%
% The scripts INITIAL, COUNT, BAYCOPT and ADDNODE run in this function's
% workspace.  Variables such as n, lamda, eb, en, E, C, accv, sensv and
% specv are not defined in this file.
% NOTE(review): presumably they are created by INITIAL or restored from
% the loaded MAT-file - confirm.  Workspace variable names are kept
% unchanged because those scripts and the saved files may rely on them.

rand('seed',sum(100*clock));

vstr = version;                 % Extract version number
% Read the whole leading integer so that versions >= 10 are handled too
vers = sscanf(vstr,'%d',1);     % Only interested in integer part

% Set background colour
whitebg('white');

if isempty(netname)
    % Initialise the GCS network
    initial;

    % Normalise data: scale every non-constant feature column to unit
    % standard deviation.  The test data is scaled by the TRAINING standard
    % deviation so that both sets live in the same feature space (and so a
    % constant test column cannot cause a division by zero).
    for i=1:n
        colstd = std(train(:,i));
        if colstd ~= 0
            disp(['normalising column ',int2str(i)]);
            train(:,i) = train(:,i)./colstd;
            test(:,i)  = test(:,i)./colstd;
        end
    end

    % Set initial node weights to three random training samples.
    % ceil(N*rand) covers every row 1..N, including the last one.
    Ip1=train(ceil(size(train,1)*rand(1)),1:n);
    Ip2=train(ceil(size(train,1)*rand(1)),1:n);
    Ip3=train(ceil(size(train,1)*rand(1)),1:n);

    % Node weight array
    w = [Ip1;Ip2;Ip3];
else
    % Restore a previously saved network
    load(netname);
end

% Split data into different classes ready for equal prior training.
% Creates data<k> (rows of class k) and Prior<k> (fraction of samples in
% class k) for k = 1..NoClasses.
for cindex=1:NoClasses
    i = find(train(:,n+1)==cindex);
    if isempty(i)
        error(['multclas: no training samples found for class ' int2str(cindex)]);
    end
    eval(['data' int2str(cindex) ' = train(i,:);']);
    eval(['Prior' int2str(cindex) ' = size(data' int2str(cindex) ',1)./size(train,1);']);
end

% Train the network
iter=0;   % current iteration / training sample no
while (iter < lamda*NoNewNodes)   % exactly lamda iterations per inserted node
    iter=iter+1;   % update sample number
    if (rem(iter,lamda)==1)
        disp(['Training epoch ',int2str(ceil(iter/lamda)),' ...']);
    end

    % 1: Select input
    % Random sampling with replacement and equal priors: pick a class
    % uniformly, then pick a row uniformly from THAT class.
    sampleclass = ceil(NoClasses*rand(1));
    eval(['sampleindex = ceil(size(data' int2str(sampleclass) ',1)*rand(1));']);
    eval(['Ip = data' int2str(sampleclass) '(sampleindex,1:n);']);

    % 2: Find winning node (smallest squared distance to the input)
    d = zeros(1,size(w,1));
    for i=1:size(w,1)
        d(i) = norm(Ip-w(i,:),metric)^2;   % Find squared error
    end
    [minval,s] = min(d);

    % 3: Update error of winning node
    E(s) = E(s) + minval;   % Sum squared error

    % 4: Adapt ref vectors
    w(s,:) = w(s,:) + eb*(Ip-w(s,:));   % Update winning node
    for i=1:size(w,1)                   % find neighbours and update them
        if C(s,i) == 1
            w(i,:) = w(i,:) + en*(Ip-w(i,:));
        end
    end

    % The remaining steps run once per epoch, i.e. every lamda iterations
    if rem(iter,lamda) == 0
        % 5: Supervised learning
        % Update the node counts
        count;
        % Now form the bayes classifier and test
        baycopt;   % bayes classifier using optimal threshold

        % Store performance results (overall and per class)
        accv  = [accv,acc];
        sensv = [sensv,sens];
        specv = [specv,spec];
        for cindex=1:NoClasses
            eval(['accv' int2str(cindex) ' = [accv' int2str(cindex) ',acc' int2str(cindex) '];']);
        end

        % 6: Save data to file gcs_<number of nodes>
        filebase = 'gcs';
        filename = [filebase,'_',int2str(size(C,1))];
        save(filename);   % saves every workspace variable

        % Plot the accuracy history against the number of nodes
        figure(1);
        clf;
        XNo = 3:size(accv,2)+2;
        plot(XNo,accv,'k');
        hold on;
        linetype = ['r';'b';'g';'y';'m';'c'];
        for cindex=1:NoClasses
            % Cycle through the colour table so more than six classes work
            linestyle = linetype(mod(cindex-1,size(linetype,1))+1);
            eval(['plot(XNo,accv' int2str(cindex) ', linestyle );']);
        end
        title('Accuracy perfomance');
        xlabel('Number of nodes');
        axis([3,size(accv,2)+3,0,1]);
        drawnow;

        % 7: Insert new node now that the necessary no of iterations have passed
        AddNode;
    end

    % 8: Decrease error of all units
    %E = E*(1-beta);
    %Freq = (1-beta)*Freq;
end

% char() pads the rows to a common length, so the lines need not be padded
% by hand to form a valid character matrix.
msg = char('Run complete. The network was saved after each epoch in the files gcs_x, ', ...
           'where x is the number of nodes in the network. ', ...
           ' ', ...
           'Use the m file factor to perform factor analysis on the current network. ', ...
           'Use replot to redisplay all plots. ', ...
           'Use plotparw to redisplay the frequency visualisation plot. ', ...
           'Use post to redisplay the posterior probability estimate. ', ...
           'USe roc_parw to redisplay the receiver operating characterisitc curve. ');
disp(msg);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -