📄 demo_red_output.m
字号:
function demo_red_output(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy, rv_num, full_sv)
% DEMO_RED_OUTPUT  Interactive viewer for a full SVM vs. a reduced-set expansion.
%
%   DEMO_RED_OUTPUT(patterns, labels, alpha, b, rvect, rweight, rthresh,
%   discrepancy, rv_num, full_sv) draws the classifier via DEMO_RED_DISPLAY
%   and adds these controls to the current figure:
%     - a slider choosing how many reduced vectors (1..length(discrepancy)) to use,
%     - "Full SVM" / "Reduced Vectors" radio buttons,
%     - a threshold-offset slider in [-0.5, 0.5] that is added to rthresh,
%     - a "Click to Retrain" button. This function blocks (WAITFOR) until the
%       button is pressed or destroyed.
%
%   Inputs (meanings inferred from how DEMO_RED_DISPLAY indexes them; confirm with caller):
%     patterns     training points, one per column
%     labels       training labels, compared against -1 and +1
%     alpha, b     full-SVM expansion coefficients and offset
%     rvect        reduced-set vectors, one per column
%     rweight      cell array; rweight{k} holds the weights of the k-vector expansion
%     rthresh      rthresh(k) is the offset of the k-vector expansion
%     discrepancy  discrepancy(k) is reported for the k-vector expansion
%     rv_num       initial number of reduced vectors
%     full_sv      1 to start on the full SVM, 0 to start on the reduced expansion
global rvn slider_rvnum svm_radio rvm_radio thresh_slider;

% All controls created here share this tag so a later call (e.g. after
% retraining) can remove them instead of stacking duplicates on the figure.
ctrl_tag = 'demo_red_output_ctrl';
delete(findobj('Tag', ctrl_tag));

demo_red_display(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy, rv_num, full_sv);
rvn = rv_num;

% Reduced-vector count slider. SliderStep is a fraction of (max - min), so a
% step of exactly one vector is 1/(max - min). A slider also needs max > min,
% so with fewer than two reduced vectors it is shown disabled.
max_rv = length(discrepancy);
if max_rv > 1
    slider_max = max_rv;
    slider_enable = 'on';
else
    slider_max = 2;
    slider_enable = 'off';
end
unit_step = 1/(slider_max - 1);
% Keep the initial value inside [1, max_rv]; rv_num may be 0 when starting on the full SVM.
slider_value = min(max(round(rv_num), 1), max(max_rv, 1));

% Function-handle callbacks capture the arguments and can reach this file's
% local functions; char-vector callbacks are evaluated in the base workspace,
% which sees neither.
slider_rvnum = uicontrol('style','slider','units','normal','pos',[0.03 0.82 0.15 0.05], ...
    'value', slider_value, 'min', 1, 'max', slider_max, ...
    'sliderstep', [unit_step, min(2*unit_step, 1)], 'enable', slider_enable, 'tag', ctrl_tag, ...
    'callback', @(src, evt) slider_rvnum_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy));
uicontrol('style','text','units','normal','pos',[0.03 0.89 0.15 0.03],'string','Reduced Vector Number', 'tag', ctrl_tag);

svm_radio = uicontrol('style','radio','units','normal','pos', [0.03 0.72 0.15 0.05], 'max', 1, 'min', 0, ...
    'value', full_sv, 'string', 'Full SVM', 'tag', ctrl_tag, ...
    'callback', @(src, evt) svm_radio_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy));
rvm_radio = uicontrol('style','radio','units','normal','pos', [0.03 0.62 0.15 0.05], 'max', 1, 'min', 0, ...
    'value', ~full_sv, 'string', 'Reduced Vectors', 'tag', ctrl_tag, ...
    'callback', @(src, evt) rvm_radio_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy));

% threshold offset slider
thresh_slider = uicontrol('style','slider','units','normal','pos',[0.03 0.52 0.15 0.05], ...
    'value', 0, 'min', -0.5, 'max', 0.5, 'tag', ctrl_tag, ...
    'callback', @(src, evt) thresh_slider_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy));
uicontrol('style','text','units','normal','pos',[0.03 0.57 0.15 0.03],'string','Threshold', 'tag', ctrl_tag);

% Block until the button sets its (numeric) UserData to 1.
cont_button = uicontrol('style','pushbutton','units','normal','pos',[0.03 0.11 0.15 0.15], ...
    'string','Click to Retrain', 'UserData', 0, 'tag', ctrl_tag, ...
    'callback', @(src, evt) set(src, 'UserData', 1));
waitfor(cont_button, 'UserData', 1);
%-------------------------------------------------------------------------------------------------
function slider_rvnum_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy)
% SLIDER_RVNUM_CALLBACK  Show the reduced expansion with the slider's vector count.
%   Rounds the slider to an integer count (and snaps the slider to it),
%   switches the radio buttons to "Reduced Vectors", and redraws when the
%   count changed or the full SVM was being shown. The current threshold
%   slider value is added to rthresh.
global rvn svm_radio rvm_radio thresh_slider;
rv_old = rvn;
rvn = round(get(gcbo, 'value'));
% Snap the thumb to the integer actually displayed (does not re-fire the callback).
set(gcbo, 'value', rvn);
showing_svm = get(svm_radio, 'value') == 1;
if rv_old ~= rvn || showing_svm
    % Keep the radio buttons consistent with what is drawn; the threshold
    % callback reads svm_radio to decide which model to show.
    set(svm_radio, 'value', 0);
    set(rvm_radio, 'value', 1);
    delta_thresh = get(thresh_slider, 'value');
    demo_red_display(patterns, labels, alpha, b, rvect, rweight, rthresh+delta_thresh, discrepancy, rvn, 0);
end
%-------------------------------------------------------------------------------------------------
function svm_radio_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy)
% SVM_RADIO_CALLBACK  Select "Full SVM" and redraw the full classifier.
%   Standalone radio buttons toggle off when clicked while already on, so
%   the clicked button is forced back on to keep exactly one option selected.
global rvm_radio thresh_slider;
set(gcbo, 'value', 1);
set(rvm_radio, 'value', 0);
delta_thresh = get(thresh_slider, 'value');
% rv_num is irrelevant in full-SVM mode, hence 0.
demo_red_display(patterns, labels, alpha, b, rvect, rweight, rthresh+delta_thresh, discrepancy, 0, 1);
%-------------------------------------------------------------------------------------------------
function rvm_radio_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy)
% RVM_RADIO_CALLBACK  Select "Reduced Vectors" and redraw the reduced expansion.
%   Uses the (rounded) count from the reduced-vector slider and the current
%   threshold offset. The clicked radio button is forced on because a
%   standalone radio button toggles off when clicked while already on.
global svm_radio slider_rvnum thresh_slider;
set(gcbo, 'value', 1);
set(svm_radio, 'value', 0);
rv_count = round(get(slider_rvnum, 'value'));
delta_thresh = get(thresh_slider, 'value');
demo_red_display(patterns, labels, alpha, b, rvect, rweight, rthresh+delta_thresh, discrepancy, rv_count, 0);
%-------------------------------------------------------------------------------------------------
function thresh_slider_callback(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy)
% THRESH_SLIDER_CALLBACK  Redraw using the threshold slider's offset.
%   The slider value is added to every entry of rthresh. Which model is
%   drawn comes from the "Full SVM" radio button, and the reduced-vector
%   count comes from the (rounded) vector-count slider.
global svm_radio slider_rvnum;
offset_shift = get(gcbo, 'value');
show_full = get(svm_radio, 'value');
n_vectors = round(get(slider_rvnum, 'value'));
shifted_thresh = rthresh + offset_shift;
demo_red_display(patterns, labels, alpha, b, rvect, rweight, shifted_thresh, discrepancy, n_vectors, show_full);
%-------------------------------------------------------------------------------------------------
function demo_red_display(patterns, labels, alpha, b, rvect, rweight, rthresh, discrepancy, rv_num, full_sv)
% DEMO_RED_DISPLAY  Plot the full SVM or a reduced-set expansion in the current axes.
%
%   Evaluates the classifier on a 40 x 40 grid over [-1,1] x [-1,1] and on
%   the training patterns, then draws:
%     - the negated decision values as a pseudocolor image,
%     - the zero contour (thick black) and the -1/+1 contours (dotted black),
%     - training points (label -1: black circles, label +1: white dots),
%     - misclassified training points (black crosses),
%     - in reduced mode, the first rv_num reduced vectors (pentagrams).
%   A status text control at the bottom of the figure reports training error
%   and model statistics. It is created once and reused on later redraws.
%
%   full_sv == 1: decision values are RD*alpha - b, with RD = rbf_dot(points, patterns, deg)
%   otherwise   : decision values are RD*rweight{rv_num} - rthresh(rv_num),
%                 with RD = rbf_dot(points, rvect(:,1:rv_num), deg)
%   NOTE(review): rthresh is unused in full-SVM mode, so a threshold offset
%   added by the callers has no visible effect there.
%   Assumes rbf_dot(X, Y, deg) returns a size(X,2) x size(Y,2) kernel matrix; TODO confirm.
global deg;

% to get rid of the old results
cla

% Test patterns on a regular grid. LINSPACE guarantees exactly test_num
% samples per axis, which the RESHAPE of the grid output relies on.
test_num = 40;
mins = [-1, -1];
maxs = [1, 1];
x_range = linspace(mins(1), maxs(1), test_num);
y_range = linspace(mins(2), maxs(2), test_num);
[xs, ys] = meshgrid(x_range, y_range);
grid_patterns = [xs(:)'; ys(:)'];

% Reuse a single status-text control instead of creating a new one on every redraw.
info_tag = 'demo_red_info_text';
info_text = findobj(gcf, 'Tag', info_tag);
if isempty(info_text)
    info_text = uicontrol('style', 'text', 'units', 'pixels', 'pos', [180, 0, 600, 40], 'Tag', info_tag);
else
    info_text = info_text(1);
end

% plabels below is a row; comparing with a row avoids implicit expansion
% into an N x N matrix when labels arrives as a column.
labels_row = labels(:)';
sv_count = nnz(abs(alpha) > 1e-5);

if full_sv == 1
    % outputs on the training patterns and on the grid
    plabels = sign((rbf_dot(patterns, patterns, deg)*alpha - b)');
    pgrid_output = (rbf_dot(grid_patterns, patterns, deg)*alpha - b)';
    Train_Errors = (plabels ~= labels_row);
    Train_Error = sum(Train_Errors)/length(labels_row);
    set(info_text, 'string', sprintf('%.1f%% Training Error, %d SVs, %d positive SVs, %d negative SVs', ...
        Train_Error*100, sv_count, nnz(alpha > 1e-5), nnz(alpha < -1e-5)));
else
    rv = rvect(:, 1:rv_num);
    w = rweight{rv_num};
    offset = rthresh(rv_num);
    % outputs on the training patterns and on the grid
    plabels = sign((rbf_dot(patterns, rv, deg)*w - offset)');
    pgrid_output = (rbf_dot(grid_patterns, rv, deg)*w - offset)';
    Train_Errors = (plabels ~= labels_row);
    Train_Error = sum(Train_Errors)/length(labels_row);
    % Produces 'weights: 0.12, -0.34,' (same text the former STRCAT loop built).
    str_discr = ['weights:', sprintf(' %.2f,', w(1:rv_num))];
    set(info_text, 'string', sprintf('%.1f%% Training Error, %d SVs, %d Reduced Vectors used, %.2f discrepancy\n %s\n Threshold: %.4f',...
        Train_Error*100, sv_count, rv_num, discrepancy(rv_num), str_discr, offset));
end

% Plot the output: negate so the colouring matches the original demo.
grid_img = -reshape(pgrid_output, test_num, test_num);
colormap('jet')
% The layers below must overlay each other; remember the caller's hold
% state so it can be restored afterwards.
was_held = ishold;
pcolor(x_range, y_range, grid_img);
hold on;
% draw the decision boundary
[c, h] = contour(x_range, y_range, grid_img, [0 0], 'k');
set(h, 'LineWidth', 5);
% draw the margin
[c, h] = contour(x_range, y_range, grid_img, [-1 1], 'k');
set(h, 'LineStyle', ':', 'LineWidth', 2);
shading interp

% plot the training patterns
p = plot(patterns(1, labels_row == -1), patterns(2, labels_row == -1), 'ko');
set(p, 'MarkerSize', 8, 'linewidth', 2);
p = plot(patterns(1, labels_row == 1), patterns(2, labels_row == 1), 'w.');
set(p, 'MarkerSize', 25);
% mark the training errors
p = plot(patterns(1, Train_Errors), patterns(2, Train_Errors), 'kx');
set(p, 'MarkerSize', 10, 'linewidth', 1.5);
% draw the reduced set vectors
if full_sv ~= 1
    p = plot(rvect(1, 1:rv_num), rvect(2, 1:rv_num), 'pentagram');
    set(p, 'MarkerSize', 10, 'MarkerEdgeColor', 'w', 'MarkerFaceColor', 'k');
end
if ~was_held
    hold off;
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -