reduced_svm.m

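function [rvect, rweight, rthresh, discrepancy]=reduced_svm(patterns, labels, svect, alpha, rv_num, epsilon_discrepancy, max_it, max_attempt)
% REDUCED_SVM  Approximate an RBF-kernel SVM expansion with fewer vectors.
%   Greedily builds up to RV_NUM reduced-set vectors (columns of RVECT) by
%   fixed-point iteration, re-fitting all weights after each new vector.
%
%   patterns             training patterns, one per column (used for the threshold)
%   labels               their class labels
%   svect                support vectors, one per column
%   alpha                signed expansion coefficients of the support vectors
%   rv_num               maximum number of reduced vectors
%   epsilon_discrepancy  stop once the discrepancy falls below this (default 0.1)
%   max_it               fixed-point iterations per attempt (default 75)
%   max_attempt          random restarts allowed per vector (default 500)
%
%   rweight{n}      weights of the first n reduced vectors
%   rthresh(n)      threshold for the first n reduced vectors
%   discrepancy(n)  squared feature-space distance between the SVM expansion
%                   and its n-vector approximation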

global C;
global kernel;
global deg;

if nargin <= 5
   epsilon_discrepancy = 0.1;
end
if nargin <= 6
   max_it = 75;
end
if nargin <= 7
   max_attempt = 500;
end

dim = size(svect,1);
%rvect=zeros(dim, rv_num);
m = min(min(svect));
M = max(max(svect));
sv_num=size(svect,2);

% kernel matrix used to compute the discrepancy
K = rbf_dot(svect, svect, deg);
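% NOTE (assumption): rbf_dot is not defined in this file; it presumably
% returns the Gaussian (RBF) kernel matrix between the columns of its first
% two arguments, with deg as the width parameter. The fixed-point update
% used below is the one derived for a Gaussian kernel.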

n = 1;
attempt = 0;
denom_too_low = 0;
no_progress = 0;
while n <= rv_num
   % random initialisation
   %rvect(:,n) = rand(dim,1)*(M-m) + m;
   rvect(:,n) = rand(dim,1)*2 - 1;
   
   k=zeros(1,sv_num+n-1);
   % compute the reduced vector
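   % Fixed-point update (Gaussian kernel): with coefficients
   %   k_i = alpha_i*K(x_i, z)    for the support vectors x_i,
   %   k_j = -beta_j*K(z_j, z)    for the reduced vectors z_j, j < n,
   % the new estimate is z = (sum_i k_i*x_i + sum_j k_j*z_j) / sum(k),
   % so z_n is fitted to the residual Psi - sum_{j<n} beta_j*Phi(z_j).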
   for it=1:max_it
      rvect_old = rvect(:,n);             % last value
      k(1:sv_num)=alpha.*rbf_dot(svect, rvect(:,n), deg);
      if n > 1
         k(sv_num+1:sv_num+n-1) = -rweight{n-1}.*rbf_dot(rvect(:,1:n-1), rvect(:,n), deg);
      end
      if abs(sum(k)) < 1e-20
         % denominator too small -> redo the optimisation of this reduced vector
         break;
      end
      rvect(:,n) = svect*k(1:sv_num)' / sum(k);
      if n > 1
         rvect(:,n) = rvect(:,n) + rvect(:,1:n-1)*k(sv_num+1:sv_num+n-1)' / sum(k);
      end
      if max(abs(rvect_old-rvect(:,n))) < 1e-3
         break;
      end
   end
   
   if abs(sum(k)) < 1e-20 || max(abs(rvect_old - rvect(:,n))) >= 1e-3
      % it did not converge -> restart with a new random initialisation
      % (testing the last step, not it >= max_it, so that convergence on
      % the final iteration is not mistaken for failure)
      attempt = attempt + 1;
      if abs(sum(k)) < 1e-20
         denom_too_low = denom_too_low + 1;
      else
         no_progress = no_progress + 1;
      end
      if mod(attempt, 100) == 0
         fprintf(1, 'Vector: %d, attempt: %d (cause of slow convergence: %d denom too low, %d no progress)\n', n, attempt, denom_too_low, no_progress);
      end
      if attempt >= max_attempt
         rvect = rvect(:,1:n-1);
         break;
      end
   else
      % compute the new set of weights
      Kzx = rbf_dot(svect, rvect(:,1:n), deg);
      Kz = rbf_dot(rvect(:,1:n), rvect(:,1:n), deg);
      rweight{n} = pinv(Kz)*(Kzx'*alpha);
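      % rweight{n} minimises ||Psi - sum_j beta_j*Phi(z_j)||^2 over beta:
      % setting the gradient to zero gives Kz*beta = Kzx'*alpha, which
      % pinv solves (in the least-squares sense if Kz is singular).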
      % compute the 'best' threshold for n reduced vectors
      RD=rbf_dot(patterns, rvect(:,1:n), deg);
      y=(RD*rweight{n})';
      rthresh(n)=-compute_thres(y, labels);
      
      % compute the discrepancy
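      % discrepancy(n) = ||Psi - Psi'||^2 with Psi = sum_i alpha_i*Phi(x_i)
      % and Psi' = sum_j beta_j*Phi(z_j), expanded through kernel values:
      %   alpha'*K*alpha + beta'*Kz*beta - 2*alpha'*Kzx*beta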
      discrepancy(n) = alpha'*K*alpha + rweight{n}'*Kz*rweight{n} - 2*alpha'*Kzx*rweight{n};
      
      % print progress
      if attempt > 1
         fprintf(1, 'Computing Reduced Vector : %d / %d, %d attempts, %.2f discrepancy\n', n, rv_num, attempt, discrepancy(n));
      else
         fprintf(1, 'Computing Reduced Vector : %d / %d, %.2f discrepancy\n', n, rv_num, discrepancy(n));
      end
      if discrepancy(n) < epsilon_discrepancy
         break;
      end
      n=n+1;
      attempt = 0;
      denom_too_low = 0;
      no_progress = 0;
   end
end
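The greedy reduced-set procedure above can be tried in isolation with the minimal sketch below. It is not the original code: it assumes a Gaussian kernel exp(-||a-b||^2/(2*sigma^2)) in place of the undefined rbf_dot, uses hypothetical toy data, starts each vector deterministically at a support vector instead of using random restarts, and omits the threshold step (compute_thres). The names X, Z, beta, sigma and rbf are illustrative only. It runs in MATLAB R2016b or later and in GNU Octave, both of which support implicit expansion.

```matlab
% Hypothetical toy data: two clusters of 2-D support vectors (one per column)
X     = [0  0.2 -0.1  1.5  1.7  1.6;
         0  0.1  0.2  1.5  1.4  1.7];
alpha = [0.5; 0.4; 0.3; 0.6; 0.5; 0.4];   % signed expansion coefficients
sigma = 0.5;                               % assumed Gaussian kernel width
rv_num = 2;                                % number of reduced vectors to build

% Assumed Gaussian kernel matrix between the columns of A and B
rbf = @(A, B) exp(-(sum(A.^2, 1)' + sum(B.^2, 1) - 2*(A'*B)) / (2*sigma^2));

K = rbf(X, X);                 % kernel matrix of the support vectors
Z = zeros(size(X, 1), 0);      % reduced vectors found so far (one per column)
beta = zeros(0, 1);            % their current weights
discrepancy = zeros(1, rv_num);

for n = 1:rv_num
    z = X(:, 3*n - 2);         % deterministic start instead of random restarts
    for it = 1:200
        % coefficients of the residual Psi - sum_j beta_j*Phi(z_j)
        k = [alpha .* rbf(X, z); -beta .* rbf(Z, z)];
        z_new = [X, Z] * k / sum(k);
        converged = max(abs(z_new - z)) < 1e-10;
        z = z_new;
        if converged
            break;
        end
    end
    Z = [Z, z];
    % re-fit all weights jointly: Kz*beta = Kzx'*alpha
    Kz = rbf(Z, Z);
    Kzx = rbf(X, Z);
    beta = pinv(Kz) * (Kzx' * alpha);
    discrepancy(n) = alpha'*K*alpha + beta'*Kz*beta - 2*alpha'*Kzx*beta;
    fprintf('n = %d, discrepancy = %.6f\n', n, discrepancy(n));
end
```

Because all weights are re-fitted jointly after each new vector, discrepancy(n) cannot increase with n in exact arithmetic, which is what makes stopping at the first n below epsilon_discrepancy a sensible rule.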
