📄 m2o_sor.c
字号:
/*---------------------------------------------------------------------------
 [Alpha,bias,kercnt] = m2o_sor(data,labels,ker,arg,C,eps,tol)

 M2O_SOR Multi-class translated to one-class SVM and solved by the
 SOR algorithm.

 Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac
 (c) Czech Technical University Prague, http://cmp.felk.cvut.cz

 Modifications
 9-july-2002, VF
 25-Nov-2001, V. Franc
-------------------------------------------------------------------- */
#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>
#include "kernel.h"

/* Integer sentinels (INT_MIN / INT_MAX) -- these are NOT floating-point
   infinities; use HUGE_VAL from <math.h> when a double bound is needed. */
#define MINUS_INF INT_MIN
#define PLUS_INF  INT_MAX

/* Prefix string comparison used to parse the kernel identifier.
   Borland's strncmpi ignores case; the strncmp fallback is case-sensitive. */
#ifdef __BORLANDC__
  #define STR_COMPARE(A,B,C) strncmpi(A,B,C) /* Borland */
#else
  #define STR_COMPARE(A,B,C) strncmp(A,B,C)  /* Linux gcc */
#endif

/* Arithmetic helpers. Every argument and the whole expansion are
   parenthesised so that expression arguments (e.g. ABS(x-y)) and uses
   inside larger expressions (e.g. ABS(d) < e) expand correctly.
   Arguments may be evaluated more than once: no side effects. */
#define ABS(A)   (((A) < 0) ? -(A) : (A))
#define MAX(A,B) (((A) > (B)) ? (A) : (B))
#define MIN(A,B) (((A) < (B)) ? (A) : (B))

/* Kronecker delta: 1 if A equals B, otherwise 0. */
#define KDELTA(A,B) ((A) == (B))

/* 1 if any two of the four arguments are equal, otherwise 0. */
#define KDELTA4(A1,A2,A3,A4) \
  (((A1) == (A2)) || ((A1) == (A3)) || ((A1) == (A4)) || \
   ((A2) == (A3)) || ((A2) == (A4)) || ((A3) == (A4)))

/* -- Global variables ----------------- */
int *labels;          /* class label of each of the N patterns (from prhs[1]) */
double *stop;         /* runKernelSK stopping threshold stop[0]; not assigned in
                         this file -- TODO confirm it is set before that solver runs */
long tmax;            /* runKernelSK iteration limit; not assigned in this file */
double C;             /* upper bound on every multiplier alpha[i] (prhs[4]) */
double *alpha;        /* multipliers of the transformed problem, num_data entries */
double bias;          /* currently unused by the active code */
long num_data;        /* size of the transformed problem, (M-1)*N */
int solution;         /* set to 1 by runKernelSK when its stopping rule holds */
long t;               /* iteration counter (runKernelSK) */
long N, M;            /* number of data, number of classes */
double *error_cache;  /* cache used by runOnec / examineExample / takeStep */
double eps = 0.001;   /* step / convergence tolerance (prhs[5]) */
double tol = 0.001;/*-----------------------------------------------*/void get_indices2( long *index, long *class, long i ){ *index = i / (M-1); *class = (i % (M-1))+1; if( *class >= labels[ *index ]) (*class)++; return;}/* ------------------------------------------------------------ Kernel function ------------------------------------------------------------ */double kernel_fce( long a, long b ){ double value; long i1,c1,i2,c2; // 1 class // value = kernel( a, b ); // 2 classes with bias // value = kernel( a, b ) + 1; // if( labels[a] != labels[b] ) value *= -1; // multiclass case get_indices2( &i1, &c1, a ); get_indices2( &i2, &c2, b ); if( KDELTA4(labels[i1],labels[i2],c1,c2) ) { value = (+KDELTA(labels[i1],labels[i2]) -KDELTA(labels[i1],c2) -KDELTA(labels[i2],c1) +KDELTA(c1,c2) )*(kernel( i1, i2 )+1); } else { value = 0; } return( value );}/*-----------------------------------*/double learned_func( long k){ double result = 0.; long i; for( i=0; i < num_data; i++ ) { if( alpha[i] > 0) { result = result + alpha[i]*kernel_fce(i,k); } } return( result );}/* ------------------------------------------------------------ SOR
------------------------------------------------------------ */void runSor( void )
{
double omega = 1;
double ndelta;
double delta;
double oldalpha;
long i, j, t;
// mexPrintf("eps %f\n", eps);
t=0;
do {
t++;
ndelta = 0;
for( i = 0; i < num_data; i++ ) {
delta=0;
for(j=0; j < num_data; j++ ) {
if( alpha[j] != 0) delta += alpha[j]*kernel_fce(i, j);
}
delta = omega*(delta - 1)/kernel_fce(i,i);
oldalpha = alpha[i];
alpha[i] -= delta;
if( alpha[i] < 0 ) alpha[i] = 0; else if( alpha[i] > C ) alpha[i] = C;
ndelta += (alpha[i]-oldalpha)*(alpha[i]-oldalpha);
}
} while(eps < ndelta && t < 50000);
// mexPrintf("t %d, ndelta %f\n", t, ndelta );
}
/* ------------------------------------------------------------ Onec------------------------------------------------------------ */void runOnec( void ){ long numChanged = 0; long examineAll = 1; long k; // allocates cache if( (error_cache = mxCalloc(num_data, sizeof(double))) == NULL) { mexErrMsgTxt("Not enough memory for error cache."); } while( numChanged > 0 || examineAll != 0 ) { numChanged = 0; if( examineAll != 0 ) { for(k=0; k < num_data; k++ ) { numChanged = numChanged + examineExample(k); } } else { for(k=0; k < num_data; k++ ) { if( alpha[k] != 0 & alpha[k] != 0 ) { numChanged = numChanged + examineExample(k); } } } if( examineAll == 1) { examineAll = 0; } else { if( numChanged == 0 ) { examineAll = 1; } } } mxFree( error_cache );}/* ----------------------------------------------------------- EXAMINESTEP function ------------------------------------------------------------*/int examineExample( long i1){ double alpha1; double E1; alpha1 = alpha[i1]; if( alpha1 > 0 & alpha1 < C) { E1 = error_cache[i1]; } else { E1 = learned_func(i1); } if(((E1-1) < -tol & alpha1 < C) | ((E1-1) > tol & alpha1 > 0)) { return( takeStep( i1 ) ); } else { return(0); }}/* --------------------------------------------------------- TAKESTEP function -----------------------------------------------------------*/int takeStep( long i1){ double alpha1, a1, E1, k11; long i; alpha1 = alpha[i1]; if( alpha1 > 0 & alpha1 < C) { E1 = error_cache[i1]; } else { E1 = learned_func(i1); } k11 = kernel_fce( i1, i1 ); a1 = alpha1 + (1 - E1)/k11; if( ABS(a1-alpha1)< eps*(a1+alpha1+eps)) return( 0 ); if( a1 > C ) { a1 = C; } else if( a1 < 0 ) { a1 = 0; } for( i=0; i < num_data; i++ ) { if( 0 < alpha[i] & alpha[i] < C ) { error_cache[i] = error_cache[i] + kernel_fce(i1,i)*(a1-alpha1); } } if( !(alpha1 > 0 & alpha1 < C)) { error_cache[i1] = E1 + k11*a1; } alpha[i1] = a1; return( 1 );}/* ------------------------------------------------------------ Kernel S-K 
algorithm------------------------------------------------------------ */void runKernelSK( void ) { double *wx; // dot products <w,x_i> double w2; // dot products <w,w> long inx, i; double min_wx; double q; // allocates cache if( (wx = mxCalloc(num_data, sizeof(double))) == NULL) { mexErrMsgTxt("Not enough memory for error cache."); } // inicialization alpha[0] = 1; w2 = kernel_fce( 0, 0); for( i = 0; i < num_data; i++ ) { wx[i] =kernel_fce( 0, i); } solution = 0; t = 0; // main optimization cycle while( solution == 0 && tmax > t ) { t++; // finds index of x with minimum <w,x> for( min_wx = PLUS_INF, inx = -1, i = 0; i < num_data; i++ ) { if( min_wx > wx[i] ) { min_wx = wx[i]; inx = i; } } if( sqrt(w2) - min_wx/sqrt( w2 ) > stop[0] ) { q = MIN(1, (w2 - min_wx)/(w2 - 2*min_wx + kernel_fce(inx,inx) )); w2 = w2*(1-q)*(1-q) + 2*(1-q)*q*wx[inx] + q*q * kernel_fce(inx,inx); for( i=0; i < num_data; i++ ) { if( alpha[i] ) alpha[i] *= 1-q; wx[i] = wx[i]*(1-q) + q*kernel_fce( i, inx ); } alpha[inx] += q; } else { solution = 1; } } mxFree( wx );}/* ============================================================== Main MEX function - interface to Matlab.============================================================== */void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ){ char ker_id[10]; long i,j ; double *oarg0, *oarg1; long i1, c1; double *tmp; long nsv, ansv; /* -- gets input arguments --------------------------- */ dataA = mxGetPr(prhs[0]); /* pointer at patterns */ dataB = dataA; dim = mxGetM(prhs[0]); /* data dimension */ N = mxGetN(prhs[0]); /* number of data */ tmp = mxGetPr(prhs[1]); if( (labels = mxCalloc(N, sizeof(int))) == NULL) { mexErrMsgTxt("Not enough memory for error cache."); } for( i = 0; i < N; i++ ) { labels[i] = (int)tmp[i]; } /* kernel identifier*/ mxGetString( prhs[2], ker_id, 10 ); if( STR_COMPARE( ker_id, "linear", 6) == 0 ) { ker = 0; } else if( STR_COMPARE( ker_id, "poly", 4) == 0 ) { ker = 1; } else if( STR_COMPARE( ker_id, 
"rbf", 3) == 0 ) { ker = 2; } else mexErrMsgTxt("Unknown kernel identifier."); /* take kernel argument */ arg1 = mxGetScalar(prhs[3]); /* if kernel is RBF than recompute its argument */ if( ker == 2) arg1 = -2*arg1*arg1; /* regularization constant */ C = mxGetScalar(prhs[4]); /* eps a tol parameter */ if( nrhs >= 6 ) eps = mxGetScalar(prhs[5]); if( nrhs >= 7 ) tol = mxGetScalar(prhs[6]); /* -- Inicialization ---------------------------- */ ker_cnt = 0; /* counter for number of kernel evaluations */ // gets number of classes = max class label, min c.l. = 1 M = MINUS_INF; for( i = 0; i < N; i++ ) { if( labels[i] > M ) M = labels[i]; } // num of transformed data num_data = (M-1)*N; /* -- calls Kernel S-K ---------------------------- */ /* create vector for Lagrangeians */ if( (alpha = mxCalloc(num_data, sizeof(double))) == NULL) { mexErrMsgTxt("Not enough memory for error cache."); } // run the main algorithm // runOnec();
runSor(); /* -- sets up outputs ------------------------------- */ // Output plhs[0] = mxCreateDoubleMatrix(M,N,mxREAL); oarg0 = mxGetPr(plhs[0]); plhs[1] = mxCreateDoubleMatrix(M,1,mxREAL); oarg1 = mxGetPr(plhs[1]); for(i=0; i < M; i++ ) { for( j=0; j < num_data; j++ ) { get_indices2( &i1, &c1, j ); oarg0[(i1*M)+i] += alpha[j]*(KDELTA(labels[i1],i+1)+KDELTA(i+1,c1)); oarg1[i] += alpha[j]*(KDELTA(labels[i1],i+1)-KDELTA(i+1,c1)); } } // for(i = 0; i < num_data; i++ ) { // oarg1[i] = alpha[i]; // } // bias // for(bias =0, i = 0; i < num_data; i++ ) { // if(labels[i]==1) bias += alpha[i]; else bias -= alpha[i]; // } // mexPrintf("bias=%f\n",bias ); // exit result plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL); *((double*)mxGetPr(plhs[2])) = ker_cnt; /* -- free memory ----------------- */ mxFree( alpha ); mxFree( labels );}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -