📄 rce_vc.m
字号:
function [trainError, testError, estTrainLabels, estTestLabels] = ...
    RCE_VC(trainFeatures, trainLabels, lambda_m, testFeatures, testLabels)
% RCE_VC  Learns a classifier and classifies a test set using the
% Reduced Coulomb Energy (RCE) algorithm.
%
% Usage
%   [trainError, testError, estTrainLabels, estTestLabels] = ...
%       RCE_VC(trainFeatures, trainLabels, lambda_m, testFeatures, testLabels)
%
% Inputs:
%   trainFeatures  - the training set vectors, one vector per column
%   trainLabels    - the labels of the above
%   lambda_m       - maximum radius
%   testFeatures   - test set, one column per vector
%   testLabels     - labels for test set
%
% Outputs:
%   trainError     - the error rate on the training set, as returned by
%                    computeError (one entry per class + total error)
%   testError      - the error rate on the test set, as returned by
%                    computeError (one entry per class + total error)
%   estTrainLabels - the labels produced by the algorithm for the
%                    training samples
%   estTestLabels  - the labels produced by the algorithm for the
%                    test samples
%
% Each training sample gets a sphere whose radius is just below the
% distance to its nearest training sample with a different label, capped
% at lambda_m. A query point is labelled by majority vote of the training
% samples whose spheres contain it; if no sphere contains it, the most
% frequent training class is used.

[Nclasses, classes] = find_classes([trainLabels(:); testLabels(:)]); % Number of classes in labels
epsilon = 1e-4;   % shrinks each radius so the nearest other-class sample stays strictly outside

Nf = size(trainFeatures, 2);

hm = findobj('Tag', 'Messages');
fprintf('RCE: Training (this might take a while)\n');
if (isempty(hm) == 0)
    s = sprintf('RCE: Training (this might take a while)');
    set(hm, 'String', s);
    refresh;
end

% Default label: the most frequent class in the training set.
priors = zeros(1, Nclasses);
for cl = 1:Nclasses
    priors(cl) = sum(trainLabels == classes(cl));
end
priors = priors ./ length(trainLabels);
[unused, bestGuess] = max(priors);
bestGuess = classes(bestGuess);

% Train the classifier: one radius per training sample.
lambda = zeros(1, Nf);
for sam = 1:Nf
    % Euclidean distances between this sample and every training sample.
    % sum(...,1) keeps the result a row vector even for 1-D features.
    delta = trainFeatures - trainFeatures(:, sam) * ones(1, Nf);
    distances = sqrt(sum(delta.^2, 1));
    [sortedDist, order] = sort(distances);

    % Positions, in sorted order, of the samples having a different label.
    otherPos = find(trainLabels(order) ~= trainLabels(sam));

    if isempty(otherPos)
        % No sample of another class exists: only the cap limits the radius.
        lambda(sam) = lambda_m;
    else
        % otherPos indexes the SORTED distances, so the nearest
        % different-label distance is sortedDist(otherPos(1)).
        lambda(sam) = min(sortedDist(otherPos(1)) - epsilon, lambda_m);
    end
end

% Now computes the training error
fprintf('RCE: Computing Training Set Error Rate\n');
if (isempty(hm) == 0)
    s = sprintf('RCE: Computing Training Set Error Rates');
    set(hm, 'String', s);
    refresh;
end

estTrainLabels = rce_vote(trainFeatures, size(trainLabels), trainFeatures, ...
                          trainLabels, lambda, Nclasses, classes, bestGuess);
trainError = computeError(classes, trainLabels, estTrainLabels);

% ------------------------------------------------
fprintf('RCE: Computing Test Set Error Rate\n');
if (isempty(hm) == 0)
    s = sprintf('RCE: Computing Test Set Error Rates');
    set(hm, 'String', s);
    refresh;
end

estTestLabels = rce_vote(testFeatures, size(testLabels), trainFeatures, ...
                         trainLabels, lambda, Nclasses, classes, bestGuess);
testError = computeError(classes, testLabels, estTestLabels);


function estLabels = rce_vote(queryFeatures, labelSize, trainFeatures, ...
                              trainLabels, lambda, Nclasses, classes, bestGuess)
% RCE_VOTE  Labels each column of queryFeatures by majority vote among the
% training samples whose sphere (radius lambda) contains it; falls back to
% bestGuess when the point lies inside no sphere.

nTrain = size(trainFeatures, 2);
nQuery = size(queryFeatures, 2);
allOnes = ones(1, nTrain);

estLabels = zeros(labelSize);
vote = zeros(1, Nclasses);

for q = 1:nQuery
    delta = trainFeatures - queryFeatures(:, q) * allOnes;
    % Take the square root so the value compared against the radius is a
    % distance, not a squared distance. lambda is not squared instead
    % because a radius can be negative (coincident points with different
    % labels give 0 - epsilon), and squaring would make it positive.
    distances = sqrt(sum(delta.^2, 1));
    inside = find(distances < lambda);   % training spheres containing the query

    if isempty(inside)
        estLabels(q) = bestGuess;
    else
        for cl = 1:Nclasses
            vote(cl) = sum(trainLabels(inside) == classes(cl));
        end
        [unused, winner] = max(vote);
        estLabels(q) = classes(winner);
    end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -