⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 errpca_diag.cpp

📁 Matlab package for PCA for datasets with missing values
💻 CPP
字号:
// ERRPCA_DIAG.CPP//// [errMx,varcost] = ERRPCA_DIAG( X, A, S, numCPU ) computes a sparse// matrix errMx of reconstruction errors (X - A*S) for the given// sparse matrix X. Output parameter varcost is needed to compute the// variance of the observation noise in PCA_DIAG. It is calculated// using Sv and Av which is the posterior variances of S and A// (assuming fully factorial posterior for both).//// numCPU specifies the number of CPUs used for parallel computing// (default 1).//// See also CF_DIAG.M// Equivalent Matlab code://   M = spones(X);//   errMx = (X - A*S).*M;//   if isempty(Muv)//       varcost = ( (A.^2)*Sv ).*M;//   else//       varcost = ( Av*S.^2 + (A.^2)*Sv + Av*Sv ).*M;//   end//   varcost = full(sum(sum(varcost)));// This software is provided "as is", without warranty of any kind.// Alexander Ilin, Tapani Raiko#include <math.h>#include <string.h>#include "mex.h"#include "matrix.h"#ifndef NOTHREADS#define USETHREADS#include <pthread.h>#endif#define METHOD_PT 0#define METHOD_VBD 1#define METHOD_PPCAD 2#ifdef OLDMATLABAPItypedef int mwIndex;typedef int mwSize;#endiftypedef struct{  double     *A, *S, *X, *ErrMx, *Sv, *Av, varcost;  mwIndex    *ir, *jc;  mwSize     ndata;  mwIndex    tfirst,jx;  mwSize     ncomp;  mwSize     n1;  int        method;}TParams;#ifdef USETHREADSvoid *thread_function(void *);void ThreadComputations(TParams*);#endifvoid mexFunction(int nlhs, mxArray *plhs[], int nrhs,                  const mxArray *prhs[]){  const mxArray *mxX, *mxA, *mxS, *mxAv, *mxSv;  mxArray       *mxErrMx;  double        *X, *A, *S, *Av, *Sv, *ErrMx;  double        *pvarcost, varcost;  mwIndex       r,jx,six,aix;  mwSize        ndata;                  // Number of observed values  mwSize        n1;                     // Dimensionalities of the  mwSize        n2;                     //  the data matrix  mwSize        ncomp;                  // Number of components  double        res;  int           numCPU;                 // Number of threads  mwIndex       *ir, *jc, k;  mwSize        nzmax;  int           method;  double        ak, sk, avk, svk;  mxX = prhs[0];  mxA = prhs[1];  mxS = prhs[2];  Sv = NULL; Av = NULL;  if( nrhs > 3 )  {      mxSv = prhs[3];      if( ! mxIsEmpty( mxSv ) )          Sv = (double *)mxGetPr( mxSv );  }  if( nrhs > 4 )  {      mxAv = prhs[4];      if( ! mxIsEmpty( mxAv ) )          Av = (double *)mxGetPr( mxAv );  }  method = METHOD_PT;  if( nlhs > 1 )      if( Av != NULL && Sv != NULL )          method = METHOD_VBD;      else if( Av == NULL && Sv != NULL )          method = METHOD_PPCAD;  X = (double *)mxGetPr( mxX );  A = (double *)mxGetPr( mxA );  S = (double *)mxGetPr( mxS );  n1 = mxGetM( mxX );  n2 = mxGetN( mxX );  ncomp = mxGetN( mxA );  ir = mxGetIr( mxX );  jc = mxGetJc( mxX );  nzmax = mxGetNzmax(mxX);  ndata = jc[n2];  // Copy the structure of matrix X to output matrix ErrMx  mxErrMx = mxCreateSparse( n1, n2, nzmax, mxREAL );  memcpy( mxGetIr( mxErrMx ), ir, nzmax*sizeof(mwIndex) );  memcpy( mxGetJc( mxErrMx ), jc, (n2+1)*sizeof(mwIndex) );  ErrMx = mxGetPr(mxErrMx);  plhs[0] = mxErrMx;  if( nlhs > 1 )  {      plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);      pvarcost = mxGetPr(plhs[1]);  }#ifdef USETHREADS  if( nrhs < 6 )      numCPU = 1;  else      numCPU = (int)*(double *)mxGetPr( prhs[5] );#else  numCPU = 1;#endif  if( numCPU == 1 )  {      jx = 0;      switch( method )      {      case METHOD_PT:          for( r=0; r < ndata; r++ )          {              res = 0;              while( r == jc[jx+1] ) jx++; // Move to next column              aix = ir[r];              six = jx*ncomp;                            for( k=0; k<ncomp; k++ )              {                  res += A[ aix ] * S[ six ];                  six++; aix += n1;              }              ErrMx[r] = X[r] - res;          }          varcost = 0;          break;      case METHOD_VBD:          varcost = 0;          for( r=0; r < ndata; r++ )          {              res = 0;              while( r == jc[jx+1] ) jx++; // Move to next column              aix = ir[r];              six = jx*ncomp;                            for( k=0; k<ncomp; k++ )              {                  ak = A[aix]; avk = Av[aix];                  sk = S[six]; svk = Sv[six];                                    res += ak *sk;                  varcost += avk * sk*sk + ak*ak * svk + avk * svk;                  six++; aix += n1;              }              ErrMx[r] = X[r] - res;          }          break;      case METHOD_PPCAD:          varcost = 0;          for( r=0; r < ndata; r++ )          {              res = 0;              while( r == jc[jx+1] ) jx++; // Move to next column              aix = ir[r];              six = jx*ncomp;                            for( k=0; k<ncomp; k++ )              {                  ak = A[aix];                  sk = S[six]; svk = Sv[six];                                    res += ak *sk;                  varcost += ak*ak * svk;                  six++; aix += n1;              }              ErrMx[r] = X[r] - res;          }          break;      }      if( nlhs > 1 )          (*pvarcost) = varcost;      return;        }#ifdef USETHREADS  /*******************************************************************                    Multi-thread implementation  *******************************************************************/  mwIndex          cfirst;                 // First column for a thread  mwIndex          tmp;  mwIndex          tlast;                  // Last value for a thread  pthread_t        *mythread;  TParams          *tp;  int              i;  mythread = (pthread_t *)malloc( numCPU*sizeof(pthread_t) );  tp = (TParams *)malloc( numCPU*sizeof(TParams) );  for( i=0; i < numCPU; i++ )  {      // Common thread arguments      tp[i].A = A; tp[i].Av = Av;      tp[i].S = S; tp[i].Sv = Sv;      tp[i].ncomp = ncomp;      tp[i].n1 = n1;      tp[i].ErrMx = ErrMx;      tp[i].X = X;      tp[i].ir = ir;      tp[i].jc = jc;      tp[i].method = method;      // Thread specific arguments      cfirst = i * (mwIndex)floor( (double)n2 / numCPU );      tp[i].jx = cfirst;      tp[i].tfirst = jc[cfirst];      if( i == numCPU-1 )          tlast = ndata;      else      {          tmp = (i+1) * (mwIndex)floor( (double)n2 / numCPU );          tlast = jc[tmp];      }      tp[i].ndata = tlast - tp[i].tfirst;     if( i < numCPU-1 )     {         if( pthread_create( &(mythread[i]), NULL, thread_function,                             (void*)(&tp[i]) ) )         {             mexErrMsgTxt("Error creating thread.");         }     }     else     {         ThreadComputations( tp + numCPU-1 );     }  }  for( i=0; i < numCPU-1; i++ )  {      if( pthread_join( mythread[i], NULL ) )      {          printf("Error joining thread\n");      }  }  if( nlhs > 1 )  {      (*pvarcost) = 0;      for( i=0; i < numCPU; i++ )          (*pvarcost) += tp[i].varcost;  }  return;#endif // USETHREADS}#ifdef USETHREADS/***  Thread function*/void *thread_function(void *arg) {     TParams*   tp = (TParams*)arg;  ThreadComputations(tp);  return(NULL);}void ThreadComputations(TParams* tp){     double        *X, *A, *S, *ErrMx, *Av, *Sv;  mwIndex       r,jx,six,aix;  mwSize        ncomp,n1;  mwSize        ndata;                  // Number of observed values  double        res;  mwIndex       *ir, *jc, k;  mwIndex       tfirst;  double        ak, sk, avk, svk, varcost;  A = tp->A; Av = tp->Av;  S = tp->S; Sv = tp->Sv;  X = tp->X;  ir = tp->ir;  jc = tp->jc;  ErrMx = tp->ErrMx;  ndata = tp->ndata;  ncomp = tp->ncomp;  n1 = tp->n1;  tfirst = tp->tfirst;  jx = tp->jx;  switch( tp->method )  {  case METHOD_PT:      for( r=tfirst; r < tfirst+ndata; r++ )      {          res = 0;          while( r == jc[jx+1] ) jx++;          aix = ir[r];          six = jx*ncomp;          for( k=0; k<ncomp; k++ )          {              res += A[ aix ] * S[ six ];              six++; aix += n1;          }          ErrMx[r] = X[r] - res;      }      tp->varcost = 0;      return;  case METHOD_VBD:      varcost = 0;      for( r=tfirst; r < tfirst+ndata; r++ )      {          res = 0;          while( r == jc[jx+1] ) jx++; // Move to next column          aix = ir[r];          six = jx*ncomp;                    for( k=0; k<ncomp; k++ )          {              ak = A[aix]; avk = Av[aix];              sk = S[six]; svk = Sv[six];                            res += ak *sk;              varcost += avk * sk*sk + ak*ak * svk + avk * svk;              six++; aix += n1;          }          ErrMx[r] = X[r] - res;      }      tp->varcost = varcost;      return;        case METHOD_PPCAD:      varcost = 0;      for( r=tfirst; r < tfirst+ndata; r++ )      {          res = 0;          while( r == jc[jx+1] ) jx++; // Move to next column          aix = ir[r];          six = jx*ncomp;                    for( k=0; k<ncomp; k++ )          {              ak = A[aix];              sk = S[six]; svk = Sv[six];                            res += ak *sk;              varcost += ak*ak * svk;              six++; aix += n1;          }          ErrMx[r] = X[r] - res;      }      tp->varcost = varcost;      return;  }}#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -