% trainnet.m
% Version in which the classes are replicated so that they all have the same size
% (when GcsProj.Gcs.EqualPriors is set). Samples are then presented once each per pass, in random order.
% General-purpose Growing Cell Structure (GCS) visualisation and classification.
%%
function GcsProj = TrainNet(GcsProj)
% TRAINNET  Train a Growing Cell Structure (GCS) network on a project.
%
%   TrainNet(GcsProj)
%   GcsProj = TrainNet(GcsProj)
%
%   If GcsProj.Gcs.Trained is 'No', a one-off set-up is performed first:
%     * The feature columns of GcsProj.dataAll are divided by their standard
%       deviation. This happens when GcsProj.Gcs.Normalise == 1 and
%       GcsProj.Gcs.NormF1 is still empty.
%     * GcsProj.NoTest randomly chosen samples are moved into GcsProj.test;
%       the rest form GcsProj.train. The original row numbers of the test
%       samples are stored in GcsProj.Gcs.index.
%     * If GcsProj.Gcs.w is empty, the three initial node weights are copied
%       from randomly chosen training samples.
%     * The training set is regrouped by class (equalsze is applied when
%       GcsProj.Gcs.EqualPriors is set). It is then saved to
%       <GcsProj.Gcs.Name>.mat as variable trdata, with its Gcs field cleared.
%
%   Training then presents GcsProj.lamda samples per epoch for
%   GcsProj.NoNewNodes+1 epochs. lamda is epochspernode times the number of
%   training samples, rounded to an integer. Samples are drawn at random
%   without replacement, and the working set is refilled once it is
%   exhausted. At the end of every epoch:
%     * the node counts and the Bayes classifier are recomputed;
%     * the network is saved with SaveGcs;
%     * a node is inserted with AddNode, except after the final epoch.
%
%   GcsProj - project structure. Columns 1..GcsProj.Gcs.n of the data hold
%             the features and column n+1 the class label (1..NoClasses).
%             The updated structure is returned. The output is optional, so
%             existing calls without an output argument still work.

GcsProj.Gcs.Status = 'Training';
GcsProj.Gcs.StatusColour = [1 0 0];
UpdateGUI;
% Fixed seed so that runs are reproducible. For a different sequence on
% each run use: rand('seed',sum(100*clock));
rand('seed',1)
vers; % Store Matlab version
whitebg('white'); % Set background colour

% One-off set-up for an untrained network: normalisation, train/test
% split, weight initialisation and class regrouping.
if strcmp(GcsProj.Gcs.Trained,'No')
    GcsProj = tnNormalise(GcsProj);
    GcsProj = tnSplitTestSet(GcsProj);
    if isempty(GcsProj.Gcs.w)
        % Set the initial node weights to three random training samples.
        % They are drawn one per statement so the random stream is used in order.
        nTrain = size(GcsProj.train,1);
        Ip1 = GcsProj.train(tnRandomRow(nTrain),1:GcsProj.Gcs.n);
        Ip2 = GcsProj.train(tnRandomRow(nTrain),1:GcsProj.Gcs.n);
        Ip3 = GcsProj.train(tnRandomRow(nTrain),1:GcsProj.Gcs.n);
        GcsProj.Gcs.w = [Ip1; Ip2; Ip3]; % Node weight array
    end
    GcsProj = tnGroupByClass(GcsProj);
    % Save the training data so that it can be reloaded at a later date.
    % The function form of save also copes with names containing spaces.
    trdata = GcsProj;
    trdata.Gcs = [];
    filename = [GcsProj.Gcs.Name,'.mat'];
    GcsProj.Gcs.Trainfilename = filename; % Store name of file in GCS structure
    save(filename,'trdata');
end

% Samples per epoch. Epoch ends are detected with rem(iter,lamda)==0,
% which can only fire when lamda is an integer.
GcsProj.lamda = round(GcsProj.epochspernode*size(GcsProj.train,1));
GcsProj.Gcs.cmap1 = []; % Clear so that plot routines recalculate colourmaps after further training
nIter = GcsProj.lamda*(GcsProj.NoNewNodes+1); % +1 so that the last node is added and trained before exiting
waitStep = max(1,floor(nIter/100)); % Refresh the wait bar about 100 times in total
pool = GcsProj.train; % Working copy from which used samples are removed
hwait = waitbar(0,'Training ...');
for iter = 1:nIter % current iteration / training sample number
    % Update the wait bar through its handle, which does not raise the
    % window. Skip the update if the user has closed the wait bar.
    if ishandle(hwait) && (rem(iter,waitStep) == 0 || iter == nIter)
        waitbar(iter/nIter,hwait);
    end
    if rem(iter-1,GcsProj.lamda) == 0 % first sample of an epoch (also works for lamda == 1)
        disp(['Training epoch ',int2str(ceil(iter/GcsProj.lamda)),' ...']);
    end
    % 1: Select the next input at random, without replacement. Refill the
    %    working set once every sample has been used.
    index = tnRandomRow(size(pool,1));
    Ip = pool(index,1:GcsProj.Gcs.n);
    pool(index,:) = [];
    if isempty(pool)
        pool = GcsProj.train;
    end
    % 2: Find the winning node (smallest squared norm distance, using Gcs.metric)
    nNodes = size(GcsProj.Gcs.w,1);
    d = zeros(1,nNodes);
    for i = 1:nNodes
        d(i) = norm(Ip-GcsProj.Gcs.w(i,:),GcsProj.Gcs.metric)^2;
    end
    [minval,s] = min(d);
    % 3: Accumulate the squared error of the winning node
    GcsProj.Gcs.E(s) = GcsProj.Gcs.E(s) + minval;
    % 4: Move the winner towards the input by eb, then its neighbours
    %    (nodes i with C(s,i) == 1) by en. Each row is updated independently.
    GcsProj.Gcs.w(s,:) = GcsProj.Gcs.w(s,:) + GcsProj.Gcs.eb*(Ip-GcsProj.Gcs.w(s,:));
    nb = find(GcsProj.Gcs.C(s,1:nNodes) == 1);
    if ~isempty(nb)
        IpRep = repmat(Ip,length(nb),1);
        GcsProj.Gcs.w(nb,:) = GcsProj.Gcs.w(nb,:) + GcsProj.Gcs.en*(IpRep-GcsProj.Gcs.w(nb,:));
    end
    % 5-7: End of epoch: supervised step, save, then grow the network
    if rem(iter,GcsProj.lamda) == 0
        % Update the node counts, then form the Bayes classifier
        % (optimal threshold) and test it.
        GcsProj.Gcs = count(GcsProj.Gcs,GcsProj.train);
        GcsProj.Gcs = baycopt(GcsProj.Gcs,GcsProj.test);
        GcsProj.Gcs.Trained = 'Yes';
        UpdateGUI;
        SaveGcs(GcsProj.Gcs); % Save the network after every epoch
        % Insert a new node, except after the final epoch
        if iter ~= nIter
            GcsProj.Gcs = AddNode(GcsProj.Gcs);
        end
    end
end
% Summary text for the user (display currently disabled). char() pads the
% rows to a common length, so the strings need not be equally long.
msg = char('Run complete. The network was saved after each epoch in the files gcs_x, ', ...
    'where x is the number of nodes in the network. ', ...
    ' ', ...
    'Use the m file factor to perform factor analysis on the current network. ', ...
    'Use replot to redisplay all plots. ', ...
    'Use plotparw to redisplay the frequency visualisation plot. ', ...
    'Use post to redisplay the posterior probability estimate. ', ...
    'USe roc_parw to redisplay the receiver operating characterisitc curve. ');
%disp(msg);
if ishandle(hwait)
    close(hwait)
end
GcsProj.Gcs.Status = 'Idle';
GcsProj.Gcs.StatusColour = 'default';
UpdateGUI;

function GcsProj = tnNormalise(GcsProj)
% TNNORMALISE  Divide each feature column of GcsProj.dataAll by its
% standard deviation and store the divisor in GcsProj.Gcs.NormF<i>.
% Columns with zero deviation are left unchanged.
% Does nothing unless GcsProj.Gcs.Normalise is 1 and GcsProj.Gcs.NormF1
% is still empty, i.e. normalisation has not been performed yet.
if isequal(GcsProj.Gcs.Normalise,1) && isempty(GcsProj.Gcs.NormF1)
    for i = 1:GcsProj.Gcs.n
        sd = std(GcsProj.dataAll(:,i));
        if sd ~= 0
            disp(['normalising column ',int2str(i)]);
            GcsProj.Gcs.(sprintf('NormF%d',i)) = sd;
            GcsProj.dataAll(:,i) = GcsProj.dataAll(:,i)./sd;
        end
    end
end

function GcsProj = tnSplitTestSet(GcsProj)
% TNSPLITTESTSET  Move GcsProj.NoTest randomly chosen samples (feature
% columns plus the class column n+1) into GcsProj.test. The remaining
% samples form GcsProj.train. The original row numbers of the test samples
% are stored, in selection order, as a column in GcsProj.Gcs.index.
GcsProj.train = GcsProj.dataAll;
GcsProj.test = [];
remaining = 1:size(GcsProj.dataAll,1); % original row numbers still in train
testindex = [];
for k = 1:GcsProj.NoTest
    index = tnRandomRow(size(GcsProj.train,1)); % Randomly select a sample
    GcsProj.test = [GcsProj.test; GcsProj.train(index,1:GcsProj.Gcs.n+1)]; % copy sample to test set
    GcsProj.train(index,:) = []; % Remove test sample from training set
    testindex = [testindex; remaining(index)];
    remaining(index) = [];
end
GcsProj.Gcs.index = testindex; % Store test sample indices in GCS structure

function GcsProj = tnGroupByClass(GcsProj)
% TNGROUPBYCLASS  Split GcsProj.train into GcsProj.data<c>, one matrix per
% class c = 1..NoClasses (selected on column n+1).
% Each class prior GcsProj.Gcs.Prior<c> is recorded. equalsze is called
% when GcsProj.Gcs.EqualPriors is set; per the file header it replicates
% classes to a common size. GcsProj.train is then rebuilt by stacking the
% per-class matrices in class order.
nTrain = size(GcsProj.train,1);
for c = 1:GcsProj.Gcs.NoClasses
    rows = find(GcsProj.train(:,GcsProj.Gcs.n+1) == c);
    GcsProj.(sprintf('data%d',c)) = GcsProj.train(rows,:);
    GcsProj.Gcs.(sprintf('Prior%d',c)) = length(rows)./nTrain;
end
if GcsProj.Gcs.EqualPriors
    GcsProj = equalsze(GcsProj);
end
GcsProj.train = []; % Complete data set, rebuilt class by class
for c = 1:GcsProj.Gcs.NoClasses
    GcsProj.train = [GcsProj.train; GcsProj.(sprintf('data%d',c))];
end

function index = tnRandomRow(nRows)
% TNRANDOMROW  Uniformly distributed random row number in 1..nRows.
% rand lies in the open interval (0,1), so ceil(nRows*rand) covers every
% row including the last. The former floor(rand*(nRows-1)+1) never
% returned nRows.
index = ceil(nRows*rand(1,1));
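% --- Residue from the web code viewer this file was copied from (translated) ---
% Keyboard shortcuts:
%   Copy code:           Ctrl + C
%   Search code:         Ctrl + F
%   Full screen:         F11
%   Toggle theme:        Ctrl + Shift + D
%   Show shortcuts:      ?
%   Increase font size:  Ctrl + =
%   Decrease font size:  Ctrl + -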