📄 weightedknnrule_vc.m
% Learns classifier and classifies test set
% using weighted k-NN rule
%
% Usage
% [trainError, testError, estTrainLabels, estTestLabels] = ...
%     weightedKNNRule_VC(trainFeatures, trainLabels, params, testFeatures, testLabels)
% where
%
% Inputs:
%   trainFeatures - the training set vectors, one vector per column
%   trainLabels   - the labels of the above
%   params        - <k>, the number of nearest neighbors
%                   <weightFile>, a file with the weights: a vector having
%                   k (non-increasing) entries. Entry 1 is the weight of
%                   the nearest neighbor, while entry k is the weight of
%                   the farthest neighbor.
%                   IF weightFile IS NOT PROVIDED, DEFAULT weights are
%                   computed using the MacLeod, Luk and Titterington
%                   weights with parameters s = 2k and alpha = 0.5
%                   (see lecture notes).
%                   To avoid providing <weightFile> there are 2 alternatives:
%                   1. in the user interface, use '' to indicate no weight file
%                   2. provide the name of a non-existing file (e.g., 'foononexist')
%                   An example weight file, "exampleWeightFile", is provided
%                   as a reference.
%   testFeatures  - test set, one column per vector
%   testLabels    - labels for the test set
%
% Outputs:
%   trainError     - the error rate on the training set (one entry per class + total error)
%   testError      - the error rate on the test set (one entry per class + total error)
%   estTrainLabels - the labels produced by the algorithm for the training samples
%   estTestLabels  - the labels produced by the algorithm for the test samples

function [trainError, testError, estTrainLabels, estTestLabels] = ...
    weightedKNNRule_VC(trainFeatures, trainLabels, algParam, testFeatures, testLabels)

[Nclasses, classes] = find_classes([trainLabels(:); testLabels(:)]);  % number of classes in labels
[Dim, Nsam] = size(trainFeatures);

% algParam has the form (k,'weightFile'), e.g. (3,'exampleWeightFile')
comma_loc  = strfind(algParam, ',');
Knn        = str2double(algParam(2:comma_loc(1)-1));
weightfile = algParam((comma_loc(1)+2):(length(algParam)-2));

%--------------------------------
% opening the weights file
if ~isempty(weightfile)
    fp = fopen(weightfile, 'r');
    if (fp ~= -1)
        weights = fscanf(fp, '%f');
        fclose(fp);
        haveWeights = 1;
        if (length(weights) > Knn)
            weights = weights(1:Knn);
            weights = weights(:);                                  % column vector
        elseif (length(weights) < Knn)
            weights = [weights(:); zeros(Knn-length(weights),1)];  % pad with zeros
        else
            weights = weights(:);                                  % column vector
        end
    else
        haveWeights = 0;
    end
else
    haveWeights = 0;
end

hm = findobj('Tag', 'Messages');
fprintf('Weighted %d-N-N Rule: Classifying Training Set\n', Knn);
if ~isempty(hm)
    s = sprintf('Weighted %d-N-N Rule: Classifying Training Set\n', Knn);
    set(hm, 'String', s);
    refresh;
end

%--------------------------------
% classifies training set
[Dim, Nsamples] = size(trainFeatures);
Onen = ones(1, Nsamples);
for sam = 1:Nsamples
    aux    = trainFeatures(:,sam) * Onen;
    sdiff  = trainFeatures - aux;
    sdiff  = sdiff .* sdiff;
    distsq = sum(sdiff);   % squared distances; no need to take square roots
    [sorted_dist, indices] = sort(distsq);
    if length(trainLabels) <= Knn
        k_nearest = trainLabels;
    else
        k_nearest = trainLabels(indices(1:Knn));
    end
    % if necessary, computes the weight vector
    if haveWeights == 0
        k1 = 2*Knn;   % s = 2k
        % deal with k > n/2, where the 2k-th neighbor does not exist
        if (k1 <= length(trainLabels))
            dmax = sqrt(sorted_dist(k1));        % distance to the s-th neighbor
        else
            dmax = sqrt(2 * max(sorted_dist));   % fallback upper bound
        end
        alpha   = 0.5;
        auxDist = sqrt(sorted_dist(1:Knn));
        weights = ((dmax - auxDist) + alpha * (dmax - auxDist(1))) ./ ...
                  ((1+alpha) * (dmax - auxDist(1)));
    end
    counts = zeros(1, Nclasses);
    for i = 1:length(k_nearest)   % may be fewer than Knn neighbors when n <= k
        label = k_nearest(i);
        label = find(classes == label);
        counts(label) = counts(label) + weights(i);
    end
    [foo, target] = max(counts);
    estTrainLabels(sam) = classes(target);
end

% ================================
% classifies the test set
hm = findobj('Tag', 'Messages');
fprintf('Weighted %d-N-N Rule: Classifying Test Set\n', Knn);
if ~isempty(hm)
    s = sprintf('Weighted %d-N-N Rule: Classifying Test Set\n', Knn);
    set(hm, 'String', s);
    refresh;
end

[Dim, Nsamples] = size(testFeatures);
for sam = 1:Nsamples
    aux    = testFeatures(:,sam) * Onen;   % Onen still has one entry per TRAINING sample
    sdiff  = trainFeatures - aux;
    sdiff  = sdiff .* sdiff;
    distsq = sum(sdiff);   % squared distances; no need to take square roots
    [sorted_dist, indices] = sort(distsq);
    if length(trainLabels) <= Knn
        k_nearest = trainLabels;
    else
        k_nearest = trainLabels(indices(1:Knn));
    end
    % if necessary, computes the weight vector
    if haveWeights == 0
        k1 = 2*Knn;   % s = 2k
        if (k1 <= length(trainLabels))
            dmax = sqrt(sorted_dist(k1));
        else
            dmax = sqrt(2 * max(sorted_dist));
        end
        alpha   = 0.5;
        auxDist = sqrt(sorted_dist(1:Knn));
        weights = ((dmax - auxDist) + alpha * (dmax - auxDist(1))) ./ ...
                  ((1+alpha) * (dmax - auxDist(1)));
    end
    counts = zeros(1, Nclasses);
    for i = 1:length(k_nearest)
        label = k_nearest(i);
        label = find(classes == label);
        counts(label) = counts(label) + weights(i);
    end
    [foo, target] = max(counts);
    estTestLabels(sam) = classes(target);
end

trainError = computeError(classes, trainLabels, estTrainLabels);
testError  = computeError(classes, testLabels,  estTestLabels);
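A minimal usage sketch follows, assuming the accompanying toolbox helpers find_classes and computeError (called inside the function) are on the MATLAB path. The toy data, the choice k = 3, and the file name myWeights.txt are illustrative only; note that params packs k and the weight-file name into a single string of the form (k,'weightFile'), which is what the parsing at the top of the function expects.

% --- usage sketch (illustrative data and file name) ---------------------
trainFeatures = [randn(2,20), randn(2,20) + 3];   % 2-D vectors, one per column
trainLabels   = [zeros(1,20), ones(1,20)];        % two classes: 0 and 1
testFeatures  = [randn(2,10), randn(2,10) + 3];
testLabels    = [zeros(1,10), ones(1,10)];

% No weight file: '' selects the default MacLeod-Luk-Titterington weights.
params = '(3,'''')';
[trainErr, testErr, estTrain, estTest] = ...
    weightedKNNRule_VC(trainFeatures, trainLabels, params, testFeatures, testLabels);

% With a weight file: a plain-text file holding k non-increasing entries
% (nearest neighbor first), read via fscanf(fp,'%f'). For k = 3 its body
% could be:  1.0 0.5 0.25   (file name below is illustrative)
params = '(3,''myWeights.txt'')';

As a quick check of the default weighting, with k = 3, neighbor distances 1, 2, 3, and s-th distance dmax = 4, the formula gives weights (5.5-d)/4.5, i.e. 1, 0.78, 0.56: the nearest neighbor gets weight 1 and the weights decrease with distance, as intended.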