multeq.m

来自「细胞生长结构可视化工具箱-MATLAB Toolbox1999.zip」· M 代码 · 共 182 行

M
182
字号
% Version in which the classes are replicated so that they all have the same size. Samples are then presented once each, in random order.

% General purpose Growing Cell Structure Visualisation and Classification
%	multclas(train,test,NoClasses,NoNewNodes,epochspernode,smooth,metric,netname)
%
function multclas(train,test,NoClasses,NoNewNodes,epochspernode,smooth,metric,netname)
% MULTCLAS  Train a Growing Cell Structure (GCS) network with equal class priors.
%
%   multclas(train,test,NoClasses,NoNewNodes,epochspernode,smooth,metric,netname)
%
%   train          - training samples, one per row. Columns 1..n are the
%                    features and column n+1 is the class label (1..NoClasses).
%   test           - test samples; columns 1..n are scaled together with the
%                    training data. Presumably same layout as train and
%                    consumed by the classifier script - TODO confirm.
%   NoClasses      - number of classes; labels are compared against 1..NoClasses.
%   NoNewNodes     - number of training periods; a node is inserted (AddNode)
%                    at the end of each period.
%   epochspernode  - passes through the class-equalised training set per period.
%   smooth         - not referenced in this file; presumably read by one of
%                    the helper scripts - TODO confirm.
%   metric         - norm type passed as second argument to NORM when
%                    measuring the distance between an input and a node.
%   netname        - empty to initialise a new network, otherwise the name of
%                    a saved workspace to LOAD and continue from.
%
%   NOTE(review): initial, equalsze, count, baycopt and AddNode are invoked
%   without arguments, so they presumably are scripts that share this
%   function's workspace. n, E, C, eb, en, accv, sensv, specv, areav and
%   acc/sens/spec are not assigned here and must come from those scripts or
%   from the loaded file. For that reason every workspace variable name used
%   below (i, s, d, Ip, w, iter, lamda, dataAll, data<k>, ...) is kept as is.

% A clock-based seed is computed but deliberately not applied: the generator
% is seeded with a constant so that runs are repeatable.
seed=sum(100*clock);
%rand('seed',seed);
rand('seed',1)

vstr = version;			% Full version string, e.g. '5.3.0.10183 (R11)'
vers = sscanf(vstr,'%d',1);	% Integer (major) part only; handles two-digit majors

% Set background colour
whitebg('white');

if isempty(netname);
	% Randomising the sample order up front is not needed: inputs are drawn
	% at random (without replacement) inside the training loop.

	% Initialise the GCS network
	initial;

	% Normalise data: scale every non-constant feature column to unit
	% standard deviation. The test set is scaled with the TRAINING standard
	% deviation so both sets share one scale and a constant test column
	% cannot cause a division by zero.
	for i=1:n
		colstd = std(train(:,i));
		if colstd ~= 0
			disp(['normalising column ',int2str(i)]);
			train(:,i) = train(:,i)./colstd;
			test(:,i) = test(:,i)./colstd;
		end
	end

	% Set initial node weights to three random training samples.
	% ceil(N*rand) covers 1..N, so the last sample can be chosen as well.
	Ip1=train(ceil(size(train,1)*rand(1)),1:n);
	Ip2=train(ceil(size(train,1)*rand(1)),1:n);
	Ip3=train(ceil(size(train,1)*rand(1)),1:n);
	% Node weight array
	w = [Ip1;Ip2;Ip3;];
else
	% Continue from a previously saved workspace. This restores every
	% variable that was saved, which can overwrite the input arguments.
	load(netname);
end


% Split data into different classes ready for equal prior training.
% The per-class matrices are created as variables data1, data2, ... and the
% observed class priors are stored in Gcs.Prior1, Gcs.Prior2, ...
for cindex =1:NoClasses
	i = find(train(:,n+1)==cindex);	% rows belonging to the current class
	eval(['data' int2str(cindex) ' = train(i,:);']);
	eval(['Gcs.Prior' int2str(cindex) ' = size(data' int2str(cindex) ',1)./size(train,1);']);
end

% Presumably replicates the data<k> matrices so that all classes have the
% same number of rows - TODO confirm against equalsze.m
equalsze;

dataAll=[];	% Complete data set
for cindex=1:NoClasses
	eval(['dataAll = [dataAll; data' int2str(cindex) '];']);
end
train = dataAll;
lamda = epochspernode*size(train,1);	% samples presented per training period

% Train the network
iter=0;	% current iteration / training sample no
% Exactly lamda*NoNewNodes samples are presented. iter is incremented at the
% top of the loop, so the test must be strict to avoid one extra pass.
while (iter<lamda*NoNewNodes)
	iter=iter+1;	% update sample number
	if (rem(iter,lamda)==1)
		disp(['Training epoch ',int2str(ceil(iter/lamda)),' ...']);
	end

	% 1: Select input - draw a random sample without replacement
	index = ceil(size(train,1)*rand(1));
	Ip = train(index,1:n);	% Select next input
	train(index,:) = [];
	if isempty(train)	% If have used all samples, replenish data
		train = dataAll;
	end

	% 2: Find winning node (smallest squared distance to the input)
	d=[];
	for i=1:size(w,1)
		d(i) = norm(Ip-w(i,:),metric)^2;	% Find squared error
	end
	[minval,s] = min(d);

	% 3: Update error of winning node
	E(s) = E(s) + minval; % Sum squared error

	% 4: Adapt ref vectors
	w(s,:) = w(s,:) + eb*(Ip-w(s,:));	% Update winning node

	for i=1:size(w,1)			% find neighbours and update them
		if C(s,i) == 1
			w(i,:) = w(i,:) + en*(Ip-w(i,:));
		end
	end

	% 5: Supervised learning - evaluate at the end of every training period
	if rem(iter,lamda) == 0
		%  : Update the node counts
		count;
		% Now form the bayes classifier and test
		baycopt;	% bayes classifier using optimal threshold	%bayclass;
		% Store performance results and plot history
		accv = [accv,acc];
		sensv = [sensv, sens];
		specv = [specv, spec];
		if NoClasses ~= 2
			% Per-class accuracy histories accv1, accv2, ...
			for cindex=1:NoClasses
				eval(['accv' int2str(cindex) ' = [accv' int2str(cindex) ',acc' int2str(cindex) '];']);
			end
		end
	end

	% 6: Save data to file and redraw the performance history
	if rem(iter,lamda) == 0
		% Base filename; the suffix is the current number of nodes
		filebase = ['gcs'];
		filename = [filebase,'_',int2str(size(C,1))];
		save(filename);
		if NoClasses == 2
			% Plot ROC Curve
			figure(2);
			areav = [areav, roc_parw(filename)];
			save(filename);	% redo to add areav to saved data
		end
		figure(1);
		clf;
		XNo = 3:size(accv,2)+2;	% x axis starts at the initial 3 nodes
		plot(XNo,accv,'k');
		hold on;
		linetype = ['r';'b';'g';'y';'m';'c'];
		if NoClasses ~= 2
			% Cycle through the colour table so more than six classes do
			% not index past its end.
			for cindex=1:NoClasses
				eval(['plot(XNo,accv' int2str(cindex) ', linetype(rem(cindex-1,size(linetype,1))+1) );']);
			end
		else
			plot(XNo,accv,'k*');
			plot(XNo,sensv,'b');
			plot(XNo,sensv,'bO');
			plot(XNo,specv,'r');
			plot(XNo,specv,'r+');
			plot(XNo,areav,'g-');

			text(3.1,0.3,'black - Accuracy');
			text(3.1,0.2,'red - Specificity');
			text(3.1,0.1,'blue - Sensitivity');
			text(3.1,0.4,'green - Area under ROC');
		end
		title('Accuracy perfomance');
		xlabel('Number of nodes');
		axis([3,size(accv,2)+3,0,1]);
		drawnow;
	end

	% 7: Insert new node if Necessary No of iterations have passed
	if rem(iter,lamda) == 0
		AddNode;
	end

	% 8: Decrease error of all units (disabled)
	%E = E*(1-beta);
	%Freq = (1-beta)*Freq;

end

% Rows of a char matrix must all have the same length, hence the padding.
msg = [ 'Run complete. The network was saved after each epoch in the files gcs_x,  ';
	'where x is the number of nodes in the network.                            ';
	'                                                                          ';
	'Use the m file factor to perform factor analysis on the current network.  ';
	'Use replot to redisplay all plots.                                        ';
	'Use plotparw to redisplay the frequency visualisation plot.               ';
	'Use post to redisplay the posterior probability estimate.                 ';
	'USe roc_parw to redisplay the receiver operating characterisitc curve.    '];

disp(msg);

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?