⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svmclassl2ls.m

📁 matlab图像处理工具相
💻 M
字号:
function [xsup,w,b,pos,timeps,alpha]=svmclassL2LS(x,y,c,lambda,kernel,kerneloption,verbose,span,qpsize,chunksize)
% SVMCLASSL2LS  Large-scale classification SVM solved by working-set decomposition.
%
% [xsup,w,b,pos,timeps,alpha]=svmclassL2LS(x,y,c,lambda,kernel,kerneloption,verbose,span,qpsize,chunksize)
%
% The dual problem is solved iteratively: at each pass a working set of at
% most qpsize points is selected, the sub-problem restricted to these points
% is handed to monqp, and the decision values are updated chunk by chunk.
% The sub-problem Hessian is K.*(y*y') + (1/c)*I and the upper bound given
% to monqp is +inf (this looks like the squared-slack / L2 soft-margin
% formulation -- TODO confirm against the toolbox documentation).
%
% Inputs
%   x            either a data matrix with one example per row, or a struct
%                with fields datafile, indice and dimension; in the struct
%                case rows are fetched through fileaccess.
%   y            column vector of labels. The working-set selection tests
%                y==1 and y==-1.
%   c            regularisation constant; 1/c is added to the Hessian diagonal.
%   lambda       forwarded unchanged to monqp (presumably its conditioning
%                parameter -- verify against monqp).
%   kernel, kerneloption
%                forwarded unchanged to svmkernel.
%   verbose      if > 0, one progress line is printed per iteration.
%                Default 0.
%   span         not used by this function.
%   qpsize       maximum working-set size (default 100). It is clipped to
%                the number of examples and made even.
%   chunksize    number of rows/columns per kernel chunk (default qpsize).
%
% Outputs
%   xsup         support vectors: x(pos,:) for a matrix input, or a copy of
%                the struct x whose indice field is restricted to pos.
%   w            alpha(pos).*y(pos).
%   b            bias.
%   pos          indices of the examples with |alpha| > 1e-12.
%   timeps       always [].
%   alpha        Lagrange multipliers of the support vectors only.
%
% The function raises an error when x is a struct and length(x.indice)
% differs from length(y).

n=size(y,1);

% ---- default arguments ---------------------------------------------------
if nargin<7
    verbose=0;
end
if nargin<9
    qpsize=100;
end
if nargin<10
    chunksize=qpsize;
end

if isstruct(x)
    if length(x.indice)~=length(y)
        error('Length of x and y should be equal');
    end
end

% The requested size is remembered so it can be restored once the number of
% KKT violations becomes small (see the end of the main loop).
maxqpsize=qpsize;
if qpsize>n
    qpsize=n;
end
% Keep the working-set size even so it can be split in two equal halves,
% one per class.
if rem(qpsize,2)==1
    qpsize=qpsize-1;
end

kkttol=1e-5;     % tolerance on the KKT conditions
difftol=1e-12;   % threshold below which a multiplier (or a change) counts as zero

alphaold=zeros(n,1);
alpha=zeros(n,1);
workingset=zeros(n,1);
s=zeros(n,1);    % kernel expansion sum_j alpha_j*y_j*k(x_i,x_j), without bias
iteration=0;
bias=0;

while 1

    % Mask of the current support vectors.
    SV=(abs(alpha)>difftol);

    %
    % Incremental update of the SVM output: only the multipliers that
    % changed since the previous iteration contribute.
    %
    if iteration==0
        changedSV=find(SV);
        changedAlpha=alpha(changedSV);
    else
        changedSV=find(abs(alpha-alphaold)>difftol);
        changedAlpha=alpha(changedSV)-alphaold(changedSV);
    end

    if ~isempty(changedSV)
        nbchanged=length(changedSV);
        chunks1=ceil(n/chunksize);
        chunks2=ceil(nbchanged/chunksize);

        for ch1=1:chunks1
            ind1=(1+(ch1-1)*chunksize):min(n,ch1*chunksize);
            % The row block depends only on ch1, so fetch it once.
            x1=getrowsL2LS(x,ind1);
            for ch2=1:chunks2
                ind2=(1+(ch2-1)*chunksize):min(nbchanged,ch2*chunksize);
                x2=getrowsL2LS(x,changedSV(ind2));
                kchunk=svmkernel(x1,kernel,kerneloption,x2);
                coeff=changedAlpha(ind2).*y(changedSV(ind2));
                s(ind1)=s(ind1)+kchunk*coeff;
            end
        end
    end

    %
    % Bias, averaged over all current support vectors.
    %
    if any(SV)
        bias=mean(y(SV)-s(SV)-y(SV).*alpha(SV)/c);
    else
        bias=0;
    end

    %
    % KKT conditions
    %
    kkt=(s+bias).*y-1;
    testkkt=abs(abs(kkt)-alpha/c);
    kktviolation=(SV & (testkkt>kkttol)) | (~SV & (kkt<-kkttol));
    if sum(kktviolation)==0
        break;   % every point satisfies the KKT conditions: done
    end

    %
    % Selection of the new working set
    %
    oldworkingset=workingset;
    workingset=zeros(n,1);

    if qpsize==n
        workingset=ones(n,1);
    else
        % Violators of each class, in random order.
        indposkktviol=find(y==1 & kktviolation);
        indposkktviol=indposkktviol(randperm(length(indposkktviol)));
        indnegkktviol=find(y==-1 & kktviolation);
        indnegkktviol=indnegkktviol(randperm(length(indnegkktviol)));
        % Non-violating points, used to fill each half when a class has
        % fewer violators than qpsize/2.
        indposOK=find(y==1 & ~kktviolation);
        indnegOK=find(y==-1 & ~kktviolation);

        halfsize=qpsize/2;
        nbposViol=min(length(indposkktviol),round(halfsize));
        nbnegViol=min(length(indnegkktviol),round(halfsize));
        nbposOK=min(halfsize-nbposViol,length(indposOK));
        % The negative fill must be bounded by the number of available
        % negative points, not by the positive ones.
        nbnegOK=min(halfsize-nbnegViol,length(indnegOK));

        ind=[indposkktviol(1:nbposViol);indposOK(1:nbposOK);indnegkktviol(1:nbnegViol);indnegOK(1:nbnegOK)];
        workingset(ind)=ones(length(ind),1);
    end

    if all(abs(oldworkingset-workingset)<difftol)
        % Same working set as in the previous iteration: add a random draw
        % over all points of each class (crude, as the original author
        % noted) so that the working set changes.
        indpos=find(y==1);
        indneg=find(y==-1);
        RandIndpos=randperm(length(indpos));
        RandIndneg=randperm(length(indneg));
        nbpos=min(length(indpos),round(qpsize/2));
        nbneg=min(length(indneg),round(qpsize/2));
        ind=[indpos(RandIndpos(1:nbpos));indneg(RandIndneg(1:nbneg))];
        workingset(ind)=ones(length(ind),1);
    end

    indworkingset=find(workingset);
    workingsize=length(indworkingset);
    nws=~workingset;
    indnws=find(nws);

    %
    % QP sub-problem on the new working set
    %
    % Qbn*alphan only involves the working-set points and the points
    % outside the working set whose multiplier is non-zero.
    %
    indnwSV=find(nws & SV);
    xws=getrowsL2LS(x,indworkingset);   % reused for Qbn*alphan and for Hbb
    yb=y(indworkingset);

    if ~isempty(indnwSV)
        Qbnalphan=0;
        nbnwSV=length(indnwSV);
        chunks=ceil(nbnwSV/chunksize);
        for ch=1:chunks
            ind=(1+(ch-1)*chunksize):min(nbnwSV,ch*chunksize);
            x2=getrowsL2LS(x,indnwSV(ind));
            pschunk=svmkernel(xws,kernel,kerneloption,x2);
            Qbnalphan=Qbnalphan+yb.*(pschunk*(alpha(indnwSV(ind)).*y(indnwSV(ind))));
        end
        e=ones(workingsize,1)-Qbnalphan;
    else
        e=ones(workingsize,1);
    end

    % Hessian of the sub-problem: the 1/c diagonal term comes from this
    % formulation, which is why no finite upper bound is given to monqp.
    psbb=svmkernel(xws,kernel,kerneloption);
    Hbb=psbb.*(yb*yb')+1/c*eye(size(psbb));

    % Equality constraint passed to monqp: the part of sum(alpha.*y) fixed
    % by the points outside the working set goes to the right-hand side.
    A=yb;
    if ~isempty(indnws)
        beq=-alpha(indnws)'*y(indnws);
    else
        beq=0;
    end

    cinfty=+inf;
    [alphab,lambdab,pos]=monqp(Hbb,e,A,beq,cinfty,lambda,0);

    % Scatter the solution back into alpha; monqp only returns the entries
    % listed in pos, so the remaining working-set multipliers are zero.
    alphaold=alpha;
    aux=zeros(workingsize,1);
    aux(pos)=alphab;
    alpha(indworkingset)=aux;
    iteration=iteration+1;

    if verbose>0
        obj=0.5*aux'*Hbb*aux-aux'*e;
        fprintf('i: %d number changedAlpha : %d  Nb KKT Violation: %d Objective Val:%f\n',iteration,length(find( abs(alpha-alphaold)> difftol)),sum(kktviolation),obj);
    end

    % Few violations left: go back to the size originally requested.
    if sum(kktviolation)<maxqpsize
        qpsize=maxqpsize;
        chunksize=maxqpsize;
    end
end

%
% Outputs
%
SV=(abs(alpha)>difftol);
pos=find(SV);

if ~isfield(x,'datafile')
    xsup=x(pos,:);
else
    xsup=x;
    xsup.indice=x.indice(pos);
end
ysup=y(pos);
w=(alpha(pos).*ysup);

% any(SV), not ~isempty(SV): SV is a mask of length n, so it is non-empty
% even when it selects nothing, and mean([]) would return NaN.
if any(SV)
    bias=mean(y(SV)-s(SV)-y(SV).*alpha(SV)/c);
else
    bias=0;
end

b=bias;
timeps=[];
alpha=alpha(pos);


function rows=getrowsL2LS(x,idx)
% Return the examples idx of x, either by indexing the data matrix or, when
% x is a struct with a datafile field, by reading them through fileaccess.
if isfield(x,'datafile')
    rows=fileaccess(x.datafile,x.indice(idx),x.dimension);
else
    rows=x(idx,:);
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -