📄 lvq3_vc.m
字号:
% Learns a classifier and classifies a test set
% using the Learning Vector Quantization algorithm nr 3 (LVQ3)
%
% Usage
%   [trainError, testError, estTrainLabels, estTestLabels] = ...
%       LVQ3_VC(trainFeatures, trainLabels, Nmu, testFeatures, testLabels)
%
% Inputs:
%   trainFeatures - the training set vectors, one vector per column
%   trainLabels   - the labels of the above (binary 0/1; enforced below)
%   Nmu           - number of centroids
%   testFeatures  - test set, one column per vector
%   testLabels    - labels for the test set
%
% Outputs
%   trainError     - the error rate on the training set (one entry per
%                    class + total error)
%   testError      - the error rate on the test set (one entry per class
%                    + total error)
%   estTrainLabels - the labels produced by the algorithm for the
%                    training samples
%   estTestLabels  - the labels produced by the algorithm for the
%                    test samples
function [trainError, testError, hatTrainLabels, hatTestLabels] = ...
    LVQ3_VC(trainFeatures, trainLabels, Nmu, testFeatures, testLabels)

% Report progress on the console and, if present, on the GUI message box.
hm = findobj('Tag', 'Messages');
fprintf('%d-LVQ3: learning: this might take a while\n', Nmu);
if (isempty(hm)==0)
    s = sprintf('%d-LVQ3: Learning, this might take a while', Nmu);
    set(hm,'String',s);
    refresh;
end

% This check assumes labels are 0/1: it rejects a training set whose
% samples all belong to a single class.
if ((sum(trainLabels) == length(trainLabels)) | (sum(~trainLabels) == length(trainLabels))),
    error('LVQ3 works only if there are features from both classes.')
end

[Nclasses, classes] = find_classes([trainLabels(:);testLabels(:)]); % Number of classes in labels

[Dim,Nsam] = size(trainFeatures);
alpha   = 10/Nsam;      % initial learning rate, decayed each epoch
dist    = zeros(Nmu,Nsam);
label   = zeros(1,Nsam);
window  = 0.25;         % relative width of the LVQ3 update window
epsilon = 0.25;         % damping factor when both closest centroids match

% Need at least one centroid per class.
if (Nclasses > Nmu)
    Nmu = Nclasses;
end

% Estimate class priors from the training label frequencies (used below
% to apportion the centroids among the classes).
priors = zeros(Nclasses,1);
for cl = 1:Nclasses,
    priors(cl) = sum(trainLabels == classes(cl));
end
priors = priors/length(trainLabels);

if (Nsam < Nmu)
    % NOTE: the condition means there are MORE centroids than training
    % points (the original message stated the opposite).
    fprintf('Careful, the number of LVQ centroids %d is larger than the number of training points %d \n', ...
        Nmu, Nsam);
    trainError = zeros(Nclasses+1,1);
    testError  = ones (Nclasses+1,1);
end

% Assigns centroids to each class, in proportion to the prior of the
% class, estimated from the training label frequencies.
for cl = 1:Nclasses,
    nCenters(cl) = round(Nmu * priors(cl));
end
% Rounding may leave too few / too many centroids: adjust one at a time.
while (sum(nCenters) < Nmu)
    [i,j] = min(nCenters);
    nCenters(j) = nCenters(j)+1;
end
while (sum(nCenters) > Nmu)
    [i,j] = max(nCenters);
    nCenters(j) = nCenters(j)-1;
end
% Guarantee every class keeps at least one centroid, stealing from the
% best-endowed class.
while (sum(nCenters==0) > 0)
    [i,j] = min(nCenters);
    [k,l] = max(nCenters);
    nCenters(j) = nCenters(j)+1;
    nCenters(l) = nCenters(l)-1;
end

%----------------
% Initialize the mu's to randomly selected points of the training set,
% in proportion to the number of samples in each class.
startMuIndex = 1;
for cl = 1:Nclasses,
    % sampling w/o replacement from the vectors of the class
    indexes  = trainLabels == classes(cl);
    [onezeros,indexes] = sort(indexes);
    numones  = sum(onezeros);
    numzeros = length(onezeros) - numones;
    auxindex = randperm(numones)+numzeros;
    auxindex = auxindex(1:nCenters(cl));
    endMuIndex = startMuIndex + nCenters(cl) -1;
    mu(:,startMuIndex:endMuIndex)      = trainFeatures(:,indexes(auxindex));
    mu_label(startMuIndex:endMuIndex)  = classes(cl);
    startMuIndex = endMuIndex+1;
end

% The central loop
%
% FIX: the original code used a hard-coded half/half split
%   mu_target = [zeros(1,floor(Nmu/2)) ones(1,Nmu-floor(Nmu/2))];
% which disagrees with the proportional per-class assignment recorded in
% mu_label above (and used for classification below) whenever the class
% priors are unequal, so some centroids were trained against the wrong
% class. The true class of each centroid is mu_label.
mu_target  = mu_label;
old_mu     = zeros(size(mu));
iterations = 0;
while ((sum(sum(abs(mu - old_mu))) > 0.01) & (iterations < 1e4)),
    iterations = iterations + 1;
    old_mu = mu;

    % finds the squared distances between each point and each centroid
    for i = 1:Nmu,
        dist(i,:) = sum((trainFeatures - mu(:,i)*ones(1,Nsam)).^2);
    end

    % sorts the distances between points and centroids,
    % retain the two closest centroids.
    [dist,label] = sort(dist);
    closest = dist(1:2,:);

    % Compute windows:
    % finds points where the ratio of the distance to the closest centroid
    % to the next closest centroid is at least (1-window)/(1+window).
    % These are points that are close to the two closest centroids.
    % auxclosest1 guards against division by zero when a point coincides
    % with its closest centroid.
    auxclosest1 = (closest(1,:) == 0);
    in_window = (min(closest(1,:)./closest(2,:), ...
        (1-auxclosest1).*closest(2,:)./ ...
        (closest(1,:)+auxclosest1(1,:))) > (1-window)/(1+window));
    % and these are their indices.
    indices = find(in_window);

    % Move the mu's (standard LVQ3 update rules)
    for i = 1:length(indices),
        x   = indices(i);
        mu1 = label(1,x);
        mu2 = label(2,x);
        if ((trainLabels(x) == mu_target(mu1)) & (trainLabels(x) == mu_target(mu2))),
            % Both centroids share the sample's class: damped attraction.
            mu(:,mu1) = mu(:,mu1) + epsilon * alpha * (trainFeatures(:,x) - mu(:,mu1));
            mu(:,mu2) = mu(:,mu2) + epsilon * alpha * (trainFeatures(:,x) - mu(:,mu2));
        elseif (trainLabels(x) == mu_target(mu1)),
            % Attract the matching centroid, repel the other.
            mu(:,mu1) = mu(:,mu1) + alpha * (trainFeatures(:,x) - mu(:,mu1));
            mu(:,mu2) = mu(:,mu2) - alpha * (trainFeatures(:,x) - mu(:,mu2));
        elseif (trainLabels(x) == mu_target(mu2)),
            mu(:,mu1) = mu(:,mu1) - alpha * (trainFeatures(:,x) - mu(:,mu1));
            mu(:,mu2) = mu(:,mu2) + alpha * (trainFeatures(:,x) - mu(:,mu2));
        end
    end
    alpha = 0.95 * alpha;   % decay the learning rate each epoch
end

%
% Now computes the training error
hm = findobj('Tag', 'Messages');
fprintf('%d-LVQ3: Computing Training Set Error Rate\n',Nmu);
if (isempty(hm)==0)
    s = sprintf('%d-LVQ3: Computing Training Set Error Rates', Nmu);
    set(hm,'String',s);
    refresh;
end

[Dim, Nsam] = size(trainFeatures);
OneD = ones(1,Nmu);
hatTrainLabels = zeros(size(trainLabels));
for sam = 1:Nsam,
    % computes squared distances, sorts them; label of the nearest
    % centroid wins. ('delta' avoids shadowing the builtin 'diff'.)
    delta = trainFeatures(:,sam) * OneD;
    delta = delta - mu;
    delta = sum(delta.* delta);
    [delta, indices] = sort(delta);
    hatTrainLabels(sam) = mu_label(indices(1));
end
trainError = computeError ( classes,trainLabels, hatTrainLabels);

% ----------------------------------------------------------------
% Now computes the testing error
hm = findobj('Tag', 'Messages');
fprintf('%d-LVQ3: Computing Test Set Error Rate\n',Nmu);
if (isempty(hm)==0)
    s = sprintf('%d-LVQ3: Computing Test Set Error Rates', Nmu);
    set(hm,'String',s);
    refresh;
end

[Dim, Nsam] = size(testFeatures);
OneD = ones(1,Nmu);
hatTestLabels = zeros(size(testLabels));
for sam = 1:Nsam,
    % computes squared distances, sorts them
    delta = testFeatures(:,sam) * OneD;
    delta = delta - mu;
    delta = sum(delta.* delta);
    [delta, indices] = sort(delta);
    hatTestLabels(sam) = mu_label(indices(1));
end
testError = computeError( classes, testLabels, hatTestLabels);
%
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -