
% nearest_neighbor_vc.m

% From a pattern-recognition classification toolbox.
% Learns a classifier and classifies the test set using the k-NN rule.
%
% Usage:
%   [trainError, testError, estTrainLabels, estTestLabels] = ...
%       Nearest_Neighbor_VC(trainFeatures, trainLabels, Knn, testFeatures, testLabels)
%
% Inputs:
%   trainFeatures  - training set vectors, one vector per column
%   trainLabels    - labels of the training vectors
%   Knn            - number of nearest neighbors (k)
%   testFeatures   - test set vectors, one vector per column
%   testLabels     - labels of the test vectors
%
% Outputs:
%   trainError     - error rate on the training set (one entry per class + total)
%   testError      - error rate on the test set (one entry per class + total)
%   estTrainLabels - estimated labels for the training set
%   estTestLabels  - estimated labels for the test set
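%
% Example usage (hypothetical data for illustration; the helper functions
% find_classes and computeError are assumed to be on the MATLAB path):
%
%   trainFeatures = [0 0 1 1; 0 1 0 1];   % one 2-D sample per column
%   trainLabels   = [0 0 1 1];
%   testFeatures  = [0.1 0.9; 0.1 0.9];
%   testLabels    = [0 1];
%   [trErr, teErr, estTr, estTe] = ...
%       Nearest_Neighbor_VC(trainFeatures, trainLabels, 1, testFeatures, testLabels);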


function [trainError, testError, estTrainLabels, estTestLabels] = ...
    Nearest_Neighbor_VC(trainFeatures, trainLabels, Knn, testFeatures, testLabels)


% Date     Name               Change
% 03/18/02 Vittorio Castelli  Now works for any number of features,
%                             any number of classes. 
%                             Added documentation

hm = findobj('Tag', 'Messages');
fprintf('%d-Nearest-Neighbor: Computing Training Set Error Rate\n', Knn);
if ~isempty(hm)
  s = sprintf('%d-Nearest-Neighbor: Computing Training Set Error Rate', Knn);
  set(hm, 'String', s);
  refresh;
end

[Nclasses, classes]  = find_classes([trainLabels(:);testLabels(:)]); % Number of classes in labels

[Dim, Nsam] = size(trainFeatures);
OneD = ones(1, Nsam);  % row of ones, used to replicate a column across all training samples


% computing  the error on the training set

hatTrainLabels = zeros(size(trainLabels));

for sam = 1:Nsam
  % squared Euclidean distances from sample 'sam' to every training sample
  diff = trainFeatures(:,sam) * OneD;  % replicate the column Nsam times
  diff = diff - trainFeatures;
  diff = sum(diff .* diff);
  [diff, indices] = sort(diff);
  % takes care of the small sample problem
  if (length(trainLabels) <= Knn)
    k_nearest = trainLabels;
  else
    k_nearest = trainLabels(indices(1:Knn));
  end
  % finds the class with largest number of neighbors
  bestClass = 0;
  maxNN     = 0;
  for cl=1:Nclasses,
    nrNN = sum(k_nearest == classes(cl));
    if(nrNN > maxNN)
      maxNN = nrNN; bestClass = cl;
    end
  end
  hatTrainLabels(sam) = classes(bestClass);
end
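
% Note: the per-class vote above can also be written without the explicit
% loop. A sketch (requires implicit expansion, MATLAB R2016b or later; like
% the loop, ties are broken toward the class with the lower index):
%
%   counts = sum(k_nearest(:) == classes(:)', 1);  % votes per class
%   [~, bestClass] = max(counts);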

% ================================================================
fprintf('%d-Nearest-Neighbor: Computing Test Set Error Rate\n', Knn);
if ~isempty(hm)
  s = sprintf('%d-Nearest-Neighbor: Computing Test Set Error Rate', Knn);
  set(hm, 'String', s);
  refresh;
end

% computing the test error

[Dim, NsamTest] = size(testFeatures);
hatTestLabels   = zeros(size(testLabels));

for sam = 1:NsamTest
  % squared distances from the test sample to every training sample
  % (OneD has length Nsam, the number of training samples)
  diff = testFeatures(:,sam) * OneD;
  diff = diff - trainFeatures;
  diff = sum(diff .* diff);
  [diff, indices] = sort(diff);

  if (length(trainLabels) <= Knn)
    k_nearest = trainLabels;
  else
    k_nearest = trainLabels(indices(1:Knn));
  end
  bestClass = 0;
  maxNN     = 0;

  for cl=1:Nclasses,
    nrNN = sum(k_nearest == classes(cl));
    if(nrNN > maxNN)
      maxNN = nrNN; bestClass = cl;
    end
  end
  hatTestLabels(sam) = classes(bestClass);
end


trainError = computeError(classes, trainLabels, hatTrainLabels);
testError  = computeError(classes, testLabels,  hatTestLabels);

% return the estimated labels declared in the function signature
estTrainLabels = hatTrainLabels;
estTestLabels  = hatTestLabels;
