📄 reduced_svm.m
字号:
function [rvect, rweight, rthresh, discrepancy]=reduced_svm(patterns, labels, svect, alpha, rv_num, epsilon_discrepancy, max_it, max_attempt)
% REDUCED_SVM  Approximate a kernel expansion sum_i alpha_i*k(svect_i, .)
% with a smaller set of "reduced" vectors.
%
%   [rvect, rweight, rthresh, discrepancy] = reduced_svm(patterns, labels,
%       svect, alpha, rv_num, epsilon_discrepancy, max_it, max_attempt)
%
%   Reduced vectors are added one at a time. Each new vector z starts at a
%   random point and is refined by the fixed-point update
%       z <- (svect*k_sv' + rvect_prev*k_rv') / sum(k)
%   where k_sv = alpha .* rbf_dot(svect, z) and
%   k_rv = -rweight{n-1} .* rbf_dot(rvect_prev, z).
%   When the update converges:
%     - all n weights are refit with pinv(Kz)*(Kzx'*alpha);
%     - a threshold is chosen from the reduced-expansion outputs on
%       `patterns`;
%     - the discrepancy
%           alpha'*K*alpha + w'*Kz*w - 2*alpha'*Kzx*w
%       is recorded.
%   When the update does not converge, the vector is restarted from a new
%   random point.
%
%   Inputs
%     patterns            - training patterns, passed to rbf_dot alongside
%                           rvect (presumably one pattern per column, as
%                           for svect -- confirm against rbf_dot)
%     labels              - training labels, passed to compute_thres
%     svect               - support vectors, one per column (dim x sv_num)
%     alpha               - coefficients of svect; used as a column vector
%     rv_num              - maximum number of reduced vectors to compute
%     epsilon_discrepancy - stop as soon as discrepancy(n) falls below this
%                           (default 0.1)
%     max_it              - max fixed-point iterations per start (default 75)
%     max_attempt         - max failed starts for one vector before giving
%                           up and returning what was found (default 500)
%
%   Outputs
%     rvect       - reduced vectors found, one per column (may be dim x 0)
%     rweight     - cell array; rweight{n} holds the n weights fitted when
%                   the first n reduced vectors were in use
%     rthresh     - rthresh(n) is the threshold for the n-vector expansion
%                   (negated value returned by compute_thres)
%     discrepancy - discrepancy(n) is the kernel-space discrepancy above
%                   for the n-vector expansion
%
%   The kernel parameter is read from the global variable DEG.
global deg;

if nargin <= 5
    epsilon_discrepancy = 0.1;
end
if nargin <= 6
    max_it = 75;
end
if nargin <= 7
    max_attempt = 500;
end

dim = size(svect,1);
sv_num = size(svect,2);

% Start every output empty so that all of them are defined even if not a
% single reduced vector can be computed.
rvect = zeros(dim, 0);
rweight = {};
rthresh = [];
discrepancy = [];

% alpha'*K*alpha does not depend on the reduced set, so compute it once.
K = rbf_dot(svect, svect, deg);
alpha_norm2 = alpha'*K*alpha;

n = 1;
attempt = 0;
denom_too_low = 0;
no_progress = 0;
while n <= rv_num
    % Random initialisation, uniform in [-1, 1]^dim.
    rvect(:,n) = rand(dim,1)*2 - 1;
    k = zeros(1, sv_num+n-1);
    converged = false;
    denom_small = false;

    % Fixed-point iteration for the n-th reduced vector.
    for it = 1:max_it
        z_old = rvect(:,n);
        k(1:sv_num) = alpha.*rbf_dot(svect, rvect(:,n), deg);
        if n > 1
            % Previously found vectors enter with the negated weights of
            % the last fit.
            k(sv_num+1:sv_num+n-1) = -rweight{n-1}.*rbf_dot(rvect(:,1:n-1), rvect(:,n), deg);
        end
        s = sum(k);
        if abs(s) < 1e-20
            % Denominator too small: abandon this start.
            denom_small = true;
            break;
        end
        rvect(:,n) = svect*k(1:sv_num)' / s;
        if n > 1
            rvect(:,n) = rvect(:,n) + rvect(:,1:n-1)*k(sv_num+1:sv_num+n-1)' / s;
        end
        if max(abs(z_old - rvect(:,n))) < 1e-3
            converged = true;
            break;
        end
    end

    if ~converged
        % This start failed: count why, then restart the same vector.
        attempt = attempt + 1;
        if denom_small
            denom_too_low = denom_too_low + 1;
        else
            no_progress = no_progress + 1;
        end
        if mod(attempt, 100) == 0
            fprintf(1, 'Vector: %d, attempt: %d (cause of slow convergence: %d denom too low, %d no progess)\n', n, attempt, denom_too_low, no_progress);
        end
        if attempt >= max_attempt
            % Give up: drop the unfinished vector and return the n-1 found.
            rvect = rvect(:,1:n-1);
            break;
        end
    else
        % Refit all n weights against the original expansion.
        Kzx = rbf_dot(svect, rvect(:,1:n), deg);
        Kz = rbf_dot(rvect(:,1:n), rvect(:,1:n), deg);
        rweight{n} = pinv(Kz)*(Kzx'*alpha);

        % Compute the 'best' threshold for n reduced vectors on the
        % training patterns.
        RD = rbf_dot(patterns, rvect(:,1:n), deg);
        y = (RD*rweight{n})';
        rthresh(n) = -compute_thres(y, labels);

        % Discrepancy between the original and reduced expansions,
        % reusing the kernel matrices computed above.
        discrepancy(n) = alpha_norm2 + rweight{n}'*Kz*rweight{n} - 2*alpha'*Kzx*rweight{n};

        % Print progress.
        if attempt > 1
            fprintf(1, 'Computing Reduced Vector : %d / %d, %d attempts, %.2f discrepancy\n', n, rv_num, attempt, discrepancy(n));
        else
            fprintf(1, 'Computing Reduced Vector : %d / %d, %.2f discrepancy\n', n, rv_num, discrepancy(n));
        end
        if discrepancy(n) < epsilon_discrepancy
            break;
        end

        % Move on to the next reduced vector with fresh failure counters.
        n = n + 1;
        attempt = 0;
        denom_too_low = 0;
        no_progress = 0;
    end
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -