⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 start_multi_classification.m

📁 最新的模式识别分类工具箱,希望对朋友们有用!
💻 M
字号:
% Main function for the GUI multi-algorithm screen.
%
% This is a script, not a function: it runs in the caller's workspace, reads
% its settings from GUI controls located by their 'Tag', and expects the
% variables 'features' and 'targets' to already exist in that workspace.
% Status and error messages are written to the control tagged 'Messages'.

% Value handed to calculate_region (as the last entry of its second
% argument) when the grid region is computed.
Npoints = 100;

% Remember the GUI figure now; gcf may refer to another window later on
% (for instance if a parameter window gets opened).
hFigure = gcf;

% Clear any message left over from a previous run.
hm = findobj('Tag', 'Messages');
set(hm, 'String', '');

% Name of the error-estimation method currently selected in the popup menu.
error_method_str = get(findobj('Tag', 'popErrorEstimation'), 'String');
error_method_val = get(findobj('Tag', 'popErrorEstimation'), 'Value');
error_method     = char(error_method_str(error_method_val));

% Validate the workspace and the GUI inputs before doing any work. Every
% failed check writes a message to the 'Messages' control and stops the
% script with RETURN. (BREAK is only valid inside a FOR/WHILE loop; outside
% one, current MATLAB releases reject it instead of ending the script.)
if isempty(whos('targets'))
   set(hm,'String','No targets on workspace. Please load targets.')
   return
end

if isempty(whos('features'))
   set(hm,'String','No features on workspace. Please load features.')
   return
end

% Number of redraws. It is later used as a matrix dimension and as a loop
% count, so it must be a single finite, real whole number of at least 1.
h = findobj('Tag', 'txtRedraws');
redraws = str2num(get(h, 'String'));
if isempty(redraws)
   set(hm,'String','Please select how many redraws are needed.')
   return
end
if (length(redraws) ~= 1) || ~isreal(redraws) || ~isfinite(redraws) || (redraws ~= fix(redraws))
   set(hm,'String','Number of redraws must be a single whole number.')
   return
end
if redraws < 1
   set(hm,'String','Number of redraws must be larger than 0.')
   return
end

% Percentage of training vectors; only required (and only checked) for the
% Holdout method. NOTE(review): the tag spelling 'txtPrecentage' presumably
% matches the tag defined in the GUI layout -- do not change it here alone.
h = findobj('Tag', 'txtPrecentage');
percent = str2num(get(h, 'String'));
if strcmp(error_method, 'Holdout')
   if isempty(percent)
      set(hm,'String','Please select the percentage of training vectors.')
      return
   end
   if floor(percent/100*length(targets)) < 1
      set(hm,'String','Number training vectors must be larger than 0.')
      return
   end
end

  
  
%Find which algorithms will be used. The list box String may be a char
%matrix (one name per row) or a cell array of names; char() turns either
%form into a char matrix so the row indexing below works.
hAlgorithms = findobj('Tag','lstChosenAlgorithms');
algorithms  = char(get(hAlgorithms,'String'));

% An empty list, or a list whose only row is blank, means nothing was
% chosen. isempty is tested first so row 1 of an empty matrix is never read.
if isempty(algorithms) || ((size(algorithms,1) == 1) && isempty(deblank(algorithms(1,:))))
   set(hm,'String','Please select at least one algorithm.')
   return
end
Nalgorithms = size(algorithms,1);

All_algorithms = read_algorithms('Algorithms.txt');

% Start from an empty list so entries left in the workspace by a previous
% run cannot leak into this one.
Chosen_algorithms = struct('Name', {}, 'Parameter', {});

for i = 1:Nalgorithms
    alg_name = deblank(algorithms(i,:));
    if strcmp(alg_name, 'SVM')
        % Open the SVM parameter window with default values unless it is
        % already open.
        if isempty(findobj('Tag','SVM_params_window'))
            h  = svm_params_window;
            h1 = findobj(h, 'Tag','LearningParameter');
            set(h1,'String','0');
            h1 = findobj(h, 'Tag','KernelParameter');
            set(h1,'String','0.1');
            h1 = findobj(h, 'Tag','GeneralParameter');
            set(h1,'String','');
        end
        % Store this entry at position i, like every other algorithm.
        Chosen_algorithms(i).Name      = 'SVM';
        Chosen_algorithms(i).Parameter = str2num(get(findobj(findobj('Tag','Main'),'Tag','txtAlgorithmParameters'),'String'));
    else
        index = strmatch(alg_name, char(All_algorithms(:).Name), 'exact');
        if isempty(index)
            % Without a definition the algorithm cannot be run; stop here
            % instead of failing later inside the processing loop.
            set(hm,'String',['Unknown algorithm: ' alg_name])
            return
        end
        index = index(1);   % use the first definition if a name is listed twice
        Chosen_algorithms(i).Name = alg_name;
        % Prompt for a parameter unless the algorithm's Field entry starts
        % with 'N' (presumably "no parameter" -- confirm against Algorithms.txt).
        if isempty(strmatch('N', All_algorithms(index).Field))
            Chosen_algorithms(i).Parameter = char(inputdlg(['Enter ' All_algorithms(index).Caption], All_algorithms(index).Name, 1, cellstr(All_algorithms(index).Default)));
        else
            Chosen_algorithms(i).Parameter = '';
        end
    end
end

%Now that the data is OK, start working
L = length(targets);   % number of samples; the same for every draw

% Cross-validation splits the samples into redraws folds of floor(L/redraws)
% samples each, so it cannot use more folds than there are samples.
if strcmp(error_method, 'Cross-Validation')
   if floor(L/redraws) < 1
      set(hm,'String','Number of redraws cannot exceed the number of samples for cross-validation.')
      return
   end
end

% Show the busy pointer on the GUI figure captured at the start of the
% script (the same figure whose pointer is restored at the end) rather than
% on gcf, which may have changed in the meantime.
set(hFigure,'pointer','watch');

%Some variable definitions
Nclasses  = find_classes(targets); %Number of classes in targets
test_err  = zeros(Nalgorithms,redraws);   % row k = algorithm k, column i = redraw i
train_err = zeros(Nalgorithms,redraws);   % same layout as test_err

%Find the region for the grid
[region,x,y] = calculate_region(features, [zeros(1,4) Npoints]);

for k = 1:Nalgorithms
   % Pass the parameter as a number when str2double converts it to a finite
   % value; otherwise pass the stored value unchanged. It does not depend
   % on the redraw, so it is converted once per algorithm.
   param = str2double(Chosen_algorithms(k).Parameter);
   if ~isfinite(param)
      param = Chosen_algorithms(k).Parameter;
   end

   for i = 1:redraws
      set(hm, 'String', [Chosen_algorithms(k).Name ' algorithm: Processing iteration ' num2str(i) ' of ' num2str(redraws) ' iterations...']);

      %Make a draw according to the error method chosen
      switch error_method
      case 'Resubstitution'
         % Train and test on all samples.
         test_indices  = 1:L;
         train_indices = 1:L;
      case 'Holdout'
         [test_indices, train_indices] = make_a_draw(floor(percent/100*L), L);
      case 'Cross-Validation'
         % Fold i is the i-th block of chunk consecutive samples; samples
         % beyond redraws*chunk are always in the training set.
         chunk         = floor(L/redraws);
         test_indices  = 1 + (i-1)*chunk : i*chunk;
         train_indices = [1:(i-1)*chunk, 1+i*chunk:L];
      end

      train_features = features(:, train_indices);
      train_targets  = targets (:, train_indices);
      test_features  = features(:, test_indices);
      test_targets   = targets (:, test_indices);

      D = feval(Chosen_algorithms(k).Name, train_features, train_targets, param, region);

      [classify, err] = classification_error(D, train_features, train_targets, region);
      train_err(k,i)  = err;
      [classify, err] = classification_error(D, test_features, test_targets, region);
      test_err(k,i)   = err;
   end
end

% Reduce the per-redraw errors to one value per algorithm, as selected in
% the 'popErrorDisplay' menu. mean(...,2)' always yields a 1-by-Nalgorithms
% row, whether there is one redraw or several, so err has the same
% orientation for every display choice.
hDisp   = findobj('Tag','popErrorDisplay');
sDisp   = get(hDisp,'String');
mean_test_err  = mean(test_err, 2)';
mean_train_err = mean(train_err, 2)';
switch char(sDisp(get(hDisp,'Value')))
case 'Test error'
    err = mean_test_err;
case 'Train error'
    err = mean_train_err;
otherwise
    % Average of test and train error, weighted by the sizes of the last
    % test and training sets that were drawn.
    n_test  = length(test_targets);
    n_train = length(train_targets);
    err = (mean_test_err*n_test + mean_train_err*n_train) / (n_test + n_train);
end

% Optionally append the Bayes error as one more bar. Only p0 is checked for
% existence, but m0, s0, w0, m1, s1 and w1 are read as well and must also be
% present in the workspace.
hBayes = findobj('Tag','chkBayes');
if get(hBayes,'Value') & ~isempty(whos('p0')),
   % s0 and s1 are squared before being passed on (presumably standard
   % deviations turned into variances -- confirm against decision_region).
   Dbayes = decision_region(m0, s0.^2, w0, m1, s1.^2, w1, p0, region);
   [classify, Bayes_err] = classification_error(Dbayes, features, targets, region);
   Nalgorithms = Nalgorithms + 1;
   err(end+1)  = Bayes_err;
   Chosen_algorithms(Nalgorithms).Name = 'Bayes err       ';
end

%Plot the results: one bar per entry of err, each labelled with the name of
%the corresponding algorithm (underscores shown as spaces).
figure
bar(err)
title('Average classification errors')
for k = 1:Nalgorithms,
   str = strrep(deblank(Chosen_algorithms(k).Name), '_', ' ');
   h   = text(k, err(k)+.02, str);
   set(h, 'HorizontalAlignment', 'Center', 'FontSize', 12)
end

% Y axis starts at 0 and reaches at least 1 (further if an error exceeds 1).
ax    = axis;
ax(3) = 0;
ax(4) = max(1, max(err));
axis(ax)

% Report completion and restore the normal pointer on the GUI figure.
s = 'Finished!';
set(hm, 'String', s);
set(hFigure, 'pointer', 'arrow');

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -