📄 start_multi_classification.m
字号:
% start_multi_classification: script behind the GUI multi-algorithm screen.
% It runs every algorithm listed in the 'lstChosenAlgorithms' listbox on the
% 'features'/'targets' variables of the current workspace, estimates their
% classification errors with the method selected in 'popErrorEstimation',
% and plots the average errors as a bar chart.
%
% This section reads the GUI state and checks that the data is loaded.
% Validation failures are reported in the 'Messages' text control and the
% script stops with RETURN (BREAK outside a loop is an error in current
% MATLAB releases).
Npoints = 100;      % grid size passed to calculate_region later on
hFigure = gcf;      % remember the GUI figure; dialogs opened later may change gcf
hm = findobj('Tag', 'Messages');
set(hm,'String','');

% A popup 'String' may be a cell array or a char matrix. cellstr normalizes
% both, so indexing with 'Value' yields a whole entry, not a single character.
hErrorMethod = findobj('Tag', 'popErrorEstimation');
error_method_val = get(hErrorMethod,'Value');
error_method_str = cellstr(get(hErrorMethod,'String'));
error_method = char(error_method_str(error_method_val));

if ~exist('targets', 'var')
    set(hm,'String','No targets on workspace. Please load targets.')
    return
end
if ~exist('features', 'var')
    set(hm,'String','No features on workspace. Please load features.')
    return
end
% Validate the number of redraws (iterations per algorithm) typed in the GUI.
% It sizes the error matrices and the main loop, so it must be a single
% positive whole number.
h = findobj('Tag', 'txtRedraws');
redraws = str2num(get(h, 'String'));
if isempty(redraws) || ~isscalar(redraws)
    set(hm,'String','Please select how many redraws are needed.')
    return
end
if (redraws < 1)
    set(hm,'String','Number of redraws must be larger than 0.')
    return
end
if ~isfinite(redraws) || (redraws ~= floor(redraws))
    set(hm,'String','Number of redraws must be a whole number.')
    return
end

% Holdout needs a percentage that leaves vectors on both sides of the split.
% NOTE(review): the tag 'txtPrecentage' is spelled this way (sic); keep it in
% sync with the GUI definition if it is ever renamed.
h = findobj('Tag', 'txtPrecentage');
percent = str2num(get(h, 'String'));
if strcmp(error_method, 'Holdout')
    if isempty(percent) || ~isscalar(percent)
        set(hm,'String','Please select the percentage of training vectors.')
        return
    end
    if (floor(percent/100*length(targets)) < 1)
        set(hm,'String','Number training vectors must be larger than 0.')
        return
    end
    % floor(percent/100*L) reaches L only when percent >= 100, which would
    % leave the other side of the split empty.
    if (percent >= 100)
        set(hm,'String','Percentage of training vectors must be smaller than 100.')
        return
    end
end
% Build Chosen_algorithms (fields Name, Parameter), one entry per row of the
% 'lstChosenAlgorithms' listbox. For table algorithms, the parameter is asked
% for with inputdlg unless the algorithm's Field entry (from Algorithms.txt)
% starts with 'N' (presumably "no parameter" - confirm in read_algorithms).
hAlgorithms = findobj('Tag','lstChosenAlgorithms');
% char() turns a cell-array listbox String into a padded char matrix and
% leaves a char matrix unchanged.
algorithms = char(get(hAlgorithms,'String'));
if isempty(algorithms) || ((size(algorithms,1) == 1) && isempty(deblank(algorithms)))
    set(hm,'String','Please select at least one algorithm.')
    return
end
Nalgorithms = size(algorithms,1);
All_algorithms = read_algorithms('Algorithms.txt');
% Start from an empty list: a script shares its caller's workspace, so
% entries left over from a previous run would otherwise survive.
Chosen_algorithms = struct('Name', {}, 'Parameter', {});
for i = 1:Nalgorithms
    name = deblank(algorithms(i,:));
    if strcmp(name, 'SVM')
        % Open the SVM parameter window with default values unless it is
        % already open.
        if isempty(findobj('Tag','SVM_params_window'))
            h = svm_params_window;
            set(findobj(h, 'Tag','LearningParameter'), 'String','0');
            set(findobj(h, 'Tag','KernelParameter'), 'String','0.1');
            set(findobj(h, 'Tag','GeneralParameter'), 'String','');
        end
        % Index with i (the original used (:), which targets every element
        % of the struct array instead of this algorithm's slot).
        Chosen_algorithms(i).Name = 'SVM';
        Chosen_algorithms(i).Parameter = str2num(get(findobj(findobj('Tag','Main'),'Tag','txtAlgorithmParameters'),'String'));
    else
        index = strmatch(name, char(All_algorithms(:).Name), 'exact');
        if isempty(index)
            % Without this, slot i would stay empty and feval would fail later.
            set(hm,'String',['Unknown algorithm: ' name])
            return
        end
        index = index(1);
        Chosen_algorithms(i).Name = name;
        if isempty(strmatch('N',All_algorithms(index).Field))
            Chosen_algorithms(i).Parameter = char(inputdlg(['Enter ' All_algorithms(index).Caption], All_algorithms(index).Name, 1, cellstr(All_algorithms(index).Default)));
        else
            Chosen_algorithms(i).Parameter = '';
        end
    end
end
% Main loop: for every chosen algorithm and every redraw, split the samples
% according to the error-estimation method, train the algorithm with feval
% and record its train and test classification errors in
% train_err / test_err (Nalgorithms x redraws).
L = length(targets);    % number of samples; loop-invariant
if strcmp(error_method, 'Cross-Validation') && (floor(L/redraws) < 1)
    % Every one of the redraws folds needs at least one test sample.
    set(hm,'String','Number of redraws cannot exceed the number of samples for cross-validation.')
    return
end

% Now that the data is OK, start working. Use the saved GUI figure handle:
% the SVM window or inputdlg opened above may have changed gcf, and the
% pointer is reset on hFigure at the end.
set(hFigure,'pointer','watch');
% Some variable definitions
Nclasses = find_classes(targets); % Number of classes in targets
test_err = zeros(Nalgorithms,redraws);
train_err = zeros(Nalgorithms,redraws);
% Find the region for the grid
[region,x,y] = calculate_region(features, [zeros(1,4) Npoints]);
for k = 1:Nalgorithms
    for i = 1:redraws
        set(hm, 'String', [Chosen_algorithms(k).Name ' algorithm: Processing iteration ' num2str(i) ' of ' num2str(redraws) ' iterations...']);
        % Make a draw according to the error method chosen
        switch error_method
            case 'Resubstitution'
                % Train and test on the full data set.
                test_indices = 1:L;
                train_indices = 1:L;
            case 'Holdout'
                % NOTE(review): make_a_draw's contract is not visible here; it
                % receives the count derived from the percentage and L.
                [test_indices, train_indices] = make_a_draw(floor(percent/100*L), L);
            case 'Cross-Validation'
                % Fold i tests a contiguous chunk of samples. The last fold
                % also takes the L - redraws*chunk leftover samples, so every
                % sample is tested exactly once.
                chunk = floor(L/redraws);
                first = (i-1)*chunk + 1;
                if (i == redraws)
                    last = L;
                else
                    last = i*chunk;
                end
                test_indices = first:last;
                train_indices = [1:first-1, last+1:L];
            otherwise
                % Never reuse indices from a previous iteration or run.
                set(hm,'String',['Unknown error estimation method: ' error_method])
                set(hFigure,'pointer','arrow');
                return
        end
        train_features = features(:, train_indices);
        train_targets = targets(:, train_indices);
        test_features = features(:, test_indices);
        test_targets = targets(:, test_indices);
        % Pass the parameter as a number when it converts to a finite one;
        % otherwise pass the stored Parameter value through unchanged.
        param = str2double(Chosen_algorithms(k).Parameter);
        if ~isfinite(param)
            param = Chosen_algorithms(k).Parameter;
        end
        D = feval(Chosen_algorithms(k).Name, train_features, train_targets, param, region);
        [classify, err] = classification_error(D, train_features, train_targets, region);
        train_err(k,i) = err;
        [classify, err] = classification_error(D, test_features, test_targets, region);
        test_err(k,i) = err;
    end
end
% Reduce the per-redraw error matrices (Nalgorithms x redraws) to one
% average error per algorithm, as selected in the 'popErrorDisplay' popup.
% mean(..., 2)' always yields a 1 x Nalgorithms row, for any number of
% redraws or algorithms, so no special case for a single redraw is needed.
hDisp = findobj('Tag','popErrorDisplay');
% cellstr: the popup String may be a char matrix or a cell array.
sDisp = cellstr(get(hDisp,'String'));
switch char(sDisp(get(hDisp,'Value')))
    case 'Test error'
        err = mean(test_err, 2)';
    case 'Train error'
        err = mean(train_err, 2)';
    otherwise
        % Combined error: test and train averages weighted by the sizes of
        % the test and train sets from the last draw of the main loop.
        Ntest = length(test_targets);
        Ntrain = length(train_targets);
        err = (mean(test_err, 2)'*Ntest + mean(train_err, 2)'*Ntrain) / (Ntest + Ntrain);
end
% Optionally append the Bayes error as an extra bar when the 'chkBayes'
% checkbox is ticked. decision_region needs p0, m0, s0, w0, m1, s1 and w1 in
% the workspace; if any of them is missing, the Bayes bar is skipped.
% The Bayes error is measured on the whole data set (features/targets).
hBayes = findobj('Tag','chkBayes');
bayes_params = {'p0', 'm0', 's0', 'w0', 'm1', 's1', 'w1'};
has_bayes_params = true;
for bayes_k = 1:length(bayes_params)
    % exist must be called here (not inside an anonymous function) so it
    % inspects this script's workspace.
    if ~exist(bayes_params{bayes_k}, 'var')
        has_bayes_params = false;
    end
end
if ~isempty(hBayes) && get(hBayes, 'Value') && has_bayes_params
    Dbayes = decision_region(m0, s0.^2, w0, m1, s1.^2, w1, p0, region);
    [classify, Bayes_err] = classification_error(Dbayes, features, targets, region);
    err(end+1) = Bayes_err;
    Nalgorithms = Nalgorithms + 1;
    Chosen_algorithms(Nalgorithms).Name = 'Bayes err ';
end
% Plot the average errors as a bar chart in a new figure. Each bar is
% labelled with its algorithm name (underscores shown as spaces) placed just
% above the bar. Then report completion and restore the GUI pointer.
figure
bar(err)
title('Average classification errors')
for k = 1:Nalgorithms
    label = strrep(deblank(Chosen_algorithms(k).Name), '_', ' ');
    text(k, err(k)+.02, label, 'HorizontalAlignment', 'Center', 'FontSize', 12);
end
% The Y axis starts at 0 and extends to at least 1.
ax = axis;
ax(3) = 0;
ax(4) = max(1, max(err));
axis(ax)
s = 'Finished!';
set(hm, 'String', s);
set(hFigure,'pointer','arrow');
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -