📄 smo2.c
字号:
for( k0 = rand(), k = k0; k < N + k0; k++ ) { i2 = k % N;#else for( k = 0; k < N; k++) { i2 = k;#endif if( alpha[i2] > 0 ) { if( takeStep(i1,i2) ) return( 1 ); } }#ifdef RANDOM for( k0 = rand(), k = k0; k < N + k0; k++ ) { i2 = k % N;#else for( k = 0; k < N; k++) { i2 = k;#endif if( takeStep(i1,i2) ) return( 1 ); } } /* if( ... ) */ return( 0 );}/* -------------------------------------------------------------- Main SMO optimization cycle.-------------------------------------------------------------- */void runSMO( void ){ long numChanged = 0; long examineAll = 1; long k; while( numChanged > 0 || examineAll ) { numChanged = 0; if( examineAll ) { for( k = 0; k < N; k++ ) { numChanged += examineExample( k ); } } else { for( k = 0; k < N; k++ ) { if( alpha[k] > 0 ) numChanged += examineExample( k ); } } if( examineAll == 1 ) examineAll = 0; else if( numChanged == 0 ) examineAll = 1; }}/* ============================================================== Main MEX function - interface to Matlab.============================================================== */void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ){ long i,j ; double *labels12, *initAlpha, *nsv, *trn_err, *margin; double nerr; char skernel[10]; double C; /* ---- check number of input arguments ------------- */ if(nrhs < 5) mexErrMsgTxt("Not enough input arguments."); if(nlhs < 2) mexErrMsgTxt("Not enough output arguments."); /* ---- check input arguments ----------------------- */ /* data matrix [dim x N ] */ if( !mxIsNumeric(prhs[0]) || !mxIsDouble(prhs[0]) || mxIsEmpty(prhs[0]) || mxIsComplex(prhs[0]) ) mexErrMsgTxt("Input X must be a real matrix."); /* vector of labels (1,2) */ if( !mxIsNumeric(prhs[1]) || !mxIsDouble(prhs[1]) || mxIsEmpty(prhs[1]) || mxIsComplex(prhs[1]) || (mxGetN(prhs[1]) != 1 && mxGetM(prhs[1]) != 1)) mexErrMsgTxt("Input I must be a real vector."); /* a string as kernel identifier ('linear',poly','rbf' ) */ if( mxIsChar(prhs[2]) != 1 || mxGetM(prhs[2]) != 1 ) mexErrMsgTxt("Input ker must be a string"); else { /* check which kernel */ mxGetString( prhs[2], skernel, 10 ); if( STR_COMPARE( skernel, "linear", 6) == 0 ) { ker = 0; } else if( STR_COMPARE( skernel, "poly", 4) == 0 ) { ker = 1; } else if( STR_COMPARE( skernel, "rbf", 3) == 0 ) { ker = 2; } else mexErrMsgTxt("Unknown kernel identifier."); } /* real input argument for polynomial and rbf kernel */ if( ker == 1 || ker == 2) { if( !mxIsNumeric(prhs[3]) || !mxIsDouble(prhs[3]) || mxIsEmpty(prhs[3]) || mxIsComplex(prhs[3]) || mxGetN(prhs[3]) != 1 || mxGetM(prhs[3]) != 1 ) mexErrMsgTxt("Input arg must be a real scalar."); else { arg1 = mxGetScalar(prhs[3]); /* take kernel argument */ /* if kernel is RBF than recompute its argument */ if( ker == 2) arg1 = -2*arg1*arg1; } } /* one or two real trade-off constant(s) */ if( !mxIsNumeric(prhs[4]) || !mxIsDouble(prhs[4]) || mxIsEmpty(prhs[4]) || mxIsComplex(prhs[4]) || (mxGetN(prhs[4]) != 1 && mxGetM(prhs[4]) != 1 )) mexErrMsgTxt("Input C must be one or two real scalar(s)."); else { C = mxGetScalar(prhs[4]); } /* real parameter eps */ if( nrhs >= 6 ) { if( !mxIsNumeric(prhs[5]) || !mxIsDouble(prhs[5]) || mxIsEmpty(prhs[5]) || mxIsComplex(prhs[5]) || mxGetN(prhs[5]) != 1 || mxGetM(prhs[5]) != 1 ) mexErrMsgTxt("Input eps must be a real scalar."); else eps = mxGetScalar(prhs[5]); /* take eps argument */ } /* real parameter tol */ if(nrhs >= 7) { if( !mxIsNumeric(prhs[6]) || !mxIsDouble(prhs[6]) || mxIsEmpty(prhs[6]) || mxIsComplex(prhs[6]) || mxGetN(prhs[6]) != 1 || mxGetM(prhs[6]) != 1 ) mexErrMsgTxt("Input tol must be a real scalar."); else tolerance = mxGetScalar(prhs[6]); /* take tolerance argument */ } /* real vector of Lagrangeian multipliers */ if(nrhs >= 8) { if( !mxIsNumeric(prhs[7]) || !mxIsDouble(prhs[7]) || mxIsEmpty(prhs[7]) || mxIsComplex(prhs[7]) || (mxGetN(prhs[7]) != 1 && mxGetM(prhs[7]) != 1 )) mexErrMsgTxt("Input Alpha must be a real vector."); } /* real scalar - bias */ if( nrhs >= 9 ) { if( !mxIsNumeric(prhs[8]) || !mxIsDouble(prhs[8]) || mxIsEmpty(prhs[8]) || mxIsComplex(prhs[8]) || mxGetN(prhs[8]) != 1 || mxGetM(prhs[8]) != 1 ) mexErrMsgTxt("Input bias must be a real scalar."); } /* ---- init variables ------------------------------- */ dataA = mxGetPr(prhs[0]); /* pointer at patterns */ dataB = dataA; dim = mxGetM(prhs[0]); /* data dimension */ N = mxGetN(prhs[0]); /* number of data */ labels12 = mxGetPr(prhs[1]); /* labels (1,2) */ ker_cnt = 0; kadd = 1/(2*C); /* allocate memory for targets (labels) (1,-1) */ if( (target = mxCalloc(N, sizeof(double) )) == NULL) { mexErrMsgTxt("Not enough memory."); } /* transform labels12 (1,2) from to targets (1,-1) */ for( i = 0; i < N; i++ ) { target[i] = - labels12[i]*2 + 3; } /* create output variable for bias */ plhs[1] = mxCreateDoubleMatrix(1,1,mxREAL); b = mxGetPr(plhs[1]); /* take init value of bias if given */ if( nrhs >= 9 ) { *b = -mxGetScalar(prhs[8]); } /* allocate memory for error_cache */ if( (error_cache = mxCalloc(N, sizeof(double) )) == NULL) { mexErrMsgTxt("Not enough memory for error cache."); } /* create vector for Lagrangeians */ plhs[0] = mxCreateDoubleMatrix(1,N,mxREAL); alpha = mxGetPr(plhs[0]); /* if Lagrangeians given then use them as initial values */ if( nrhs >= 8 ) { initAlpha = mxGetPr(prhs[7]); for( i = 0; i < N; i++ ) { alpha[i] = initAlpha[i]; } /* Init error cache for non-bound multipliers. */ for( i = 0; i < N; i++ ) { if( alpha[i] != 0 && alpha[i] != C ) { error_cache[i] = learned_func(i) - target[i]; } } } /* ---- run SMO ------------------------------------------- */ runSMO(); /* ---- output statistics --------------------------------- */ if( nlhs >= 3 ) { /* count number of support vectors */ plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL); nsv = mxGetPr(plhs[2]); *nsv = 0; for( i = 0; i < N; i++ ) { if( alpha[i] > ZERO_LIM ) (*nsv)++; } } if( nlhs >= 4) { /* evaluates classification error on traning patterns */ plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL); trn_err = mxGetPr(plhs[3]); nerr = 0; for( i = 0; i < N; i++ ) { if( target[i] == 1 ) { if( learned_func2(i) < 0 ) nerr++; } else if( learned_func2(i) >= 0 ) nerr++; } *trn_err = nerr/N; } if( nlhs >= 5) { /* compute margin */ plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL); margin = mxGetPr(plhs[4]); *margin = 0; for( i = 0; i < N; i++ ) { for( j = 0; j < N; j++ ) { if( alpha[i] > 0 && alpha[j] > 0 ) *margin += alpha[i]*alpha[j]*target[i]*target[j]*kernel(i,j); } } *margin = 1/sqrt(*margin); } if( nlhs >= 6 ) { plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL); (*mxGetPr(plhs[5])) = (double)ker_cnt; } /* decision function of type <w,x>+b is used */ *b = -*b; /* ----- free memory --------------------------------------- */ mxFree( error_cache ); mxFree( target );}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -