⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo2.c

📁 支持向量机工具箱
💻 C
📖 第 1 页 / 共 2 页
字号:
      for( k0 = rand(), k = k0; k < N + k0; k++ ) {         i2 = k % N;#else      for( k = 0; k < N; k++) {         i2 = k;#endif         if( alpha[i2] > 0 ) {            if( takeStep(i1,i2) )               return( 1 );         }      }#ifdef RANDOM      for( k0 = rand(), k = k0; k < N + k0; k++ ) {         i2 = k % N;#else      for( k = 0; k < N; k++) {         i2 = k;#endif         if( takeStep(i1,i2) )            return( 1 );      }   } /* if( ... ) */   return( 0 );}/* -------------------------------------------------------------- Main SMO optimization cycle.-------------------------------------------------------------- */void runSMO( void ){   long numChanged = 0;   long examineAll = 1;   long k;   while( numChanged > 0 || examineAll ) {      numChanged = 0;      if( examineAll ) {         for( k = 0; k < N; k++ ) {            numChanged += examineExample( k );         }      }      else {         for( k = 0; k < N; k++ ) {            if( alpha[k] > 0 )               numChanged += examineExample( k );         }      }      if( examineAll == 1 )         examineAll = 0;      else if( numChanged == 0 )         examineAll = 1;   }}/* ============================================================== Main MEX function - interface to Matlab.============================================================== */void mexFunction( int nlhs, mxArray *plhs[],		  int nrhs, const mxArray*prhs[] ){   long i,j ;   double *labels12, *initAlpha, *nsv, *trn_err, *margin;   double nerr;   char skernel[10];   double C;   /* ---- check number of input arguments  ------------- */   if(nrhs < 5)      mexErrMsgTxt("Not enough input arguments.");   if(nlhs < 2)      mexErrMsgTxt("Not enough output arguments.");   /* ---- check input arguments  ----------------------- */   /* data matrix [dim x N ] */   if( !mxIsNumeric(prhs[0]) || !mxIsDouble(prhs[0]) ||       mxIsEmpty(prhs[0])    || mxIsComplex(prhs[0]) )      mexErrMsgTxt("Input X must be a real matrix.");   /* vector of labels (1,2) */   if( !mxIsNumeric(prhs[1]) || !mxIsDouble(prhs[1]) ||       mxIsEmpty(prhs[1])    || mxIsComplex(prhs[1]) ||       (mxGetN(prhs[1]) != 1 && mxGetM(prhs[1]) != 1))      mexErrMsgTxt("Input I must be a real vector.");   /* a string as kernel identifier ('linear',poly','rbf' ) */   if( mxIsChar(prhs[2]) != 1 || mxGetM(prhs[2]) != 1 )      mexErrMsgTxt("Input ker must be a string");   else {       /* check which kernel  */       mxGetString( prhs[2], skernel, 10 );       if( STR_COMPARE( skernel, "linear", 6) == 0 ) {          ker = 0;       } else if( STR_COMPARE( skernel, "poly", 4) == 0 ) {          ker = 1;       } else if( STR_COMPARE( skernel, "rbf", 3) == 0 ) {          ker = 2;       } else          mexErrMsgTxt("Unknown kernel identifier.");   }   /*  real input argument for polynomial and rbf kernel   */   if( ker == 1 || ker == 2) {      if( !mxIsNumeric(prhs[3]) || !mxIsDouble(prhs[3]) ||         mxIsEmpty(prhs[3])    || mxIsComplex(prhs[3]) ||         mxGetN(prhs[3]) != 1  || mxGetM(prhs[3]) != 1 )         mexErrMsgTxt("Input arg must be a real scalar.");      else {         arg1 = mxGetScalar(prhs[3]);  /* take kernel argument */         /* if kernel is RBF than recompute its argument */         if( ker == 2) arg1 = -2*arg1*arg1;      }   }   /*  one or two real trade-off constant(s)  */   if( !mxIsNumeric(prhs[4]) || !mxIsDouble(prhs[4]) ||       mxIsEmpty(prhs[4])    || mxIsComplex(prhs[4]) ||       (mxGetN(prhs[4]) != 1  && mxGetM(prhs[4]) != 1 ))      mexErrMsgTxt("Input C must be one or two real scalar(s).");   else {      C = mxGetScalar(prhs[4]);   }   /* real parameter eps */   if( nrhs >= 6 ) {      if( !mxIsNumeric(prhs[5]) || !mxIsDouble(prhs[5]) ||         mxIsEmpty(prhs[5])    || mxIsComplex(prhs[5]) ||         mxGetN(prhs[5]) != 1  || mxGetM(prhs[5]) != 1 )         mexErrMsgTxt("Input eps must be a real scalar.");      else         eps = mxGetScalar(prhs[5]);   /* take eps argument */   }   /* real parameter tol */   if(nrhs >= 7) {      if( !mxIsNumeric(prhs[6]) || !mxIsDouble(prhs[6]) ||         mxIsEmpty(prhs[6])    || mxIsComplex(prhs[6]) ||         mxGetN(prhs[6]) != 1  || mxGetM(prhs[6]) != 1 )         mexErrMsgTxt("Input tol must be a real scalar.");      else         tolerance = mxGetScalar(prhs[6]);  /* take tolerance argument */   }   /* real vector of Lagrangeian multipliers */   if(nrhs >= 8) {      if( !mxIsNumeric(prhs[7]) || !mxIsDouble(prhs[7]) ||          mxIsEmpty(prhs[7])    || mxIsComplex(prhs[7]) ||          (mxGetN(prhs[7]) != 1  && mxGetM(prhs[7]) != 1 ))          mexErrMsgTxt("Input Alpha must be a real vector.");   }   /* real scalar - bias */   if( nrhs >= 9 ) {      if( !mxIsNumeric(prhs[8]) || !mxIsDouble(prhs[8]) ||         mxIsEmpty(prhs[8])    || mxIsComplex(prhs[8]) ||         mxGetN(prhs[8]) != 1  || mxGetM(prhs[8]) != 1 )         mexErrMsgTxt("Input bias must be a real scalar.");   }   /* ---- init variables ------------------------------- */   dataA = mxGetPr(prhs[0]);  /* pointer at patterns */   dataB = dataA;   dim = mxGetM(prhs[0]);     /* data dimension */   N = mxGetN(prhs[0]);       /* number of data */   labels12 = mxGetPr(prhs[1]);    /* labels (1,2) */      ker_cnt = 0;   kadd = 1/(2*C);   /* allocate memory for targets (labels) (1,-1) */   if( (target = mxCalloc(N, sizeof(double) )) == NULL) {      mexErrMsgTxt("Not enough memory.");   }   /* transform labels12 (1,2) from to targets (1,-1) */   for( i = 0; i < N; i++ ) {      target[i] = - labels12[i]*2 + 3;   }   /* create output variable for bias */   plhs[1] = mxCreateDoubleMatrix(1,1,mxREAL);   b = mxGetPr(plhs[1]);   /* take init value of bias if given */   if( nrhs >= 9 ) {      *b = -mxGetScalar(prhs[8]);     }   /* allocate memory for error_cache */   if( (error_cache = mxCalloc(N, sizeof(double) )) == NULL) {      mexErrMsgTxt("Not enough memory for error cache.");   }   /* create vector for Lagrangeians */   plhs[0] = mxCreateDoubleMatrix(1,N,mxREAL);   alpha = mxGetPr(plhs[0]);   /* if Lagrangeians given then use them as initial values */   if( nrhs >= 8 ) {      initAlpha = mxGetPr(prhs[7]);      for( i = 0; i < N; i++ ) {         alpha[i] = initAlpha[i];      }      /* Init error cache for non-bound multipliers. */      for( i = 0; i < N; i++ ) {         if( alpha[i] != 0 && alpha[i] != C ) {            error_cache[i] = learned_func(i) - target[i];         }      }   }   /* ---- run SMO ------------------------------------------- */   runSMO();   /* ---- output statistics --------------------------------- */   if( nlhs >= 3 ) {      /* count number of support vectors */      plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL);      nsv = mxGetPr(plhs[2]);      *nsv = 0;      for( i = 0; i < N; i++ ) {         if( alpha[i] > ZERO_LIM ) (*nsv)++;      }   }   if( nlhs >= 4) {     /* evaluates classification error on traning patterns */     plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL);     trn_err = mxGetPr(plhs[3]);     nerr = 0;     for( i = 0; i < N; i++ ) {        if( target[i] == 1 ) {           if( learned_func2(i) < 0 ) nerr++;        }        else           if( learned_func2(i) >= 0 ) nerr++;     }     *trn_err = nerr/N;   }   if( nlhs >= 5) {           /* compute margin */      plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);      margin = mxGetPr(plhs[4]);      *margin = 0;      for( i = 0; i < N; i++ ) {        for( j = 0; j < N; j++ ) {           if( alpha[i] > 0 && alpha[j] > 0 )              *margin += alpha[i]*alpha[j]*target[i]*target[j]*kernel(i,j);        }      }      *margin = 1/sqrt(*margin);   }   if( nlhs >= 6 ) {     plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);     (*mxGetPr(plhs[5])) = (double)ker_cnt;   }   /* decision function of type <w,x>+b is used */   *b = -*b;   /* ----- free memory --------------------------------------- */   mxFree( error_cache );   mxFree( target );}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -