📄 svmclass2.c
字号:
/* --------------------------------------------------------------------
 [pred_labels, dec_fun] = svmclass2( tst_data, trn_data, trn_labels, alpha, bias, ker, arg )

 Evaluate the SVM decision function.

 To build the MEX file, run 'mex svmclass2.c kernel.c'.

 inputs:
  tst_data   [D x M] matrix of M testing D-dimensional patterns.
  trn_data   [D x N] matrix of N training D-dimensional patterns.
  trn_labels [1 x N] pattern labels (1 for the 1st class, 2 for the 2nd class).
  alpha      [1 x N] Lagrange multipliers of the training patterns.
  bias       [real] bias of the decision function.
  ker        [string] kernel identifier: 'linear', 'poly', 'rbf', 'sigmoid'.
  arg        [reals] kernel argument:
               for 'linear'  arg = [];
               for 'poly'    arg[0] is the degree of the polynomial;
               for 'rbf'     arg[0] is sigma;
               for 'sigmoid' the kernel is tanh(arg[0]*<x,y> + arg[1]),
                             so arg must have two elements.

 outputs:
  pred_labels [1 x M] predicted labels of the testing patterns
              (1 if dec_fun >= 0, otherwise 2).
  dec_fun     [1 x M] values of the decision function for the testing patterns.
Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac (c) Czech Technical University Prague, http://cmp.felk.cvut.cz.
Modifications: 20-may-2002, V.Franc, multipliers canbe also negative 14-November-2001, V.Franc, sigmoid kernel 30-September-2001, V. Franc, comments 22-September-2001, V. Franc, kernel.c used. 19-September-2001, V. Franc, tunned. 18-September-2001, V. Franc, created -------------------------------------------------------------------- */#include "mex.h"#include "matrix.h"#include <math.h>#include <stdlib.h>#include <string.h>#include "kernel.h"/* case insensitive string comparision */#ifdef __BORLANDC__ #define STR_COMPARE(A,B,C) strncmpi(A,B,C) /* Borland */#else #define STR_COMPARE(A,B,C) strncasecmp(A,B,C) /* Linux gcc */#endif/* ------------------------------------------------------- Main MEX function called from Matlab.--------------------------------------------------------*/void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray*prhs[] ){ long i, j; /* loop variables */ double *labels; /* pointer at input calss labels */ double *predLabels; /* predicted labels */ double *fpred; /* values of decision function */ char skernel[10]; /* string identifier of used kernel */ long M = 0; /* number of testing patterns */ long N = 0; /* number of training patterns */ double *target; /* pointer at labels */ double *alpha; /* Lagrange multipliers */ double bias; /* bias */ double afun; /* ---- check number of input arguments ------------- */ if(nrhs < 7) mexErrMsgTxt("Not enough input arguments."); if(nlhs < 1) mexErrMsgTxt("Not enough output arguments."); /* ---- check input arguments ----------------------- */ /* data matrix [dim x M ] */ if( !mxIsNumeric(prhs[0]) || !mxIsDouble(prhs[0]) || mxIsEmpty(prhs[0]) || mxIsComplex(prhs[0]) ) mexErrMsgTxt("Input Xtst must be a real matrix."); /* data matrix [dim x N ] */ if( !mxIsNumeric(prhs[1]) || !mxIsDouble(prhs[1]) || mxIsEmpty(prhs[1]) || mxIsComplex(prhs[1]) ) mexErrMsgTxt("Input Xtrn must be a real matrix."); /* vector of labels (1,2) */ if( !mxIsNumeric(prhs[2]) || !mxIsDouble(prhs[2]) || 
mxIsEmpty(prhs[2]) || mxIsComplex(prhs[2]) || (mxGetN(prhs[2]) != 1 && mxGetM(prhs[2]) != 1)) mexErrMsgTxt("Input Itrn must be a real vector."); /* Lagrange multipliers */ if( !mxIsNumeric(prhs[3]) || !mxIsDouble(prhs[3]) || mxIsEmpty(prhs[3]) || mxIsComplex(prhs[3]) || (mxGetN(prhs[3]) != 1 && mxGetM(prhs[3]) != 1)) mexErrMsgTxt("Input alpha must be a real vector."); /* bias */ if( !mxIsNumeric(prhs[4]) || !mxIsDouble(prhs[4]) || mxIsEmpty(prhs[4]) || mxIsComplex(prhs[4]) || mxGetN(prhs[4]) != 1 || mxGetM(prhs[4]) != 1 ) mexErrMsgTxt("Input bias must be a real scalar."); /* a string as kernel identifier ('linear',poly','rbf' ) */ if( mxIsChar(prhs[5]) != 1 || mxGetM(prhs[5]) != 1 ) mexErrMsgTxt("Input ker must be a string"); else { /* check which kernel */ mxGetString( prhs[5], skernel, 10 ); if( STR_COMPARE( skernel, "linear", 6) == 0 ) { ker = 0; } else if( STR_COMPARE( skernel, "poly", 4) == 0 ) { ker = 1; } else if( STR_COMPARE( skernel, "rbf", 3) == 0 ) { ker = 2; } else if( STR_COMPARE( skernel, "sigmoid", 7) == 0) { ker = 3; } else mexErrMsgTxt("Unknown kernel identifier."); } /* real input argument for polynomial and rbf kernel */ if( ker > 0) { if( !mxIsNumeric(prhs[6]) || !mxIsDouble(prhs[6]) || mxIsEmpty(prhs[6]) || mxIsComplex(prhs[6]))// mxGetN(prhs[6]) != 1 || mxGetM(prhs[6]) != 1 ) mexErrMsgTxt("Input arg must be a real valued vector."); else { arg1 = mxGetScalar(prhs[6]); /* take kernel argument */ /* if kernel is RBF than compute its argument */ if( ker == 2) { arg1 = -2*arg1*arg1; } /* for sigmoid we need two arguments */ if( ker == 3) { arg2 = mxGetPr(prhs[6])[1];}; } } /* ---- init variables ------------------------------- */ dataA = mxGetPr(prhs[0]); /* pointer at testing patterns */ dataB = mxGetPr(prhs[1]); /* pointer at traning patterns */ dim = mxGetM(prhs[0]); /* data dimension */ M = mxGetN(prhs[0]); /* number of testing data */ N = mxGetN(prhs[1]); /* number of training data */ alpha = mxGetPr(prhs[3]); /* pointer at Lagrangeians */ 
bias = mxGetScalar(prhs[4]); /* take bias */ if( dim != mxGetM( prhs[1] ) ) { mexErrMsgTxt("Dimension of training and testing patterns differs."); } labels = mxGetPr(prhs[2]); /* labels (1,2) */ /* allocate memory for targets (labels) (1,-1) */ if( (target = mxCalloc(N, sizeof(double) )) == NULL) { mexErrMsgTxt("Not enough memory."); } /* transform labels from (1,2) to (1,-1) */ for( i = 0; i < N; i++ ) { target[i] = - labels[i]*2 + 3; } /* create vector for output labels */ plhs[0] = mxCreateDoubleMatrix(1,M,mxREAL); predLabels = mxGetPr(plhs[0]); /* create vector for vlues of decision function */ if( nlhs >= 2 ) { plhs[1] = mxCreateDoubleMatrix(1,M,mxREAL); fpred = mxGetPr(plhs[1]); } /* Evaluates decision fuction. */ for( i = 0; i < M; i++ ) { afun = 0; /* compute a value of decision function of i-th testing pattern */ for( j = 0; j < N; j++ ) { if( alpha[j] != 0 ) { afun += alpha[j]*target[j]*kernel(i,j); } } afun += bias; /* store computed value */ if( afun >= 0 ) predLabels[i] = 1; else predLabels[i] = 2; if( nlhs >= 2 ) fpred[i] = afun; } /* ----- free memory -------------------------------------- */ mxFree( target );}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -