⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 m2o_sor.c

📁 支持向量机工具箱
💻 C
字号:
/*---------------------------------------------------------------------------
 [Alpha,bias,kercnt] = m2o_sor(data,labels,ker,arg,C,eps,tol)

 M2O_SOR Multi-class SVM problem translated to a one-class SVM problem
 and solved by the SOR (successive over-relaxation) algorithm.

 Inputs (as read by mexFunction):
   data    dim x N matrix of patterns, one pattern per column.
   labels  N class labels; the number of classes M is the largest label
           and the smallest label is expected to be 1.
   ker     kernel identifier: 'linear', 'poly' or 'rbf'.
   arg     kernel argument (for 'rbf' it is replaced by -2*arg^2).
   C       regularization constant, upper bound on every multiplier.
   eps     optional SOR stopping threshold (global eps starts at 0.001).
   tol     optional KKT tolerance of the Onec solver (global tol starts
           at 0.001).

 Outputs:
   Alpha   M x N matrix of multipliers mapped back to classes/patterns.
   bias    M x 1 vector.
   kercnt  number of kernel evaluations (ker_cnt).

 Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac
 (c) Czech Technical University Prague, http://cmp.felk.cvut.cz

 Modifications
  9-july-2002, VF
 25-Nov-2001, V. Franc
-------------------------------------------------------------------- */
#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>
#include "kernel.h"

#define MINUS_INF INT_MIN
#define PLUS_INF  INT_MAX

/* Prefix comparison used for the kernel identifier.
   NOTE: only the Borland branch is case insensitive (strncmpi); on other
   compilers strncmp is used, which is case sensitive. */
#ifdef __BORLANDC__
   #define STR_COMPARE(A,B,C)      strncmpi(A,B,C)  /* Borland */
#else
  #define STR_COMPARE(A,B,C)      strncmp(A,B,C) /* Linux gcc */
#endif

/* Every argument and the whole expansion are parenthesized. The previous
   ABS(A) expanded ABS(a1-alpha1) < x into
   (a1-alpha1 < 0) ? (-a1-alpha1) : (a1-alpha1) < x, negating only the
   first operand and binding '<' inside the conditional. */
#define ABS(A)     (((A) < 0) ? -(A) : (A))
#define MAX(A,B)   (((A) > (B)) ? (A) : (B) )
#define MIN(A,B)   (((A) < (B)) ? (A) : (B) )
/* Kronecker delta: 1 if A equals B, otherwise 0. */
#define KDELTA(A,B) ((A)==(B))
/* 1 if any two of the four arguments are equal, otherwise 0. */
#define KDELTA4(A1,A2,A3,A4) (((A1)==(A2)) || ((A1)==(A3)) || ((A1)==(A4)) || \
                              ((A2)==(A3)) || ((A2)==(A4)) || ((A3)==(A4)))

/* -- Global variables ----------------- */
int *labels;          /* class label of each original pattern */
double *stop;         /* stopping threshold read by runKernelSK (stop[0]);
                         NOTE(review): never assigned in the visible code */
long tmax;            /* iteration limit of runKernelSK; never assigned in
                         the visible code, so it stays 0 */
double C;             /* regularization constant, upper bound on alpha */
double *alpha;        /* multipliers of the transformed problem */
double bias;          /* not used by the visible code */
long num_data;        /* number of transformed data, (M-1)*N */
int solution;         /* set to 1 by runKernelSK when it stops */
long t;               /* iteration counter of runKernelSK */
long N, M;    // number of data, number of classes
double *error_cache;  /* cached outputs used by the Onec solver */
double eps = 0.001;   /* SOR stopping threshold / Onec minimal step */
double tol = 0.001;/*-----------------------------------------------*/void get_indices2( long *index, long *class, long i ){   *index = i / (M-1);    *class = (i % (M-1))+1;   if( *class >= labels[ *index ]) (*class)++;   return;}/* ------------------------------------------------------------  Kernel function ------------------------------------------------------------ */double kernel_fce( long a, long b ){  double value;  long i1,c1,i2,c2;  // 1 class  //  value = kernel( a, b );  // 2 classes with bias  // value = kernel( a, b ) + 1;  // if( labels[a] != labels[b] ) value *= -1;  // multiclass case  get_indices2( &i1, &c1, a );  get_indices2( &i2, &c2, b );  if( KDELTA4(labels[i1],labels[i2],c1,c2) ) {    value = (+KDELTA(labels[i1],labels[i2])              -KDELTA(labels[i1],c2)             -KDELTA(labels[i2],c1)             +KDELTA(c1,c2)            )*(kernel( i1, i2 )+1);  }  else  {    value = 0;  }   return( value );}/*-----------------------------------*/double learned_func( long k){  double result = 0.;  long i;  for( i=0; i < num_data; i++ ) {    if( alpha[i] > 0) {      result = result + alpha[i]*kernel_fce(i,k);    }  }        return( result );}/* ------------------------------------------------------------ SOR
------------------------------------------------------------ */void runSor( void )
{
   double omega = 1;
   double ndelta;
   double delta;
   double oldalpha;
   long i, j, t;

   //   mexPrintf("eps %f\n", eps);
   
   t=0;
   do {      
      t++;

      ndelta = 0;
      for( i = 0; i < num_data; i++ ) {
            
         delta=0;
         for(j=0; j < num_data; j++ ) {
            if( alpha[j] != 0) delta += alpha[j]*kernel_fce(i, j);
         }

         delta = omega*(delta - 1)/kernel_fce(i,i);

         oldalpha = alpha[i];            
         alpha[i] -= delta;
         if( alpha[i] < 0 ) alpha[i] = 0; else if( alpha[i] > C ) alpha[i] = C;
            
         ndelta += (alpha[i]-oldalpha)*(alpha[i]-oldalpha);
      }
      
   } while(eps < ndelta && t < 50000);
   //   mexPrintf("t %d, ndelta %f\n", t, ndelta );
}
/* ------------------------------------------------------------  Onec------------------------------------------------------------ */void runOnec( void ){  long numChanged = 0;  long examineAll = 1;  long k;  // allocates cache  if( (error_cache = mxCalloc(num_data, sizeof(double))) == NULL) {      mexErrMsgTxt("Not enough memory for error cache.");  }  while( numChanged > 0 || examineAll != 0 )  {    numChanged = 0;    if( examineAll != 0 ) {      for(k=0; k < num_data; k++ ) {         numChanged = numChanged + examineExample(k);      }    }    else {      for(k=0; k < num_data; k++ ) {         if( alpha[k] != 0 & alpha[k] != 0 ) {            numChanged = numChanged + examineExample(k);         }      }    }    if( examineAll == 1) {      examineAll = 0;    }    else {       if( numChanged == 0 ) {         examineAll = 1;       }    }  }  mxFree( error_cache );}/* ----------------------------------------------------------- EXAMINESTEP function ------------------------------------------------------------*/int examineExample( long i1){  double alpha1;  double E1;  alpha1 = alpha[i1];         if( alpha1 > 0 & alpha1 < C) {    E1 = error_cache[i1];     }  else   {    E1 = learned_func(i1);  }      if(((E1-1) < -tol & alpha1 < C) | ((E1-1) > tol & alpha1 > 0))   {    return( takeStep( i1 ) );  }  else {    return(0);  }}/* ---------------------------------------------------------  TAKESTEP function  -----------------------------------------------------------*/int takeStep( long i1){  double alpha1, a1, E1, k11;  long i;  alpha1 = alpha[i1];   if( alpha1 > 0 & alpha1 < C) {    E1 = error_cache[i1];  }  else   {    E1 = learned_func(i1);  }   k11 = kernel_fce( i1, i1 );  a1 = alpha1 + (1 - E1)/k11;  if( ABS(a1-alpha1)< eps*(a1+alpha1+eps)) return( 0 );  if( a1 > C ) {     a1 = C;  }  else if( a1 < 0 ) {    a1 = 0;  }  for( i=0; i < num_data; i++ ) {    if( 0 < alpha[i] & alpha[i] < C ) {          error_cache[i] = error_cache[i] +          kernel_fce(i1,i)*(a1-alpha1);    
}  }  if( !(alpha1 > 0 & alpha1 < C)) {     error_cache[i1] = E1 + k11*a1;  }  alpha[i1] = a1;  return( 1 );}/* ------------------------------------------------------------  Kernel S-K algorithm------------------------------------------------------------ */void runKernelSK( void ) {  double *wx;   // dot products <w,x_i>  double w2;    // dot products <w,w>  long inx, i;  double min_wx;  double q;  // allocates cache  if( (wx = mxCalloc(num_data, sizeof(double))) == NULL) {      mexErrMsgTxt("Not enough memory for error cache.");  }  // inicialization  alpha[0] = 1;  w2 = kernel_fce( 0, 0);  for( i = 0; i < num_data; i++ ) {    wx[i] =kernel_fce( 0, i);  }    solution = 0;  t = 0;  // main optimization cycle  while( solution == 0 && tmax > t )  {     t++;     // finds index of x with minimum <w,x>     for( min_wx = PLUS_INF, inx = -1, i = 0; i < num_data; i++ ) {        if( min_wx > wx[i] ) {           min_wx = wx[i];           inx = i;        }     }        if( sqrt(w2) - min_wx/sqrt( w2 ) > stop[0] ) {           q = MIN(1, (w2 - min_wx)/(w2 - 2*min_wx + kernel_fce(inx,inx) ));              w2 = w2*(1-q)*(1-q) + 2*(1-q)*q*wx[inx] + q*q * kernel_fce(inx,inx);             for( i=0; i < num_data; i++ ) {          if( alpha[i] ) alpha[i] *= 1-q;          wx[i] = wx[i]*(1-q) + q*kernel_fce( i, inx );        }        alpha[inx] += q;       }     else     {        solution = 1;     }  }  mxFree( wx );}/* ============================================================== Main MEX function - interface to Matlab.============================================================== */void mexFunction( int nlhs, mxArray *plhs[],		  int nrhs, const mxArray*prhs[] ){  char ker_id[10];  long i,j ;  double *oarg0, *oarg1;  long i1, c1;  double *tmp;  long nsv, ansv;  /* -- gets input arguments --------------------------- */  dataA = mxGetPr(prhs[0]);  /* pointer at patterns */  dataB = dataA;  dim = mxGetM(prhs[0]);            /* data dimension */  N = mxGetN(prhs[0]);              /* 
number of data */  tmp = mxGetPr(prhs[1]);  if( (labels = mxCalloc(N, sizeof(int))) == NULL) {      mexErrMsgTxt("Not enough memory for error cache.");  }  for( i = 0; i < N; i++ ) {    labels[i] = (int)tmp[i];  }  /* kernel identifier*/  mxGetString( prhs[2], ker_id, 10 );    if( STR_COMPARE( ker_id, "linear", 6) == 0 ) {     ker = 0;  } else if( STR_COMPARE( ker_id, "poly", 4) == 0 ) {     ker = 1;  } else if( STR_COMPARE( ker_id, "rbf", 3) == 0 ) {     ker = 2;  } else     mexErrMsgTxt("Unknown kernel identifier.");  /* take kernel argument */  arg1 = mxGetScalar(prhs[3]);      /* if kernel is RBF than recompute its argument */  if( ker == 2) arg1 = -2*arg1*arg1;  /* regularization constant */  C = mxGetScalar(prhs[4]);  /* eps a tol parameter */  if( nrhs >= 6 ) eps = mxGetScalar(prhs[5]);  if( nrhs >= 7 ) tol = mxGetScalar(prhs[6]);  /* -- Inicialization ---------------------------- */  ker_cnt = 0;   /* counter for number of kernel evaluations */  // gets number of classes = max class label, min c.l. = 1   M = MINUS_INF;   for( i = 0; i < N; i++ ) {      if( labels[i] > M ) M = labels[i];   }  // num of transformed data  num_data = (M-1)*N;  /* -- calls Kernel S-K ---------------------------- */   /* create vector for Lagrangeians */  if( (alpha = mxCalloc(num_data, sizeof(double))) == NULL) {      mexErrMsgTxt("Not enough memory for error cache.");   }  // run the main algorithm //  runOnec();
  runSor();    /* -- sets up outputs ------------------------------- */  // Output   plhs[0] = mxCreateDoubleMatrix(M,N,mxREAL);  oarg0 = mxGetPr(plhs[0]);  plhs[1] = mxCreateDoubleMatrix(M,1,mxREAL);  oarg1 = mxGetPr(plhs[1]);  for(i=0; i < M; i++ ) {    for( j=0; j < num_data; j++ ) {       get_indices2( &i1, &c1, j );       oarg0[(i1*M)+i] += alpha[j]*(KDELTA(labels[i1],i+1)+KDELTA(i+1,c1));       oarg1[i] += alpha[j]*(KDELTA(labels[i1],i+1)-KDELTA(i+1,c1));    }  }  //  for(i = 0; i < num_data; i++ ) {  //   oarg1[i] = alpha[i];  //  }  // bias    //  for(bias =0, i = 0; i < num_data; i++ ) {  //   if(labels[i]==1) bias += alpha[i]; else bias -= alpha[i];  //  }  //  mexPrintf("bias=%f\n",bias );  // exit result  plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL);  *((double*)mxGetPr(plhs[2])) = ker_cnt;  /* -- free memory ----------------- */  mxFree( alpha );  mxFree( labels );}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -