⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 h2m_glvq_model.c

📁 最详尽的神经网络源码
💻 C
📖 第 1 页 / 共 2 页
字号:
		for (l = 0 ; l < m ; l++)
			
		{
			
			ind = labels[l];
			
			ld  = l*d;
			
			for (i = 0 ; i < Ntrain ; i++)
				
			{
				
				if(ytrain[i] == ind)
					
				{
					id  = i*d;
					
					for(j = 0 ; j < d ; j++)
						
					{
						temp             = (Xtrain[j + id] - mtemp[j + ld]);

						stdtemp[j + ld] += (temp*temp);
						
					}
					
					
				}
			}
				
		}



		for (l = 0 ; l < m ; l++)
			
		{
			
			ld   = l*d;
			
			temp = 1.0/(Nk[l] - 1);
					
			
			for(j = 0 ; j < d ; j++)
				
			{
				
				stdtemp[j + ld] = temp*sqrt(stdtemp[j + ld]);
							
			}
			
		}

		
		plhs[0]               = mxCreateDoubleMatrix(d , Nproto, mxREAL);
		
		Wproto_est            = mxGetPr(plhs[0]);


		for(l = 0 ; l < Nproto ; l++)
			
		{
			ld = l*d;		
			
			for(i = 0 ; i < m ; i++)
			{
				
				if(labels[i] == yproto_est[l] )
					
				{
					
					indice = i*m;
					
				}
								
			}

			
			for(i = 0 ; i < d ; i++)
				
			{
				
				Wproto_est[i + ld]  = mtemp[i + indice] + stdtemp[i + indice]*randn();
				
			}
			
		}
		
	}
	
	else
	{
		
		
		Wproto                = mxGetPr(prhs[2]);
		
		if(mxGetNumberOfDimensions(prhs[2]) !=2 || mxGetM(prhs[2]) != d)
		{
			
			mexErrMsgTxt("Wproto must be (d x Nproto)");
			
		}

		Nproto                = mxGetN(prhs[2]);

		plhs[0]               = mxCreateDoubleMatrix(d , Nproto, mxREAL);
		
		Wproto_est            = mxGetPr(plhs[0]);
	
		for( i = 0 ; i < d*Nproto ; i++)
			
		{
			
			Wproto_est[i] = Wproto[i];
			
		}

	}
	


// Outputs 


			  	
	plhs[2]               = mxCreateDoubleMatrix(1 , options.nb_iterations , mxREAL);
						
	E_H2MGLVQ             = mxGetPr(plhs[2]);
	

	
    temp_train            = mxMalloc(Ntrain*sizeof(double));

    index_train           = mxMalloc(Ntrain*sizeof(int));

    dist                  = mxMalloc(Nproto*sizeof(double));

    powdist               = mxMalloc(Nproto*sizeof(double));

	/* Main Call */


	h2m_glvq_model(Xtrain , ytrain , options , d , Ntrain , Nproto , m , 
		           Wproto_est , yproto_est , E_H2MGLVQ,
			       labels , temp_train , index_train , dist , powdist ,  
	     		   Nk);

		

	mxFree(labels);
	
	mxFree(ytrainsorted);

	mxFree(temp_train);

	mxFree(index_train);

	mxFree(dist);

	mxFree(powdist);

	mxFree(Nk);
			
	
	if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
		
	{
		
		mxFree(mtemp);
		
		mxFree(stdtemp);
				
	}
	
	
}


/*-------------------------------------------------------------------------------------------------------------- */

void h2m_glvq_model(double *Xtrain , double *ytrain , OPTIONS options , int d , int Ntrain , int Nproto , int m , 
					double *Wproto_est , double *yproto_est , double *E_H2MGLVQ,
					double *labels , double *temp_train , int *index_train , double *dist ,  double *powdist , 
					int *Nk)
/*
 * h2m_glvq_model: train the prototypes Wproto_est in place with the H2M-GLVQ rule.
 *
 * For every iteration t = 0 .. options.nb_iterations-1:
 *   - the learning rate epsilonl (used for prototypes of other classes) moves
 *     geometrically from options.epsilonl (t = 0) to options.epsilonk (last t);
 *     prototypes of the sample's own class always use options.epsilonk;
 *   - the exponent tcurrent moves linearly from options.tmax to options.tmin;
 *   - if options.shuffle is set, samples are visited in a random order;
 *   - every sample moves every prototype, and its cost lnu is added to E_H2MGLVQ[t].
 *
 * Xtrain      : d x Ntrain samples, sample i stored at Xtrain[i*d .. i*d+d-1].
 * ytrain      : Ntrain sample labels.
 * options     : reads epsilonk, epsilonl, xi, tmin, tmax, nb_iterations, shuffle.
 * d, Ntrain   : feature dimension and number of samples.
 * Nproto, m   : number of prototypes and number of entries in labels[].
 * Wproto_est  : d x Nproto prototypes (prototype j at Wproto_est[j*d]); updated in place.
 * yproto_est  : Nproto prototype labels.
 * E_H2MGLVQ   : output, length nb_iterations; per-iteration sum of lnu.
 * labels      : m class labels; assumed to contain every ytrain value (TODO confirm in caller).
 * temp_train  : scratch, length Ntrain (random shuffle keys).
 * index_train : scratch, length Ntrain (sample visiting order).
 * dist        : scratch, length Nproto (squared distances to the current sample).
 * powdist     : scratch, length Nproto ((1/dist)^tcurrent).
 * Nk          : output, length m; Nk[l] = number of prototypes labelled labels[l].
 */
{
	int i , j , l , t , indice , ind , kmin , lmin , indtemp , jd , Nkk , Nll;

	/* Initialised so Nk[] is never indexed with an indeterminate value,
	   even if a sample label is missing from labels[]. */
	int ind_label = 0;

	double yi , temp , sum , dkmin , dlmin , double_max = 1.79769313486231*10e307 , nu , lnu , cte , tmp ;

	/* Both schedules divide by (nb_iterations - 1). Use 1 when there is a
	   single iteration: only t = 0 is evaluated, so the exponent is 0 and
	   nothing divides by zero. Holding the span as a double also keeps the
	   epsilonl exponent below from being truncated by integer division. */
	double span  = (options.nb_iterations > 1) ? (double)(options.nb_iterations - 1) : 1.0;

	double stept = (options.tmax - options.tmin)/span;

	double epsilonl , epsilonk = options.epsilonk , cte_epsilonl = (options.epsilonk/options.epsilonl);

	double sumdkt , sumdlt , tcurrent , dkH , dlH , powdkmin , powdlmin , ak , al , alpha , tmpkmin , tmplmin;

	double ctetmp1 , ctetmp2 , ctetmp3 , ctetmp4 , ctetmp5 , ctetmp6;


	/* Nk[l] = number of prototypes whose label equals labels[l]. */
	for (l = 0 ; l < m ; l++)
	{
		Nk[l] = 0;

		ind   = (int)labels[l];

		for (i = 0 ; i < Nproto ; i++)
		{
			if(yproto_est[i] == ind)
			{
				Nk[l]++;
			}
		}
	}

	/* Default visiting order: samples in storage order. */
	for(i = 0 ; i < Ntrain ; i++)
	{
		index_train[i] = i;
	}

	tcurrent = options.tmax;

	for (t = 0 ; t < options.nb_iterations ; t++)
	{
		/* Fractional exponent t/(nb_iterations-1) in [0, 1], computed in
		   floating point. With int/int division it would be 0 for every t
		   except the last, which would freeze the schedule. */
		epsilonl = options.epsilonl*pow(cte_epsilonl , (double)t/span);

		if(options.shuffle)
		{
			/* Random permutation: sort random keys, carrying the sample indices along. */
			for(i = 0 ; i < Ntrain ; i++)
			{
				temp_train[i]  = rand();

				index_train[i] = i;
			}

			qsindex(temp_train , index_train , 0 , Ntrain - 1 ); 
		}

		E_H2MGLVQ[t]     = 0.0;

		for(i = 0 ; i < Ntrain ; i++)
		{
			indice       = index_train[i];

			ind          = indice*d;          /* offset of the current sample in Xtrain */

			yi           = ytrain[indice];

			/* Position of the sample's label in labels[] (last match wins). */
			for (j = 0 ; j <  m ; j++)
			{
				if(labels[j] == yi)
				{
					ind_label = j;
				}
			}

			dkmin        = double_max;

			dlmin        = double_max;

			kmin         = 0;

			lmin         = 0;

			sumdkt       = 0.0;

			sumdlt       = 0.0;

			/* Squared Euclidean distance to every prototype. Track the nearest
			   prototype of the same class (kmin, dkmin) and of any other class
			   (lmin, dlmin), and accumulate (1/dist)^tcurrent per group.
			   NOTE(review): a zero distance makes 1.0/sum infinite, and the
			   ratios below then become NaN. */
			for(j = 0 ; j < Nproto ; j++)
			{
				sum = 0.0;

				jd  = j*d;

				for(l = 0 ; l < d ; l++)
				{
					temp = (Xtrain[l + ind] - Wproto_est[l + jd]);

					sum += (temp*temp);
				}

				dist[j]    = sum;

				powdist[j] = pow(1.0/sum , tcurrent); 

				if(yproto_est[j] == yi)
				{
					if(sum < dkmin)
					{
						dkmin = sum;

						kmin  = j;
					}

					sumdkt    += powdist[j];
				}
				else
				{
					if(sum < dlmin)
					{
						dlmin = sum;

						lmin  = j;
					}

					sumdlt    += powdist[j];
				}
			}

			Nkk           = Nk[ind_label];    /* prototypes of the sample's class */

			Nll           = Nproto - Nkk;     /* prototypes of every other class  */

			/* dkH = Nkk*dkmin / sum_j (dkmin/dist_j)^tcurrent over same-class
			   prototypes. This is the harmonic mean of their distances when
			   tcurrent == 1. dlH is computed the same way for the other classes. */
			tmpkmin       = sumdkt/powdist[kmin];

			dkH           = (Nkk*dkmin)/(tmpkmin);

			powdkmin      = tmpkmin/dkmin;

			tmplmin       = sumdlt/powdist[lmin];

			dlH           = (Nll*dlmin)/(tmplmin);

			powdlmin      = tmplmin/dlmin;

			/* Relative difference nu in [-1, 1] and its sigmoid cost lnu. */
			tmp           = 1.0/(dkH + dlH);

			nu            = (dkH  - dlH)*tmp;

			lnu           = 1.0/(1.0 + exp(-options.xi*nu));

			E_H2MGLVQ[t] += lnu;

			/* Update coefficients. ctetmp3 (nearest) and ctetmp5 (others) act on
			   same-class prototypes with rate epsilonk; ctetmp4 (nearest) and
			   ctetmp6 (others) act on other-class prototypes with rate epsilonl. */
			ak            = Nkk/(powdkmin*powdkmin);

			al            = Nll/(powdlmin*powdlmin);

			cte           = 4.0*options.xi*lnu*(1.0 - lnu)*(tmp*tmp);

			ctetmp1       = epsilonk*dlH*cte*ak;

			ctetmp2       = epsilonl*dkH*cte*al;

			ctetmp3       = -ctetmp1*((tcurrent - 1.0)*(powdkmin/dkmin) - (tcurrent/(dkmin*dkmin)));

			ctetmp4       = ctetmp2*((tcurrent - 1.0)*(powdlmin/dlmin) - (tcurrent/(dlmin*dlmin)));

			ctetmp5       = ctetmp1*(powdkmin/sumdkt)*tcurrent;

			ctetmp6       = -ctetmp2*(powdlmin/sumdlt)*tcurrent;

			/* Move every prototype along (x - w) by its own step alpha. */
			for(j = 0 ; j < Nproto ; j++)
			{
				jd      = j*d;

				if(yproto_est[j] == yi)
				{
					if(kmin == j)
					{
						alpha = ctetmp3;
					}
					else
					{
						alpha = ctetmp5*(powdist[j]/dist[j]);
					}
				}
				else
				{
					if(lmin == j)
					{
						alpha = ctetmp4;
					}
					else
					{
						alpha = ctetmp6*(powdist[j]/dist[j]);
					}
				}

				for( l = 0 ; l < d ; l++)
				{
					indtemp              = l + jd;

					Wproto_est[indtemp] += alpha*(Xtrain[l + ind] - Wproto_est[indtemp]); 
				}
			}
		}

		/* Linear schedule: tcurrent reaches options.tmin on the last iteration. */
		tcurrent -= stept;  	
	}
}


/*----------------------------------------------------------------------------------------------------------------------------------------- */


/*
 * qs: sort a[lo..hi] (both bounds inclusive) in ascending order, in place.
 * Uses a Hoare-style partition around the middle element, then recurses on
 * the left part and then on the right part.
 */
void qs(double *a, int lo, int hi)
{
	int          left  = lo;
	int          right = hi;
	const double pivot = a[(lo + hi) / 2];

	/* Partition: skip elements already on the correct side, then swap the
	   out-of-place pair. Stop once the two cursors have crossed. */
	for (;;)
	{
		while (a[left] < pivot)
			++left;
		while (a[right] > pivot)
			--right;

		if (left <= right)
		{
			double swap = a[left];

			a[left]  = a[right];
			a[right] = swap;

			++left;
			--right;
		}

		if (left > right)
			break;
	}

	if (lo < right)
		qs(a, lo, right);
	if (left < hi)
		qs(a, left, hi);
}

/*----------------------------------------------------------------------------------------------------------------------------------------- */

/*
 * qsindex: sort a[lo..hi] (both bounds inclusive) in ascending order, in
 * place. Every swap on a[] is mirrored on index[], so if index[] starts as
 * 0..n-1 it ends up holding the sorting permutation. The partition scheme
 * is the same as qs(); the relative order of equal keys is unspecified.
 */
void qsindex(double *a, int *index, int lo, int hi)
{
	int          left  = lo;
	int          right = hi;
	const double pivot = a[(lo + hi) / 2];

	for (;;)
	{
		while (a[left] < pivot)
			++left;
		while (a[right] > pivot)
			--right;

		if (left <= right)
		{
			/* Swap the keys and their tags together. */
			double key = a[left];
			int    tag = index[left];

			a[left]      = a[right];
			a[right]     = key;
			index[left]  = index[right];
			index[right] = tag;

			++left;
			--right;
		}

		if (left > right)
			break;
	}

	if (lo < right)
		qsindex(a, index, lo, right);
	if (left < hi)
		qsindex(a, index, left, hi);
}


/* --------------------------------------------------------------------------- */

 
 void randini(void) 
	 
 {
	 
	 /* SHR3 Seed initialization */
	 
	 jsrseed  = (UL) time( NULL );
	 
	 jsr     ^= jsrseed;
	 
	 
	 /* KISS Seed initialization */
	 
#ifdef ranKISS
	  
	 z        = (UL) time( NULL );
	 
	 w        = (UL) time( NULL ); 
	 
	 jcong    = (UL) time( NULL );
	 
	 mix(z , w , jcong);
	 
#endif 
	 
	 
 }
 

/* --------------------------------------------------------------------------- */

void randnini(void) 
{	  
	register const double m1 = 2147483648.0, m2 = 4294967296.0 ;

	register double  invm1;

	register double dn = 3.442619855899 , tn = dn , vn = 9.91256303526217e-3 , q; 
	
	int i;

	
	/* Ziggurat tables for randn */	 
	
	invm1             = 1.0/m1;
	
	q                 = vn/exp(-0.5*dn*dn);  
	
	kn[0]             = (dn/q)*m1;	  
	
	kn[1]             = 0;
		  
	wn[0]             = q*invm1;	  
	
	wn[zigstep - 1 ]  = dn*invm1;
	
	fn[0]             = 1.0;	  
	
	fn[zigstep - 1]   = exp(-0.5*dn*dn);		
	
	for(i = (zigstep - 2) ; i >= 1 ; i--)      
	{   
		dn              = sqrt(-2.*log(vn/dn + exp(-0.5*dn*dn)));          
	
		kn[i+1]         = (dn/tn)*m1;		  
		
		tn              = dn;          
		
		fn[i]           = exp(-0.5*dn*dn);          
		
		wn[i]           = dn*invm1;      
	}

}


/* --------------------------------------------------------------------------- */


/*
 * nfix: slow path of the Ziggurat normal sampler, entered from randn() when
 * the fast test fails. It works on the current global hz/iz state:
 *   - iz != 0 (a wedge): accept x = hz*wn[iz] when a uniform point in the
 *     wedge lies under exp(-x*x/2). Otherwise draw a fresh hz/iz and retry
 *     the fast test, looping until a value is accepted.
 *   - iz == 0 (base strip): draw from the tail beyond rtail by rejection and
 *     return rtail + x with the sign of hz.
 */
double nfix(void) 
{
	const double rtail = 3.442620;   /* start of the right tail */

	double xs , ys;

	while (1)
	{
		xs = hz*wn[iz];

		if (iz != 0)
		{
			/* Wedge acceptance test. */
			if ((fn[iz] + rand()*(fn[iz-1] - fn[iz])) < exp(-0.5*xs*xs))
				return xs;

			/* Rejected: start over with a new draw and the fast test. */
			hz = randint;

			iz = (hz & (zigstep - 1));

			if (abs(hz) < kn[iz])
				return (hz*wn[iz]);

			continue;
		}

		/* Base strip: tail sample by exponential rejection. */
		do
		{
			xs = -log(rand())*0.2904764;   /* 0.2904764 is 1/rtail */

			ys = -log(rand());
		}
		while ((ys + ys) < (xs*xs));

		return (hz > 0) ? (rtail + xs) : (-rtail - xs);
	}
}


/* --------------------------------------------------------------------------- */


/*
 * randn: draw one normal deviate with the Ziggurat method. The caller uses
 * it as mean + std*randn(), which presumes N(0,1) output.
 * Fast path: take a raw integer hz, select layer iz from its low bits, and
 * return hz*wn[iz] when |hz| lies inside that layer's rectangle. All other
 * cases are delegated to nfix().
 */
double randn(void) 
{
	hz = randint;

	iz = (hz & (zigstep - 1));

	if (abs(hz) < kn[iz])
		return hz*wn[iz];

	return nfix();
}



/* --------------------------------------------------------------------------- */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -