
📄 lvq3_vc.m

📁 The latest pattern recognition classification toolbox; I hope it is useful to everyone!
💻 MATLAB (M)
% Learns a classifier and classifies the test set
% using Learning Vector Quantization algorithm nr 3 (LVQ3).
%
% Usage
%      [trainError, testError, estTrainLabels, estTestLabels] = ...
%           LVQ3_VC(trainFeatures, trainLabels, Nmu, testFeatures, testLabels)
%
% Inputs:
%   trainFeatures  - the training set vectors, one vector per column
%   trainLabels    - the labels of the above
%   Nmu            - number of centroids
%   testFeatures   - test set, one column per vector
%   testLabels     - labels for the test set
%
% Outputs:
%   trainError     - the error rate on the training set (one entry per
%                    class + total error)
%   testError      - the error rate on the test set (one entry per
%                    class + total error)
%   estTrainLabels - the labels produced by the algorithm for the
%                    training samples
%   estTestLabels  - the labels produced by the algorithm for the
%                    test samples

function [trainError, testError, hatTrainLabels, hatTestLabels] = ...
    LVQ3_VC(trainFeatures, trainLabels, Nmu, testFeatures, testLabels)

hm = findobj('Tag', 'Messages');
fprintf('%d-LVQ3: learning: this might take a while\n', Nmu);
if (isempty(hm) == 0)
  s = sprintf('%d-LVQ3: Learning, this might take a while', Nmu);
  set(hm, 'String', s);
  refresh;
end

% The algorithm expects binary 0/1 labels with both classes present.
if ((sum(trainLabels) == length(trainLabels)) | (sum(~trainLabels) == length(trainLabels))),
  error('LVQ3 works only if there are features from both classes.')
end

% find_classes is a helper function from the same toolbox.
[Nclasses, classes] = find_classes([trainLabels(:); testLabels(:)]);  % number of classes in labels

[Dim, Nsam] = size(trainFeatures);
alpha   = 10/Nsam;       % learning rate, decayed every iteration
dist    = zeros(Nmu, Nsam);
label   = zeros(1, Nsam);
window  = 0.25;          % relative window between the two closest centroids
epsilon = 0.25;          % damping factor when both closest centroids match the label

if (Nclasses > Nmu)
  Nmu = Nclasses;
end

% Class priors estimated from the training label frequencies
priors = zeros(Nclasses, 1);
for cl = 1:Nclasses,
  priors(cl) = sum(trainLabels == classes(cl));
end
priors = priors/length(trainLabels);

if (Nsam < Nmu)
  fprintf('Careful, the number of LVQ centroids %d is larger than the number of training points %d\n', ...
          Nmu, Nsam);
  trainError = zeros(Nclasses+1, 1);
  testError  = ones(Nclasses+1, 1);
end

% Assign centroids to each class in proportion to the prior of the class,
% estimated from the training label frequencies
for cl = 1:Nclasses,
  nCenters(cl) = round(Nmu * priors(cl));
end
while (sum(nCenters) < Nmu)
  [i, j] = min(nCenters);
  nCenters(j) = nCenters(j) + 1;
end
while (sum(nCenters) > Nmu)
  [i, j] = max(nCenters);
  nCenters(j) = nCenters(j) - 1;
end
while (sum(nCenters == 0) > 0)
  [i, j] = min(nCenters);
  [k, l] = max(nCenters);
  nCenters(j) = nCenters(j) + 1;
  nCenters(l) = nCenters(l) - 1;
end

% ----------------
% Initialize the mu's to randomly selected points of the training set,
% in proportion to the number of samples in each class
startMuIndex = 1;
for cl = 1:Nclasses,
  % sampling w/o replacement from the vectors of the class
  indexes             = trainLabels == classes(cl);
  [onezeros, indexes] = sort(indexes);
  numones             = sum(onezeros);
  numzeros            = length(onezeros) - numones;
  auxindex            = randperm(numones) + numzeros;
  auxindex            = auxindex(1:nCenters(cl));
  endMuIndex          = startMuIndex + nCenters(cl) - 1;
  mu(:, startMuIndex:endMuIndex)    = trainFeatures(:, indexes(auxindex));
  mu_label(startMuIndex:endMuIndex) = classes(cl);
  startMuIndex        = endMuIndex + 1;
end

% The central loop.  mu_label (assigned above) records the class of each centroid.
old_mu     = zeros(size(mu));
iterations = 0;

while ((sum(sum(abs(mu - old_mu))) > 0.01) & (iterations < 1e4)),
  iterations = iterations + 1;
  old_mu     = mu;

  % find the distances between each point and each centroid
  for i = 1:Nmu,
    dist(i,:) = sum((trainFeatures - mu(:,i)*ones(1,Nsam)).^2);
  end

  % sort the distances between points and centroids,
  % retain the two closest centroids
  [dist, label] = sort(dist);
  closest       = dist(1:2,:);

  % Compute windows:
  % find points where the ratio of the distance to the closest centroid to
  % the next closest centroid is at least (1-window)/(1+window);
  % these are points that are close to both of the two closest centroids
  auxclosest1 = (closest(1,:) == 0);
  in_window   = (min(closest(1,:)./closest(2,:), ...
                     (1-auxclosest1).*closest(2,:)./ ...
                     (closest(1,:)+auxclosest1)) > (1-window)/(1+window));
  % and these are their indices
  indices = find(in_window);

  % Move the mu's
  for i = 1:length(indices),
    x   = indices(i);
    mu1 = label(1,x);
    mu2 = label(2,x);
    if ((trainLabels(x) == mu_label(mu1)) & (trainLabels(x) == mu_label(mu2))),
      % both closest centroids have the same class as the sample:
      % move both towards it, damped by epsilon
      mu(:,mu1) = mu(:,mu1) + epsilon * alpha * (trainFeatures(:,x) - mu(:,mu1));
      mu(:,mu2) = mu(:,mu2) + epsilon * alpha * (trainFeatures(:,x) - mu(:,mu2));
    elseif (trainLabels(x) == mu_label(mu1)),
      % only the closest centroid matches: attract it, repel the other
      mu(:,mu1) = mu(:,mu1) + alpha * (trainFeatures(:,x) - mu(:,mu1));
      mu(:,mu2) = mu(:,mu2) - alpha * (trainFeatures(:,x) - mu(:,mu2));
    elseif (trainLabels(x) == mu_label(mu2)),
      % only the second-closest centroid matches: attract it, repel the other
      mu(:,mu1) = mu(:,mu1) - alpha * (trainFeatures(:,x) - mu(:,mu1));
      mu(:,mu2) = mu(:,mu2) + alpha * (trainFeatures(:,x) - mu(:,mu2));
    end
  end

  alpha = 0.95 * alpha;
end

% ----------------------------------------------------------------
% Now compute the training error
hm = findobj('Tag', 'Messages');
fprintf('%d-LVQ3: Computing Training Set Error Rate\n', Nmu);
if (isempty(hm) == 0)
  s = sprintf('%d-LVQ3: Computing Training Set Error Rates', Nmu);
  set(hm, 'String', s);
  refresh;
end

[Dim, Nsam]    = size(trainFeatures);
OneD           = ones(1, Nmu);
hatTrainLabels = zeros(size(trainLabels));
for sam = 1:Nsam,
  % compute squared distances to all centroids, take the closest one
  diff = trainFeatures(:, sam) * OneD;
  diff = diff - mu;
  diff = sum(diff .* diff);
  [diff, indices]     = sort(diff);
  hatTrainLabels(sam) = mu_label(indices(1));
end
% computeError is a helper function from the same toolbox
trainError = computeError(classes, trainLabels, hatTrainLabels);

% ----------------------------------------------------------------
% Now compute the testing error
hm = findobj('Tag', 'Messages');
fprintf('%d-LVQ3: Computing Test Set Error Rate\n', Nmu);
if (isempty(hm) == 0)
  s = sprintf('%d-LVQ3: Computing Test Set Error Rates', Nmu);
  set(hm, 'String', s);
  refresh;
end

[Dim, Nsam]   = size(testFeatures);
OneD          = ones(1, Nmu);
hatTestLabels = zeros(size(testLabels));
for sam = 1:Nsam,
  % compute squared distances to all centroids, take the closest one
  diff = testFeatures(:, sam) * OneD;
  diff = diff - mu;
  diff = sum(diff .* diff);
  [diff, indices]    = sort(diff);
  hatTestLabels(sam) = mu_label(indices(1));
end
testError = computeError(classes, testLabels, hatTestLabels);
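For reference, a minimal usage sketch follows. It is not part of lvq3_vc.m: the synthetic Gaussian data, the choice Nmu = 8, and the variable names are illustrative only, and it assumes the companion toolbox helpers find_classes.m and computeError.m are on the MATLAB path. Labels must be 0/1, as the function itself requires.

% Usage sketch (illustrative, not part of the toolbox)
% Two synthetic Gaussian classes; features are columns, labels are 0/1.
N  = 200;
X0 = randn(2, N) + [ 2;  2] * ones(1, N);     % class 0 samples
X1 = randn(2, N) + [-2; -2] * ones(1, N);     % class 1 samples
trainFeatures = [X0(:, 1:150),   X1(:, 1:150)];
trainLabels   = [zeros(1, 150),  ones(1, 150)];
testFeatures  = [X0(:, 151:end), X1(:, 151:end)];
testLabels    = [zeros(1, 50),   ones(1, 50)];

Nmu = 8;   % number of centroids (arbitrary illustrative choice)
[trainError, testError, estTrainLabels, estTestLabels] = ...
    LVQ3_VC(trainFeatures, trainLabels, Nmu, testFeatures, testLabels);

% Per the header comment, the last entry of each error vector is the total error rate.
fprintf('Training error: %g, test error: %g\n', trainError(end), testError(end));

If your platform treats file names case-sensitively, call the function as lvq3_vc to match the file name, since MATLAB dispatches on the file name rather than the name in the function declaration.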
