⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 grlvq_model.c

📁 最详尽的神经网络源码
💻 C
📖 第 1 页 / 共 3 页
字号:
	{
		
		mxFree(mtemp);
		
		mxFree(stdtemp);
		
		mxFree(Nk);
		
	}
	
	if(mxIsEmpty(prhs[4]))
		
	{
		
		mxFree(lambda);
				
	}
	
}


/*-------------------------------------------------------------------------------------------------------------- */

/*
 * grlvq_weighted_distance: relevance-weighted distance between one sample and
 * one prototype, both stored as d contiguous doubles.
 *   use_d4 == 0 : sum_l lambda[l] * (x[l] - w[l])^2
 *   use_d4 != 0 : sum_l lambda[l] * (x[l] - w[l])^4
 */
static double grlvq_weighted_distance(const double *x, const double *w,
                                      const double *lambda, int d, int use_d4)
{
	double dist = 0.0;
	double diff;
	int l;

	for (l = 0; l < d; l++) {
		diff = x[l] - w[l];
		if (use_d4) {
			dist += lambda[l]*diff*diff*diff*diff;
		} else {
			dist += lambda[l]*diff*diff;
		}
	}
	return dist;
}

/*
 * grlvq_model: train a Generalized Relevance LVQ model by per-sample
 * stochastic gradient steps.
 *
 * Memory layout, as fixed by the indexing below:
 *   Xtrain      Ntrain samples of dimension d; sample i is Xtrain[i*d .. i*d+d-1].
 *   ytrain      Ntrain labels.
 *   Wproto_est  Nproto prototypes with the same layout as Xtrain. Updated in place.
 *   yproto_est  Nproto prototype labels.
 *   lambda_est  d relevance weights. Updated in place when options.updatelambda is set.
 *   E_GRLVQ     Must hold options.nb_iterations entries. E_GRLVQ[t] receives the sum,
 *               over the samples of epoch t, of the cost 1/(1+exp(-xi*nu)).
 *   temp_train  Scratch buffer of Ntrain doubles (random sort keys for shuffling).
 *   index_train Scratch buffer of Ntrain ints (visiting order of the samples).
 *   m, lambda, labels  Not used by this routine. Kept for the caller's interface.
 *
 * For each sample, the closest same-class prototype (k) and the closest
 * other-class prototype (l) are found. k is pulled toward the sample and l is
 * pushed away. The relevances are then adjusted, clamped at 0 and renormalised
 * to sum to one. Samples for which the step is undefined are skipped and add
 * nothing to E_GRLVQ:
 *   - no prototype of the sample's class exists, or no prototype of another class exists;
 *   - both winning distances are 0.
 */
void grlvq_model(double *Xtrain , double *ytrain , OPTIONS options , int d , int Ntrain , int Nproto , int m ,
				 double *Wproto_est , double *yproto_est , double *E_GRLVQ , double *lambda , double *lambda_est ,
				 double *labels , double *temp_train , int *index_train)
{
	/*
	 * One flag selects the metric for BOTH the winner search and the gradient
	 * step. Previously the distance tested `metric_method != 0` (squared), while
	 * the update tested `metric_method == 2` (4th power). For metric_method 0 or 2,
	 * the prototypes therefore followed the gradient of a different distance than
	 * the one that selected them.
	 *
	 * The distance test is kept as the reference:
	 *   non-zero -> weighted squared Euclidean;
	 *   0        -> weighted 4th power.
	 * NOTE(review): confirm this convention against the MATLAB help / MEX gateway.
	 */
	const int use_d4 = (options.metric_method == 0);
	int i, j, l, t, indice, kmin, lmin;
	double yi, dist_proto, dkmin, dlmin, denom, inv_denom, nu, lnu;
	double cte_tmp, ctek, ctel, ctelamb, cte_lambda, diffk, diffl, sqk, sql;
	const double *x;
	double *wk, *wl;

	(void) m;
	(void) lambda;
	(void) labels;

	for (i = 0; i < Ntrain; i++) {
		index_train[i] = i;
	}

	for (t = 0; t < options.nb_iterations; t++) {
		if (options.shuffle) {
			/* Random visiting order: sort random keys and carry the indices along.
			   NOTE(review): temp_train is double, so rand() is presumably a
			   file-level macro yielding a double. Confirm. */
			for (i = 0; i < Ntrain; i++) {
				temp_train[i]  = rand();
				index_train[i] = i;
			}
			qsindex(temp_train, index_train, 0, Ntrain - 1);
		}

		E_GRLVQ[t] = 0.0;

		for (i = 0; i < Ntrain; i++) {
			indice = index_train[i];
			x      = Xtrain + indice*d;
			yi     = ytrain[indice];

			/* Closest same-class prototype (kmin) and closest rival prototype (lmin).
			   Distances that are +inf or NaN never win the strict `<` test. */
			dkmin = HUGE_VAL;
			dlmin = HUGE_VAL;
			kmin  = -1;
			lmin  = -1;
			for (j = 0; j < Nproto; j++) {
				dist_proto = grlvq_weighted_distance(x, Wproto_est + j*d, lambda_est, d, use_d4);
				if (yproto_est[j] == yi) {
					if (dist_proto < dkmin) {
						dkmin = dist_proto;
						kmin  = j;
					}
				} else {
					if (dist_proto < dlmin) {
						dlmin = dist_proto;
						lmin  = j;
					}
				}
			}

			/* The GLVQ step needs both a correct and a rival prototype. The old code
			   silently fell back to prototype 0 here. */
			if (kmin < 0 || lmin < 0) {
				continue;
			}

			/* Both distances 0 gives 0/0 = NaN, which would poison the prototypes. */
			denom = dkmin + dlmin;
			if (denom == 0.0) {
				continue;
			}

			inv_denom = 1.0/denom;
			nu        = (dkmin - dlmin)*inv_denom;        /* in [-1, 1] when both distances are >= 0 */
			lnu       = 1.0/(1.0 + exp(-options.xi*nu));  /* sigmoid cost of this sample */

			E_GRLVQ[t] += lnu;

			/* NOTE(review): the exact derivative of nu carries 1/(dk+dl)^2. As in the
			   original code, only one factor of 1/(dk+dl) is applied here (a variant
			   with the squared factor had been left commented out). */
			cte_tmp = options.xi*lnu*(1.0 - lnu)*inv_denom;
			ctek    = cte_tmp*options.epsilonk*dlmin;
			ctel    = cte_tmp*options.epsilonl*dkmin;
			ctelamb = options.epsilonlambda*cte_tmp;

			/* kmin != lmin, because their labels differ. */
			wk         = Wproto_est + kmin*d;
			wl         = Wproto_est + lmin*d;
			cte_lambda = 0.0;

			for (l = 0; l < d; l++) {
				/* Differences are taken before the prototypes move. The lambda step
				   below also uses these pre-update differences. */
				diffk = x[l] - wk[l];
				diffl = x[l] - wl[l];

				if (use_d4) {
					wk[l] += ctek*8.0*lambda_est[l]*diffk*diffk*diffk;
					wl[l] -= ctel*8.0*lambda_est[l]*diffl*diffl*diffl;
				} else {
					wk[l] += 4.0*ctek*lambda_est[l]*diffk;
					wl[l] -= 4.0*ctel*lambda_est[l]*diffl;
				}

				if (options.updatelambda) {
					if (use_d4) {
						sqk = diffk*diffk;
						sql = diffl*diffl;
						lambda_est[l] -= ctelamb*4.0*(dlmin*sqk*sqk - dkmin*sql*sql);
					} else {
						lambda_est[l] -= ctelamb*2.0*(dlmin*diffk*diffk - dkmin*diffl*diffl);
					}
					if (lambda_est[l] < 0.0) {
						lambda_est[l] = 0.0;   /* relevances stay non-negative */
					}
					cte_lambda += lambda_est[l];
				}
			}

			/* Renormalise the relevances to sum to one. Skip this when every lambda
			   clamped to 0: 1/0 would turn all of them (and later the prototypes)
			   into NaN. */
			if (options.updatelambda && cte_lambda > 0.0) {
				cte_lambda = 1.0/cte_lambda;
				for (l = 0; l < d; l++) {
					lambda_est[l] *= cte_lambda;
				}
			}
		}
	}
}



/*----------------------------------------------------------------------------------------------------------------------------------------- */


/*
 * qs: sort a[lo..hi] (inclusive bounds) into ascending order, in place.
 *
 * Uses Hoare-partition quicksort with the middle element as pivot.
 * - An empty or single-element range (hi <= lo) is a no-op, so qs(a, 0, n - 1)
 *   is safe when n == 0. The old code read a[0] and a[-1] in that case.
 * - Stack depth is bounded by O(log n): the function recurses only into the
 *   smaller partition and loops on the larger one.
 */
void qs(double *a, int lo, int hi)
{
	while (lo < hi) {
		int i = lo;
		int j = hi;
		/* lo + (hi - lo)/2 equals (lo + hi)/2 here but cannot overflow. */
		const double x = a[lo + (hi - lo)/2];
		double h;

		/* partition */
		do {
			while (a[i] < x) {
				i++;
			}
			while (a[j] > x) {
				j--;
			}
			if (i <= j) {
				h    = a[i];
				a[i] = a[j];
				a[j] = h;
				i++;
				j--;
			}
		} while (i <= j);

		/* a[lo..j] and a[i..hi] are now independent. Recurse on the smaller one
		   and continue the loop on the larger one. */
		if (j - lo < hi - i) {
			if (lo < j) {
				qs(a, lo, j);
			}
			lo = i;
		} else {
			if (i < hi) {
				qs(a, i, hi);
			}
			hi = j;
		}
	}
}

/*----------------------------------------------------------------------------------------------------------------------------------------- */

/*
 * qsindex: sort a[lo..hi] (inclusive bounds) ascending, in place, applying
 * every swap to index[] as well. This lets the caller recover the
 * permutation; grlvq_model uses it to shuffle sample indices.
 *
 * Uses Hoare-partition quicksort with the middle element as pivot.
 * - An empty or single-element range (hi <= lo) is a no-op, so
 *   qsindex(a, idx, 0, n - 1) is safe when n == 0. The old code read a[0] and
 *   a[-1] in that case.
 * - Stack depth is bounded by O(log n): the function recurses only into the
 *   smaller partition and loops on the larger one.
 */
void qsindex(double *a, int *index, int lo, int hi)
{
	while (lo < hi) {
		int i = lo;
		int j = hi;
		int ind;
		/* lo + (hi - lo)/2 equals (lo + hi)/2 here but cannot overflow. */
		const double x = a[lo + (hi - lo)/2];
		double h;

		/* partition: keys and indices are swapped together */
		do {
			while (a[i] < x) {
				i++;
			}
			while (a[j] > x) {
				j--;
			}
			if (i <= j) {
				h        = a[i];
				a[i]     = a[j];
				a[j]     = h;
				ind      = index[i];
				index[i] = index[j];
				index[j] = ind;
				i++;
				j--;
			}
		} while (i <= j);

		/* Recurse on the smaller partition and keep looping on the larger one. */
		if (j - lo < hi - i) {
			if (lo < j) {
				qsindex(a, index, lo, j);
			}
			lo = i;
		} else {
			if (i < hi) {
				qsindex(a, index, i, hi);
			}
			hi = j;
		}
	}
}

/* --------------------------------------------------------------------------- */

 
 void randini(void) 
	 
 {
	 
	 /* SHR3 Seed initialization */
	 
	 jsrseed  = (UL) time( NULL );
	 
	 jsr     ^= jsrseed;
	 
	 
	 /* KISS Seed initialization */
	 
#ifdef ranKISS
	  
	 z        = (UL) time( NULL );
	 
	 w        = (UL) time( NULL ); 
	 
	 jcong    = (UL) time( NULL );
	 
	 mix(z , w , jcong);
	 
#endif 
	 
	 
 }
 

/* --------------------------------------------------------------------------- */

void randnini(void) 
{	  
	register const double m1 = 2147483648.0 ;

	register double  invm1;

	register double dn = 3.442619855899 , tn = dn , vn = 9.91256303526217e-3 , q; 
	
	int i;

	
	/* Ziggurat tables for randn */	 
	
	invm1             = 1.0/m1;
	
	q                 = vn/exp(-0.5*dn*dn);  
	
	kn[0]             = (dn/q)*m1;	  
	
	kn[1]             = 0;
		  
	wn[0]             = q*invm1;	  
	
	wn[zigstep - 1 ]  = dn*invm1;
	
	fn[0]             = 1.0;	  
	
	fn[zigstep - 1]   = exp(-0.5*dn*dn);		
	
	for(i = (zigstep - 2) ; i >= 1 ; i--)      
	{   
		dn              = sqrt(-2.*log(vn/dn + exp(-0.5*dn*dn)));          
	
		kn[i+1]         = (dn/tn)*m1;		  
		
		tn              = dn;          
		
		fn[i]           = exp(-0.5*dn*dn);          
		
		wn[i]           = dn*invm1;      
	}

}


/* --------------------------------------------------------------------------- */


/*
 * nfix: slow path of the ziggurat standard-normal sampler (Marsaglia & Tsang).
 *
 * randn() calls this when its fast test fails. It reads the globals hz (the
 * last integer draw) and iz (the strip index) left by the caller, and may
 * redraw both. It loops until it returns a deviate:
 *   - iz == 0 (base strip): draws from the tail beyond r by exponential
 *     rejection, and returns +/-(r + x) with the sign of hz.
 *   - otherwise: accepts x = hz*wn[iz] if a uniform point in the wedge lies
 *     under exp(-x*x/2). If not, it draws a new hz/iz and retries the fast
 *     test, returning hz*wn[iz] on success.
 *
 * NOTE(review): rand() is used as a uniform deviate in (0,1): it feeds -log()
 * and scales differences of fn. That only works if rand is redefined by a
 * file-level macro, because the C library rand() returns an int in
 * [0, RAND_MAX]. Confirm.
 * NOTE(review): abs() takes an int. If hz is declared wider (e.g. long), the
 * comparison uses a truncated value (Marsaglia's reference uses fabs(hz)).
 * Confirm hz's type.
 * NOTE(review): x and y are static and hz/iz are globals, so concurrent calls
 * share state.
 */
double nfix(void) 
{	
	const double r = 3.442620; 	/* Start of the right tail (edge of the base strip) */	
	
	static double x, y;
	
	for(;;)

	{
		
		/* candidate point in strip iz */
		x = hz*wn[iz];
					
		if(iz == 0)
		
		{	/* iz==0, handle the base strip: sample the tail beyond r */
			
			do
			{	
				x = -log(rand())*0.2904764;  /* .2904764 is 1/r, so x ~ Exp(1)/r */  
								
				y = -log(rand());			/* y ~ Exp(1) */
			} 
			
			/* accept when 2y >= x^2 */
			while( (y + y) < (x*x));
			
			/* sign taken from the integer draw */
			return (hz > 0) ? (r + x) : (-r - x);	
		}
		
		/* wedge test: uniform height between fn[iz] and fn[iz-1] vs. the density */
		if( (fn[iz] + rand()*(fn[iz-1] - fn[iz])) < ( exp(-0.5*x*x) ) ) 
		
		{
		
			return x;

		}

		
		/* rejected: redraw and retry the fast rectangle test */
		hz = randint;		

		iz = (hz & (zigstep - 1));		
		
		if(abs(hz) < kn[iz]) 

		{
			return (hz*wn[iz]);	

		}


	}

}


/* --------------------------------------------------------------------------- */


/*
 * randn: draw one standard normal deviate with the ziggurat method.
 *
 * Fast path: a fresh integer draw (randint) is stored in hz, and the strip
 * index iz is taken from hz & (zigstep - 1). This is the low bits when zigstep
 * is a power of two. If |hz| < kn[iz], the point lies inside the strip's
 * rectangle and hz*wn[iz] is returned directly.
 *
 * Otherwise nfix() handles the wedge and tail cases, reading the hz and iz
 * just stored here.
 */
double randn(void)
{
	hz = randint;
	iz = hz & (zigstep - 1);

	if (abs(hz) < kn[iz]) {
		return hz*wn[iz];
	}
	return nfix();
}



⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -