⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 start_classify.m

📁 Duda的《模式分类》第二版的配套的Matlab源代码
💻 M
字号:
function [D, test_err, train_err, train_patterns, train_targets, reduced_patterns, reduced_targets] = start_classify(patterns, targets, error_method, redraws, percent, Preprocessing_algorithm, PreprocessingParameters, Classification_algorithm, AlgorithmParameters, region, hm, SepratePreprocessing, plot_on)

% Main function for evaluating a single classifier
% Inputs:
%   patterns                    - The examples of the data (one example per column)
%   targets                     - The labels for the data (one label per column of patterns)
%   error_method                - Error estimation method: 'Cross-Validation', 'Holdout' or 'Resubstitution'
%   redraws                     - Number of redraws needed (for 'Cross-Validation': number of folds)
%   percent                     - Percentage of training vectors (used by 'Holdout' only)
%   Preprocessing_algorithm     - A preprocessing algorithm: 'None', 'PCA', 'Whitening_transform',
%                                 'Scaling_transform', 'FishersLinearDiscriminant', or the name of
%                                 any other preprocessing function (called through feval)
%   PreprocessingParameters     - ...and its parameters
%   Classification_algorithm    - A classification algorithm (function name called through feval).
%                                 'None' is honoured by the generic preprocessing branch only: the
%                                 data is preprocessed and plotted, D is all zeros, and the errors
%                                 of that draw are left at zero.
%   AlgorithmParameters         - ...and its parameters
%   region                      - Decision region vector: [-x x -y y number_of_points]
%   hm                          - Handle to the message box on the GUI (Can be [])
%   SepratePreprocessing        - Perform separate preprocessing for each class
%   plot_on                     - Plot during preprocessing
%
% Outputs:
%   D                           - The last decision region (N-by-N, N = region(5))
%   test_err                    - The test errors  ((Nclasses+1)-by-redraws, one column per draw)
%   train_err                   - The train errors ((Nclasses+1)-by-redraws, one column per draw)
%   train_patterns              - The train patterns of the last draw
%   train_targets               - ...and targets
%   reduced_patterns            - patterns after preprocessing (last draw)
%   reduced_targets             - ...and targets

%Some variable definitions
N               = region(5);
Nclasses        = length(unique(targets)); %Number of classes in targets
test_err        = zeros(Nclasses+1,redraws);
train_err       = zeros(Nclasses+1,redraws);
[flatxy, x, y]  = region_grid(region, N);  %Grid on which the classifier is evaluated
L               = length(targets);         %Number of examples (loop invariant)

reduced_patterns= [];
reduced_targets = [];

%hParent must always be defined: several branches below test it with isempty(),
%which would fail with an undefined-variable error if hm is empty.
hParent         = [];
if ~isempty(hm),
    hParent     = get(hm,'Parent'); %Get calling window tag
end

for i = 1:redraws,
    if ~isempty(hm),
        set(hm, 'String', ['Processing iteration ' num2str(i) ' of ' num2str(redraws) ' iterations...']);
    end

    %Make a draw according to the error method chosen
    switch error_method
    case 'Resubstitution'
        %Train and test on the full data set
        test_indices  = 1:L;
        train_indices = 1:L;
    case 'Holdout'
        %NOTE(review): the FIRST output of make_a_draw is used as the test set; whether
        %floor(percent/100*L) ends up being the training or the test size depends on
        %make_a_draw's definition - confirm against "percent = percentage of training vectors".
        [test_indices, train_indices] = make_a_draw(floor(percent/100*L), L);
    case 'Cross-Validation'
        %Draw i tests on the i-th contiguous chunk and trains on the rest
        %(the last mod(L, redraws) examples are never in a test chunk)
        chunk         = floor(L/redraws);
        test_indices  = 1 + (i-1)*chunk : i * chunk;
        train_indices = [1:(i-1)*chunk, i * chunk + 1:L];
    otherwise
        error('Unknown error estimation method: %s', error_method)
    end
    train_patterns = patterns(:, train_indices);
    train_targets  = targets (:, train_indices);
    test_patterns  = patterns(:, test_indices);
    test_targets   = targets (:, test_indices);

    %Preprocess and then find decision region
    switch Preprocessing_algorithm
    case 'None'
        disp('Generating decision region')
        D = reshape(feval(Classification_algorithm, train_patterns, train_targets, flatxy, AlgorithmParameters), N, N);
        disp('Calculating the error')
        [train_err(:,i), test_err(:,i)] = calculate_error (D, train_patterns, train_targets, test_patterns, test_targets, region, Nclasses);

    case {'PCA', 'Whitening_transform'}
        disp('Performing preprocessing')
        [reduced_patterns, reduced_targets, uw, m] = feval(Preprocessing_algorithm, train_patterns, train_targets, PreprocessingParameters);
        %Re-project the training data: subtract m and multiply by uw
        reduced_patterns = uw*(train_patterns - m*ones(1,size(train_patterns,2)));
        disp('Generating decision region')
        [region, x, y] = calculate_region(uw*(patterns-m*ones(1,size(patterns,2))), region);
        %region now describes the transformed space; rebuild the evaluation grid so that
        %D and the region handed to calculate_error describe the same points
        flatxy = region_grid(region, N);
        D = reshape(feval(Classification_algorithm, reduced_patterns, reduced_targets, flatxy, AlgorithmParameters), N, N);
        disp('Calculating the error')
        [train_err(:,i), test_err(:,i)] = calculate_error (D, reduced_patterns, reduced_targets, uw*(test_patterns-m*ones(1,size(test_patterns,2))), test_targets, region, Nclasses);

    case 'Scaling_transform'
        disp('Performing preprocessing')
        [reduced_patterns, reduced_targets, w, m] = feval(Preprocessing_algorithm, train_patterns, train_targets, PreprocessingParameters);
        disp('Generating decision region')
        [region, x, y] = calculate_region((patterns-m*ones(1,size(patterns,2)))./(w * ones(1,size(patterns,2))), region);
        %region now describes the scaled space; rebuild the evaluation grid to match
        flatxy = region_grid(region, N);
        D = reshape(feval(Classification_algorithm, reduced_patterns, reduced_targets, flatxy, AlgorithmParameters), N, N);
        disp('Calculating the error')
        [train_err(:,i), test_err(:,i)] = calculate_error (D, reduced_patterns, reduced_targets, (test_patterns-m*ones(1,size(test_patterns,2)))./(w * ones(1,size(test_patterns,2))), test_targets, region, Nclasses);

    case 'FishersLinearDiscriminant'
        disp('Performing preprocessing')
        [reduced_patterns, reduced_targets, w] = feval(Preprocessing_algorithm, train_patterns, train_targets, []);
        [region, x, y] = calculate_region(reduced_patterns, region);
        %region now describes the projected space; rebuild the evaluation grid to match
        flatxy = region_grid(region, N);
        disp('Generating decision region')
        D = reshape(feval(Classification_algorithm, reduced_patterns, reduced_targets, flatxy, AlgorithmParameters), N, N);
        disp('Calculating the error')
        %Test patterns are projected on w and padded with a zero second coordinate
        [train_err(:,i), test_err(:,i)] = calculate_error (D, reduced_patterns, reduced_targets, [w'*test_patterns; zeros(1,length(test_targets))], test_targets, region, Nclasses);

        %If possible, replot the data
        if ~isempty(hParent),
            hold off
            plot_scatter([w'*patterns; zeros(1,length(targets))], targets, hParent)
            hold on
        end

    otherwise
        disp('Performing preprocessing')
        if SepratePreprocessing,
            disp('Perform seperate preprocessing for each class.')
            %Preprocess every class present in this draw on its own and concatenate
            %the results, in the order returned by unique()
            reduced_patterns = [];
            reduced_targets  = [];
            class_labels     = unique(train_targets);
            for c = 1:length(class_labels),
                in_c = find(train_targets == class_labels(c));
                [class_patterns, class_targets] = feval(Preprocessing_algorithm, train_patterns(:,in_c), train_targets(in_c), PreprocessingParameters, plot_on);
                reduced_patterns = [reduced_patterns, class_patterns];
                reduced_targets  = [reduced_targets,  class_targets];
            end
        else
            [reduced_patterns, reduced_targets] = feval(Preprocessing_algorithm, train_patterns, train_targets, PreprocessingParameters, plot_on);
        end
        pause(1);
        plot_process([]);
        %Keep only the columns that have at least one finite coordinate
        indices          = find(sum(isfinite(reduced_patterns)) > 0);
        reduced_patterns = reduced_patterns(:,indices);
        reduced_targets  = reduced_targets(:,indices);
        if ((i == redraws) & (~isempty(hParent)))
            %Plot only during the last iteration
            plot_scatter(reduced_patterns, reduced_targets, hParent, 1)
            axis(region(1:4))
        end
        %Show Voronoi diagram, reusing the tagged figure if it already exists
        hVoronoi = findobj('Tag','Voronoi diagram');
        if ~isempty(hVoronoi),
            figure(hVoronoi)
            clf;
        else
            figure;
            set(gcf,'Tag','Voronoi diagram');
        end
        hold on
        contour(x,y,voronoi_regions(reduced_patterns, region),length(reduced_targets))
        plot_scatter(reduced_patterns, reduced_targets)
        hold off
        axis(region(1:4));
        grid on;
        title('Voronoi regions')
        if ~isempty(hParent),
            figure(hParent)
        end

        if strcmp(Classification_algorithm,'None'),
            %No classification was asked for: return an all-zero decision region and
            %leave this draw's errors at zero instead of feval-ing a function named 'None'
            D = zeros(N);
            set(gcf,'pointer','arrow');
            continue
        end

        %Each class must keep at least two reduced points (counts assume 0/1 labels)
        if (sum(reduced_targets) <= 1) | (sum(~reduced_targets) <= 1)
            error('Too few reduced points (This program needs at least two points of each class). Please restart.')
        end

        disp('Generating decision region')
        D = reshape(feval(Classification_algorithm, reduced_patterns, reduced_targets, flatxy, AlgorithmParameters), N, N);
        disp('Calculating the error')
        %Errors are measured on the original (non-reduced) training and test sets
        [train_err(:,i), test_err(:,i)] = calculate_error (D, train_patterns, train_targets, test_patterns, test_targets, region, Nclasses);

    end

end


function [flatxy, x, y] = region_grid(region, N)
%Build the N-by-N evaluation grid spanning region(1:4) = [xmin xmax ymin ymax].
%Returns:
%   flatxy - 2-by-N^2 matrix of grid points in column-major order, so that
%            reshape(per_point_values, N, N) gives rows indexed by y and columns by x
%   x, y   - 1-by-N vectors of grid coordinates along each axis
x      = linspace(region(1), region(2), N);
y      = linspace(region(3), region(4), N);
mx     = ones(N,1) * x;
my     = y' * ones(1,N);
flatxy = [mx(:), my(:)]';
  

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -