📄 kernelskf.c
字号:
/*---------------------------------------------------------------------------
 [Alpha,bias,sol,t,kercnt,margin,trnerr,margin2] =
     kernelskf(data,labels,stop,ker,arg,tmax,C)

 KERNELSKF Kernel Schlesinger-Kozinec's algorithm.

 It solves the Support Vector Machines problem with a quadratic cost
 function for classification violations.

 Inputs:
  data    [dim x N] training patterns.
  labels  [1 x N] labels of training patterns; each must be 1 or 2 and
          both classes must be present.
  stop    [1 x 2] if stop(1) == 1 then the stopping condition
          m*-m < stop(2) is used; if stop(1) == 2 then the stopping
          condition (m*-m)/m < stop(2) is used. Here m* is the optimal
          margin and m is the margin of the found hyperplane (in the given
          feature space).
  ker     [string] kernel, see 'help kernel'.
  arg     [...] argument of given kernel, see 'help kernel'.
  tmax    [int] maximal number of iterations (Inf = unlimited).
  C       [real] trade-off between margin and training error.

 Outputs:
  Alpha   [1xN] Lagrangians defining found decision rule.
  bias    [real] bias (threshold) of found decision rule.
  sol     [int]  1 solution is found.
                 0 algorithm stopped (t == tmax) before it converged.
                -1 hyperplane with margin greater than epsilon does not
                   exist; Alpha and bias are then meaningless.
  t       [int] number of iterations.
  kercnt  [int] number of kernel evaluations.
  margin  [real] margin between classes (computed with the kernel without
          the 1/(2C) diagonal term).
  trnerr  [real] training error.
  margin2 [real] margin between classes in the non-linear space, reported
          as ||w1-w2|| in the feature space whose kernel has 1/(2C) added
          on the diagonal.

 See also SVM.

 Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac
 (c) Czech Technical University Prague, http://cmp.felk.cvut.cz
 Written Vojtech Franc (diploma thesis) 02.11.1999, 13.4.2000

 Modifications
 19-Nov-2001, V.Franc
 13-Nov-2001, V.Franc
 12-Nov-2001, V.Franc
 5-Nov-2001, V.Franc
 4-Nov-2001, V.Franc, created
-------------------------------------------------------------------- */

#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>
#include "kernel.h"

#define MINUS_INF INT_MIN
#define PLUS_INF INT_MAX

/* case insensitive string comparison */
#ifdef __BORLANDC__
  #define STR_COMPARE(A,B,C) strncmpi(A,B,C)     /* Borland */
#else
  #define STR_COMPARE(A,B,C) strncasecmp(A,B,C)  /* Linux gcc */
#endif

#define MAX(A,B)   (((A) > (B)) ? (A) : (B) )
#define MIN(A,B)   (((A) < (B)) ? (A) : (B) )

/* Returns nonzero when A is a non-empty, real, double-precision numeric array. */
static int is_real_double(const mxArray *A)
{
  return mxIsNumeric(A) && mxIsDouble(A) && !mxIsEmpty(A) && !mxIsComplex(A);
}

/* Step length k of one Schlesinger-Kozinec update w <- (1-k)*w + k*x.
   num = <w - w_other, w - x> and den = ||w - x||^2. The result is clamped to
   [0,1] so that w stays a convex combination of training patterns (alpha
   never becomes negative). */
static double sk_step(double num, double den)
{
  if (den <= 0)   /* x coincides with w up to rounding; any k gives the same w */
    return (num > 0) ? 1.0 : 0.0;
  return MIN(1.0, MAX(0.0, num/den));
}

/* Fraction of training patterns misclassified by
     f(x_i) = sum_j s_j*alpha[j]*kernel(i,j) + bias,
   where s_j = +1 for label 1 and -1 otherwise. Pattern i counts as an error
   when (3 - 2*labels[i]) * f(x_i) < 0, i.e. label 1 expects f > 0 and label 2
   expects f < 0. Uses kernel() without the 1/(2C) diagonal term and only
   evaluates it for nonzero alpha[j]. */
static double training_error(const double *alpha, const double *labels,
                             double bias, long N)
{
  long i, j;
  double errors = 0;

  for (i = 0; i < N; i++) {
    double dfun = 0;
    for (j = 0; j < N; j++) {
      if (alpha[j] == 0) continue;
      if (labels[j] == 1) dfun += alpha[j]*kernel(i,j);
      else                dfun -= alpha[j]*kernel(i,j);
    }
    if ((3 - labels[i]*2)*(dfun + bias) < 0) errors++;
  }
  return errors/N;
}

/* Returns 1/||w||, where
     ||w||^2 = sum_ij s_i*s_j*alpha[i]*alpha[j]*kernel(i,j)
   and s_i*s_j = +1 for equal labels, -1 otherwise. Uses kernel() without the
   1/(2C) diagonal term; pairs with a zero alpha are skipped. */
static double feature_margin(const double *alpha, const double *labels, long N)
{
  long i, j;
  double norm2 = 0;

  for (i = 0; i < N; i++) {
    if (alpha[i] == 0) continue;   /* hoisted: no kernel call for this row */
    for (j = 0; j < N; j++) {
      if (alpha[j] == 0) continue;
      if (labels[i] == labels[j]) norm2 += alpha[i]*alpha[j]*kernel(i,j);
      else                        norm2 -= alpha[i]*alpha[j]*kernel(i,j);
    }
  }
  return 1/sqrt(norm2);
}

/* ==============================================================
 Main MEX function - interface to Matlab.

 prhs = {data, labels, stop, ker, arg, tmax, C}; at least 3 outputs are
 required and up to 8 are filled (see the header above for their meaning).

 The two class vectors are kept implicitly:
   w1 = sum of alpha[i]*x_i over label-1 patterns,
   w2 = sum of alpha[i]*x_i over label-2 patterns,
 in the feature space of kernel() with kadd = 1/(2C) added on its diagonal.
 Caches hold w1x[i] = <w1,x_i>, w2x[i] = <w2,x_i> and a = <w1,w1>,
 b = <w2,w2>, c = <w1,w2>, so every update costs O(N) kernel evaluations.
============================================================== */
void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] )
{
  char skernel[10];    /* kernel identifier */
  long t;              /* iteration number */
  long i;              /* loop variable */
  int sol;             /* solution: 1=found, 0=not found, -1=does not exist */
  double a, b, c;      /* <w1,w1>, <w2,w2>, <w1,w2> */
  double emin, gmin;   /* min over class 1 of <w1-w2,x>, min over class 2 of <w2-w1,x> */
  long inx1, inx2;     /* patterns attaining emin and gmin */
  double proj1 = 0;    /* <w1-w2, x_inx1 - w2> */
  double proj2 = 0;    /* <w2-w1, x_inx2 - w1> */
  double *w1x;         /* cache: w1x[i] = <w1,x_i> */
  double *w2x;         /* cache: w2x[i] = <w2,x_i> */
  double k;            /* step length of the current update */
  double kxx;          /* <x,x> = kernel(x,x) + kadd of the updating pattern */
  double margin2;      /* squared distance ||w1-w2||^2 in the feature space */
  double normw;        /* ||w1-w2|| */
  double *labels;      /* pointer at labels */
  long N;              /* number of training patterns */
  double *stop;        /* stop[0] stopping criterion and stop[1] its argument */
  double tmax_arg;     /* raw tmax argument */
  long tmax;           /* maximal number of iterations */
  double C;            /* trade-off constant */
  double kadd;         /* additional term on the kernel diagonal */
  double *alpha;       /* Lagrangians */
  double *bias;        /* threshold of the learned decision function */

  /* ---- CHECK INPUT ARGUMENTS ----------------------- */
  if (nrhs < 7) mexErrMsgTxt("Not enough input arguments.");
  if (nlhs < 3) mexErrMsgTxt("Not enough output arguments.");

  /* data matrix [dim x N] */
  if (!is_real_double(prhs[0]))
    mexErrMsgTxt("Input X must be a real matrix.");

  /* labels [1 x N] */
  if (!is_real_double(prhs[1]))
    mexErrMsgTxt("Input I must be a real vector.");

  /* stopping condition [1 x 2] */
  if (!is_real_double(prhs[2]) || mxGetN(prhs[2]) != 2 || mxGetM(prhs[2]) != 1)
    mexErrMsgTxt("Input stop must be a vector [1 x 2].");

  /* kernel identifier ('linear', 'poly', 'rbf'): case-insensitive prefix match */
  if (!mxIsChar(prhs[3]) || mxGetM(prhs[3]) != 1)
    mexErrMsgTxt("Input ker must be a string");
  mxGetString(prhs[3], skernel, sizeof skernel);
  if (STR_COMPARE(skernel, "linear", 6) == 0)    ker = 0;
  else if (STR_COMPARE(skernel, "poly", 4) == 0) ker = 1;
  else if (STR_COMPARE(skernel, "rbf", 3) == 0)  ker = 2;
  else mexErrMsgTxt("Unknown kernel identifier.");

  /* real scalar argument of the polynomial and RBF kernels */
  if (ker == 1 || ker == 2) {
    if (!is_real_double(prhs[4]) || mxGetN(prhs[4]) != 1 || mxGetM(prhs[4]) != 1)
      mexErrMsgTxt("Input arg must be a real scalar.");
    arg1 = mxGetScalar(prhs[4]);
    /* the RBF kernel is given the precomputed value -2*arg^2 */
    if (ker == 2) arg1 = -2*arg1*arg1;
  }

  /* tmax */
  if (!is_real_double(prhs[5]) || (mxGetN(prhs[5]) != 1 && mxGetM(prhs[5]) != 1))
    mexErrMsgTxt("Input tmax must be an integer.");

  /* trade-off constant C */
  if (!is_real_double(prhs[6]) || (mxGetN(prhs[6]) != 1 && mxGetM(prhs[6]) != 1))
    mexErrMsgTxt("Input C must be a real scalar.");

  /* ---- GET INPUT ARGUMENTS ------------------------------- */
  dataA  = mxGetPr(prhs[0]);   /* pointer at patterns */
  dataB  = mxGetPr(prhs[0]);   /* pointer at patterns */
  labels = mxGetPr(prhs[1]);   /* pointer at labels */
  dim    = mxGetM(prhs[0]);    /* data dimension */
  N      = mxGetN(prhs[0]);    /* number of data */
  stop   = mxGetPr(prhs[2]);

  /* Every pattern needs exactly one label, otherwise labels[i] is read out
     of bounds below. */
  if (mxGetNumberOfElements(prhs[1]) != mxGetN(prhs[0]))
    mexErrMsgTxt("Input I must contain one label per column of X.");
  for (i = 0; i < N; i++) {
    if (labels[i] != 1 && labels[i] != 2)
      mexErrMsgTxt("Input I must contain only labels 1 and 2.");
  }

  /* Only stop(1) == 1 (absolute) and 2 (relative) are implemented; any other
     value would disable both stopping tests and report convergence at once. */
  if (stop[0] != 1 && stop[0] != 2)
    mexErrMsgTxt("Input stop(1) must be 1 or 2.");

  /* +-Inf means "no limit"; NaN and values below 1 mean no iteration.
     Clamping avoids an out-of-range double -> long conversion. */
  tmax_arg = mxGetScalar(prhs[5]);
  if (mxIsInf(tmax_arg) || tmax_arg >= INT_MAX) tmax = INT_MAX;
  else if (mxIsNaN(tmax_arg) || tmax_arg < 1)   tmax = 0;
  else                                          tmax = (long)tmax_arg;

  C = mxGetScalar(prhs[6]);
  /* additional term added to the kernel value on the diagonal; C == 0 disables it */
  kadd = (C != 0) ? 1/(2*C) : 0;

  /* The first pattern of each class is the initial w1 resp. w2. */
  for (inx1 = -1, inx2 = -1, i = 0; i < N && (inx1 == -1 || inx2 == -1); i++) {
    if (labels[i] == 1 && inx1 == -1)      inx1 = i;
    else if (labels[i] == 2 && inx2 == -1) inx2 = i;
  }
  if (inx1 == -1 || inx2 == -1)
    mexErrMsgTxt("Both classes (labels 1 and 2) must be present.");

  /* create vector for Lagrangians (zero-initialized by Matlab) */
  plhs[0] = mxCreateDoubleMatrix(1, N, mxREAL);
  alpha = mxGetPr(plhs[0]);
  alpha[inx1] = 1;
  alpha[inx2] = 1;

  /*-- INITIALIZATION ------------------------------*/
  ker_cnt = 0;   /* counter for number of kernel evaluations */

  if ((w1x = mxCalloc(N, sizeof *w1x)) == NULL)
    mexErrMsgTxt("Not enough memory for error cache.");
  if ((w2x = mxCalloc(N, sizeof *w2x)) == NULL)
    mexErrMsgTxt("Not enough memory for error cache.");

  for (i = 0; i < N; i++) {
    w1x[i] = kernel(inx1, i);
    w2x[i] = kernel(inx2, i);
  }
  w1x[inx1] += kadd;
  w2x[inx2] += kadd;
  a = kernel(inx1, inx1) + kadd;
  b = kernel(inx2, inx2) + kadd;
  c = kernel(inx1, inx2);

  sol = 0;
  t = 0;
  /* Coinciding initial vectors: the classes overlap, and the stopping tests
     below would divide by ||w1-w2|| = 0. */
  if (a - 2*c + b <= 0) sol = -1;

  /* -- MAIN OPTIMIZATION CYCLE ------------------------ */
  while (sol == 0 && t < tmax) {
    t++;
    sol = 1;

    /* Worst-placed pattern of each class. +inf (not INT_MAX) as the start
       value so that large kernel values still select an index. */
    emin = HUGE_VAL;
    gmin = HUGE_VAL;
    for (i = 0; i < N; i++) {
      if (labels[i] == 1) {
        if (w1x[i] - w2x[i] < emin) { emin = w1x[i] - w2x[i]; inx1 = i; }
      } else {
        if (w2x[i] - w1x[i] < gmin) { gmin = w2x[i] - w1x[i]; inx2 = i; }
      }
    }

    /* normw > 0 here: the loop is left as soon as ||w1-w2||^2 <= 0. */
    normw = sqrt(a - 2*c + b);
    proj1 = emin + b - c;   /* <w1-w2, x_inx1 - w2> */
    proj2 = gmin + a - c;   /* <w2-w1, x_inx2 - w1> */

    /* The class whose worst pattern has the smaller projection is updated,
       provided its stopping condition is not yet satisfied. */
    if (proj1 < proj2 &&
        ((stop[0] == 2 && (1 - proj1/(normw*normw)) >= stop[1]/2) ||
         (stop[0] == 1 && (normw - proj1/normw) >= stop[1]))) {
      /* -- adaptation of w1 towards x_inx1 --
         k = <w1-w2, w1-x> / ||w1-x||^2 with ||w1-x||^2 = a - 2<w1,x> + <x,x>. */
      kxx = kernel(inx1, inx1) + kadd;
      k = sk_step(a - emin - c, a - 2*w1x[inx1] + kxx);
      sol = 0;

      /* -- update of cached values (uses w1x before it is overwritten) -- */
      a = a*(1-k)*(1-k) + 2*(1-k)*k*w1x[inx1] + k*k*kxx;
      c = c*(1-k) + k*w2x[inx1];
      for (i = 0; i < N; i++) {
        if (labels[i] == 1 && alpha[i]) alpha[i] *= 1-k;
        w1x[i] = w1x[i]*(1-k) + k*kernel(i, inx1);
      }
      w1x[inx1] += k*kadd;
      alpha[inx1] += k;
    } else if ((stop[0] == 2 && (1 - proj2/(normw*normw)) >= stop[1]/2) ||
               (stop[0] == 1 && (normw - proj2/normw) >= stop[1])) {
      /* -- adaptation of w2 towards x_inx2 --
         k = <w2-w1, w2-x> / ||w2-x||^2 with ||w2-x||^2 = b - 2<w2,x> + <x,x>. */
      kxx = kernel(inx2, inx2) + kadd;
      k = sk_step(b - gmin - c, b - 2*w2x[inx2] + kxx);
      sol = 0;

      /* -- update of cached values (uses w2x before it is overwritten) -- */
      b = b*(1-k)*(1-k) + 2*(1-k)*k*w2x[inx2] + k*k*kxx;
      c = c*(1-k) + k*w1x[inx2];
      for (i = 0; i < N; i++) {
        w2x[i] = (1-k)*w2x[i] + k*kernel(i, inx2);
        if (labels[i] == 2 && alpha[i]) alpha[i] *= 1-k;
      }
      alpha[inx2] += k;
      w2x[inx2] += k*kadd;
    }

    /* Compare the square itself: sqrt() of a value made slightly negative by
       rounding would be NaN and escape the test. Converging to the zero
       vector means the classes overlap. */
    if (a - 2*c + b <= 0) sol = -1;
  }

  /* A negative projection means the found hyperplane does not separate. */
  if (sol == 1 && (proj1 < 0 || proj2 < 0)) sol = 0;

  /* --- COMPUTATION OF OUTPUT VALUES ----------------------- */
  /* squared margin in the transformed space; it is <= 0 exactly when
     sol == -1, in which case the normalized outputs below are meaningless */
  margin2 = a - 2*c + b;

  /* threshold after normalization */
  plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
  bias = mxGetPr(plhs[1]);
  *bias = (b - a)/margin2;

  /* normal vector in the transformed space after normalization */
  for (i = 0; i < N; i++) alpha[i] *= 2/margin2;

  /* training error */
  if (nlhs >= 7) {
    plhs[6] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[6]) = training_error(alpha, labels, *bias, N);
  }

  /* margin */
  if (nlhs >= 6) {
    plhs[5] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[5]) = feature_margin(alpha, labels, N);
  }

  /* number of kernel evaluations; read after trnerr and margin so that
     their evaluations are included */
  if (nlhs >= 5) {
    plhs[4] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[4]) = (double)ker_cnt;
  }

  /* solution 1 (found), 0 (not found), -1 (does not exist);
     nlhs >= 3 is guaranteed by the argument check */
  plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
  *mxGetPr(plhs[2]) = (double)sol;

  /* number of iterations */
  if (nlhs >= 4) {
    plhs[3] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[3]) = (double)t;
  }

  /* ||w1-w2|| in the feature space */
  if (nlhs >= 8) {
    plhs[7] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[7]) = sqrt(margin2);
  }

  /* ----- FREE MEMORY ----------------------- */
  mxFree(w1x);
  mxFree(w2x);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -