⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 irwls1.m

📁 支持向量机的一个训练软件
💻 M
字号:
function [nsv,al3,bi,T,Lp]=irwls1(x,y,ker,C,epsi,par,tol)
%
%
%  This function solves the Support Vector Machine for regression.
%
%     [nsv,al3,bi,T,Lp]=irwls1(x,y,ker,C,epsi,par,tol);
%
% This function solves the SVM using the IRWLS procedure. 
% Parameters:
%
%              x: matrix of input vectors. Each row of x represents a d-dimensional vector. 
%                 The number of rows is the number of training samples.
%              y: column vector of the target values. Each row represents the target of each sample in x.
%              ker: the kernel to be employed. It has to be a string out of: 'linear', 'poly_h', 'poly_i' and 'rbf'.
%                Type help kernel for further information.
%              C: is the penalty factor of the SVM soft-margin formulation.
%              epsi: is the epsilon insensitive zone defined around the regression estimate.
%              par: is the parameter of the used kernel. Type help kernel for further information.
%                (optional) When omitted it is set to 2 if ker starts with 'p', to
%                sqrt(size(x,2)) if ker starts with 'r', and to 0 otherwise.
%              tol: tolerance on the relative change of Lp, the primal functional, to stop the algorithm.
%                The default value is 10^-5. (optional)
%
% Outputs:
%
%              nsv: number of support vectors.
%              al3: the value of the alphas-alpha*.  
%              bi: the value of the bias.
%              T: running time.
%              Lp: A vector that shows the value of the primal 1/2\|w\|^2+C\sum_{i=1}^n \xi_i
%               (computed as al3'*H*al3/2 + C*sum(|e_i|-epsi over |e_i|>epsi)) in each iteration
%               of the algorithm. Lp(1) is the initial value N*C.
%
% If the number of input arguments is not between 5 and 7 this help is
% printed and nsv, al3, bi and Lp are returned empty.
%
% Examples:
%
%       with linear kernel
%               x=rand(20,1);
%               y=2*x+2+.2*randn(size(x));
%               [nsv,alpha,bias]=irwls1(x,y,'linear',10,0.1);
%               svrplot(x,y,'linear',alpha,bias,0.1);
%
%               W=x'*alpha;
%               X_test=rand(200,1);
%               Y_test=2*X_test+2+.2*randn(size(X_test));
%               Output=X_test*W+bias;
%               Error=Y_test-Output;
%
%               
%       with nonlinear (RBF) kernel:
%               x=10*(rand(100,1)-.5);
%               y=sinc(x)+.1*randn(size(x));
%               [nsv,alpha,bias]=irwls1(x,y,'rbf',10,.05,0.5);
%               svrplot(x,y,'rbf',alpha,bias,0.05,0.5);
%
%               X_test=10*(rand(1000,1)-.5);
%               Y_test=sinc(X_test)+.1*randn(size(X_test));
%               Output=kernel('rbf',X_test,x(find(alpha),:),0.5)*alpha(find(alpha))+bias;
%               Error=Y_test-Output;
%
% Author:   Fernando Perez-Cruz (fernandop@ieee.org)
% Version:  1.0
% Date:     18th March 2002.
%
%

T=clock;

% Wrong number of arguments: show the help and leave cleanly. The outputs are
% only assigned when the caller asked for them, so a bare call still prints
% nothing but the help text.
if (nargin<5 | nargin>7)
   help irwls1
   if(nargout>0)
      nsv=[];
      al3=[];
      bi=[];
      Lp=[];
   end
   T=etime(clock,T);
   return
end

% Default kernel parameter, chosen from the first letter of the kernel name.
if(nargin==5)
   if(ker(1)=='p')
      disp('The used kernel needs an input parameter. We have set it to 2.');
      par=2;
   elseif(ker(1)=='r')
      fprintf(1,'The used kernel needs an input parameter. We have set it to %1.3f\n',sqrt(size(x,2)));
      par=sqrt(size(x,2));
   else
      par=0;
   end
end
if(nargin<7)
   tol=10^-5;
end

K=10^6;          % upper bound applied to the IRWLS weights a
N=size(x,1);     % number of training samples
ns=-1;           % the stop test requires k>(ns+1)
k=2;             % index of the next entry of Lp
hacer=1;         % loop flag ("hacer" = "do"): cleared when the algorithm stops
bi=0;

H=kernel(ker,x,x,par);

% Initially every sample is in the free set i1 (odd indices on the positive
% side, even indices on the negative side) with weight C; the sets i2p/i2n of
% samples fixed at +C/-C start empty.
i1p=1:2:N;
i1n=2:2:N;
i1=[i1p i1n];

a=zeros(N,1);
a(i1)=C;

al3=zeros(N,1);
i2p=[];
i2n=[];

Lp(1)=N*C;
e=y;             % residual y-H*al3-bi for al3=0, bi=0

% NOTE(review): neither the outer loop nor the two bisection loops below has
% an iteration cap; termination relies on the tests on Lp and on the step.
while(hacer)
   al3_a=al3;
   bi_a=bi;

   % Samples in i2p/i2n are fixed at +C/-C; the free set i1 is obtained from
   % the weighted least-squares system below.
   al3=zeros(N,1);
   al3(i2p)=C;
   al3(i2n)=-C;
   if(length(i1))
      n1=length(i1);
      E=ones(N,1);
      E(i1n)=-1;
      M=[H(i1,i1)+diag(1./(a(i1))) ones(n1,1);ones(1,n1) 0];
      rhs=[y(i1)-E(i1)*epsi+C*H(i1,i2n)*ones(length(i2n),1)-C*H(i1,i2p)*ones(length(i2p),1);C*length(i2n)-C*length(i2p)];
      % Solve with backslash instead of forming inv(M) explicitly.
      aux=M\rhs;
      al3(i1)=aux(1:n1);
      bi=aux(n1+1);
   end
   e_a=e;

   % Incremental residual update over the coefficients that changed. With an
   % empty I the matrix product is a zero column, so the bias change is still
   % applied (no guard on length(I) is needed).
   I=find(al3-al3_a);
   e=e_a-H(:,I)*(al3(I)-al3_a(I))-(bi-bi_a);

   Lp(k)=irwls1_primal(al3,e,H,C,epsi);

   if((abs(Lp(k)-Lp(k-1))/Lp(k-1))<tol & k>(ns+1))
      % Relative change of the primal below tol: stop.
      nsv=length(i1)+length(i2n)+length(i2p);
      hacer=0;
   else
      if((Lp(k)-Lp(k-1))/Lp(k-1)>tol)
         % The full step increased the primal: search along the segment
         % t*new+(1-t)*old, t in (0,1], for a shorter step.
         LL=Lp(k-1);
         HH=Lp(k);
         posL=0;
         posH=1;
         % Halve posH while the value at posH is above the previous primal;
         % once it is not, one more evaluation at posH/2 sets posL and ends
         % the loop.
         while(posL==0)
            jj=(posH+posL)/2;
            valor=irwls1_trial(jj,al3,bi,al3_a,bi_a,e_a,H,C,epsi);
            if(HH>LL)
               HH=valor;
               posH=jj;
            else
               LL=valor;
               posL=jj;
            end
         end

         % Probe at posH+posL and use it as the upper end of a new bracket
         % whose lower end is the accepted posH.
         jj=posH+posL;
         valor=irwls1_trial(jj,al3,bi,al3_a,bi_a,e_a,H,C,epsi);
         posL=posH;
         posH=jj;
         if(valor>Lp(k-1))
            % Bisect until the primal at jj is within 10^-3 (relative) of
            % the previous primal value.
            while(abs(valor-Lp(k-1))/Lp(k-1)>10^-3)
               jj=(posH+posL)/2;
               valor=irwls1_trial(jj,al3,bi,al3_a,bi_a,e_a,H,C,epsi);
               if(valor>Lp(k-1))
                  posH=jj;
               else
                  posL=jj;
               end
            end
         end
         pos=jj;

         % Adopt the interpolated solution, its residual and its primal value.
         [Lp(k),al3,bi,e]=irwls1_trial(pos,al3,bi,al3_a,bi_a,e_a,H,C,epsi);

         if(pos<10^-10)
            % The accepted step is negligible: stop.
            hacer=0;
            nsv=length(i1)+length(i2n)+length(i2p);
         end
      end

      % Samples on or outside the epsilon tube, by sign of the residual.
      i1p=find(e>=epsi);
      i1n=find(e<=-epsi);

      % IRWLS weights a=C/(|e|-epsi) on those samples, capped at K.
      a=zeros(N,1);
      a(i1p)=C./(e(i1p)-epsi);
      a(i1n)=-C./(e(i1n)+epsi);
      a(find(a>K))=K;

      % i2p: alpha within 1% of +C in both the current and the previous
      % iterate, residual above epsi and 15*(e-epsi)>14*(e_a-epsi). These
      % samples are fixed at +C in the next solve and removed from i1p; if
      % that would empty i1p, the i2p sample with the smallest residual is
      % released from i2p first.
      i2p=find(al3>=(C-C/100) & al3<=(C+C/100) & al3_a>=(C-C/100) & al3_a<=(C+C/100) & e>epsi & 15*(e-epsi)>14*(e_a-epsi));
      if(length(setdiff(i1p,i2p)))
         i1p=setdiff(i1p',i2p)';
      else
         [val,idx]=min(e(i2p));
         i2p=setdiff(i2p,i2p(idx));
         i1p=setdiff(i1p',i2p)';
      end

      % i2n: mirror image of i2p for alphas near -C.
      i2n=find(al3>=(-C-C/100) & al3<=(-C+C/100) & al3_a>=(-C-C/100) & al3_a<=(-C+C/100) & e<-epsi & 15*(e+epsi)<14*(e_a+epsi));
      if(length(setdiff(i1n,i2n)))
         i1n=setdiff(i1n',i2n)';
      else
         [val,idx]=max(e(i2n));
         i2n=setdiff(i2n,i2n(idx));
         i1n=setdiff(i1n',i2n)';
      end

      i1=[i1p;i1n];

      k=k+1;
   end
end
T=etime(clock,T);


function Lval=irwls1_primal(al,err,H,C,epsi)
% Primal value al'*H*al/2 + C*sum(|err_i|-epsi) over the samples with |err_i|>epsi.
out=find(abs(err)>epsi);
Lval=al'*H*al/2+C*sum(abs(err(out))-epsi);


function [Lval,alb,bb,eb]=irwls1_trial(t,al3,bi,al3_a,bi_a,e_a,H,C,epsi)
% Evaluate the point t*(al3,bi)+(1-t)*(al3_a,bi_a): returns its primal value
% Lval, the interpolated coefficients alb and bias bb, and the residual eb
% obtained by updating e_a (the residual at al3_a, bi_a).
alb=al3*t+al3_a*(1-t);
bb=bi*t+bi_a*(1-t);
I=find(alb-al3_a);
% An empty I gives a zero column from the product, so the bias term is
% always applied.
eb=e_a-H(:,I)*(alb(I)-al3_a(I))-(bb-bi_a);
Lval=irwls1_primal(alb,eb,H,C,epsi);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -