⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nca.c

📁 功能为neighborhood components analysis
💻 C
字号:
/*
 * simple code for computing the KL-divergence objective function and gradient
 * from "Neighbourhood Components Analysis" Goldberger et al, NIPS04
 *
 * charless fowlkes
 * fowlkes@cs.berkeley.edu
 * 2005-02-23
 *
 */

#include <mex.h>
#include <string.h>
#include <math.h>

/*
 * Reject anything other than a real, full (non-sparse) double array.
 * The code below reads mxGetPr() as a dense column-major double buffer,
 * which is only meaningful for such arrays.
 */
static void require_real_full_double(const mxArray *arg, const char *msg)
{
    if (!mxIsDouble(arg) || mxIsComplex(arg) || mxIsSparse(arg)) {
        mexErrMsgTxt(msg);
    }
}

/*
 * Class index of point i, given the N x K column-major indicator matrix Y:
 * the LAST column k with Y(i,k) != 0.  Returns K if row i is all zero.
 */
static size_t class_of(const double *Y, size_t N, size_t K, size_t i)
{
    size_t ci = K;
    for (size_t k = 0; k < K; k++) {
        if (Y[k * N + i] != 0) {
            ci = k;
        }
    }
    return ci;
}

/*
 * MATLAB entry point:  [f, M] = nca(A, X, Y, AXT)
 *
 * Inputs (real, full double matrices in MATLAB column-major storage):
 *   A   - OD x ID matrix; only its size is used here.
 *   X   - N x ID data, one point per row.
 *   Y   - N x K class indicator; point i belongs to the last class k with
 *         Y(i,k) != 0.  Every row must contain at least one nonzero entry.
 *   AXT - OD x N projected points, one per column.  Presumably A*X' as
 *         computed by the caller -- NOTE(review): only its size is checked.
 *
 * Outputs:
 *   f - sum_i log(p_i), where p_i = sum_{j in C_i} p_ij is the softmax
 *       probability that point i picks a neighbour of its own class.
 *   M - ID x ID matrix
 *         sum_i ( sum_k p_ik x_ik x_ik' - (1/p_i) sum_{j in C_i} p_ij x_ij x_ij' )
 *       with x_ij = x_i - x_j.  In the NCA paper the gradient of f w.r.t.
 *       A is 2*A*M; that product is not formed here (presumably the caller
 *       does it -- confirm).
 *
 * NOTE: if p_i underflows to 0, f becomes -Inf and M gets Inf/NaN entries.
 */
void mexFunction (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
    // check number of arguments
    if (nlhs < 2) {
        mexErrMsgTxt("Too few output arguments.");
    }
    if (nlhs >= 3) {
        mexErrMsgTxt("Too many output arguments.");
    }
    if (nrhs < 4) {
        mexErrMsgTxt("Too few input arguments.");
    }
    if (nrhs >= 5) {
        mexErrMsgTxt("Too many input arguments.");
    }

    // mxGetPr() indexing below is only valid for real, full double arrays
    require_real_full_double(prhs[0], "A must be a real, full double matrix");
    require_real_full_double(prhs[1], "X must be a real, full double matrix");
    require_real_full_double(prhs[2], "Y must be a real, full double matrix");
    require_real_full_double(prhs[3], "AX must be a real, full double matrix");

    // get arguments (size_t matches mxGetM/mxGetN and avoids int overflow)
    const size_t ID = mxGetN(prhs[0]);
    const size_t OD = mxGetM(prhs[0]);

    const double *X = mxGetPr(prhs[1]);
    if (mxGetN(prhs[1]) != ID) { mexErrMsgTxt("data (X) dimension  does not match (A) input dimension"); }
    const size_t N = mxGetM(prhs[1]);

    const double *Y = mxGetPr(prhs[2]);
    const size_t K = mxGetN(prhs[2]);
    if (mxGetM(prhs[2]) != N) { mexErrMsgTxt("different #of class labels (Y) and point coordinates (X)"); }

    const double *AXT = mxGetPr(prhs[3]);
    if (mxGetN(prhs[3]) != N) { mexErrMsgTxt("AX has wrong # colums"); }
    if (mxGetM(prhs[3]) != OD) { mexErrMsgTxt("AX has wrong # rows"); }

    mexPrintf("pts=%lu ", (unsigned long)N);
    mexPrintf("classes=%lu ", (unsigned long)K);
    mexPrintf("indim=%lu ", (unsigned long)ID);
    mexPrintf("outdim=%lu \n", (unsigned long)OD);

    // Resolve each point's class once.  A row with no label would otherwise
    // produce index -1 and an out-of-bounds read of Y.
    size_t *cls = mxCalloc(N, sizeof *cls);
    for (size_t i = 0; i < N; i++) {
        cls[i] = class_of(Y, N, K, i);
        if (cls[i] == K) {
            mexErrMsgTxt("every row of Y must contain a nonzero class label");
        }
    }

    // set up output arguments; mxCreateDoubleMatrix zero-initializes both
    plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
    plhs[1] = mxCreateDoubleMatrix(ID, ID, mxREAL);
    double *F = mxGetPr(plhs[0]);
    double *M = mxGetPr(plhs[1]);

    // ED2(i,j) = exp(-||AXT(:,i) - AXT(:,j)||^2)
    double *ED2 = mxCalloc(N * N, sizeof *ED2);
    for (size_t i = 0; i < N; i++) {
        for (size_t j = 0; j < N; j++) {
            double d2 = 0;
            for (size_t k = 0; k < OD; k++) {
                const double d = AXT[i * OD + k] - AXT[j * OD + k];
                d2 = d2 + d * d;
            }
            ED2[i * N + j] = exp(-d2);
        }
    }

    // Softmax normaliser den(i) = sum_{k != i} ED2(i,k), computed once per
    // point rather than once per (i,j) pair.
    double *den = mxCalloc(N, sizeof *den);
    for (size_t i = 0; i < N; i++) {
        double s = 0;
        for (size_t k = 0; k < N; k++) {
            if (k != i) {
                s = s + ED2[i * N + k];
            }
        }
        den[i] = s;
    }

    // P(i,j) = p_ij stored at P[j*N+i]; a point never picks itself (p_ii = 0)
    double *P = mxCalloc(N * N, sizeof *P);
    for (size_t j = 0; j < N; j++) {
        for (size_t i = 0; i < N; i++) {
            P[j * N + i] = (i == j) ? 0 : ED2[j * N + i] / den[i];
        }
    }

    // pi(i): probability of drawing a neighbour in i's own class; F = sum log pi
    double *pi = mxCalloc(N, sizeof *pi);
    F[0] = 0;
    for (size_t i = 0; i < N; i++) {
        const double *same = Y + cls[i] * N;  // column of Y for i's class
        double p = 0;
        for (size_t j = 0; j < N; j++) {
            if (same[j] != 0) {
                p = p + P[j * N + i];
            }
        }
        pi[i] = p;
        F[0] = F[0] + log(p);
    }

    // gradient matrix M (see header); diff holds x_i - x_k for the current pair
    double *diff = mxCalloc(ID, sizeof *diff);
    for (size_t i = 0; i < N; i++) {
        const double *same = Y + cls[i] * N;
        const double inv_pi = 1 / pi[i];

        // add in first sum: sum_k p_ik x_ik x_ik'
        for (size_t k = 0; k < N; k++) {
            const double pik = P[k * N + i];
            for (size_t m = 0; m < ID; m++) {
                diff[m] = X[m * N + i] - X[m * N + k];
            }
            for (size_t m = 0; m < ID; m++) {
                for (size_t n = 0; n < ID; n++) {
                    M[m * ID + n] = M[m * ID + n] + pik * diff[m] * diff[n];
                }
            }
        }

        // subtract off second sum (only over class of point i)
        for (size_t j = 0; j < N; j++) {
            if (same[j] != 0) {
                const double w = inv_pi * P[j * N + i];
                for (size_t m = 0; m < ID; m++) {
                    diff[m] = X[m * N + i] - X[m * N + j];
                }
                for (size_t m = 0; m < ID; m++) {
                    for (size_t n = 0; n < ID; n++) {
                        M[m * ID + n] = M[m * ID + n] - w * diff[m] * diff[n];
                    }
                }
            }
        }
    }

    // memory from mxCalloc must be released with mxFree, never free()
    mxFree(diff);
    mxFree(pi);
    mxFree(P);
    mxFree(den);
    mxFree(ED2);
    mxFree(cls);
}



⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -