⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nn_predict.c

📁 最详尽的神经网络源码
💻 C
字号:

/*


  Predict class labels for test data with a trained nearest-prototype (GLVQ/GRLVQ) model.

  Usage
  ------

  [ytest_est , dist] = NN_predict(Xtest , Wproto_est , yproto_est , [lambda_est] , [options]);

  
  Inputs
  -------

  Xtest                                 Test data (d x Ntest)
  Wproto_est                            Estimated prototype weights (d x Nproto)
  yproto_est                            Estimated prototype labels  (1 x Nproto)
  lambda_est                            Estimated feature weighting factors (d x 1). Default lambda_est = ones(d , 1);
  options.metric_method                 = 1 for (lambda-weighted, squared) Euclidean distance (default), = 2 for d4 (sum_j lambda_j (x_j - w_j)^4)

  
  Outputs
  -------
  
  ytest_est                             Estimated labels  (1 x Ntest)
  dist                                  Weighted distance between each Xtest column and each prototype (Nproto x Ntest)


  To compile
  ----------


  mex  -g  -output NN_predict.dll NN_predict.c

  mex -f mexopts_intelamd.bat -output NN_predict.dll NN_predict.c

  

  Example 1
  ---------
  

  close all
  load ionosphere
  Nproto_pclass                      = 4*ones(1 , length(unique(y)));
  
  options.epsilonk                   = 0.005;
  options.epsilonl                   = 0.001;
  options.epsilonlambda              = 10e-8;
  options.xi                         = 10;
  options.nb_iterations              = 5000;
  options.metric_method              = 1;
  options.shuffle                    = 1;
  options.updatelambda               = 1;

  options.method                     = 7;
  options.holding.rho                = 0.7;
  options.holding.K                  = 1;


  X                                  = normalize(X);
  [Itrain , Itest]                   = sampling(X , y , options);
  [Xtrain , ytrain , Xtest , ytest]  = samplingset(X , y , Itrain , Itest);

  

  [Wproto , yproto , lambda]         = ini_proto(Xtrain , ytrain , Nproto_pclass);
  [Wproto_est , yproto_est , lambda_est,  E_GRLVQ]    = grlvq_model(Xtrain , ytrain , Wproto , yproto , lambda, options);
  
  [ytest_est , disttest]             = NN_predict(Xtest , Wproto_est , yproto_est , lambda_est , options);
  [ytrain_est , disttrain]           = NN_predict(Xtrain , Wproto_est , yproto_est , lambda_est , options);

  Perftrain                          = perf_classif(ytrain , ytrain_est); 
  Perftest                           = perf_classif(ytest , ytest_est);

  dktrain                            = min(disttrain(yproto==0 , :));
  dltrain                            = min(disttrain(yproto~=0 , :));
  nutrain                            = (dktrain - dltrain)./(dktrain + dltrain);
  [tptrain , fptrain]                = basicroc(ytrain , nutrain);

   
  dktest                             = min(disttest(yproto==0 , :));
  dltest                             = min(disttest(yproto~=0 , :));
  nutest                             = (dktest - dltest)./(dktest + dltest);
  [tptest , fptest]                  = basicroc(ytest , nutest);


  disp('Performances Train/Test')
  disp([Perftrain , Perftest])
  
  figure(1)
  plot(E_GRLVQ);
  title('E_{GRLVQ}(t)' , 'fontsize' , 12)
  
  figure(2)
  stem(lambda_est);
  title('\lambda' , 'fontsize' , 12)

  figure(3)
  plot(fptrain , tptrain , fptest , tptest , 'r' , 'linewidth'  , 2)
  xlabel('false positive rate');
  ylabel('true positive rate');
  title('ROC curve','fontsize' , 12);
  legend(['Train'] , ['Test'])



 Author : Sébastien PARIS : sebastien.paris@lsis.org
 -------  Date : 04/09/2006

 Reference "A new Generalized LVQ Algorithm via Harmonic to Minimum Distance Measure Transition", A.K. Qin, P.N. Suganthan and J.J. Liang,
 ---------  IEEE International Conference on System, Man and Cybernetics, 2004



*/


#include <math.h>
#include <mex.h>



/*
 * Run-time options read from the optional 5th MATLAB argument.
 *
 *   metric_method : selector for the distance used by glvq_predict
 *                   (mexFunction initialises it to 1 before reading options).
 */
struct OPTIONS
{
	int metric_method;
};

typedef struct OPTIONS OPTIONS;


/* Function prototypes */

/* Nearest-prototype prediction; fills ytest_est (Ntest labels) and dist (Nproto x Ntest). */
void glvq_predict(double *Xtest , double *Wproto , double *yproto , double *lambda ,
                  int d , int Ntest , int Nproto , OPTIONS options ,
                  double *ytest_est , double *dist);



/*-------------------------------------------------------------------------------------------------------------- */


/* Nonzero when a can be read through mxGetPr as a dense array of doubles. */
static int nn_is_full_double(const mxArray *a)
{
	return mxIsDouble(a) && !mxIsSparse(a);
}


/*
 * MEX gateway:
 *
 *   [ytest_est , dist] = NN_predict(Xtest , Wproto_est , yproto_est , [lambda_est] , [options]);
 *
 * prhs[0]  Xtest     (d x Ntest)  double matrix, one test vector per column
 * prhs[1]  Wproto    (d x Nproto) double matrix, one prototype per column
 * prhs[2]  yproto    Nproto prototype labels (row or column vector)
 * prhs[3]  lambda    optional, d feature weights; omitted or empty -> ones(d , 1)
 * prhs[4]  options   optional struct; field metric_method (default 1) is forwarded to glvq_predict
 *
 * plhs[0]  ytest_est (1 x Ntest)      label assigned by glvq_predict to each test column
 * plhs[1]  dist      (Nproto x Ntest) distances computed by glvq_predict (returned only if requested)
 */
void mexFunction( int nlhs, mxArray *plhs[] , int nrhs, const mxArray *prhs[] )
{
	double  *Xtest , *Wproto , *yproto , *lambda;
	double  *ytest_est , *dist;
	OPTIONS  options = {1};
	int      d , Ntest , Nproto , i;
	int      lambda_is_default;
	mxArray *mxtemp , *mxdist;


	if (nrhs < 3)
	{
		mexErrMsgTxt("At least 3 inputs are required : NN_predict(Xtest , Wproto_est , yproto_est , [lambda_est] , [options])");
	}

	if (nlhs > 2)
	{
		mexErrMsgTxt("At most 2 outputs : [ytest_est , dist]");
	}


	/* Input 1  Xtest */

	if (!nn_is_full_double(prhs[0]))
	{
		mexErrMsgTxt("Xtest must be a full double matrix");
	}

	if (mxGetNumberOfDimensions(prhs[0]) != 2)
	{
		mexErrMsgTxt("Xtest must be (d x Ntest)");
	}

	Xtest = mxGetPr(prhs[0]);
	d     = (int) mxGetM(prhs[0]);
	Ntest = (int) mxGetN(prhs[0]);


	/* Input 2  Wproto */

	if (!nn_is_full_double(prhs[1]))
	{
		mexErrMsgTxt("Wproto must be a full double matrix");
	}

	if (mxGetNumberOfDimensions(prhs[1]) != 2 || (int) mxGetM(prhs[1]) != d)
	{
		mexErrMsgTxt("Wproto must be (d x Nproto)");
	}

	Wproto = mxGetPr(prhs[1]);
	Nproto = (int) mxGetN(prhs[1]);

	/* Every test column is assigned yproto[k] for some prototype k, so prototypes are required. */
	if (Nproto < 1 && Ntest > 0)
	{
		mexErrMsgTxt("Wproto must contain at least one prototype");
	}


	/* Input 3  yproto : exactly one label per prototype (row or column vector). */

	if (!nn_is_full_double(prhs[2]))
	{
		mexErrMsgTxt("yproto must be a full double vector");
	}

	if (mxGetNumberOfDimensions(prhs[2]) != 2 || (int) mxGetNumberOfElements(prhs[2]) != Nproto)
	{
		mexErrMsgTxt("yproto must be (1 x Nproto)");
	}

	yproto = mxGetPr(prhs[2]);


	/* Input 4  lambda (optional) */

	lambda_is_default = (nrhs < 4) || mxIsEmpty(prhs[3]);

	if (!lambda_is_default)
	{
		if (!nn_is_full_double(prhs[3]))
		{
			mexErrMsgTxt("lambda must be a full double vector");
		}

		/* glvq_predict reads lambda[0 .. d-1]; a shorter vector would be read out of bounds. */
		if ((int) mxGetNumberOfElements(prhs[3]) != d)
		{
			mexErrMsgTxt("lambda must be (d x 1)");
		}
	}


	/* Input 5  options (optional struct) */

	if ((nrhs >= 5) && !mxIsEmpty(prhs[4]) && mxIsStruct(prhs[4]))
	{
		mxtemp = mxGetField(prhs[4] , 0 , "metric_method");

		/* mxGetScalar converts any numeric/logical class; an empty field keeps the default. */
		if (mxtemp != NULL && !mxIsEmpty(mxtemp))
		{
			options.metric_method = (int) mxGetScalar(mxtemp);
		}
	}


	/* All inputs validated: build lambda and the outputs. */

	if (lambda_is_default)
	{
		lambda = (double *) mxMalloc(d*sizeof(double));

		for (i = 0 ; i < d ; i++)
		{
			lambda[i] = 1.0;
		}
	}
	else
	{
		lambda = mxGetPr(prhs[3]);
	}

	plhs[0]   = mxCreateDoubleMatrix(1 , Ntest , mxREAL);
	ytest_est = mxGetPr(plhs[0]);

	/* glvq_predict always fills a distance buffer. plhs is sized for the outputs the caller
	   requested, so the matrix is only handed back through plhs[1] when nlhs >= 2. */
	mxdist = mxCreateDoubleMatrix(Nproto , Ntest , mxREAL);
	dist   = mxGetPr(mxdist);


	/* Main Call */

	glvq_predict(Xtest , Wproto , yproto , lambda , d , Ntest , Nproto , options ,
	             ytest_est , dist);


	if (nlhs >= 2)
	{
		plhs[1] = mxdist;
	}
	else
	{
		mxDestroyArray(mxdist);
	}

	if (lambda_is_default)
	{
		mxFree(lambda);
	}
}


/*-------------------------------------------------------------------------------------------------------------- */

/*
 * Lambda-weighted distance between the d-vectors x and w:
 *   use_d4 == 0 : sum_j lambda[j]*(x[j]-w[j])^2
 *   use_d4 != 0 : sum_j lambda[j]*(x[j]-w[j])^4
 * Products are evaluated left to right (lambda*diff*diff...) so results stay bit-identical
 * to the former hand-unrolled loops.
 */
static double nn_weighted_distance(const double *x , const double *w , const double *lambda ,
                                   int d , int use_d4)
{
	int    j;
	double diff , acc = 0.0;

	for (j = 0 ; j < d ; j++)
	{
		diff = x[j] - w[j];

		if (use_d4)
		{
			acc += lambda[j]*diff*diff*diff*diff;
		}
		else
		{
			acc += lambda[j]*diff*diff;
		}
	}

	return acc;
}


/*
 * Nearest-prototype prediction.
 *
 * Xtest     : d x Ntest test vectors, column-major (column l starts at Xtest + l*d)
 * Wproto    : d x Nproto prototypes, column-major (column i starts at Wproto + i*d)
 * yproto    : Nproto prototype labels
 * lambda    : d per-feature weights
 * options   : options.metric_method selects the distance
 *               2 (documented "d4") or 0   -> sum_j lambda[j]*(x[j]-w[j])^4
 *               1 (default) or other value -> sum_j lambda[j]*(x[j]-w[j])^2
 *             Previously any nonzero value took the squared branch, so the documented value 2
 *             could never select d4; 0 keeps selecting d4 for backward compatibility.
 * ytest_est : out, ytest_est[l] = yproto[k] where k is the prototype with the smallest distance
 *             to test column l (first one on ties; k = 0 if no distance compares below +Inf,
 *             e.g. all +Inf or NaN)
 * dist      : out, Nproto x Ntest column-major, dist[i + l*Nproto] = distance(test l, prototype i)
 *
 * Nothing is written when Nproto < 1, since there is no label to choose from.
 */
void glvq_predict(double *Xtest , double *Wproto , double *yproto , double *lambda , int d , int Ntest , int Nproto , OPTIONS options , 
				  double *ytest_est , double *dist)
{
	const int use_d4 = (options.metric_method == 0) || (options.metric_method == 2);
	int       i , l , ind;
	double    disttmp , dist_min;
	double   *x , *dist_col;

	if (Nproto <= 0)
	{
		return;
	}

	for (l = 0 ; l < Ntest ; l++)
	{
		x        = Xtest + l*d;
		dist_col = dist + l*Nproto;
		dist_min = HUGE_VAL;
		ind      = 0;

		for (i = 0 ; i < Nproto ; i++)
		{
			disttmp     = nn_weighted_distance(x , Wproto + i*d , lambda , d , use_d4);
			dist_col[i] = disttmp;

			/* strict '<' keeps the first prototype on ties */
			if (disttmp < dist_min)
			{
				dist_min = disttmp;
				ind      = i;
			}
		}

		ytest_est[l] = yproto[ind];
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -