⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 srng_model.c

📁 最详尽的神经网络源码
💻 C
📖 第 1 页 / 共 3 页
字号:
	mxFree(Np);

	mxFree(distj);

	mxFree(rank_distj);

	mxFree(index_distj);

	mxFree(hproto);

	mxFree(sproto);


	if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
		
	{
		
		mxFree(mtemp);
		
		mxFree(stdtemp);

    	mxFree(Nk);
		
	}
	
	if(mxIsEmpty(prhs[4]))
		
	{
		
		mxFree(lambda);
				
	}
	
}


/*-------------------------------------------------------------------------------------------------------------- */

/*
 * srng_model -- online training of a Supervised Relevance Neural Gas (SRNG)
 * classifier.
 *
 * Memory layout (fixed by the indexing below): sample i is
 * Xtrain[i*d .. i*d + d - 1], prototype j is Wproto_est[j*d .. j*d + d - 1],
 * and lambda_est holds one relevance weight per dimension (length d).
 *
 * Inputs:
 *   Xtrain, ytrain   Ntrain training vectors of dimension d and their labels.
 *   options          nb_iterations, shuffle, sigmastart/sigmaend/sigmastretch
 *                    (neighbourhood schedule), metric_method, xi, threshold,
 *                    epsilonk/epsilonl/epsilonlambda (step sizes), updatelambda.
 *   yproto_est       label of each of the Nproto prototypes.
 *   Nproto_max       length of hproto/sproto; assumed >= the number of
 *                    prototypes sharing any one label -- TODO confirm with caller.
 * In/out:
 *   Wproto_est       prototypes, updated in place.
 *   lambda_est       relevance weights; when options.updatelambda is set they
 *                    are updated, clipped at 0 and renormalised to sum 1.
 * Output:
 *   E_SRNG           E_SRNG[t] = cost accumulated over all samples at iteration t.
 * Scratch:
 *   temp_train, index_train (length Ntrain); distj, rank_distj, index_distj
 *   (length Nproto is always enough); hproto, sproto (length Nproto_max).
 *
 * m, lambda and labels are part of the interface but are not read here.
 *
 * A sample is skipped (no update, no cost) when no prototype carries its label
 * or when no prototype carries a different label: the SRNG cost needs both.
 */
void srng_model(double *Xtrain , double *ytrain , OPTIONS options , int d , int Ntrain , int Nproto , int m , 
				double *Wproto_est , double *yproto_est , double *E_SRNG , double *lambda , double *lambda_est ,
				double *labels , double *temp_train , int *index_train  , double *distj , int *rank_distj, int *index_distj , double *hproto , double* sproto , int Nproto_max)
{
	int i , j , l , t , indice , ind , lmin , offkmin , offlmin , indtemp  , jd;

	int coj , rank;

	/* double_max (~DBL_MAX) marks "no rival prototype found yet". */
	double yi , temp , sum_srng , dlmin , double_max = 1.79769313486231*10e307 ;

	double lnu , cte , ctek , ctel  , ctelamb , tmptmp , tmptmp2, tmptmp1 , cte_tmp;

	double sigma , qj  , dist_proto , sum_lambda , dist_temp , weight;

	(void) m;
	(void) lambda;
	(void) labels;

	/* Rank-0 neighbourhood weight is exp(-0/sigma) = 1 for every sigma. */
	hproto[0]        = 1.0;

	sproto[0]        = 1.0;

	for(i = 0 ; i < Ntrain ; i++)
	{
		index_train[i] = i;
	}

	for (t = 0 ; t < options.nb_iterations ; t++)
	{
		/* Visit the samples in random order: sort random keys and carry the
		   sample indices along. Fewer than two samples need no sort, and
		   qsindex must not be handed the empty range [0, -1]. */
		if(options.shuffle)
		{
			for(i = 0 ; i < Ntrain ; i++)
			{
				temp_train[i]  = rand();

				index_train[i] = i;
			}

			if (Ntrain > 1)
			{
				qsindex(temp_train , index_train , 0 , Ntrain - 1 );
			}
		}

		/* Neighbourhood size: sigmastart at t = 0, tending to sigmaend as t
		   grows (for sigmastretch > 0). */
		sigma            = (options.sigmaend - options.sigmastart) *(1.0 - 2.0/(1.0 + exp(t*options.sigmastretch))) + options.sigmastart;

		/* hproto[k] = exp(-k/sigma); sproto[k] = hproto[0] + ... + hproto[k]. */
		for (i = 1 ; i < Nproto_max ; i++)
		{
			hproto[i]    = exp(-i/sigma);

			sproto[i]    = hproto[i] + sproto[i - 1];
		}

		E_SRNG[t]        = 0.0;

		for(i = 0 ; i < Ntrain ; i++)
		{
			indice       = index_train[i];

			ind          = indice*d;

			yi           = ytrain[indice];

			dlmin        = double_max;

			lmin         = 0;

			coj          = 0;

			/* Relevance-weighted distance to every prototype. Same-label
			   prototypes are collected in distj/index_distj; the closest
			   different-label prototype (the rival) is kept in dlmin/lmin. */
			for(j = 0 ; j < Nproto ; j++)
			{
				dist_proto = 0.0;

				jd         = j*d;

				if (options.metric_method)
				{
					/* sum_l lambda_l * (x_l - w_l)^2 */
					for(l = 0 ; l < d ; l++)
					{
						temp        = (Xtrain[l + ind] - Wproto_est[l + jd]);

						dist_proto += lambda_est[l]*temp*temp;
					}
				}
				else
				{
					/* sum_l lambda_l * (x_l - w_l)^4 */
					for(l = 0 ; l < d ; l++)
					{
						temp        = (Xtrain[l + ind] - Wproto_est[l + jd]);

						dist_proto += lambda_est[l]*temp*temp*temp*temp;
					}
				}

				if(yproto_est[j] == yi)
				{
					distj[coj]       = dist_proto;

					index_distj[coj] = j;

					coj++;
				}
				else if(dist_proto < dlmin)
				{
					dlmin        = dist_proto;

					lmin         = j;
				}
			}

			/* The cost needs a same-label prototype and a rival. Without the
			   former, sproto[coj - 1] would read sproto[-1] and the sort below
			   would run on the empty range [0, -1]; without the latter, lmin
			   would silently point at prototype 0. Skip such a sample. */
			if ((coj == 0) || (dlmin == double_max))
			{
				continue;
			}

			/* Rank the same-label prototypes: after the sort distj is
			   ascending and rank_distj[j] is the position (in index_distj) of
			   the j-th closest one. */
			for(j = 0 ; j < coj ; j++)
			{
				rank_distj[j]   = j;
			}

			qsindex(distj , rank_distj , 0 , coj - 1 );

			sum_lambda     = 0.0;

			sum_srng       = 0.0;

			temp           = 1.0/sproto[coj - 1];   /* normalises the rank weights of the coj prototypes to sum 1 */

			offlmin        = lmin*d;

			for(j = 0 ; j < coj ; j++)
			{
				rank       = rank_distj[j];

				weight     = hproto[j]*temp;

				/* Ranks whose normalised weight does not exceed the threshold
				   neither contribute to the cost nor get updated. */
				if (weight > options.threshold)
				{
					dist_temp  = distj[j];

					tmptmp     = 1.0/(dist_temp + dlmin);

					qj         = (dist_temp - dlmin)*tmptmp;        /* relative distance, in [-1, 1] for non-negative distances */

					lnu        = 1.0/(1.0 + exp(-options.xi*qj));   /* logistic(xi*qj) */

					cte        = lnu*weight;

					sum_srng  += cte;

					/* NOTE(review): the exact derivative of qj carries
					   1/(dj + dl)^2; this implementation uses the single factor
					   tmptmp = 1/(dj + dl) (an earlier version used
					   tmptmp*tmptmp). Kept unchanged. */
					cte_tmp    = options.xi*cte*(1.0 - lnu)*tmptmp;

					ctek       = options.epsilonk*cte_tmp*dlmin;

					ctel       = options.epsilonl*cte_tmp;

					ctelamb    = options.epsilonlambda*cte_tmp;

					offkmin    = index_distj[rank]*d;

					/* NOTE(review): for metric_method == 0 the distance uses
					   fourth powers, yet the steps below use (x - w), as in the
					   squared case -- confirm this is intended. */
					for( l = 0 ; l < d ; l++)
					{
						/* Move the same-label prototype towards the sample. */
						indtemp               = l + offkmin;

						tmptmp1               = Xtrain[l + ind] - Wproto_est[indtemp];

						Wproto_est[indtemp]  += 4.0*ctek*lambda_est[l]*tmptmp1;

						/* Push the rival prototype away from the sample. */
						indtemp               = l + offlmin;

						tmptmp2               = Xtrain[l + ind] - Wproto_est[indtemp];

						Wproto_est[indtemp]  -= 4.0*ctel*dist_temp*lambda_est[l]*tmptmp2;

						if (options.updatelambda)
						{
							lambda_est[l]        -= ctelamb*2.0*(dlmin*tmptmp1*tmptmp1 - dist_temp*tmptmp2*tmptmp2);
						}
					}
				}
			}

			E_SRNG[t]    += sum_srng;

			if (options.updatelambda)
			{
				/* Clip relevances at zero, then renormalise them to sum 1. */
				for (l = 0 ; l < d ; l++)
				{
					lambda_est[l]         = MAX(0 , lambda_est[l]);

					sum_lambda           += lambda_est[l];
				}

				if (sum_lambda > 0.0)
				{
					sum_lambda         = 1.0/sum_lambda;

					for (l = 0 ; l < d ; l++)
					{
						lambda_est[l] *= sum_lambda;
					}
				}
				else
				{
					/* Every relevance was clipped to zero: 1.0/0.0 would turn
					   them all into NaN. Restart from uniform relevances. */
					for (l = 0 ; l < d ; l++)
					{
						lambda_est[l] = 1.0/d;
					}
				}
			}
		}
	}
}



/*----------------------------------------------------------------------------------------------------------------------------------------- */


/*
 * qs -- sort a[lo..hi] (inclusive bounds) into ascending order, in place.
 *
 * Hoare-style partition around the value of the middle element.
 * An empty or one-element range (hi <= lo) is a no-op; the original version
 * read a[-1] when called with hi < lo.
 * After each partition the function recurses into the smaller side and loops
 * over the larger one, so the recursion depth stays O(log n).
 */
void qs(double  *a , int lo, int hi)
{
	while (lo < hi)
	{
		int    i = lo , j = hi;
		double x = a[lo + (hi - lo)/2] , h;   /* midpoint without (lo+hi) overflow */

		/* partition: afterwards a[lo..j] <= x <= a[i..hi] */
		do
		{
			while (a[i] < x) i++;
			while (a[j] > x) j--;
			if (i <= j)
			{
				h    = a[i];
				a[i] = a[j];
				a[j] = h;
				i++;
				j--;
			}
		}
		while (i <= j);

		/* recurse on the smaller part, iterate on the larger one */
		if ((j - lo) < (hi - i))
		{
			if (lo < j) qs(a , lo , j);
			lo = i;
		}
		else
		{
			if (i < hi) qs(a , i , hi);
			hi = j;
		}
	}
}

/*----------------------------------------------------------------------------------------------------------------------------------------- */

/*
 * qsindex -- sort a[lo..hi] (inclusive bounds) into ascending order, in place,
 * applying every swap to index[] as well, so each index entry travels with its
 * key. If index holds 0..n-1 beforehand, index[k] afterwards is the original
 * position of the k-th smallest value. Equal keys may be reordered (not stable).
 *
 * Uses the same partition scheme as qs().
 * An empty or one-element range (hi <= lo) is a no-op; the original version
 * read a[-1] when called with hi < lo, e.g. with a count of 0.
 * The function recurses into the smaller side and iterates over the larger one,
 * keeping the recursion depth O(log n).
 */
void qsindex (double  *a, int *index , int lo, int hi)
{
	while (lo < hi)
	{
		int    i = lo , j = hi , ind;
		double x = a[lo + (hi - lo)/2] , h;   /* midpoint without (lo+hi) overflow */

		/* partition: afterwards a[lo..j] <= x <= a[i..hi] */
		do
		{
			while (a[i] < x) i++;
			while (a[j] > x) j--;
			if (i <= j)
			{
				h        = a[i];
				a[i]     = a[j];
				a[j]     = h;
				ind      = index[i];
				index[i] = index[j];
				index[j] = ind;
				i++;
				j--;
			}
		}
		while (i <= j);

		/* recurse on the smaller part, iterate on the larger one */
		if ((j - lo) < (hi - i))
		{
			if (lo < j) qsindex(a , index , lo , j);
			lo = i;
		}
		else
		{
			if (i < hi) qsindex(a , index , i , hi);
			hi = j;
		}
	}
}

/* --------------------------------------------------------------------------- */

 
 void randini(void) 
	 
 {
	 
	 /* SHR3 Seed initialization */
	 
	 jsrseed  = (UL) time( NULL );
	 
	 jsr     ^= jsrseed;
	 
	 
	 /* KISS Seed initialization */
	 
#ifdef ranKISS
	  
	 z        = (UL) time( NULL );
	 
	 w        = (UL) time( NULL ); 
	 
	 jcong    = (UL) time( NULL );
	 
	 mix(z , w , jcong);
	 
#endif 
	 
	 
 }
 

/* --------------------------------------------------------------------------- */

void randnini(void) 
{	  
	register const double m1 = 2147483648.0, m2 = 4294967296.0 ;

	register double  invm1;

	register double dn = 3.442619855899 , tn = dn , vn = 9.91256303526217e-3 , q; 
	
	int i;

	
	/* Ziggurat tables for randn */	 
	
	invm1             = 1.0/m1;
	
	q                 = vn/exp(-0.5*dn*dn);  
	
	kn[0]             = (dn/q)*m1;	  
	
	kn[1]             = 0;
		  
	wn[0]             = q*invm1;	  
	
	wn[zigstep - 1 ]  = dn*invm1;
	
	fn[0]             = 1.0;	  
	
	fn[zigstep - 1]   = exp(-0.5*dn*dn);		
	
	for(i = (zigstep - 2) ; i >= 1 ; i--)      
	{   
		dn              = sqrt(-2.*log(vn/dn + exp(-0.5*dn*dn)));          
	
		kn[i+1]         = (dn/tn)*m1;		  
		
		tn              = dn;          
		
		fn[i]           = exp(-0.5*dn*dn);          
		
		wn[i]           = dn*invm1;      
	}

}


/* --------------------------------------------------------------------------- */


/*
 * nfix -- slow path of the ziggurat normal sampler. randn() calls it when
 * its fast test |hz| < kn[iz] fails.
 *
 * Works on the global state hz (random integer) and iz (strip index) left
 * behind by randn(). It loops until a value is accepted:
 *   - iz == 0 (base strip): draw x = -log(U)*0.2904764 (0.2904764 being 1/r)
 *     and y = -log(U) until 2*y >= x*x, then return r + x or -r - x,
 *     depending on the sign of hz;
 *   - otherwise: x = hz*wn[iz] is accepted when
 *     fn[iz] + U*(fn[iz-1] - fn[iz]) < exp(-x*x/2);
 *   - on rejection, a fresh hz/iz is drawn from randint and the fast test of
 *     randn() is repeated.
 * Here U stands for one call to rand().
 *
 * NOTE(review): the maths above needs rand() to return a uniform value in
 * (0, 1). The standard C rand() returns an int in [0, RAND_MAX], so this
 * presumably relies on a project macro redefining rand() -- confirm.
 * NOTE(review): abs() takes an int. If hz is declared wider than int
 * (e.g. long), labs()/fabs() would be needed -- confirm hz's type.
 */
double nfix(void) 
{	
	const double r = 3.442620f; 	/* Start of the right tail (float literal, so only float precision) */	
	
	/* NOTE(review): x and y need not persist across calls; static makes nfix non-reentrant */
	static double x, y;
	
	for(;;)

	{
		
		/* candidate value for strip iz */
		x = hz*wn[iz];
					
		if(iz == 0)
		
		{	/* iz==0, handle the base strip: sample from the tail beyond r */
			
			do
			{	
				x = -log(rand())*0.2904764;  /* .2904764 is 1/r */  
								
				y = -log(rand());			
			} 
			
			/* reject until 2*y >= x*x */
			while( (y + y) < (x*x));
			
			return (hz > 0) ? (r + x) : (-r - x);	
		}
		
		/* accept x if a point drawn between fn[iz] and fn[iz-1] lies under exp(-x*x/2) */
		if( (fn[iz] + rand()*(fn[iz-1] - fn[iz])) < ( exp(-0.5*x*x) ) ) 
		
		{
		
			return x;

		}

		
		/* rejected: draw a new integer and strip, retry the fast test of randn() */
		hz = randint;		

		iz = (hz & (zigstep - 1));		
		
		if(abs(hz) < kn[iz]) 

		{
			return (hz*wn[iz]);	

		}


	}

}


/* --------------------------------------------------------------------------- */


/*
 * randn -- draw one normal (Gaussian) deviate from the ziggurat tables built
 * by randnini().
 *
 * Takes a random integer hz from randint and uses its low bits
 * (hz & (zigstep - 1)) as the strip index iz. The value hz*wn[iz] is accepted
 * at once when |hz| < kn[iz]; otherwise nfix() performs the rejection step.
 * hz and iz are file-level state shared with nfix().
 */
double randn(void)
{
	hz = randint;
	iz = (hz & (zigstep - 1));

	if (abs(hz) < kn[iz])
	{
		/* fast path: point lies inside the rectangular part of the strip */
		return (hz*wn[iz]);
	}

	return nfix();
}



⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -