⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 perf_classif.c

📁 最详尽的神经网络源码
💻 C
字号:
/*

 perf_classif : Returns the Classification Rate  and the confusion matrix.
 

 Usage
 -------

 [R , mat_conf ] = perf_classif(ytest , ytest_est , [m]);

 Inputs
 -------

 ytest         True labels   (1 x Ntest)

 ytest_est     Estimated labels   (1 x Ntest)


 Outputs
 -------

 R             Classification rate (fraction of samples for which ytest_est equals ytest)

 mat_conf      Confusion Matrix (m x m) of counts, where m is the number of labels, i.e. m = length(unique(ytest));
               rows index the estimated label and columns the true label (labels in ascending order).

	
 To compile
 ----------

mex -output perf_classif.dll perf_classif.c 


mex -f mexopts_intelamd.bat -output perf_classif.dll perf_classif.c


Example 1
---------

Ntest                = 100;
ytest                = double(rand(1 , Ntest) > 0.5);
ytest_est            = double(rand(1 , Ntest) > 0.5);


[R , mat_conf]       = perf_classif(ytest , ytest_est);
 


  
Author : Sébastien PARIS : sebastien.paris@lsis.org
-------

*/#include <limits.h>
#include <math.h>
#include "mex.h"


/* Alias for INT_MAX; not referenced by any code in this file. */
#define MAX_INF INT_MAX

/*
 * Forward declaration of the sort helper defined at the end of this file:
 * sorts array[left..right] (inclusive bounds) in ascending order, in place.
 */
void qs(double *array, int left, int right);


/*--------------------------------------------------------------------------------*/
/*--------------------------------------------------------------------------------*/
/*--------------------------------------------------------------------------------*/



/*
 * find_label: binary search for `value` in labels[0..nlabels-1], which must be
 * sorted in strictly increasing order.
 * Returns the index of the matching label, or -1 when `value` is not one of
 * the labels (including when the comparison is unordered, i.e. NaN).
 */
static int find_label(const double *labels, int nlabels, double value)
{
	int lo = 0;
	int hi = nlabels - 1;
	int mid;

	while (lo <= hi)
	{
		mid = lo + (hi - lo) / 2;

		if (labels[mid] == value)
		{
			return mid;
		}
		else if (labels[mid] < value)
		{
			lo = mid + 1;
		}
		else if (labels[mid] > value)
		{
			hi = mid - 1;
		}
		else
		{
			return -1;  /* unordered comparison (NaN): no match possible */
		}
	}
	return -1;
}

/*
 * MEX gateway:  [R , mat_conf] = perf_classif(ytest , ytest_est , [m])
 *
 * Inputs
 *   prhs[0]  ytest      (1 x Ntest) real double, true labels
 *   prhs[1]  ytest_est  (1 x Ntest) real double, estimated labels
 *   prhs[2]  m          optional scalar, size of the confusion matrix; must be
 *                       >= the number of distinct labels in ytest (extra
 *                       rows/columns are left at zero). Defaults to that number.
 *
 * Outputs
 *   plhs[0]  R          fraction of samples with ytest_est == ytest
 *                       (NaN when Ntest == 0)
 *   plhs[1]  mat_conf   (m x m) counts. With labels = sorted distinct values of
 *                       ytest, entry (row k, column l) counts samples whose true
 *                       label is labels[l] and estimated label is labels[k].
 *
 * An estimated label that never occurs in ytest has no row in mat_conf; such a
 * sample is left out of mat_conf but still counts as an error in R.
 * NaN labels are not supported.
 */
void mexFunction( int nlhs, mxArray *plhs[] , int nrhs, const mxArray *prhs[] )
{
	const double *ytest, *ytest_est;
	double       *labels, *R, *mat_conf;
	double        mreq;
	mxArray      *conf;
	int           i, Ntest, nlabels, m, correct, itrue, iest;

	/* ----- Input 1 : ytest ----- */
	if (nrhs < 2)
	{
		mexErrMsgTxt("perf_classif requires at least two inputs: ytest and ytest_est");
	}
	if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]))
	{
		mexErrMsgTxt("ytest must be a real double array");
	}
	if (mxGetM(prhs[0]) != 1)
	{
		mexErrMsgTxt("ytest must be (1 x Ntest)");
	}
	ytest = mxGetPr(prhs[0]);
	Ntest = (int) mxGetN(prhs[0]);

	/* ----- Input 2 : ytest_est ----- */
	if (!mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]))
	{
		mexErrMsgTxt("ytest_est must be a real double array");
	}
	if (mxGetM(prhs[1]) != 1 || (int) mxGetN(prhs[1]) != Ntest)
	{
		mexErrMsgTxt("ytest_est must be (1 x Ntest)");
	}
	ytest_est = mxGetPr(prhs[1]);

	/* Distinct labels of ytest in ascending order: copy, sort, de-duplicate in
	   place. Allocate at least one element so mxMalloc never sees a zero size. */
	labels = mxMalloc((size_t) (Ntest > 0 ? Ntest : 1) * sizeof *labels);
	for (i = 0; i < Ntest; i++)
	{
		labels[i] = ytest[i];
	}
	if (Ntest > 1)
	{
		qs(labels, 0, Ntest - 1);
	}
	nlabels = 0;
	for (i = 0; i < Ntest; i++)
	{
		if (nlabels == 0 || labels[i] != labels[nlabels - 1])
		{
			labels[nlabels++] = labels[i];
		}
	}

	/* ----- Optional input 3 : m (confusion-matrix size) ----- */
	m = nlabels;
	if (nrhs > 2)
	{
		mreq = mxGetScalar(prhs[2]);
		/* Written as !(...) so that a NaN m is rejected as well. */
		if (!(mreq >= (double) nlabels && mreq <= (double) INT_MAX))
		{
			mexErrMsgTxt("m must be at least the number of distinct labels in ytest");
		}
		m = (int) mreq;
	}

	/* ----- Confusion matrix (column-major: row = estimated, column = true) ----- */
	conf     = mxCreateDoubleMatrix(m , m , mxREAL);
	mat_conf = mxGetPr(conf);
	correct  = 0;

	for (i = 0; i < Ntest; i++)
	{
		itrue = find_label(labels, nlabels, ytest[i]);
		iest  = find_label(labels, nlabels, ytest_est[i]);

		if (itrue < 0 || iest < 0)
		{
			continue;  /* no cell for this sample; it still counts as an error in R */
		}

		mat_conf[(size_t) iest + (size_t) itrue * (size_t) m] += 1.0;

		if (iest == itrue)
		{
			correct++;
		}
	}

	/* ----- Output 1 : classification rate ----- */
	plhs[0] = mxCreateDoubleMatrix(1 , 1 , mxREAL);
	R       = mxGetPr(plhs[0]);
	R[0]    = (Ntest > 0) ? (double) correct / (double) Ntest : mxGetNaN();

	/* ----- Output 2 : hand back the confusion matrix only if it was requested ----- */
	if (nlhs > 1)
	{
		plhs[1] = conf;
	}
	else
	{
		mxDestroyArray(conf);
	}

	mxFree(labels);
}
/*----------------------------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------------------------*/
/*-------------------------------------------------------------------------------------------------*/


/*
 * qs: sorts array[left..right] (inclusive bounds) in ascending order, in place.
 *
 * Uses three-way partitioning (< pivot | == pivot | > pivot). Long runs of
 * equal keys, which are the normal case for class labels, are therefore
 * settled in one pass instead of degrading to O(n^2) time.
 * The smaller partition is handled by recursion and the larger one by looping,
 * which keeps the recursion depth at O(log n).
 * A range with right <= left (empty or single element) is a no-op.
 * NaN values compare as "equal" to every pivot, so they do not break
 * termination, but they are not placed in any meaningful order.
 */
void qs( double *array, int left, int right )
{
	int    lt, gt, i;
	double pivot, tmp;

	while (left < right)
	{
		pivot = array[left + (right - left) / 2];   /* middle element as pivot */

		/* Invariant: array[left..lt-1] < pivot, array[lt..i-1] == pivot,
		   array[i..gt] not yet examined, array[gt+1..right] > pivot. */
		lt = left;
		gt = right;
		i  = left;

		while (i <= gt)
		{
			if (array[i] < pivot)
			{
				tmp       = array[lt];
				array[lt] = array[i];
				array[i]  = tmp;
				lt++;
				i++;
			}
			else if (array[i] > pivot)
			{
				tmp       = array[gt];
				array[gt] = array[i];
				array[i]  = tmp;
				gt--;   /* the element swapped in from gt is still unexamined */
			}
			else
			{
				i++;
			}
		}

		/* Recurse into the smaller side, loop on the larger one. */
		if (lt - left < right - gt)
		{
			qs(array, left, lt - 1);
			left = gt + 1;
		}
		else
		{
			qs(array, gt + 1, right);
			right = lt - 1;
		}
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -