% lvq1_vc.m
% LVQ1_VC  Learn an LVQ1-style (Learning Vector Quantization) classifier on a
% training set and use it to classify both the training and the test set.
%
% Usage
%   [trainError, testError, hatTrainLabels, hatTestLabels] = ...
%       LVQ1_VC(trainFeatures, trainLabels, Nmu, testFeatures, testLabels)
%
% Inputs:
%   trainFeatures - training vectors, one vector per column (Dim x Nsam)
%   trainLabels   - labels of the training vectors (row or column vector)
%   Nmu           - requested number of centroids; raised to the number of
%                   classes that occur in the training set if smaller
%   testFeatures  - test vectors, one vector per column
%   testLabels    - labels of the test vectors (row or column vector)
%
% Outputs:
%   trainError     - error rate on the training set as returned by
%                    computeError (one entry per class + total error)
%   testError      - error rate on the test set, same layout
%   hatTrainLabels - labels assigned to the training samples
%                    (same size as trainLabels)
%   hatTestLabels  - labels assigned to the test samples
%                    (same size as testLabels)
%
% Algorithm:
%   Centroids are split among the classes in proportion to their training
%   frequency and seeded at randomly chosen training samples of that class
%   (so results depend on the random state). In each epoch every training
%   sample is assigned to its nearest centroid, each centroid is relabelled
%   by majority vote of its samples, and then moved toward its same-label
%   samples and away from the others. The learning rate starts at 0.9 and
%   is multiplied by 0.95 after each epoch. Training stops once the total
%   absolute centroid movement in an epoch is at most 0.1 (an absolute
%   threshold, so it depends on the feature scale).
%
% If there are fewer training samples than centroids nothing is trained:
% a warning is printed, trainError is all zeros, testError is all ones and
% both label outputs are all zeros.
function [trainError, testError, hatTrainLabels, hatTestLabels] = ...
    LVQ1_VC(trainFeatures, trainLabels, Nmu, testFeatures, testLabels)

hm = findobj('Tag', 'Messages');
report_status(hm, '%d-LVQ1: Training', '%d-LVQ1: Training', Nmu);

alpha = 0.9;        % initial learning rate
alphaDecay = 0.95;  % learning rate is multiplied by this after each epoch
tol = 0.1;          % stop once total absolute centroid movement <= tol

[Dim, Nsam] = size(trainFeatures);
% Use a row copy of the training labels so that indexing with the row
% assignment vectors below works whatever the caller's orientation.
labelsRow = reshape(trainLabels, 1, []);
% NOTE(review): find_classes is defined elsewhere; presumably it returns
% the number of distinct labels and the label values - confirm.
[Nclasses, classes] = find_classes([trainLabels(:); testLabels(:)]);

% Training samples per class. A class seen only in the test set has zero
% and cannot get a centroid, since there is no sample to seed it with.
counts = zeros(1, Nclasses);
for cl = 1:Nclasses
    counts(cl) = sum(labelsRow == classes(cl));
end

% At least one centroid per class that occurs in the training set.
Nmu = max([Nmu, sum(counts > 0), 1]);

if Nsam < Nmu
    fprintf('Careful, the number of LVQ centroids %d is greater than the number of training points %d\n', ...
        Nmu, Nsam);
    trainError = zeros(Nclasses+1, 1);
    testError = ones(Nclasses+1, 1);
    hatTrainLabels = zeros(size(trainLabels));
    hatTestLabels = zeros(size(testLabels));
    return;
end

% Assign centroids to classes in proportion to the training-label
% frequencies. A class never gets more centroids than it has samples.
nCenters = min(round(Nmu * (counts / numel(labelsRow))), counts);
% Rounding may leave too few centroids: give extras to the least-served
% classes that still have unused samples.
while sum(nCenters) < Nmu
    room = find(nCenters < counts);
    if isempty(room)
        break;   % only possible if labels and features disagree in length
    end
    [~, k] = min(nCenters(room));
    nCenters(room(k)) = nCenters(room(k)) + 1;
end
% ...or too many: take them from the best-served class.
while sum(nCenters) > Nmu
    [~, j] = max(nCenters);
    nCenters(j) = nCenters(j) - 1;
end
% Every class present in training must own at least one centroid; move
% one over from the best-served class (which then has >= 2) until it does.
starved = find(counts > 0 & nCenters == 0, 1);
while ~isempty(starved)
    [~, l] = max(nCenters);
    nCenters(starved) = nCenters(starved) + 1;
    nCenters(l) = nCenters(l) - 1;
    starved = find(counts > 0 & nCenters == 0, 1);
end
Nmu = sum(nCenters);   % normally unchanged; keeps loops below in range

% Seed each class's centroids at distinct samples of that class, drawn at
% random without replacement.
slot0 = 1;
for cl = 1:Nclasses
    members = find(labelsRow == classes(cl));
    perm = randperm(numel(members));
    slots = slot0:(slot0 + nCenters(cl) - 1);
    mu(:, slots) = trainFeatures(:, members(perm(1:nCenters(cl))));
    mu_label(slots) = classes(cl);
    slot0 = slot0 + nCenters(cl);
end

% Main loop. old_mu starts at Inf so at least one epoch always runs.
old_mu = inf(size(mu));
while sum(abs(mu(:) - old_mu(:))) > tol
    old_mu = mu;

    % Assign every training sample to its nearest centroid. The explicit
    % dimension keeps this correct when there is only one centroid.
    [~, nearest] = min(sq_dists(trainFeatures, mu), [], 1);

    % Relabel each centroid by majority vote of its assigned samples; on a
    % tie the class listed first wins. A centroid with no samples keeps its
    % label.
    for i = 1:Nmu
        assigned = labelsRow(nearest == i);
        if ~isempty(assigned)
            votes = zeros(1, Nclasses);
            for cl = 1:Nclasses
                votes(cl) = sum(assigned == classes(cl));
            end
            [~, best] = max(votes);
            mu_label(i) = classes(best);
        end
    end

    % Move each centroid by alpha times the mean signed displacement:
    % toward samples sharing its label (+1) and away from the others (-1).
    % mean(..., 2) stays a per-feature mean even for a one-sample cluster.
    for i = 1:Nmu
        members = find(nearest == i);
        if ~isempty(members)
            sgn = 2 * (labelsRow(members) == mu_label(i)) - 1;
            delta = trainFeatures(:, members) - mu(:, i) * ones(1, numel(members));
            mu(:, i) = mu(:, i) + alpha * mean(delta .* (ones(Dim, 1) * sgn), 2);
        end
    end
    alpha = alphaDecay * alpha;
end

% Training-set error.
report_status(hm, '%d-LVQ1: Computing Training Set Error Rate', ...
    '%d-LVQ1: Computing Training Set Error Rates', Nmu);
hatTrainLabels = nearest_labels(trainFeatures, mu, mu_label, size(trainLabels));
trainError = computeError(classes, trainLabels, hatTrainLabels);

% Test-set error.
report_status(hm, '%d-LVQ1: Computing Test Set Error Rate', ...
    '%d-LVQ1: Computing Test Set Error Rates', Nmu);
hatTestLabels = nearest_labels(testFeatures, mu, mu_label, size(testLabels));
testError = computeError(classes, testLabels, hatTestLabels);


% REPORT_STATUS  Print a progress line to the console and, when the
% 'Messages' GUI object HM exists, show the same status there.
function report_status(hm, consoleFmt, guiFmt, Nmu)
fprintf([consoleFmt '\n'], Nmu);
if ~isempty(hm)
    set(hm, 'String', sprintf(guiFmt, Nmu));
    refresh;
end


% SQ_DISTS  Squared Euclidean distances, d(k, n) = ||features(:,n) - mu(:,k)||^2.
% Summing along dimension 1 keeps this correct for 1-D features.
function d = sq_dists(features, mu)
nSamples = size(features, 2);
nCentroids = size(mu, 2);
d = zeros(nCentroids, nSamples);
for k = 1:nCentroids
    diffs = features - mu(:, k) * ones(1, nSamples);
    d(k, :) = sum(diffs .^ 2, 1);
end


% NEAREST_LABELS  Give each column of FEATURES the label of its nearest
% centroid. The result has size OUTSIZE (the matching label vector's size);
% ties go to the lowest-numbered centroid.
function hat = nearest_labels(features, mu, mu_label, outSize)
hat = zeros(outSize);
[~, nearest] = min(sq_dists(features, mu), [], 1);
hat(1:numel(nearest)) = mu_label(nearest);
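% (Residue of the web page this file was copied from - code viewer keyboard
% shortcuts: copy code Ctrl+C; search code Ctrl+F; full screen F11;
% toggle theme Ctrl+Shift+D; show shortcuts ?; increase font size Ctrl+=;
% decrease font size Ctrl+-.)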