⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 srng_model.c

📁 最详尽的神经网络源码
💻 C
📖 第 1 页 / 共 3 页
字号:
			
			co++;
			
		}
		
	}
	
	else
	{
			
		
		yproto                = mxGetPr(prhs[3]);
		
		if(mxGetNumberOfDimensions(prhs[3]) !=2)
		{
			
			mexErrMsgTxt("yproto must be (1 x Nproto)");
			
		}

		Nproto                = mxGetN(prhs[3]);

		plhs[1]               = mxCreateDoubleMatrix(1 , Nproto, mxREAL);
		
		yproto_est            = mxGetPr(plhs[1]);


		for( i = 0 ; i < Nproto ; i++)
			
		{
			
			yproto_est[i] = yproto[i];
			
		}

	}
	

	Np           = mxMalloc(m*sizeof(int));
	
	
	for (l = 0 ; l < m ; l++)
		
	{
		
		ind   = labels[l];
		
		Np[l] = 0;
		
		for (i = 0 ; i < Nproto ; i++)
			
		{
			
			if(yproto_est[i] == ind)
				
			{
				
				Np[l]++;
				
			}
		}
		
		if(Np[l] > Nproto_max)
			
		{
			
			Nproto_max = Np[l];
			
		}
		
	}
	




/* Input 3   Wproto */
	
	


	if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
		
	{
				
		mtemp    = mxMalloc(d*m*sizeof(double));
		
		stdtemp  = mxMalloc(d*m*sizeof(double));

    	Nk       = mxMalloc(m*sizeof(int));


		for(i = 0 ; i < d*m ; i++)
		{

			mtemp[i]   = 0.0;

			stdtemp[i] = 0.0;

		}
		


		for (l = 0 ; l < m ; l++)
			
		{
			
			ind   = labels[l];
			
			ld    = l*d;

     		Nk[l] = 0;


			
			for (i = 0 ; i < Ntrain ; i++)
				
			{
				
				if(ytrain[i] == ind)
					
				{
					id  = i*d;
					
					for(j = 0 ; j < d ; j++)
						
					{
						
						mtemp[j + ld] += Xtrain[j + id];
						
					}


					Nk[l]++;

										
				}
			}
				
		}


		for (l = 0 ; l < m ; l++)
			
		{
			
			ld   = l*d;
			
			temp = 1.0/Nk[l];
					
			
			for(j = 0 ; j < d ; j++)
				
			{
				
				mtemp[j + ld] *=temp;
				
				
			}
			
		}
		

		for (l = 0 ; l < m ; l++)
			
		{
			
			ind = labels[l];
			
			ld  = l*d;
			
			for (i = 0 ; i < Ntrain ; i++)
				
			{
				
				if(ytrain[i] == ind)
					
				{
					id  = i*d;
					
					for(j = 0 ; j < d ; j++)
						
					{
						temp             = (Xtrain[j + id] - mtemp[j + ld]);

						stdtemp[j + ld] += (temp*temp);
						
					}
					
					
				}
			}
				
		}



		for (l = 0 ; l < m ; l++)
			
		{
			
			ld   = l*d;
			
			temp = 1.0/(Nk[l] - 1);
					
			
			for(j = 0 ; j < d ; j++)
				
			{
				
				stdtemp[j + ld] = temp*sqrt(stdtemp[j + ld]);
				
				
			}
			
		}

		
		plhs[0]               = mxCreateDoubleMatrix(d , Nproto, mxREAL);
		
		Wproto_est            = mxGetPr(plhs[0]);


		for(l = 0 ; l < Nproto ; l++)
			
		{
			ld = l*d;		
			
			for(i = 0 ; i < m ; i++)
			{
				
				if(labels[i] == yproto_est[l] )
					
				{
					
					indice = i*m;
					
				}
				
				
			}

			
			for(i = 0 ; i < d ; i++)
				
			{
				
				Wproto_est[i + ld]  = mtemp[i + indice] + stdtemp[i + indice]*randn();
				
			}
			
		}
	}
	
	else
	{
		
		
		Wproto                = mxGetPr(prhs[2]);
		
		if(mxGetNumberOfDimensions(prhs[2]) !=2 || mxGetM(prhs[2]) != d)
		{
			
			mexErrMsgTxt("Wproto must be (d x Nproto)");
			
		}

		Nproto                = mxGetN(prhs[2]);

		plhs[0]               = mxCreateDoubleMatrix(d , Nproto, mxREAL);
		
		Wproto_est            = mxGetPr(plhs[0]);
	
		for( i = 0 ; i < d*Nproto ; i++)
			
		{
			
			Wproto_est[i] = Wproto[i];
			
		}
		
	}
	



	
	/* Input 5   lambda */

	plhs[2]                = mxCreateDoubleMatrix(d , 1 , mxREAL);
	
	lambda_est             = mxGetPr(plhs[2]);	


	if ((nrhs < 5) || mxIsEmpty(prhs[4]) )
	{
       
		lambda = (double *)mxMalloc(d*sizeof(double));

		
		for (i = 0 ; i < d ; i++)
		{

			lambda[i]          = 1.0/d;

            lambda_est[i]      = lambda[i] ;

		}

       
	}

	else

	{

    lambda                 = mxGetPr(prhs[4]);
		
	for (i = 0 ; i < d ; i++)
		{

			lambda_est[i]       = lambda[i] ;

		}	
	}


   /* Input 6   option */
	

	if ( (nrhs > 5) && !mxIsEmpty(prhs[5]) )
		
	{
		
				
		mxtemp                                   = mxGetField(prhs[5] , 0, "epsilonk");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.epsilonk                     = tmp[0];
			
		}

		mxtemp                                   = mxGetField(prhs[5] , 0, "epsilonl");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.epsilonl                     = tmp[0];
			
			if (options.epsilonl>options.epsilonk)
			{
				
				
				mexErrMsgTxt("Epsilon_l < Epsilon_k");
				
			}

			
		}
		
		mxtemp                                   = mxGetField(prhs[5] , 0, "epsilonlambda");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.epsilonlambda                = tmp[0];
			
			if (options.epsilonlambda>options.epsilonl/10)
			{
				
				
				mexErrMsgTxt("Epsilon_lambda < Epsilon_l/10");
				
			}

			
		}


		mxtemp                                   = mxGetField(prhs[5] , 0, "sigmastart");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.sigmastart                   = tmp[0];
			
		}

		mxtemp                                   = mxGetField(prhs[5] , 0, "sigmaend");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.sigmaend                     = tmp[0];
			
		}

		mxtemp                                   = mxGetField(prhs[5] , 0, "sigmastretch");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.sigmastretch                 = tmp[0];
			
		}

		
		mxtemp                                   = mxGetField(prhs[5] , 0, "threshold");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.threshold                    = tmp[0];
			
		}
		
		mxtemp                                   = mxGetField(prhs[5] , 0, "xi");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.xi                           = tmp[0];
			
		}
		
		
		
		mxtemp                                   = mxGetField(prhs[5] , 0, "nb_iterations");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.nb_iterations                = (int) tmp[0];
			
		}
		
		mxtemp                                   = mxGetField(prhs[5] , 0, "metric_method");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.metric_method                = (int) tmp[0];
			
		}
	
		
		mxtemp                                   = mxGetField(prhs[5] , 0, "shuffle");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.shuffle                      = (int) tmp[0];
			
		}

		mxtemp                                   = mxGetField(prhs[5] , 0, "updatelambda");
		
		if(mxtemp != NULL)
		{
			
			tmp                                  = mxGetPr(mxtemp);
			
			options.updatelambda                 = (int) tmp[0];
			
		}

		
	}
	
	
	
	
	/*---------- Outputs --------*/ 


			  	
    plhs[3]                = mxCreateDoubleMatrix(1 , options.nb_iterations , mxREAL);
						
	E_SRNG                 = mxGetPr(plhs[3]);


   /*---------- Temporary matrices --------*/ 

	
    temp_train            = mxMalloc(Ntrain*sizeof(double));

    index_train           = mxMalloc(Ntrain*sizeof(int));

	distj                 = mxMalloc(Nproto_max*sizeof(double));

    hproto                = mxMalloc(Nproto_max*sizeof(double));

	sproto                = mxMalloc(Nproto_max*sizeof(double));

    rank_distj            = mxMalloc(Nproto_max*sizeof(int));

    index_distj           = mxMalloc(Nproto_max*sizeof(int));







	/* Main Call */


	srng_model(Xtrain , ytrain , options , d , Ntrain , Nproto , m , 
		       Wproto_est , yproto_est , E_SRNG , lambda , lambda_est , 
			   labels , temp_train , index_train  , distj , rank_distj , index_distj, hproto , sproto , Nproto_max);

		
   /*---------- Free memory --------*/ 

	
	mxFree(labels);
	
	mxFree(ytrainsorted);

	mxFree(temp_train);

	mxFree(index_train);
	

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -