⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bsvm2_mex.c

📁 很好的matlab模式识别工具箱
💻 C
📖 第 1 页 / 共 2 页
字号:
  long i,j ;         /* common use loop variables */
  long inx1, inx2;   
  long NA; 
  long tmax;         /* input arg - max number of iteration */ 
  long t;            /* output arg - number of iterations */
  long verb;         /* input argument */
  double thlb;       /* input arg - threshold on lower bound */
  double C;          /* input arg - regularization const */
  double tolrel;     /* input arg */
  double tolabs;     /* input arg */
  double trnerr;     /* output arg */
  double *tmp_ptr;  
  double *tmp_ptr1;
  double *tmp_ptr2; 
  double *vector_c;  /* auxiliary */ 
  double *Alpha;     /* solution vector */ 
  double *History;   /* output arg */
  double *diagK;     /* cache for diagonal of virtual K matrix */

  /*------------------------------------------------------------------- */
  /* Take input arguments                                               */
  /*------------------------------------------------------------------- */

  if( nrhs != 12) mexErrMsgTxt("Incorrect number of input arguments.");

  dataA = mxGetPr(prhs[0]);   /* pointers at data */
  dataB = dataA;
  dim = mxGetM(prhs[0]);      /* data dimension */
  num_data = mxGetN(prhs[0]); /* number of data */
  labels = mxGetPr(prhs[1]);  /* pointer at data labels */

  /* take kernel identifier and its argument */
  ker = kernel_id( prhs[2] ); 
  if( ker == -1 ) mexErrMsgTxt("Improper kernel identifier.");
  arg1 = mxGetPr(prhs[3]);

  C = mxGetScalar(prhs[4]);   /* regularization constant */

  /* take string identifier QP solver to be used */
  if( mxIsChar( prhs[5] ) != 1) mexErrMsgTxt("solver must be string.");
  buf_len = (mxGetM(prhs[5]) * mxGetN(prhs[5])) + 1;
  buf_len = (buf_len > 20) ? 20 : buf_len;
  mxGetString( prhs[5], solver, buf_len );

  /* maximal allowed number of iterations */
  tmax = mxIsInf( mxGetScalar(prhs[6])) ? INT_MAX : (long)mxGetScalar(prhs[6]); 
  tolabs = mxGetScalar(prhs[7]);   /* abs. precision defining stopping cond*/
  tolrel = mxGetScalar(prhs[8]);   /* rel. precision defining stopping cond*/
  /* threshold on lower bound */
  thlb = mxIsInf( mxGetScalar(prhs[9])) ? DBL_MAX : (double)mxGetScalar(prhs[9]); 

  Cache_Size = (long)mxGetScalar(prhs[10]);  /* cache size */
  if( Cache_Size < 1 ) mexErrMsgTxt("Cache must be greater than 1."); 
  if( Cache_Size > num_data ) Cache_Size = num_data; 

  verb = (long)mxGetScalar(prhs[11]);  /* verbosity on/off */

  /*------------------------------------------------------------------- */
  /* Initialization (caches, etc.)                                     */
  /*------------------------------------------------------------------- */


  /* constant added to diagonal of separable problem */
  if( C!=0 ) kernel_diag = 1/(2*C); else kernel_diag = 0;

  /* num_classes = max( labels ) */
  num_classes = MINUS_INF; 
  for( i = 0; i < num_data; i++ ) { 
     if( labels[i] > num_classes ) num_classes = (long)labels[i]; 
  }

  /* computes number of virtual "single-class" examples */
  num_virt_data = (num_classes-1)*num_data;

  ker_cnt = 0;    /* counter of kernel evaluations */
  access_cnt = 0;  /* counter for access to the kernel matrix */

  /* allocates and precomputes diagonal of virtual K matrix */
  diagK = mxCalloc(num_virt_data, sizeof(double));
  if( diagK == NULL ) mexErrMsgTxt("Not enough memory.");
  for(i = 0; i < num_virt_data; i++ ) {
    diagK[i] = kernel_fce(i,i);
  }

  /* allocates memory for kernel cache */
  kernel_columns = mxCalloc(Cache_Size, sizeof(double*));
  if( kernel_columns == NULL ) mexErrMsgTxt("Not enough memory.");
  cache_index = mxCalloc(Cache_Size, sizeof(double));
  if( cache_index == NULL ) mexErrMsgTxt("Not enough memory.");

  for(i = 0; i < Cache_Size; i++ ) {
    kernel_columns[i] = mxCalloc(num_data, sizeof(double));
    if(kernel_columns[i] == NULL) mexErrMsgTxt("Not enough memory.");

    cache_index[i] = -2;
  }
  first_kernel_inx = 0;

  /* allocates memory for three virtual kernel matrix columns */
  for(i = 0; i < 3; i++ ) {
    virt_columns[i] = mxCalloc(num_virt_data, sizeof(double));
    if(virt_columns[i] == NULL) mexErrMsgTxt("Not enough memory.");
  }
  first_virt_inx = 0; 

  /* Solution vector */
  Alpha = mxCalloc(num_virt_data, sizeof(double));
  if( Alpha == NULL ) mexErrMsgTxt("Not enough memory.");

  /* Vector c; for this problem set to zero */
  vector_c = mxCalloc(num_virt_data, sizeof(double));
  if( vector_c == NULL ) mexErrMsgTxt("Not enough memory.");
  for(i = 0; i < num_virt_data; i++ ) vector_c[i] = 0;

  /*------------------------------------------------------------------- */
  /* Call QP solver                                                     */
  /*------------------------------------------------------------------- */

  if ( strcmp( solver, "mdm" )==0 ) {  
     exitflag = gmnp_mdm( &get_col, diagK, vector_c, num_virt_data, tmax, 
         tolabs, tolrel, thlb, Alpha, &t, &History, verb );
  } else if ( strcmp( solver, "imdm" )==0 ) {  
     exitflag = gmnp_imdm( &get_col, diagK, vector_c, num_virt_data, tmax, 
         tolabs, tolrel, thlb, Alpha, &t, &History, verb );
  } else if ( strcmp( solver, "iimdm" )==0 ) {  
     exitflag = gmnp_iimdm( &get_col, diagK, vector_c, num_virt_data, tmax, 
         tolabs, tolrel, thlb, Alpha, &t, &History, verb );
  } else if ( strcmp( solver, "keerthi" )==0 ) {  
     exitflag = gmnp_keerthi( &get_col, diagK, vector_c, num_virt_data, tmax, 
         tolabs, tolrel, thlb, Alpha, &t, &History, verb );
  } else if ( strcmp( solver, "kowalczyk" )==0 ) {  
     exitflag = gmnp_kowalczyk( &get_col, diagK, vector_c, num_virt_data, tmax, 
         tolabs, tolrel, thlb, Alpha, &t, &History, verb );
  } else if ( strcmp( solver, "kozinec" )==0 ) {  
     exitflag = gmnp_kozinec( &get_col, diagK, vector_c, num_virt_data, tmax, 
         tolabs, tolrel, thlb, Alpha, &t, &History, verb );
  } else {
     mexErrMsgTxt("Unknown solver identifier.");
  }

  /*------------------------------------------------------------------- */
  /* Generate outputs                                                   */
  /*------------------------------------------------------------------- */

  /* matrix Alpha [num_classes x num_data] */
  plhs[0] = mxCreateDoubleMatrix(num_classes,num_data,mxREAL);
  tmp_ptr1 = mxGetPr(plhs[0]);

  /* bias vector b [num_classes x 1] */
  plhs[1] = mxCreateDoubleMatrix(num_classes,1,mxREAL);
  tmp_ptr2 = mxGetPr(plhs[1]);

  for( i=0; i < num_classes; i++ ) {
    for( j=0; j < num_virt_data; j++ ) {
       get_indices2( &inx1, &inx2, j );

       tmp_ptr1[(inx1*num_classes)+i] += 
            Alpha[j]*(KDELTA(labels[inx1],i+1)+KDELTA(i+1,inx2));
       tmp_ptr2[i] += Alpha[j]*(KDELTA(labels[inx1],i+1)-KDELTA(i+1,inx2));
    }
  }

  /* exit_flag [1x1] */
  plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[2])) = (double)exitflag;

  /* kercnt [1x1] */
  plhs[3] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[3])) = (double)ker_cnt;

  /* access [1x1] */
  plhs[4] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[4])) = (double)access_cnt;

  /* trnerr [1x1] */
  err_bit = mxCalloc(num_data, sizeof(int));
  if( err_bit == NULL ) mexErrMsgTxt("Not enough memory.");
  for( i=0; i < num_classes; i++ ) {
    for( j=0; j < num_virt_data; j++ ) {
       get_indices2( &inx1, &inx2, j );
       if( Alpha[j] > 2*C ) err_bit[inx1] = 1; 
    }
  }

  for( trnerr = 0, i = 0; i < num_data; i++ ) trnerr += err_bit[i];

  trnerr = trnerr/num_data;
  plhs[5] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[5])) = trnerr;

  /* t [1x1] */
  plhs[6] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[6])) = (double)t;

  /* NA [1x1] */
  for( NA = 0, j=0; j < num_virt_data; j++ ) {
     if( Alpha[j] > 0 ) NA++; 
  }

  plhs[7] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[7])) = (double)NA;

  /* UB [1x1] */
  plhs[8] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[8])) = History[INDEX(1,t,2)];

  /* LB [1x1] */
  plhs[9] = mxCreateDoubleMatrix(1,1,mxREAL);
  *(mxGetPr(plhs[9])) = History[INDEX(0,t,2)];

  /* History [2 x (t+1)] */
  plhs[10] = mxCreateDoubleMatrix(2,t+1,mxREAL);
  tmp_ptr = mxGetPr( plhs[10] );
  for( i = 0; i <= t; i++ ) {
     tmp_ptr[INDEX(0,i,2)] = History[INDEX(0,i,2)];
     tmp_ptr[INDEX(1,i,2)] = History[INDEX(1,i,2)];
  }

  /*------------------------------------------------------------------- */
  /* Free used memory                                                   */
  /*------------------------------------------------------------------- */
  mxFree( vector_c );
  mxFree( Alpha );
  mxFree( History );
  mxFree( diagK );
  for(i = 0; i < Cache_Size; i++ ) mxFree(kernel_columns[i]);
  for(i = 0; i < 3; i++ ) mxFree(virt_columns[i]);
  mxFree( kernel_columns );
  mxFree( cache_index );
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -