⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 emd_mex.c

📁 earth mover s distance
💻 C
字号:

#include "mex.h"
#include "emd.h"

/* Element type of the MATLAB arrays read and written in this file. */
#define DATA_TYPE double

/*
 * Index into a MATLAB matrix (2-D array stored column by column).
 *   v   : pointer to the first element of the array
 *   d   : number of rows
 *   i,j : zero-based row and column index
 * Expands to an lvalue, so it can be read from or assigned to.
 * Every argument is parenthesized so that expression arguments (for
 * example a conditional expression for v) expand with the intended
 * precedence.
 */
#define MAT_ELEM(v,d,i,j) (*((v)+((d)*(j))+(i)))


/* Cost matrix data and its number of rows; dist() reads both. */
DATA_TYPE *ptrCost=NULL;
int rowsCost=0;

/*
 * Ground-distance callback passed to emd(): looks up one entry of the
 * cost matrix referenced by the file-level ptrCost / rowsCost.
 *   iF1 : feature whose value is used as the row index
 *   jF2 : feature whose value is used as the column index
 * Returns the selected matrix entry converted to float.
 */
float dist(feature_t *iF1, feature_t *jF2)
{
    const feature_t row = *iF1;
    const feature_t col = *jF2;
    DATA_TYPE cost = MAT_ELEM(ptrCost, rowsCost, row, col);

    return (float)cost;
}


/*
 * MEX gateway computing the earth mover's distance between two signatures.
 *
 * Inputs (prhs), all full (non-sparse), 2-D, numeric arrays of class double:
 *   prhs[0] : weights of the first signature, 1-by-n1 row vector
 *   prhs[1] : weights of the second signature, 1-by-n2 row vector
 *   prhs[2] : n1-by-n2 cost matrix; dist() reads entry (i,j) for the pair
 *             (feature i of signature 1, feature j of signature 2)
 *
 * Outputs (plhs):
 *   plhs[0] : value returned by emd(), as a double scalar
 *   plhs[1] : only when two outputs are requested; flowSize-by-3 matrix whose
 *             columns hold the from, to and amount fields of each flow_t
 *             written by emd().
 *             NOTE(review): from/to are presumably zero-based feature
 *             indices, since features are set to 0..n-1 below; confirm
 *             against emd().
 *
 * Invalid input is reported through mexErrMsgTxt, which aborts the call.
 */
void mexFunction(
int          nlhs,
mxArray      *plhs[],
int          nrhs,
const mxArray *prhs[]
)
{
    int i,n1,n2;
    DATA_TYPE* ptr;
    feature_t *f1,*f2;
    float *w1,*w2;
    float       e;
    flow_t      *flow;
    int  flowSize = 0;
    signature_t s1,s2;

    /* Check for proper arguments */
    if (nrhs != 3)
    {
        mexErrMsgTxt("3 input arguments required.");
    }

    if(nlhs > 2)
    {
        mexErrMsgTxt("Too many outputs.");
    }

    if(!mxIsNumeric(prhs[0]) || !mxIsNumeric(prhs[1]) || !mxIsNumeric(prhs[2]))
    {
        mexErrMsgTxt("Input arguments must be numeric matrices\n");
    }

    if(mxIsSparse(prhs[0])||mxIsSparse(prhs[1])||mxIsSparse(prhs[2]))
    {
        mexErrMsgTxt("Sparse matrices are not supported.");
    }

    if (mxGetNumberOfDimensions(prhs[0])>2 ||
        mxGetNumberOfDimensions(prhs[1])>2 ||
        mxGetNumberOfDimensions(prhs[2])>2 )
    {
        mexErrMsgTxt("Multidemnsionl arrays are not supported!\n");
    }

    if( mxGetClassID(prhs[0])!=mxDOUBLE_CLASS ||
        mxGetClassID(prhs[1])!=mxDOUBLE_CLASS ||
        mxGetClassID(prhs[2])!=mxDOUBLE_CLASS )
    {
        mexErrMsgTxt("Currently only double type is supported.\n");
    }

    /*
     * Signature sizes are the column counts of the weight vectors.
     * These assignments are unconditional: a dangling `if` used to guard
     * the first one, leaving n1 uninitialized for inputs with zero rows.
     */
    n1 = (int)mxGetN(prhs[0]);
    n2 = (int)mxGetN(prhs[1]);
    if(!n1 || !n2)
    {
        mexErrMsgTxt("Signatures can not be empty!\n");
    }
    if(1!=mxGetM(prhs[0]) || 1!=mxGetM(prhs[1]))
    {
        mexErrMsgTxt("Weights of signatures should be row vectors!\n");
    }

    /* Publish the cost matrix for dist(), then validate its shape. */
    ptrCost=(DATA_TYPE*)mxGetData(prhs[2]);
    rowsCost=mxGetM(prhs[2]);
    if(n1!=mxGetM(prhs[2]) ||  n2!=mxGetN(prhs[2]))
    {
        mexErrMsgTxt("Size of cost matrix is not consistent with signatures.\n");
    }

    /* First signature: feature i is simply the index i, weights narrowed to float. */
    f1=mxCalloc(n1,sizeof(feature_t));
    w1=mxCalloc(n1,sizeof(float));
    ptr=(DATA_TYPE*)mxGetData(prhs[0]);
    for(i=0;i<n1;i++)
    {
        w1[i]=(float)ptr[i];
        f1[i]=i;
    }

    /* Second signature, built the same way. */
    f2=mxCalloc(n2,sizeof(feature_t));
    w2=mxCalloc(n2,sizeof(float));
    ptr=(DATA_TYPE*)mxGetData(prhs[1]);
    for(i=0;i<n2;i++)
    {
        w2[i]=(float)ptr[i];
        f2[i]=i;
    }

    s1.n=n1;
    s1.Features=f1;
    s1.Weights=w1;
    s2.n=n2;
    s2.Features=f2;
    s2.Weights=w2;

    if(nlhs<=1)
    {
        /* Distance only: no flow buffer is passed to emd(). */
        e = emd(&s1, &s2, dist, 0, 0);
        plhs[0]=mxCreateDoubleScalar(e);
    }
    else if(nlhs==2)
    {
        /* Buffer sized for the maximum number of flow entries. */
        flow=mxCalloc(n1+n2-1,sizeof(flow_t));
        e = emd(&s1, &s2, dist, flow, &flowSize);
        plhs[0]=mxCreateDoubleScalar(e);

        /* One row per flow entry; columns are from, to, amount. */
        plhs[1]= mxCreateDoubleMatrix(flowSize,3,mxREAL);
        ptr=(DATA_TYPE*)mxGetData(plhs[1]);
        for(i=0;i<flowSize;i++)
        {
            MAT_ELEM(ptr,flowSize,i,0)=flow[i].from;
            MAT_ELEM(ptr,flowSize,i,1)=flow[i].to;
            MAT_ELEM(ptr,flowSize,i,2)=flow[i].amount;
        }
        mxFree(flow);
    }

    mxFree(f1);
    mxFree(f2);
    mxFree(w1);
    mxFree(w2);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -