⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ndtimes.c

📁 本文件采用Matlab软件
💻 C
字号:
/*

 Compute matrix product A*B  for ND slice matrices

 Usage
 -----

 C = ndtimes(A,B)


 Inputs
 -------

 A           Matrix (k x n x a1 x a2 x ... al)    L = (l + 2) dimensional

 B           Matrix (n x p x b1 x b2 x ... bs)    S = (s + 2) dimensional


 CONDITIONS !!

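 ai = bi, i = 1, ..., min(l , s)  (only the product of these shared dimensions is
 actually checked; the input with fewer dimensions is reused across the extra
 trailing dimensions of the other one)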



 Outputs
 -------

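 C           Matrix (k x p x c1 x ... ) holding the slice-wise products A*B; its
             trailing dimensions are those of the input with more dimensions (A when tied)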



 Example
 -------

 d  = 4;
 M  = 3; 
 N  = 10000;
 Q  = [3 1 0 0 ; 0 1 0 0 ; 0 0 3 1 ; 0 0 0 1];
 Fk = randn(d , d , M);
 Xk = randn(d , 1 , M , N); 
 Qk = Q(: , : , ones(1 , M)); 
 Nk = randn(d , 1 , M , N);
 C  = reshape( ndtimes(Fk , Xk) + ndtimes(permute(ndchol(Qk) , [2 1 3]) , Nk) , [d M N]);





 Compile with:
 ------------

 mex   ndtimes.c

 or


 mex  -f mexopts_intel10amd.bat -output ndtimes.dll ndtimes.c

 Author          Sébastien PARIS (sebastien.paris@lsis.org) (5/4/08)
 -------


*/



#include <limits.h>
#include <malloc.h>
#include "mex.h"


/* --------------------------------------- DECLARATION  ------------------------------------- */


/* Batched slice product C(:,:,indC[v]) = A(:,:,indA[v]) * B(:,:,indB[v]) for v < sizC; defined below. */
void ndtimes(double *A, double *B, double *C,
             int k, int n, int p,
             int *indA, int *indB, int *indC,
             int sizC);


/* -------------------------------------------------------------------------------------------- */


/*
 * MEX gateway for C = ndtimes(A, B).
 *
 * Inputs : A (k x n x a1 x ... x al) and B (n x p x b1 x ... x bs), both
 *          real, full double arrays.
 * Checks : the inner sizes must agree (A is k x n, B is n x p) and the
 *          trailing dimensions both inputs have must agree.  Only the
 *          PRODUCT of those shared trailing dimensions is compared, so
 *          differently shaped but equally sized trailing dimensions pass.
 * Output : C (k x p x ...), whose trailing dimensions are copied from the
 *          input with more dimensions (A when both have the same number).
 *          The lower-rank input's slices are reused cyclically across the
 *          extra trailing dimensions of the higher-rank one.
 *
 * Dimensions are handled as mwSize (the type mxGetDimensions returns and
 * mxCreateNumericArray expects) and are only narrowed to int after checking
 * that they fit, because ndtimes() does its index arithmetic in int.
 */
void mexFunction( int nlhs, mxArray *plhs[] , int nrhs, const mxArray *prhs[] )
{
	const double  int_limit = (double)INT_MAX;
	const mwSize *dimsA, *dimsB, *dimsSrc;
	mwSize       *dimsC;
	mwSize        numDimsA, numDimsB, numDimsC, numShared, i;
	mwSize        k, n, p, tpA = 1, tpB = 1, sizC, v;
	int          *indA, *indB, *indC;
	int           arg;
	double       *A, *B, *C;

	/* Check nargin */
	if (nrhs != 2)
	{
		mexErrMsgTxt("Two ND matrix are requiered");
	}

	/* mxGetPr() exposes the raw data as double: any other class, complex or
	   sparse storage would be misread (and possibly read out of bounds). */
	for (arg = 0; arg < 2; arg++)
	{
		if (!mxIsDouble(prhs[arg]) || mxIsComplex(prhs[arg]) || mxIsSparse(prhs[arg]))
		{
			mexErrMsgTxt("Inputs must be real, full (non-sparse) double arrays");
		}
	}

	/* ----------- Input ------------ */

	A        = mxGetPr(prhs[0]);
	numDimsA = mxGetNumberOfDimensions(prhs[0]);
	dimsA    = mxGetDimensions(prhs[0]);

	B        = mxGetPr(prhs[1]);
	numDimsB = mxGetNumberOfDimensions(prhs[1]);
	dimsB    = mxGetDimensions(prhs[1]);

	k = dimsA[0];
	n = dimsA[1];
	p = dimsB[1];

	if (n != dimsB[0])
	{
		mexErrMsgTxt("Inner dimensions are not matching !! A(k x n x ...) and B(n x p x ...)");
	}

	/* Number of slices over the trailing dimensions both inputs have. */
	numShared = (numDimsA < numDimsB) ? numDimsA : numDimsB;
	for (i = 2; i < numShared; i++)
	{
		tpA *= dimsA[i];
		tpB *= dimsB[i];
	}

	if (tpA != tpB)
	{
		mexErrMsgTxt("Dimensions > 2 are not matching");
	}

	/* Extra trailing dimensions only grow the slice count of their own input
	   (at most one of these two loops runs). */
	for (i = numShared; i < numDimsA; i++)
	{
		tpA *= dimsA[i];
	}
	for (i = numShared; i < numDimsB; i++)
	{
		tpB *= dimsB[i];
	}

	/* C follows the higher-rank input; ties go to A. */
	if (numDimsA >= numDimsB)
	{
		numDimsC = numDimsA;
		dimsSrc  = dimsA;
		sizC     = tpA;
	}
	else
	{
		numDimsC = numDimsB;
		dimsSrc  = dimsB;
		sizC     = tpB;
	}

	/* ndtimes() does its index arithmetic in int: refuse anything it cannot address. */
	if ((double)k > int_limit || (double)n > int_limit || (double)p > int_limit ||
		(double)sizC > int_limit ||
		(double)mxGetNumberOfElements(prhs[0]) > int_limit ||
		(double)mxGetNumberOfElements(prhs[1]) > int_limit ||
		(double)k * (double)p * (double)sizC > int_limit)
	{
		mexErrMsgTxt("Arrays are too large: ndtimes indexes with int");
	}

	/* Slice v of C pairs slice v of the higher-rank input with slice (v mod tp)
	   of the other input; for equal ranks every index is simply v.  sizC > 0
	   implies tpA > 0 and tpB > 0, so the modulo never divides by zero. */
	indA = mxMalloc(sizC * sizeof(*indA));
	indB = mxMalloc(sizC * sizeof(*indB));
	indC = mxMalloc(sizC * sizeof(*indC));

	for (v = 0; v < sizC; v++)
	{
		indA[v] = (int)(v % tpA);
		indB[v] = (int)(v % tpB);
		indC[v] = (int)v;
	}

	/* ----------- Output ------------ */

	dimsC    = mxMalloc(numDimsC * sizeof(*dimsC));
	dimsC[0] = k;
	dimsC[1] = p;
	for (i = 2; i < numDimsC; i++)
	{
		dimsC[i] = dimsSrc[i];
	}

	plhs[0] = mxCreateNumericArray(numDimsC, dimsC, mxDOUBLE_CLASS, mxREAL);
	C       = mxGetPr(plhs[0]);

	/* ----------------- Array multiplication ------------ */

	ndtimes(A, B, C, (int)k, (int)n, (int)p, indA, indB, indC, (int)sizC);

	/* ----------------------- Free space ---------------- */

	mxFree(indA);
	mxFree(indB);
	mxFree(indC);
	mxFree(dimsC);
}

/* ----------------------------------------------------------------------------------------------- */

/*
 * Batched column-major matrix product.
 *
 * For each v in [0, sizC):
 *     C(:, :, indC[v]) = A(:, :, indA[v]) * B(:, :, indB[v])
 * where every A slice is k x n, every B slice is n x p and every C slice is
 * k x p, each stored contiguously in column-major order.
 *
 * Parameters
 *   A, B        input slice stacks (read only)
 *   C           output slice stack; every addressed k x p slice is overwritten
 *   k, n, p     slice sizes; with n <= 0 the addressed C slices are zeroed,
 *               with k <= 0 or p <= 0 nothing is written
 *   indA/B/C    per-output slice indices (assumed non-negative and in range)
 *   sizC        number of slice products to compute
 *
 * Offsets are formed in size_t so that slice_index * slice_size cannot
 * overflow int when the arrays hold more than INT_MAX elements.
 */
void ndtimes(double *A, double *B , double *C , int k , int n , int p , int *indA , int *indB , int *indC , int sizC)
{
	size_t rows, inner, cols, kn, np, kp;
	size_t aoff, boff, coff, t, l, i;
	int    v;

	if (k <= 0 || p <= 0)
	{
		return;                          /* C slices are empty: nothing to write */
	}

	rows  = (size_t)k;
	cols  = (size_t)p;
	inner = (n > 0) ? (size_t)n : 0;     /* non-positive inner size -> zero products */

	kn = rows * inner;                   /* elements per A slice */
	np = inner * cols;                   /* elements per B slice */
	kp = rows * cols;                    /* elements per C slice */

	for (v = 0; v < sizC; v++)
	{
		aoff = (size_t)indA[v] * kn;
		boff = (size_t)indB[v] * np;
		coff = (size_t)indC[v] * kp;

		for (t = 0; t < rows; t++)
		{
			for (l = 0; l < cols; l++)
			{
				/* Accumulate locally: same summation order, no store through C per term. */
				double acc = 0.0;

				for (i = 0; i < inner; i++)
				{
					acc += A[aoff + t + i * rows] * B[boff + l * inner + i];
				}

				C[coff + t + l * rows] = acc;
			}
		}
	}
}







⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -