⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nsvm.m

📁 支持向量机一个很好的
💻 M
字号:
function [w, gamma, trainCorr, testCorr, cpu_time, nu] = nsvm(A, d, k, nu, output, wp1, arm)
% NSVM  Train a linear classifier sign(x'*w - gamma) with a Newton-type
% Lagrangian SVM solver, optionally measuring correctness on the training
% data (k = 1) or by k-fold cross-validation (k > 1).
%
% version 1.5
% last revision: 07/07/03
%=====================================================================
% Usage: [w, gamma, train, test, time, nu] = nsvm(A, d, k, nu, output, wp1, arm);
%
% Input parameters (trailing arguments may be omitted or passed as []):
%    A:      data points, one per row
%    d:      labels, 1's or -1's, one per row of A (column vector)
%    k:      how to divide the data set into test and training sets
%            k = 0: simply run the algorithm without any correctness
%                   calculation (default)
%            k = 1: run the algorithm and calculate correctness on the
%                   whole data set
%            1 < k < # of rows: divide the data set into test and training
%                   sets using the k-fold method
%            k = # of rows: use the 'leave one out' method
%            k must not exceed the number of rows of A
%    nu:     weighting parameter
%            -1 - easy estimation
%             0 - hard estimation (default)
%            any other value - used as nu by the algorithm
%    output: 0 - no output (default), 1 - print progress
%    wp1:    weight of class 1, must satisfy 0 <= wp1 <= 1 (default 0.5);
%            the weight of class -1 is 1-wp1 (1 is more important,
%            0 is less important)
%    arm:    1 - use an Armijo step size, 0 - otherwise (default)
%
% Output parameters:
%    w:         the normal vector of the classifier
%    gamma:     the threshold
%    trainCorr: training set correctness in percent (averaged over folds
%               when k > 1; 0 when k = 0)
%    testCorr:  test set correctness in percent (averaged over folds;
%               0 when k <= 1)
%    cpu_time:  elapsed time (average per fold when k > 1)
%    nu:        the value used for nu
%
% Notes: the rows of A and d are randomly permuted first, so the fold
% assignment differs between calls. When k > 1, the returned w and gamma
% are those trained on the last fold.
%===========================================================
if nargin < 7 || isempty(arm)
    arm = 0;
end
if nargin < 6 || isempty(wp1)
    wp1 = .5;
end
if nargin < 5 || isempty(output)
    output = 0;
end
if nargin < 4 || isempty(nu) || nu == 0
    nu = EstNuLong(A, d);    % default is hard estimation
elseif nu == -1
    nu = EstNuShort(A, d);   % easy estimation
end
if nargin < 3 || isempty(k)
    k = 0;
end

% random permutation of the data points
r = randperm(size(d,1));
d = d(r,:);
A = A(r,:);

trainCorr = 0;
testCorr = 0;

% k == 0: train only, no correctness calculation
if k == 0
    tic;
    [w, gamma, iter] = core(A, d, nu, arm, wp1);
    cpu_time = toc;
    if output == 1
        fprintf(1,'\nNumber of Iterations: %d',iter);
        fprintf(1,'\nElapsed time: %10.2f\n\n',cpu_time);
    end
    return
end

% k == 1: only training set correctness is calculated
if k == 1
    tic;
    [w, gamma, iter] = core(A, d, nu, arm, wp1);
    trainCorr = correctness(A, d, w, gamma);
    cpu_time = toc;
    if output == 1
        fprintf(1,'\nTraining set correctness: %3.2f%%',trainCorr);
        fprintf(1,'\nNumber of Iterations: %d',iter);
        fprintf(1,'\nElapsed time: %10.2f\n\n',cpu_time);
    end
    return
end

% k > 1: k-fold cross-validation
sm = size(A,1);
if k > sm
    % more folds than points would create empty test sets (0/0 correctness)
    error('nsvm:badK', 'k (%d) must not exceed the number of data points (%d)', k, sm);
end
accuIter = 0;
lastToc = 0;                      % used for calculating per-fold time
indx = floor(sm*(0:k)/k);         % indx(i+1) is the last row of fold i

tic;
for i = 1:k
    % split the training set from the test set
    testRows  = (indx(i)+1):indx(i+1);
    trainRows = [1:indx(i), (indx(i+1)+1):sm];
    Ctest  = A(testRows,:);
    dtest  = d(testRows);
    Ctrain = A(trainRows,:);
    dtrain = d(trainRows);

    [w, gamma, iter] = core(Ctrain, dtrain, nu, arm, wp1);
    tmpTrainCorr = correctness(Ctrain, dtrain, w, gamma);
    tmpTestCorr  = correctness(Ctest, dtest, w, gamma);

    if output == 1
        fprintf(1,'________________________________________________\n');
        fprintf(1,'Fold %d\n',i);
        fprintf(1,'Training set correctness: %3.2f%%\n',tmpTrainCorr);
        fprintf(1,'Testing set correctness: %3.2f%%\n',tmpTestCorr);
        fprintf(1,'Number of iterations: %d\n',iter);
        fprintf(1,'Elapsed time: %10.2f\n',toc-lastToc);
        lastToc = toc;
    end

    trainCorr = trainCorr + tmpTrainCorr;
    testCorr  = testCorr + tmpTestCorr;
    accuIter  = accuIter + iter;  % accumulative iterations
end % end of for (looping through test sets)

trainCorr = trainCorr/k;
testCorr  = testCorr/k;
cpu_time  = toc/k;

if output == 1
    fprintf(1,'==============================================');
    fprintf(1,'\nTraining set correctness: %3.2f%%',trainCorr);
    fprintf(1,'\nTesting set correctness: %3.2f%%',testCorr);
    % %g: the average is generally not an integer (%d would print e-notation)
    fprintf(1,'\nAverage number of iterations: %g',accuIter/k);
    fprintf(1,'\nAverage cpu_time: %10.2f\n',cpu_time);
end
return;  % nsvm function return
%%%%%%%%%%% core calculation function %%%%%%%%%%%%%%%%%%%%%
function [w, gamma, iter] = core(A, d, nu, arm, wp1)
% CORE  Dispatch to the solver suited to the shape of A.
% With at least as many rows (points) as columns (features), the
% Sherman-Morrison-Woodbury variant only factors an (n+1)-by-(n+1) matrix;
% otherwise the m-by-m matrix Q is formed and solved directly.
[m, n] = size(A);
if m >= n
    [w, gamma, iter] = nsvm_with_smw(A, d, nu, arm, wp1);
else
    [w, gamma, iter] = nsvm_without_smw(A, d, nu, arm, wp1);
end
return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%        NSVM for m>=n                                          %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [w, gamma, iter] = nsvm_with_smw(A, d, nu, arm, wp1)
% NSVM_WITH_SMW  Newton iteration driving hu = v - max(v - alpha*u, 0) to
% zero, where v = Q*u - e, Q = diag(vt)/nu + H*H' and H = D*[A -e].
% Q is never formed: each Newton system is solved with the
% Sherman-Morrison-Woodbury identity.
% Returns the classifier sign(x'*w - gamma) and the iteration count.
maxIter   = 100;
maxArmijo = 5;      % step reductions before falling back to a full step
tol       = 1e-3;
[m, n] = size(A);
iter = 0;
u = zeros(m,1);
e = ones(m,1);
H = [spdiags(d,0,m,m)*A, -d];

% per-point weights on the u/nu term; all ones when the classes are balanced
vt = ones(m,1);
if wp1 ~= .5
    vt(d == 1)  = 1 - wp1;
    vt(d == -1) = wp1;
end
alpha = 1.1*((1/nu) + (norm(H',2)^2));

v  = vt.*u/nu + H*(H'*u) - e;        % v = Q*u - e
hu = v - max(v - alpha*u, 0);
while norm(hu) > tol && iter < maxIter
    iter = iter + 1;
    E = sign(max(v - alpha*u, 0));    % 0/1 indicator of the active components
    % Newton matrix = diag(1./temp) + diag(1-E)*H*H'; invert it with SMW
    temp = 1./(alpha*E + (1-E).*vt/nu);
    G   = spdiags(temp.*(1-E),0,m,m);
    FHU = temp.*hu;
    GH  = G*H;
    R   = chol(speye(n+1) + H'*GH);
    sol = R\(R'\(H'*FHU));
    delta = FHU - GH*sol;

    if arm == 1
        % Armijo rule on the implicit Lagrangian, whose gradient is
        % (I - Q/alpha)*hu: require a decrease of at least 0.25*lambda*slope.
        Qu     = v + e;
        Qdelta = vt.*delta/nu + H*(H'*delta);
        slope  = hu'*delta - (hu'*Qdelta)/alpha;
        su     = implicit_lagrangian(u, Qu, alpha);
        lambda = 1;
        unew   = u - delta;
        sunew  = implicit_lagrangian(unew, Qu - Qdelta, alpha);
        at = 0;
        while su - sunew < 0.25*lambda*slope && at < maxArmijo
            at = at + 1;
            lambda = 0.1*lambda;
            unew  = u - lambda*delta;
            sunew = implicit_lagrangian(unew, Qu - lambda*Qdelta, alpha);
        end
        if su - sunew < 0.25*lambda*slope
            unew = u - delta;         % no acceptable step: take the full step
        end
    else
        unew = u - delta;
    end

    u  = unew;
    v  = vt.*u/nu + H*(H'*u) - e;
    hu = v - max(v - alpha*u, 0);
end
w = A'*(d.*u);
gamma = -sum(d.*u);
return

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%        NSVM for m<n                                           %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [w, gamma, iter] = nsvm_without_smw(A, d, nu, arm, wp1)
% NSVM_WITHOUT_SMW  Same iteration as NSVM_WITH_SMW, but Q (m-by-m) is
% formed explicitly and each Newton system is solved directly.
% Returns the classifier sign(x'*w - gamma) and the iteration count.
maxIter   = 100;
maxArmijo = 5;      % step halvings before falling back to a full step
tol       = 1e-3;
m = size(A,1);
iter = 0;
u = zeros(m,1);
e = ones(m,1);
HH = diag(d)*(A*A'+1)*diag(d);       % H*H' with H = D*[A -e]

% per-point weights on the u/nu term; all ones when the classes are balanced
vt = ones(m,1);
if wp1 ~= .5
    vt(d == 1)  = 1 - wp1;
    vt(d == -1) = wp1;
end
Q = diag(vt)/nu + HH;
alpha = 1.1*((1/nu) + norm(HH));

v  = Q*u - e;
hu = v - max(v - alpha*u, 0);
while norm(hu) > tol && iter < maxIter
    iter = iter + 1;
    E = sign(max(v - alpha*u, 0));    % 0/1 indicator of the active components
    J = diag(1-E)*Q + diag(alpha*E);  % generalized Jacobian of hu
    delta = J\hu;

    if arm == 1
        % Armijo rule on the implicit Lagrangian, whose gradient is
        % (I - Q/alpha)*hu: require a decrease of at least 0.25*lambda*slope.
        Qu     = Q*u;
        Qdelta = Q*delta;
        slope  = hu'*delta - (hu'*Qdelta)/alpha;
        su     = implicit_lagrangian(u, Qu, alpha);
        lambda = 1;
        unew   = u - delta;
        sunew  = implicit_lagrangian(unew, Qu - Qdelta, alpha);
        at = 0;
        while su - sunew < 0.25*lambda*slope && at < maxArmijo
            at = at + 1;
            lambda = 0.5*lambda;
            unew  = u - lambda*delta;
            sunew = implicit_lagrangian(unew, Qu - lambda*Qdelta, alpha);
        end
        if su - sunew < 0.25*lambda*slope
            unew = u - delta;         % no acceptable step: take the full step
        end
    else
        unew = u - delta;
    end

    u  = unew;
    v  = Q*u - e;
    hu = v - max(v - alpha*u, 0);
end
w = A'*(d.*u);
gamma = -sum(d.*u);
return

%%%%%%%%%%%%%%%% objective used by the Armijo step %%%%%%%%%%%%%%%%
function val = implicit_lagrangian(u, Qu, alpha)
% IMPLICIT_LAGRANGIAN  Value at u (given Qu = Q*u) of
%   1/2 u'Qu - e'u + 1/(2*alpha)*(||(Qu - e - alpha*u)_+||^2 - ||Qu - e||^2)
r = Qu - 1;
val = 0.5*(u'*Qu) - sum(u) + (norm(max(r - alpha*u, 0))^2 - norm(r)^2)/(2*alpha);
return

%%%%%%%%%%%%%%%% correctness calculation %%%%%%%%%%%%%%%%
function corr = correctness(AA, dd, w, gamma)
% CORRECTNESS  Percentage of rows of AA whose predicted label
% sign(AA*w - gamma) equals dd. Points exactly on the plane (sign 0) count
% as misclassified.
p = sign(AA*w - gamma);
corr = nnz(p == dd)/size(AA,1)*100;
return

%%%%%%%%%%%%%%EstNuLong%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% hard way to estimate nu if not specified by the user
function value = EstNuLong(C, d)
% ESTNULONG  Fixed-point estimate of nu computed from the eigendecomposition
% of H2*H2', where H2 is [C -e] restricted to at most 200 randomly chosen
% rows. Starts at lambda = 1; returns 1 if the iteration does not converge
% within 100 steps or produces a non-finite value.
maxRows = 200;
maxIter = 100;
tol     = 10e-4;                      % note: 10e-4 == 1e-3
m = size(C,1);
H = [C, -ones(m,1)];
if m <= maxRows
    H2 = H;
    d2 = d;
else
    r   = randperm(m);
    sel = r(1:maxRows);
    H2  = H(sel,:);
    d2  = d(sel);
end
[V, S] = eig(H2*H2');
s  = diag(S);                         % eigenvalues (column)
yt = (d2'*V)';                        % labels in the eigenbasis (column)

lamda    = 1;
lamdaOld = lamda + 1;
cnt = 0;
while abs(lamdaOld - lamda) > tol && cnt < maxIter
    cnt = cnt + 1;
    lamdaOld = lamda;
    den = s + lamda;
    nu1 = sum(lamda./den);
    pr  = sum(s./den.^2);
    ee  = sum(s.*yt.^2./den.^3);
    waw = sum(lamda^2*yt.^2./den.^2);
    lamda = nu1*ee/(pr*waw);
end
if isfinite(lamda) && abs(lamdaOld - lamda) <= tol
    value = lamda;
else
    value = 1;                        % did not converge
end
return

%%%%%%%%%%%%%%%%%EstNuShort%%%%%%%%%%%%%%%%%%%%%%%%
% easy way to estimate nu if not specified by the user
function value = EstNuShort(C, d)
% ESTNUSHORT  Number of columns of C divided by the sum of its squared
% entries. d is unused (kept for a signature parallel to ESTNULONG).
value = 1/(sum(C(:).^2)/size(C,2));
return

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -