kernelskf.c~

/*---------------------------------------------------------------------------
[Alpha,bias,sol,t,kercnt,margin,trnerr] =
   kernelskf(data,labels,stop,ker,arg,tmax,C)

 KERNELSKF Kernel Schlesinger-Kozinec's algorithm.

 It solves the Support Vector Machines problem with a quadratic
 cost function for classification violations.

 Inputs:
  data [dim x N] training patterns.
  labels [1 x N] labels of training patterns (1 or 2).
  stop [1 x 2] if stop(1) == 1 then the stopping condition m*-m < stop(2)
     is used; if stop(1) == 2 then the stopping condition (m*-m)/m < stop(2)
     is used. Here m* is the optimal margin and m is the margin of the
     found hyperplane (in the given feature space).
  ker [string] kernel, see 'help kernel'.
  arg [...] argument of the given kernel, see 'help kernel'.
  tmax [int] maximal number of iterations.
  C [real] trade-off between margin and training error.

 Outputs:
  Alpha [1xN] Lagrangians defining the found decision rule.
  bias [real] bias (threshold) of the found decision rule.
  sol [int]  1 solution is found,
             0 algorithm stopped (t == tmax) before it converged,
            -1 hyperplane with margin greater than epsilon
               does not exist.
  t [int] number of iterations.
  kercnt [int] number of kernel evaluations.
  margin [real] margin between classes.
  trnerr [real] training error.

 See also SVM.

 Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac
 (c) Czech Technical University Prague, http://cmp.felk.cvut.cz
 Written by Vojtech Franc (diploma thesis) 02.11.1999, 13.4.2000

 Modifications
  19-Nov-2001, V.Franc
  13-Nov-2001, V.Franc
  12-Nov-2001, V.Franc
  5-Nov-2001, V.Franc
  4-Nov-2001, V.Franc, created
-------------------------------------------------------------------- */

#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>

#include "kernel.h"

#define MINUS_INF INT_MIN
#define PLUS_INF  INT_MAX

/* case insensitive string comparison */
#ifdef __BORLANDC__
   #define STR_COMPARE(A,B,C)      strncmpi(A,B,C)     /* Borland */
#else
   #define STR_COMPARE(A,B,C)      strncasecmp(A,B,C)  /* Linux gcc */
#endif

#define MAX(A,B)   (((A) > (B)) ? (A) : (B) )
#define MIN(A,B)   (((A) < (B)) ? (A) : (B) )

/* ==============================================================
 Main MEX function - interface to Matlab.
============================================================== */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[] )
{
   char skernel[10];
   long t;              /* iteration number */
   long i, j;           /* loop variables */
   int sol;             /* solution: 1=found, 0=not found, -1=does not exist, ... */
   double a, b, c;      /* auxiliary variables: <w1,w1>, <w2,w2>, <w1,w2> */
   double emin, gmin;   /* --//--              */
   long inx1, inx2;     /* --//--              */
   double proj1, proj2; /* --//--              */
   double *w1x;         /* <w1,x>; cache       */
   double *w2x;         /* <w2,x>; cache       */
   double k;            /* --//--              */
   double margin2;      /* squared margin in the feature space */
   double normw;        /* ||w1-w2|| */
   double *labels;      /* pointer at labels */
   long N;              /* number of training patterns */
   double *stop;        /* stop[0] stopping criterion and stop[1] its argument */
   long tmax;           /* maximal number of iterations */
   double C;            /* trade-off constant */
   double kadd;         /* diagonal additional term */
   double *alpha;       /* Lagrangians */
   double *bias;        /* threshold of the learned indicator function */
   double margin;       /* margin in the original space */
   double trn_err;      /* training error */
   double dfun;         /* value of decision function */

   /* ker, arg1, dataA, dataB, dim and ker_cnt are not declared in this
      file; they are the kernel-evaluation globals expected from kernel.h. */

   /* ---- CHECK INPUT ARGUMENTS  ----------------------- */
   if(nrhs < 7)
      mexErrMsgTxt("Not enough input arguments.");

   if(nlhs < 3)
      mexErrMsgTxt("Not enough output arguments.");

   /* data matrix [dim x N ] */
   if( !mxIsNumeric(prhs[0]) || !mxIsDouble(prhs[0]) ||
       mxIsEmpty(prhs[0])    || mxIsComplex(prhs[0]) )
      mexErrMsgTxt("Input X must be a real matrix.");

   /* labels [1 x N ] */
   if( !mxIsNumeric(prhs[1]) || !mxIsDouble(prhs[1]) ||
       mxIsEmpty(prhs[1])    || mxIsComplex(prhs[1]) )
      mexErrMsgTxt("Input I must be a real vector.");

   /* stopping condition */
   if( !mxIsNumeric(prhs[2]) || !mxIsDouble(prhs[2]) ||
       mxIsEmpty(prhs[2])    || mxIsComplex(prhs[2]) ||
       mxGetN(prhs[2]) != 2  || mxGetM(prhs[2]) != 1 )
      mexErrMsgTxt("Input stop must be a vector [1 x 2].");

   /* a string as kernel identifier ('linear', 'poly', 'rbf') */
   if( mxIsChar(prhs[3]) != 1 || mxGetM(prhs[3]) != 1 )
      mexErrMsgTxt("Input ker must be a string");
   else {
      /* check which kernel */
      mxGetString( prhs[3], skernel, 10 );
      if( STR_COMPARE( skernel, "linear", 6) == 0 ) {
         ker = 0;
      } else if( STR_COMPARE( skernel, "poly", 4) == 0 ) {
         ker = 1;
      } else if( STR_COMPARE( skernel, "rbf", 3) == 0 ) {
         ker = 2;
      } else
         mexErrMsgTxt("Unknown kernel identifier.");
   }

   /* real input argument for polynomial and rbf kernel */
   if( ker == 1 || ker == 2) {
      if( !mxIsNumeric(prhs[4]) || !mxIsDouble(prhs[4]) ||
          mxIsEmpty(prhs[4])    || mxIsComplex(prhs[4]) ||
          mxGetN(prhs[4]) != 1  || mxGetM(prhs[4]) != 1 )
         mexErrMsgTxt("Input arg must be a real scalar.");
      else {
         arg1 = mxGetScalar(prhs[4]);  /* take kernel argument */
         /* if kernel is RBF then recompute its argument */
         if( ker == 2) arg1 = -2*arg1*arg1;
      }
   }

   /* tmax */
   if( !mxIsNumeric(prhs[5]) || !mxIsDouble(prhs[5]) ||
       mxIsEmpty(prhs[5])    || mxIsComplex(prhs[5]) ||
       (mxGetN(prhs[5]) != 1  && mxGetM(prhs[5]) != 1 ))
      mexErrMsgTxt("Input tmax must be an integer.");

   /* trade-off constant C (real scalar) */
   if( !mxIsNumeric(prhs[6]) || !mxIsDouble(prhs[6]) ||
       mxIsEmpty(prhs[6])    || mxIsComplex(prhs[6]) ||
       (mxGetN(prhs[6]) != 1  && mxGetM(prhs[6]) != 1 ))
      mexErrMsgTxt("Input C must be a real scalar.");

   /* ---- GET INPUT ARGUMENTS ------------------------------- */
   dataA = mxGetPr(prhs[0]);  /* pointer at patterns */
   dataB = mxGetPr(prhs[0]);  /* pointer at patterns */
   labels = mxGetPr(prhs[1]); /* pointer at labels */
   dim = mxGetM(prhs[0]);     /* data dimension */
   N = mxGetN(prhs[0]);       /* number of data */
   stop = mxGetPr(prhs[2]);
   if( mxIsInf( mxGetScalar(prhs[5])) ) {
      tmax = INT_MAX;
   } else {
      tmax = (long)mxGetScalar(prhs[5]);
   }
   C = mxGetScalar(prhs[6]);

   // computes additional term to kernel value on the diagonal
   if( C != 0 ) kadd = 1/(2*C); else kadd = 0;

   /* create vector for Lagrangians */
   plhs[0] = mxCreateDoubleMatrix(1,N,mxREAL);
   alpha = mxGetPr(plhs[0]);

   /*-- INITIALIZATION ------------------------------*/
   ker_cnt = 0;  /* counter for number of kernel evaluations */

   // initialization of cached values
   if( (w1x = mxCalloc(N, sizeof(double))) == NULL) {
      mexErrMsgTxt("Not enough memory for error cache.");
   }
   if( (w2x = mxCalloc(N, sizeof(double))) == NULL) {
      mexErrMsgTxt("Not enough memory for error cache.");
   }

   // takes two vectors as an initial solution
   for( inx1 = -1, inx2 = -1, i=0; i < N && (inx1==-1 || inx2==-1); i++ ) {
      if( labels[i] == 1 && inx1 == -1) {
         inx1 = i;
         alpha[i] = 1;
      } else if( labels[i] == 2 && inx2 == -1) {
         alpha[i] = 1;
         inx2 = i;
      }
   }

   // both classes must be present, otherwise kernel() gets index -1
   if( inx1 == -1 || inx2 == -1 )
      mexErrMsgTxt("Training data must contain patterns with labels 1 and 2.");

   // initializes cached values
   for( i=0; i < N; i++ ) {
      w1x[i] = kernel(inx1,i);
      w2x[i] = kernel(inx2,i);
   }
   w1x[inx1] += kadd;
   w2x[inx2] += kadd;

   a = kernel( inx1, inx1) + kadd;
   b = kernel( inx2, inx2) + kadd;
   c = kernel( inx1, inx2);

   sol = 0;
   t = 0;

   /* -- MAIN OPTIMIZATION CYCLE ------------------------ */
   while( sol == 0 && tmax > t )
   {
      t++;
      sol = 1;

      // -- compute auxiliary variables --
      emin = PLUS_INF;
      gmin = PLUS_INF;
      for( i = 0; i < N; i++ )
      {
         if( labels[i]== 1 ) {
            if( w1x[i] - w2x[i] < emin) {
               emin = w1x[i] - w2x[i];
               inx1 = i;
            }
         }
         else {
            if( w2x[i] - w1x[i] < gmin ) {
               gmin = w2x[i] - w1x[i];
               inx2 = i;
            }
         }
      }

      // normw = sqrt(<w1-w2,w1-w2>)
      normw = sqrt( a-2*c+b );

      // projection <x,(w1 - w2)> for x from X1
      proj1 = (emin + b - c );

      // projection <x,(w2 - w1)> for x from X2
      proj2 = (gmin + a - c);

      /* --- stopping condition for the 1st class ------ */
      // (proj1 < proj2) ~ the worst point will be used for update
      if( (proj1 < proj2 ) &&
          //        ((stop[0]==2 && (1-(proj1)/(normw*normw)) >= stop[1] ) ||
          ((stop[0]==2 && (1-(proj1)/(normw*normw)) >= stop[1]/2 ) ||
           (stop[0]==1 && (normw-proj1/normw) >= stop[1] )
          )
        )
      {
         // -- Adaptation phase of vector alpha1 ----------------------------
         // optimal step k = <w1-w2,w1-x>/||w1-x||^2, where
         // ||w1-x||^2 = a + k(x,x) + kadd - 2<w1,x>
         k = (a - emin - c)/(a+kernel(inx1,inx1)+kadd-2*w1x[inx1] );
         k = MIN( 1, k );

         sol = 0;

         // -- UPDATE OF CACHED VALUES -----------------------------------
         a = a*(1-k)*(1-k) + 2*(1-k)*k*w1x[inx1] + k*k * (kernel(inx1,inx1)+kadd );
         c = c*(1-k) + k*w2x[inx1];

         for( i=0; i < N; i++ ) {
            if( labels[i] == 1 && alpha[i] ) alpha[i] *= 1-k;
            w1x[i] = w1x[i]*(1-k) + k*kernel( i, inx1 );
         }
         w1x[inx1] += k*kadd;
         alpha[inx1] += k;
      }
      else
      {
         // --- stopping condition for the 2nd class ------
         //       if( (stop[0]==2 && 2*(1-proj1/(normw*normw)) >= stop[1] ) ||
         //           (stop[0]==1 && (normw-proj1/normw) >= stop[1])
         //       if( (stop[0]==2 && (1-(proj2)/(normw*normw)) >= stop[1] ) ||
         if( (stop[0]==2 && (1-(proj2)/(normw*normw)) >= stop[1]/2 ) ||
             (stop[0]==1 && (normw-proj2/normw) >= stop[1])
           )
         {
            // -- Adaptation phase ----------------------------------
            // optimal step k = <w2-w1,w2-x>/||w2-x||^2
            k = (b - gmin - c)/(b+kernel(inx2,inx2)+kadd-2*w2x[inx2]);
            k = MIN( 1, k );

            sol = 0;

            // -- UPDATE OF CACHES ---------------------------------------
            b = b*(1-k)*(1-k) + 2*(1-k)*k*w2x[inx2] + k*k*(kernel(inx2,inx2)+kadd );
            c = c*(1-k) + k*w1x[inx2];

            for(i = 0; i < N; i++ ) {
               w2x[i] = (1-k)*w2x[i] + k*kernel(i, inx2);
               if( labels[i] == 2 && alpha[i]) alpha[i] *= 1-k;
            }
            alpha[inx2] += k;
            w2x[inx2] += k*kadd;
         }
      }

      if( sqrt( a -2*c +b ) <= 0 ) {
         // algorithm has converged to the zero vector --> classes overlap
         sol = -1;
      }
   }  // while(...)

   if( sol == 1 && (proj1 < 0 || proj2 < 0) ) {
      sol = 0;
   }

   /* --- COMPUTATION OF OUTPUT VALUES ----------------------- */
   // squared margin in transformed space
   margin2 = a - 2*c + b;

   // threshold after normalization
   plhs[1] = mxCreateDoubleMatrix(1,1,mxREAL);
   bias = mxGetPr(plhs[1]);
   *bias = (b-a)/margin2;

//  mexPrintf("f1=%f, f2=%f\n", (2*emin+b-a)/(normw*normw),
//   (2*gmin+a-b)/(normw*normw));
//  mexPrintf("0.5*(min <w,x1> - max<w,x2>)/|w|=%f\n", 0.5*(emin+gmin)/sqrt(margin2));

   // solution (normal vect. in the transformed space) after normalization
   for( i=0; i < N; i++ ) {
      alpha[i] *= 2/margin2;
   }

   // training errors
   if( nlhs >= 7 )
   {
      trn_err = 0;
      for(i = 0; i < N; i++ ) {
         dfun = 0;
         for( j=0; j < N; j++ ) {
            if( alpha[j] != 0 ) {
               if( labels[j] == 1)
                  dfun += alpha[j]*kernel(i,j);
               else
                  dfun -= alpha[j]*kernel(i,j);
            }
         }
         // (3 - 2*label) maps label 1 to +1 and label 2 to -1
         if( (3-labels[i]*2)*(dfun + *bias) < 0) trn_err++;
      }
      plhs[6] = mxCreateDoubleMatrix(1,1,mxREAL);
      (*mxGetPr(plhs[6])) = trn_err/N;
   }

   // compute margin
   if( nlhs >= 6 ) {
      margin = 0;
      for(i = 0; i < N; i++ ) {
         for( j=0; j < N; j++ ) {
            if( alpha[i] != 0 && alpha[j] != 0 ) {
               if( labels[i] == labels[j] )
                  margin += alpha[i]*alpha[j]*kernel(i,j);
               else
                  margin -= alpha[i]*alpha[j]*kernel(i,j);
            }
         }
      }
      margin = 1/sqrt(margin);

      plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);
      (*mxGetPr(plhs[5])) = margin;
   }

   // number of kernel evaluations
   if( nlhs >= 5 ) {
      plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);
      (*mxGetPr(plhs[4])) = (double)ker_cnt;
   }

   // solution 1 (found), 0 (not found), -1 (does not exist)
   if( nlhs >= 3 ) {
      plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL);
      (*mxGetPr(plhs[2])) = (double)sol;
   }

   // number of iterations
   if( nlhs >= 4 ) {
      plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL);
      (*mxGetPr(plhs[3])) = (double)t;
   }

   /* ----- FREE MEMORY ----------------------- */
   mxFree( w1x );
   mxFree( w2x );
}
