brutesearchmex.cpp

来自「K-nearest neighbors 搜索 聚类时经常使用的一种方法 国外网站」· C++ 代码 · 共 274 行

CPP
274
字号
#include <cmath>
#include <cstddef>
#include <vector>

#include "mex.h"
#include "BruteSearchTranspose.cpp"






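// In MATLAB the function is meant to be called as idc=NNSearch(x,y,pk,ptrtree).
// NOTE(review): that call signature appears to come from a tree-based variant;
// mexFunction below actually reads (p, qp[, mode[, k or r]]).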

/* the gateway function */
void mexFunction( int nlhs, mxArray *plhs[],
int nrhs,   mxArray *prhs[])
{
    
    double       *p;//reference points
    double       *qp;//query points
    double* results;
    char*        String;//Input String
    int String_Leng;//length of the input String
    int N,Nq,dim,i,j;
    double* distances;//distances for k-neighbors
    int* idck;//k-neighbors
    int idc;//nearest neighbor
    double*k;
    int kint;
    double* r;
    double* D;//Output distance matrix
    double mindist;
    
    //Errors Check
    
    if (nlhs>2)
    {
        mexErrMsgTxt("Too many outputs");
    }
    
    if (nrhs<2)
    {
        mexErrMsgTxt("Not enough inputs");
    }
    
    
    N=mxGetN(prhs[0]);//number of reference points
    Nq=mxGetN(prhs[1]);//number of query points
    dim=mxGetM(prhs[0]);//dimension of points
    
//     mexPrintf("Dimension %4.4d\n",dim);
   
    
    //Check is inputs has the same dimension
    if (mxGetM(prhs[1])!=dim)
    {
        mexErrMsgTxt("Points must have the same dimension");
    }
    
    
    //Checking input format
    if( !mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) )
    { mexErrMsgTxt("Inputs points must be double matrix ");
    }
    
    
    
    p = mxGetPr(prhs[0]);//puntatore all'array dei punti
    qp = mxGetPr(prhs[1]);//puntatore all'array dei punti
    

    
    //Getting the Input String
    
    if (nrhs>2)
    {
        
        
        if(!mxIsChar(prhs[2]))//is the input a string
            
        {
            mexErrMsgTxt("Input 3 must be of type char.");
        }
        
        
        String_Leng=mxGetN(prhs[2]);//StringLength
        
        if (String_Leng>1)//Check if string is correct
        {
            mexErrMsgTxt("Input 3 must be only one String.");
        }
        
        
        
        String =mxArrayToString (prhs[2]);//get the string
        
        if(String == NULL)
        {
            mexErrMsgTxt("Could not convert input to string.");
        }
//         mexPrintf("The input string is:  %s\n",String);
    }
    else
    {
        String=NULL;
//         mexPrintf("No input string");
    }
    
    
    //Choose the algorithm from the input String
    
    if(String==NULL)
    {  //Nearest Neighbor
//         mexPrintf("Nearest Neighbor\n");
        plhs[0] = mxCreateDoubleMatrix(Nq, 1,mxREAL);//costruisce l'output array
        results = mxGetPr(plhs[0]);//appicicaci il puntatore
        
        if (nlhs==1)
        { //Search with no distance
            
            for (i=0;i<Nq;i++)
            {
                results[i]=BruteSearch(p,&qp[i*dim] ,N,dim,&mindist)+1;
            }
        }
        else //Search with distance
        { //build the output distance matrix
            plhs[1] = mxCreateDoubleMatrix(Nq, 1,mxREAL);
            D = mxGetPr(plhs[1]);//appicicaci il puntatore
            
            for (i=0;i<Nq;i++)
            {
                
               
                results[i]=BruteSearch(p,&qp[i*dim] ,N,dim,&mindist)+1;
                D[i]=sqrt(mindist);
            }
        }
        
    }
    
    
    
    else if(String[0]=='k')
    {       //KNearest Neighbor
        
//         mexPrintf("KNearest Neighbor\n");
        
        
        
        if (nrhs<4)
        {
            mexErrMsgTxt("Number of neighbours not given");
        }
        
        k=mxGetPr(prhs[3]);
        
        kint=k[0];
        
        
        idck=new int[kint];
        distances=new double[kint];
     
        plhs[0] = mxCreateDoubleMatrix(Nq,kint,mxREAL);//costruisce l'output array
        results = mxGetPr(plhs[0]);//appicicaci il puntatore
        
        if (nlhs==1) //Search without Distance matrix output
        {
            for (i=0;i<Nq;i++)
            {
             
                BruteKSearch(p,&qp[i*dim] ,kint,N,dim,distances,idck);//Run the query
                
                for(j=0;j<kint;j++)
                {   
                    results[j*Nq+i]=idck[j]+1;//plus one Matlab notation
                }
                
            }
        }
        else//Search with Distance matrix
        {
            plhs[1] = mxCreateDoubleMatrix(Nq,kint,mxREAL);//costruisce l'output array
            D = mxGetPr(plhs[1]);//appicicaci il puntatore
            for (i=0;i<Nq;i++)
            {
               
                BruteKSearch(p,&qp[i*dim] ,kint,N,dim,distances,idck);//Run the query
                
                for(j=0;j<kint;j++)
                {
                    D[j*Nq+i]=sqrt(distances[j]);
                    results[j*Nq+i]=idck[j]+1;//plus one Matlab notation
                }
               
            }
            
        }
    }
    
    
    else if(String[0]=='r')
    {       //Radius Search
       
        
        if (Nq>1)
        {
            mexErrMsgTxt("Radius Serach possible with only one query point");
        }
        
        if (nrhs<4)
        {
            mexErrMsgTxt("Radius not given");
        }
        
        r=mxGetPr(prhs[3]);
       
        vector<int> idcv;
        
        if (nlhs==1)//Search without distance
        {
            BruteRSearch(p,qp,*r,N,dim,&idcv);
            
            kint=idcv.size();//Size of vector
            
            plhs[0] = mxCreateDoubleMatrix(1,kint,mxREAL);//costruisce l'output array
            results = mxGetPr(plhs[0]);
           
            for(i=0;i<kint;i++)//copy in the otput array
            {
                //              mexPrintf("%4.1d\n",idcv[i]);
                results[i]=idcv[i]+1;//copy the output
            }
            
        }
        else
        {
        vector<double> distvect;//declare the distance vector
       
        BruteRSearchWithDistance(p,qp,*r,N,dim,&idcv,&distvect);
        
        kint=idcv.size();//Size of vector
     
       
        plhs[0] = mxCreateDoubleMatrix(1,kint,mxREAL);//costruisce l'output array
        results = mxGetPr(plhs[0]);
       
        plhs[1] = mxCreateDoubleMatrix(1,kint,mxREAL);//costruisce l'output array
        D= mxGetPr(plhs[1]);
       
        for(i=0;i<kint;i++)//copy in the otput array
        {
          
            results[i]=idcv[i]+1;//copy the output
            D[i]=distvect[i];//The distance is already radsquared
        }
        }
    }
    
    
    
    else
    {
        mexErrMsgTxt("Invalid Input String.");
    }
    
    //deallocatememory
    delete [] distances;
    delete [] idck;
}






⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?