📄 m2o_smo.c
字号:
/*---------------------------------------------------------------------------
 [Alpha,bias,kercnt] = m2o_smo(data,labels,ker,arg,C,eps,tol)

 M2O_SMO Multi-class translated to one-class SVM and solved by the
 modified SMO for one-class problem.

 Each pattern i with label y_i is expanded into M-1 virtual examples,
 one per competing class c != y_i, giving num_data = (M-1)*N Lagrange
 multipliers alpha[]. These are optimized one at a time (see takeStep)
 inside an SMO-style outer loop (see runOnec).

 Inputs (as read by mexFunction):
   data    [dim x N] patterns, one per column.
   labels  N class labels; converted to int, must be >= 1, and
           M = max(labels) must be >= 2.
   ker     kernel identifier: 'linear', 'poly' or 'rbf'.
   arg     kernel argument; for 'rbf' it is replaced by -2*arg^2
           before kernel() is used.
   C       upper bound on every multiplier (box [0, C]).
   eps     optional, default 0.001: minimal relative step in takeStep.
   tol     optional, default 0.001: tolerance of the optimality test
           in examineExample.
 Outputs:
   Alpha   [M x N] matrix folded from the multipliers.
   bias    [M x 1] vector folded from the multipliers.
   kercnt  value of the kernel-evaluation counter ker_cnt.

 Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac
 (c) Czech Technical University Prague, http://cmp.felk.cvut.cz

 Modifications
   9-july-2002, VF
   25-Nov-2001, V. Franc
-------------------------------------------------------------------- */
#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>
/* NOTE(review): dataA, dataB, dim, ker, arg1, ker_cnt and kernel() are not
   defined in this file; presumably they come from kernel.h / kernel.c. */
#include "kernel.h"

#define MINUS_INF INT_MIN
#define PLUS_INF INT_MAX

#define DEFAULT_EPS 0.001
#define DEFAULT_TOL 0.001

/* Prefix comparison of the kernel identifier. Borland's strncmpi ignores
   case; strncmp (all other compilers) is case sensitive. */
#ifdef __BORLANDC__
 #define STR_COMPARE(A,B,C) strncmpi(A,B,C)  /* Borland */
#else
 #define STR_COMPARE(A,B,C) strncmp(A,B,C)   /* Linux gcc */
#endif

/* Every argument and the whole expansion are parenthesized so the macros
   are safe inside larger expressions. */
#define ABS(A) (((A) < 0) ? -(A) : (A))
#define MAX(A,B) (((A) > (B)) ? (A) : (B))
#define MIN(A,B) (((A) < (B)) ? (A) : (B))
#define KDELTA(A,B) ((A) == (B))
#define KDELTA4(A1,A2,A3,A4) (((A1)==(A2)) || ((A1)==(A3)) || ((A1)==(A4)) || \
                              ((A2)==(A3)) || ((A2)==(A4)) || ((A3)==(A4)))

/* -- Global variables ----------------- */
int *labels;          /* class label (1..M) of each of the N patterns */
double *stop;         /* NOTE(review): not used in this file */
long tmax;            /* NOTE(review): not used in this file */
double C;             /* upper bound of the box constraint 0 <= alpha <= C */
double *alpha;        /* num_data multipliers of the one-class problem */
double bias;          /* NOTE(review): not used in this file */
long num_data;        /* (M-1)*N virtual examples */
int solution;         /* NOTE(review): not used in this file */
long t;               /* NOTE(review): not used in this file */
long N, M;            /* number of data, number of classes */
double *error_cache;  /* learned_func(i) cached for non-bound i */
double tol = DEFAULT_TOL;
double eps = DEFAULT_EPS;

int examineExample(long i1);
int takeStep(long i1);

/*-----------------------------------------------
 Maps virtual example i (0 <= i < (M-1)*N) to the pattern *index it came
 from and the competing class *class. *class runs over 1..M and skips
 labels[*index].
-----------------------------------------------*/
void get_indices2( long *index, long *class, long i )
{
    *index = i / (M-1);
    *class = (i % (M-1)) + 1;
    /* shift past the pattern's own label so *class != labels[*index] */
    if (*class >= labels[*index]) {
        (*class)++;
    }
}

/* ------------------------------------------------------------
 Kernel of the transformed one-class problem between virtual examples
 a and b:
   (d(y1,y2) - d(y1,c2) - d(y2,c1) + d(c1,c2)) * (kernel(i1,i2) + 1)
 where d is the Kronecker delta. If no two of the four labels coincide,
 the factor is 0, so kernel() is not called.
------------------------------------------------------------ */
double kernel_fce( long a, long b )
{
    long i1, c1, i2, c2;

    get_indices2(&i1, &c1, a);
    get_indices2(&i2, &c2, b);

    if (KDELTA4(labels[i1], labels[i2], c1, c2)) {
        return (+KDELTA(labels[i1], labels[i2])
                - KDELTA(labels[i1], c2)
                - KDELTA(labels[i2], c1)
                + KDELTA(c1, c2)) * (kernel(i1, i2) + 1);
    }
    return 0;
}

/*-----------------------------------
 Returns sum_i alpha[i]*kernel_fce(i,k) over the positive multipliers.
-----------------------------------*/
double learned_func( long k )
{
    double result = 0.;
    long i;

    for (i = 0; i < num_data; i++) {
        if (alpha[i] > 0) {
            result = result + alpha[i] * kernel_fce(i, k);
        }
    }
    return result;
}

/* ------------------------------------------------------------
 Main optimization loop. It alternates a sweep over all examples with
 sweeps over the non-bound examples (0 < alpha < C). It stops when a
 full sweep changes nothing.
------------------------------------------------------------ */
void runOnec( void )
{
    long numChanged = 0;
    long examineAll = 1;
    long k;

    /* allocates cache; zeros match learned_func() for alpha == 0 */
    if ((error_cache = mxCalloc(num_data, sizeof(double))) == NULL) {
        mexErrMsgTxt("Not enough memory for error cache.");
    }

    while (numChanged > 0 || examineAll != 0) {
        numChanged = 0;
        if (examineAll != 0) {
            for (k = 0; k < num_data; k++) {
                numChanged = numChanged + examineExample(k);
            }
        } else {
            /* only examples strictly inside the box */
            for (k = 0; k < num_data; k++) {
                if (alpha[k] != 0 && alpha[k] != C) {
                    numChanged = numChanged + examineExample(k);
                }
            }
        }

        if (examineAll == 1) {
            examineAll = 0;
        } else if (numChanged == 0) {
            examineAll = 1;
        }
    }

    mxFree(error_cache);
    error_cache = NULL;
}

/* -----------------------------------------------------------
 Calls takeStep(i1) if example i1 violates the optimality test within
 tol. Returns 1 if alpha[i1] was changed, else 0.
------------------------------------------------------------*/
int examineExample( long i1 )
{
    double alpha1 = alpha[i1];
    double E1;

    if (alpha1 > 0 && alpha1 < C) {
        E1 = error_cache[i1];
    } else {
        E1 = learned_func(i1);
    }

    if (((E1 - 1) < -tol && alpha1 < C) || ((E1 - 1) > tol && alpha1 > 0)) {
        return takeStep(i1);
    }
    return 0;
}

/* ---------------------------------------------------------
 Optimizes the single multiplier alpha[i1]. It takes the unconstrained
 step (1 - E1)/k11 and clips the result to [0, C]. It then updates the
 error cache of non-bound examples. Returns 1 on update, 0 if the step is
 negligible (relative to eps) or k11 is not positive.
-----------------------------------------------------------*/
int takeStep( long i1 )
{
    double alpha1 = alpha[i1];
    double a1, E1, k11;
    long i;

    if (alpha1 > 0 && alpha1 < C) {
        E1 = error_cache[i1];
    } else {
        E1 = learned_func(i1);
    }

    k11 = kernel_fce(i1, i1);
    if (k11 <= 0) {
        return 0;  /* no well-defined step: avoid division by zero */
    }

    a1 = alpha1 + (1 - E1) / k11;

    if (fabs(a1 - alpha1) < eps * (a1 + alpha1 + eps)) {
        return 0;
    }

    if (a1 > C) {
        a1 = C;
    } else if (a1 < 0) {
        a1 = 0;
    }

    /* incremental update of cached outputs of non-bound examples
       (alpha[i1] still holds its old value here) */
    for (i = 0; i < num_data; i++) {
        if (0 < alpha[i] && alpha[i] < C) {
            error_cache[i] = error_cache[i] + kernel_fce(i1, i) * (a1 - alpha1);
        }
    }
    /* i1 was not cached: E1 was computed with the old alpha1, so add only
       the change (a1 - alpha1), not a1 itself */
    if (!(alpha1 > 0 && alpha1 < C)) {
        error_cache[i1] = E1 + k11 * (a1 - alpha1);
    }

    alpha[i1] = a1;
    return 1;
}

/* ==============================================================
 Main MEX function - interface to Matlab.
============================================================== */
void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] )
{
    char ker_id[10];
    long i, j;
    long i1, c1, y1;
    double *oarg0, *oarg1;
    double *tmp;

    if (nrhs < 5) {
        mexErrMsgTxt("Not enough input arguments.");
    }

    /* -- gets input arguments --------------------------- */
    dataA = mxGetPr(prhs[0]);   /* pointer at patterns */
    dataB = dataA;
    dim = mxGetM(prhs[0]);      /* data dimension */
    N = mxGetN(prhs[0]);        /* number of data */
    if (N < 1) {
        mexErrMsgTxt("No training data.");
    }

    if ((long)mxGetNumberOfElements(prhs[1]) < N) {
        mexErrMsgTxt("Number of labels is smaller than number of data.");
    }
    tmp = mxGetPr(prhs[1]);
    if ((labels = mxCalloc(N, sizeof(int))) == NULL) {
        mexErrMsgTxt("Not enough memory for labels.");
    }
    for (i = 0; i < N; i++) {
        labels[i] = (int)tmp[i];
    }

    /* kernel identifier */
    mxGetString(prhs[2], ker_id, 10);
    if (STR_COMPARE(ker_id, "linear", 6) == 0) {
        ker = 0;
    } else if (STR_COMPARE(ker_id, "poly", 4) == 0) {
        ker = 1;
    } else if (STR_COMPARE(ker_id, "rbf", 3) == 0) {
        ker = 2;
    } else {
        mexErrMsgTxt("Unknown kernel identifier.");
    }

    /* take kernel argument */
    arg1 = mxGetScalar(prhs[3]);
    /* if kernel is RBF then recompute its argument */
    if (ker == 2) {
        arg1 = -2 * arg1 * arg1;
    }

    /* regularization constant */
    C = mxGetScalar(prhs[4]);

    /* eps and tol: globals persist between MEX calls, so reset defaults */
    eps = DEFAULT_EPS;
    tol = DEFAULT_TOL;
    if (nrhs > 5) {
        eps = mxGetScalar(prhs[5]);
    }
    if (nrhs > 6) {
        tol = mxGetScalar(prhs[6]);
    }

    /* -- Inicialization ---------------------------- */
    ker_cnt = 0;   /* counter for number of kernel evaluations */

    /* number of classes = max class label; labels must start at 1 */
    M = MINUS_INF;
    for (i = 0; i < N; i++) {
        if (labels[i] < 1) {
            mexErrMsgTxt("Labels must be integers >= 1.");
        }
        if (labels[i] > M) {
            M = labels[i];
        }
    }
    if (M < 2) {
        mexErrMsgTxt("At least two classes are required.");
    }

    /* num of transformed data */
    num_data = (M - 1) * N;

    /* create vector for Lagrangeians */
    if ((alpha = mxCalloc(num_data, sizeof(double))) == NULL) {
        mexErrMsgTxt("Not enough memory for Lagrange multipliers.");
    }

    /* run the main algorithm */
    runOnec();

    /* -- sets up outputs ------------------------------- */
    plhs[0] = mxCreateDoubleMatrix(M, N, mxREAL);
    oarg0 = mxGetPr(plhs[0]);
    plhs[1] = mxCreateDoubleMatrix(M, 1, mxREAL);
    oarg1 = mxGetPr(plhs[1]);

    /* Virtual example j = (pattern i1, class c1) adds alpha[j] to rows
       labels[i1] and c1 of column i1 of Alpha. It adds +alpha[j] to
       bias[labels[i1]] and -alpha[j] to bias[c1]. Each element is
       accumulated in increasing j. */
    for (j = 0; j < num_data; j++) {
        get_indices2(&i1, &c1, j);
        y1 = labels[i1];
        oarg0[(i1 * M) + (y1 - 1)] += alpha[j];
        oarg0[(i1 * M) + (c1 - 1)] += alpha[j];
        oarg1[y1 - 1] += alpha[j];
        oarg1[c1 - 1] -= alpha[j];
    }

    /* number of kernel evaluations */
    plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *((double*)mxGetPr(plhs[2])) = ker_cnt;

    /* -- free memory ----------------- */
    mxFree(alpha);
    alpha = NULL;
    mxFree(labels);
    labels = NULL;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -