⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 knnclass.c

📁 该工具箱是用于统计模式识别的
💻 C
字号:
/*---------------------------------------------------------------------------
 [tst_labels] = knnclass(tst_data,trn_data,trn_labels, k)

 KNNCLASS k-Nearest Neighbours classifier.

 Input:
   tst_data   [dim x n_tst] data to be classified.
   trn_data   [dim x n_trn] training data.
   trn_labels [ 1 x n_trn]  labels of training data.
   k          [int]         number of neighbours.

 Output:
   tst_labels [1 x n_tst]   estimated labels of testing data.

 Statistical Pattern Recognition Toolbox, Vojtech Franc, Vaclav Hlavac
 (c) Czech Technical University Prague, http://cmp.felk.cvut.cz

 Written Vojtech Franc (diploma thesis) 02.11.1999, 13.4.2000

 Modifications
 18-sep-2002, V. Franc
-------------------------------------------------------------------- */

#include "mex.h"
#include "matrix.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <limits.h>

#define MAX_INF INT_MAX
#define MAX(A,B)   (((A) > (B)) ? (A) : (B) )
#define MIN(A,B)   (((A) < (B)) ? (A) : (B) )

/* Marker written into the distance cache for a training point that has
   already been taken as a neighbour. Real distances are never negative. */
#define KNN_USED (-1.0)

/*
 * Euclidean distance between two dim-dimensional points a and b.
 *
 * The difference is squared directly (rather than expanding
 * a*a - 2*a*b + b*b) so that rounding can never make the sum negative.
 * If the sum is NaN (NaN or infinite coordinates) the point is treated
 * as infinitely far away, so the result is always in [0, +inf].
 */
static double knn_distance( const double *a, const double *b, size_t dim )
{
  double sum = 0.0;
  size_t l;

  for( l = 0; l < dim; l++ ) {
    double d = a[l] - b[l];
    sum += d * d;
  }

  if( !(sum >= 0.0) ) {     /* true only for NaN */
    sum = HUGE_VAL;
  }
  return sqrt( sum );
}

/*
 * Returns the label that occurs most often among labels[0..k-1].
 * On a tie the label that appears first in the array wins; since the
 * caller stores neighbours from nearest to farthest, ties go to the
 * label of the nearer neighbour. Requires k >= 1.
 */
static long knn_majority_label( const long *labels, size_t k )
{
  long best_label = labels[0];
  size_t best_count = 0;
  size_t l, j;

  for( l = 0; l < k; l++ ) {
    size_t count = 0;
    for( j = 0; j < k; j++ ) {
      if( labels[l] == labels[j] ) count++;
    }
    if( count > best_count ) {
      best_count = count;
      best_label = labels[l];
    }
  }
  return best_label;
}

/* ==============================================================
 Main MEX function - interface to Matlab.

 prhs[0] = tst_data   [dim x n_tst]
 prhs[1] = trn_data   [dim x n_trn]
 prhs[2] = trn_labels (at least n_trn elements)
 prhs[3] = k          (scalar, >= 1)
 plhs[0] = tst_labels [1 x n_tst]

 For every testing column the k training columns with the smallest
 Euclidean distance are found and the most frequent label among them
 is written to the output. If k exceeds n_trn, all training points vote.
 Labels are converted to long before voting, so a fractional label is
 truncated.

 NOTE(review): the inputs are read through mxGetPr without a type
 check; this assumes they are real double arrays - confirm at call site.
============================================================== */
void mexFunction( int nlhs, mxArray *plhs[],
		  int nrhs, const mxArray*prhs[] )
{
  size_t n_tst, n_trn, dim, k, i, j, l;
  const double *tst_data, *trn_data, *trn_labels;
  double *tst_labels;
  double *dist;
  double k_arg;
  long *max_labels;

  (void) nlhs;

  /* -- gets input arguments --------------------------- */

  /* The argument count must be verified before any prhs[] is touched. */
  if( nrhs != 4 ) {
     mexErrMsgTxt("Incorrect number of input arguments.");
  }

  tst_data = mxGetPr(prhs[0]);
  trn_data = mxGetPr(prhs[1]);
  trn_labels = mxGetPr(prhs[2]);

  dim = mxGetM(prhs[0]);                /* data dimension */
  n_tst = mxGetN(prhs[0]);              /* number of testing data */
  n_trn = mxGetN(prhs[1]);              /* number of training data */

  if( dim != mxGetM( prhs[1] )) {
     mexErrMsgTxt("Dimension of training and testing data differs.");
  }

  /* trn_labels[j] is read for every j < n_trn below. */
  if( mxGetM(prhs[2]) * mxGetN(prhs[2]) < n_trn ) {
     mexErrMsgTxt("Not enough training labels for the training data.");
  }

  if( mxGetM(prhs[3]) * mxGetN(prhs[3]) < 1 ) {
     mexErrMsgTxt("Number of neighbours k must be a scalar.");
  }
  k_arg = mxGetScalar(prhs[3]);
  if( !(k_arg >= 1.0) ) {               /* also rejects NaN */
     mexErrMsgTxt("Number of neighbours k must be at least 1.");
  }

  // output labels
  plhs[0] = mxCreateDoubleMatrix(1,n_tst,mxREAL);
  tst_labels = mxGetPr(plhs[0] );

  if( n_tst == 0 ) {
     return;                            /* nothing to classify */
  }
  if( n_trn == 0 ) {
     mexErrMsgTxt("Training data are empty.");
  }

  /* More neighbours than training points is clamped: everybody votes.
     The comparison is done in double so the cast below cannot overflow. */
  k = (k_arg < (double) n_trn) ? (size_t) k_arg : n_trn;

  /*--------------------------*/

  if( (dist = mxCalloc(n_trn, sizeof(double))) == NULL) {
      mexErrMsgTxt("Not enough memory for error cache.");
  }

  if( (max_labels = (long*) mxCalloc(k, sizeof(long))) == NULL) {
      mexErrMsgTxt("Not enough memory for error cache.");
  }

  for( i=0; i < n_tst; i++ ) {

    /* distances from the i-th testing point to all training points */
    for( j=0; j < n_trn; j++ ) {
      dist[j] = knn_distance( tst_data + i*dim, trn_data + j*dim, dim );
    }

    /* pick the k nearest training points, nearest first */
    for( l=0; l < k; l++) {
      size_t nearest = n_trn;           /* n_trn means "none found yet" */

      for( j=0; j < n_trn; j++ ) {
        if( dist[j] < 0.0 ) continue;   /* already taken */
        if( nearest == n_trn || dist[j] < dist[nearest] ) {
          nearest = j;
        }
      }

      /* k <= n_trn and all live distances are >= 0, so an unused
         point always exists here and nearest < n_trn. */
      max_labels[l] = (long) trn_labels[nearest];
      dist[nearest] = KNN_USED;         /* exclude from later passes */
    }

    tst_labels[i] = knn_majority_label( max_labels, k );
  }

  mxFree( dist );  // free mem
  mxFree( max_labels );  // free mem
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -