gmm.c

来自「LastWave」· C语言 代码 · 共 480 行

C
480
字号
#include "lastwave.h"
#include "signals.h"
#include "images.h"

#include <math.h>
#include <string.h>

/*
 * GMM (generalized method of moments) estimation.
 *
 * A model registers callbacks with InitGMM(): moment functions g_i(params),
 * their derivatives, and a symmetric matrix Omega(params).  The estimator
 * minimizes the quadratic form g' W g over the parameters, where the weighting
 * matrix W is the identity, a user supplied matrix, or Omega(params)^-1 taken
 * at the previous estimate (see EstimateGMM()).
 *
 * Matrices are images stored row-major: entry (r, c) of an image with ncol
 * columns is pixels[r*ncol + c].  As used in this file, SizeImage() takes
 * (image, ncol, nrow).
 */

/* Scratch vector holding the nMom moment values g_i(params). */
static SIGNAL theMoments = NULL;
/* Scratch nParam x nMom image: pixels[k*nMom + i] = dmomFunc(i, k, params). */
static IMAGE theDMoments = NULL;
/* Only its size is used in this file: the number of model parameters. */
static SIGNAL theParams = NULL;
/* Model callbacks registered by InitGMM(). */
static LWFLOAT (*theMomFunc)(int n, SIGNAL params);
static LWFLOAT (*theDMomFunc)(int nMom, int nParam, SIGNAL params);
static LWFLOAT (*theMDMomFunc)(int nMom, int nParam, SIGNAL params);
static LWFLOAT (*theMatrixFunc)(int i, int j, SIGNAL params);
static void (*thePrintParamFunc)(SIGNAL params);
/* Weighting matrix W (nMom x nMom).  NULL stands for the identity. */
static IMAGE theGMMMatrix = NULL;
/* Number of samples N; the covariance estimates are divided by it. */
static int theNumberOfSamples;

/* Provided elsewhere in LastWave.  InverseImage() is used in place here. */
extern void InverseImage(IMAGE i, IMAGE j);
extern void ConjGrad(SIGNAL startPoint, LWFLOAT tolerance, LWFLOAT (*f)(SIGNAL point),
                     void (*df)(SIGNAL point, SIGNAL grad), int *nIter, LWFLOAT *minVal,
                     char flagPrint);

/* Relative parameter change below which the iterated GMM is considered converged. */
#define STOP .001
/* ConjGrad tolerances for the multistart pass and for the following GMM iterations. */
#define GMM_START_TOL 0.0001
#define GMM_ITER_TOL  0.0000001

/*
 * Register a GMM model and (re)size the work buffers.
 *
 *   numberOfSamples : N, the covariance estimates are divided by it
 *   nMom            : number of moment functions
 *   nParam          : number of model parameters
 *   momFunc         : momFunc(i, params) is the i-th moment, 0 <= i < nMom
 *   dmomFunc        : dmomFunc(i, k, params) is used as d(moment i)/d(param k)
 *                     in the gradient of the objective
 *   mdmomFunc       : mdmomFunc(i, k, params) is entry (i, k) of the matrix G
 *                     used by the covariance formulas
 *   matrixFunc      : matrixFunc(i, j, params) is entry (i, j) of the symmetric
 *                     matrix Omega; only i >= j is requested
 *   printParamFunc  : stored only, not called in this file
 *
 * Any previously set weighting matrix is discarded (back to the identity),
 * since its size may no longer match nMom.
 */
void InitGMM(int numberOfSamples, int nMom, int nParam,
             LWFLOAT (*momFunc)(int n, SIGNAL params),
             LWFLOAT (*dmomFunc)(int nMom, int nParam, SIGNAL params),
             LWFLOAT (*mdmomFunc)(int nMom, int nParam, SIGNAL params),
             LWFLOAT (*matrixFunc)(int i, int j, SIGNAL params),
             void (*printParamFunc)(SIGNAL params))
{
  if (theMoments == NULL) theMoments = NewSignal();
  SizeSignal(theMoments, nMom, YSIG);
  if (theParams == NULL) theParams = NewSignal();
  SizeSignal(theParams, nParam, YSIG);
  if (theDMoments == NULL) theDMoments = NewImage();
  SizeImage(theDMoments, nMom, nParam);

  theMomFunc = momFunc;
  theDMomFunc = dmomFunc;
  theMDMomFunc = mdmomFunc;
  theMatrixFunc = matrixFunc;
  thePrintParamFunc = printParamFunc;

  if (theGMMMatrix != NULL) {
    DeleteImage(theGMMMatrix);
    theGMMMatrix = NULL;
  }

  theNumberOfSamples = numberOfSamples;
}

/*
 * Matrix product m3 = m1 * m2.
 * If m3 is NULL a temporary image is allocated.  m3 must not alias m1 or m2.
 * Returns m3.
 */
static IMAGE MultMatrix(IMAGE m1, IMAGE m2, IMAGE m3)
{
  int i, j, k, n;

  if (m1->ncol != m2->nrow) Errorf("MultMatrix() : Bad size");
  /* Must be checked before SizeImage(m3, ...), which would clobber an aliased operand */
  if (m3 == m2) Errorf("MultMatrix() : Weird m2 == m3");
  if (m3 == m1) Errorf("MultMatrix() : Weird m1 == m3");

  if (m3 == NULL) m3 = TNewImage();
  SizeImage(m3, m2->ncol, m1->nrow);

  for (i = 0, n = 0; i < m3->nrow; i++) {
    for (j = 0; j < m3->ncol; j++, n++) {
      m3->pixels[n] = 0;
      for (k = 0; k < m1->ncol; k++) {
        m3->pixels[n] += m1->pixels[i*m1->ncol+k] * m2->pixels[k*m2->ncol+j];
      }
    }
  }

  return m3;
}

/*
 * Size O to nMom x nMom and fill it with Omega(param).
 * Only entries i >= j are evaluated through theMatrixFunc; they are mirrored
 * so every entry of O is written.
 */
static void GMMOmega(SIGNAL param, IMAGE O)
{
  int i, j;
  int size = theMoments->size;

  SizeImage(O, size, size);
  for (i = 0; i < size; i++) {
    for (j = 0; j <= i; j++) {
      O->pixels[i*size+j] = O->pixels[j*size+i] = (*theMatrixFunc)(i, j, param);
    }
  }
}

/* Return a temporary nMom x nParam image G with G(i, j) = theMDMomFunc(i, j, param). */
static IMAGE GMMJacobian(SIGNAL param)
{
  IMAGE G = TNewImage();
  int i, j, n;

  SizeImage(G, param->size, theMoments->size);
  for (n = 0, i = 0; i < theMoments->size; i++) {
    for (j = 0; j < param->size; j++, n++) {
      G->pixels[n] = (*theMDMomFunc)(i, j, param);
    }
  }
  return G;
}

/*
 * Set the GMM weighting matrix W.
 *
 *  matrix != NULL : W is a copy of 'matrix' (must be nMom x nMom)
 *  params != NULL : W = Omega(params)^-1   (ignored when matrix != NULL)
 *  both NULL      : W is the identity (theGMMMatrix is released and set to NULL)
 *
 * The new W is fully built (and validated) before theGMMMatrix is touched, so
 * an error leaves the previous weighting matrix intact.
 */
void SetGMMMatrix(SIGNAL params, IMAGE matrix)
{
  IMAGE W;

  /* Identity: no matrix is stored, MinGMM/DMinGMM use a dedicated branch */
  if (params == NULL && matrix == NULL) {
    if (theGMMMatrix != NULL) DeleteImage(theGMMMatrix);
    theGMMMatrix = NULL;
    return;
  }

  if (theMoments == NULL) Errorf("You must init the gmm method first");

  if (matrix != NULL) {
    if (matrix->nrow != theMoments->size || matrix->ncol != theMoments->size)
      Errorf("SetGMMMatrix() : trying to set a bad sized matrix");
    W = matrix;
  }
  else {
    W = TNewImage();
    GMMOmega(params, W);
    InverseImage(W, W);
  }

  if (theGMMMatrix == NULL) theGMMMatrix = NewImage();
  CopyImage(W, theGMMMatrix);
}

/*
 * The function to minimize: g' W g, with g_i = theMomFunc(i, param).
 * Only the symmetric part of W contributes to a quadratic form, so the
 * off-diagonal pair (i, j) is weighted by W(i, j) + W(j, i).  For a symmetric
 * W this equals the usual 2 * W(i, j).
 */
static LWFLOAT MinGMM(SIGNAL param)
{
  int i, j;
  int size = theMoments->size;
  LWFLOAT f = 0;
  LWFLOAT r;

  /* Identity weighting: plain sum of squares */
  if (theGMMMatrix == NULL) {
    for (i = 0; i < size; i++) {
      r = (*theMomFunc)(i, param);
      f += r*r;
    }
    return f;
  }

  for (i = 0; i < size; i++) theMoments->Y[i] = (*theMomFunc)(i, param);

  for (i = 0; i < size; i++) {
    for (j = 0; j < i; j++) {
      f += theMoments->Y[i]*theMoments->Y[j]
           *(theGMMMatrix->pixels[i*size+j] + theGMMMatrix->pixels[j*size+i]);
    }
    f += theMoments->Y[i]*theMoments->Y[i]*theGMMMatrix->pixels[i*(size+1)];
  }

  return f;
}

/*
 * Gradient of MinGMM() with respect to the parameters, written to grad->Y[k]
 * for 0 <= k < nParam (grad is assumed to be at least that long).
 */
static void DMinGMM(SIGNAL param, SIGNAL grad)
{
  int i, j, k;
  int nMom = theMoments->size;
  int nParam = theParams->size;
  LWFLOAT *g;
  LWFLOAT s;

  /* Moments are evaluated once and reused for every parameter */
  for (i = 0; i < nMom; i++) theMoments->Y[i] = (*theMomFunc)(i, param);
  g = theMoments->Y;

  /* Identity weighting: d(sum g_i^2)/dp_k = sum 2 * dg_i/dp_k * g_i */
  if (theGMMMatrix == NULL) {
    for (k = 0; k < nParam; k++) {
      s = 0;
      for (i = 0; i < nMom; i++) s += 2*(*theDMomFunc)(i, k, param)*g[i];
      grad->Y[k] = s;
    }
    return;
  }

  for (i = 0; i < nMom; i++) {
    for (k = 0; k < nParam; k++) theDMoments->pixels[k*nMom+i] = (*theDMomFunc)(i, k, param);
  }

  /* Same symmetrization of W as in MinGMM() so the gradient matches the objective */
  for (k = 0; k < nParam; k++) {
    s = 0;
    for (i = 0; i < nMom; i++) {
      for (j = 0; j < i; j++) {
        s += (theDMoments->pixels[k*nMom+i]*g[j] + theDMoments->pixels[k*nMom+j]*g[i])
             *(theGMMMatrix->pixels[i*nMom+j] + theGMMMatrix->pixels[j*nMom+i]);
      }
      s += 2*g[i]*theDMoments->pixels[k*nMom+i]*theGMMMatrix->pixels[i*(nMom+1)];
    }
    grad->Y[k] = s;
  }
}

/*
 * Covariance of the estimator for the current weighting matrix W:
 *
 *   (G'WG)^-1 G'W Omega W G (G'WG)^-1 / N
 *
 * with G = GMMJacobian(param) and Omega = GMMOmega(param).  W is the identity
 * when theGMMMatrix == NULL.  Returns a temporary nParam x nParam image.
 */
static IMAGE ErrorCovariance(SIGNAL param)
{
  IMAGE G, WG, Gt, B, WGB, BGtW, O, OWGB, cov;
  int n;

  G = GMMJacobian(param);
  WG = (theGMMMatrix == NULL ? G : MultMatrix(theGMMMatrix, G, NULL));

  /* B = (G'WG)^-1 */
  Gt = TNewImage();
  TranspImage(G, Gt);
  B = MultMatrix(Gt, WG, NULL);
  InverseImage(B, B);

  /* WGB = W G B and its transpose (B G'W for symmetric W and B) */
  WGB = MultMatrix(WG, B, NULL);
  BGtW = TNewImage();
  TranspImage(WGB, BGtW);

  O = TNewImage();
  GMMOmega(param, O);

  OWGB = MultMatrix(O, WGB, NULL);
  cov = MultMatrix(BGtW, OWGB, NULL);

  for (n = 0; n < cov->nrow*cov->ncol; n++) cov->pixels[n] /= theNumberOfSamples;

  return cov;
}

/*
 * Theoretical covariance of the estimator when W = Omega^-1:
 *
 *   (G' Omega^-1 G)^-1 / N
 *
 * Returns a temporary nParam x nParam image.
 */
static IMAGE ThErrorCovariance(SIGNAL param)
{
  IMAGE G, Gt, Oinv, OinvG, cov;
  int n;

  G = GMMJacobian(param);

  Oinv = TNewImage();
  GMMOmega(param, Oinv);
  InverseImage(Oinv, Oinv);

  OinvG = MultMatrix(Oinv, G, NULL);

  Gt = TNewImage();
  TranspImage(G, Gt);
  cov = MultMatrix(Gt, OinvG, NULL);
  InverseImage(cov, cov);

  for (n = 0; n < cov->nrow*cov->ncol; n++) cov->pixels[n] /= theNumberOfSamples;

  return cov;
}

/*
 * YES when no parameter moved by more than STOP relative to its new value.
 * Written as |new - old| > STOP * |new| to avoid dividing by a zero parameter.
 */
static char GMMConverged(SIGNAL newParam, SIGNAL oldParam)
{
  int i;

  for (i = 0; i < newParam->size; i++) {
    if (fabs(newParam->Y[i] - oldParam->Y[i]) > STOP*fabs(newParam->Y[i])) return NO;
  }
  return YES;
}

/*
 * Main procedure for (iterated) GMM estimation.
 *
 *   lv       : list of starting points (signals of nParam values).  The first
 *              iteration runs ConjGrad from each of them and keeps the one
 *              with the smallest objective (ties go to the later one).
 *   nGMMIter : maximum number of GMM iterations.  The loop also stops as soon
 *              as GMMConverged() holds.  A value < 1 imposes no cap.
 *   matrix   : initial weighting matrix (NULL = identity).  Each following
 *              iteration uses W = Omega(previous estimate)^-1.
 *
 * Returns a temporary signal holding the estimate.
 */
SIGNAL EstimateGMM(LISTV lv, int nGMMIter, IMAGE matrix)
{
  int nIter, i, n;
  LWFLOAT minVal, bestMinVal = 0, f;
  SIGNAL oldParam, paramStart;
  VALUE val;
  char flagContinue, flagFound;

  if (lv->length <= 0) Errorf("EstimateGMM() : empty list of starting points");

  SetGMMMatrix(NULL, matrix);

  /* Temporaries: nothing leaks if an Errorf() fires in the middle of the loop */
  paramStart = TNewSignal();
  oldParam = TNewSignal();

  flagContinue = YES;
  n = 1;
  while (flagContinue) {
    if (n == 1) {
      /* Multistart: keep the best minimum found over all starting points */
      flagFound = NO;
      for (i = 0; i < lv->length; i++) {
        if (GetListvNth(lv, i, &val, &f) != signaliType) Errorf("Bad type in listv");
        CopySig((SIGNAL) val, paramStart);
        if (paramStart->size != theParams->size) Errorf("EstimateGMM() : Weird bad number of parameters");
        ConjGrad(paramStart, GMM_START_TOL, MinGMM, DMinGMM, &nIter, &minVal, NO);
        if (flagFound && bestMinVal < minVal) continue;
        CopySig(paramStart, oldParam);
        bestMinVal = minVal;
        flagFound = YES;
      }
      CopySig(oldParam, paramStart);
    }
    else {
      /* Regular GMM iteration, restarting from the previous estimate */
      ConjGrad(paramStart, GMM_ITER_TOL, MinGMM, DMinGMM, &nIter, &minVal, NO);
      if (GMMConverged(paramStart, oldParam)) flagContinue = NO;
    }

    CopySig(paramStart, oldParam);
    if (n == nGMMIter) flagContinue = NO;
    n++;

    /* Next iteration weights the moments with Omega(current estimate)^-1 */
    if (flagContinue) SetGMMMatrix(paramStart, NULL);
  }

  return paramStart;
}

/* Errorf() unless InitGMM() has been called. */
static void CheckGMMInit(void)
{
  if (theMoments == NULL) Errorf("You must init the gmm method first");
}

/* Errorf() unless params holds exactly one value per model parameter. */
static void CheckGMMParams(SIGNAL params)
{
  if (params->size != theParams->size) Errorf("Bad number of parameters");
}

/*
 * 'cut' / 'dcut' actions: evaluate the objective (flagDerivative == NO) or its
 * j-th partial derivative (flagDerivative == YES, j read first from argv).
 * The remaining arguments give the nParam parameter values.  At most one of
 * them may be a signal; the result is then a signal with that parameter's
 * values on X and the objective / derivative on Y.  Otherwise it is a float.
 */
static void GMMCut(char **argv, char flagDerivative)
{
  SIGNAL param, val = NULL, res, grad = NULL;
  int i, n, j = 0, p = -1;

  CheckGMMInit();

  if (flagDerivative) {
    argv = ParseArgv(argv, tINT, &j, -1);
    if (j < 0 || j >= theParams->size) Errorf("Bad parameter number '%d'", j);
    grad = TNewSignal();
    SizeSignal(grad, theParams->size, YSIG);
  }

  param = TNewSignal();
  SizeSignal(param, theParams->size, YSIG);

  for (i = 0; *argv != NULL; i++, argv++) {
    if (i >= theParams->size) Errorf("Too many parameters");
    if (ParseFloat_(*argv, 0, &(param->Y[i]))) continue;
    /* A second signal would leave the first signal's slot with an unset value */
    if (val != NULL) Errorf("Only one parameter can be a signal");
    if (!ParseSignalI_(*argv, NULL, &val)) Errorf("Expecting a signal");
    p = i;
  }
  /* Missing trailing parameters would otherwise be read uninitialized */
  if (i < theParams->size) Errorf("Not enough parameters");

  res = TNewSignal();
  if (val != NULL) SizeSignal(res, val->size, XYSIG);
  else SizeSignal(res, 1, YSIG);

  for (n = 0; n < res->size; n++) {
    if (val != NULL) {
      param->Y[p] = val->Y[n];
      res->X[n] = val->Y[n];
    }
    if (!flagDerivative) res->Y[n] = MinGMM(param);
    else {
      DMinGMM(param, grad);
      res->Y[n] = grad->Y[j];
    }
  }

  if (val != NULL) SetResultValue(res);
  else SetResultFloat(res->Y[0]);
}

/*
 * The main gmm command.  Actions:
 *   cut / dcut : see GMMCut()
 *   matrix     : set W = Omega(params)^-1 and return it
 *   N          : number of samples
 *   moment     : signal of the nMom moment values at params
 *   start      : EstimateGMM(listOfStartPoints, [nIter = 1], [matrix])
 *   error      : ErrorCovariance(params)
 *   therror    : ThErrorCovariance(params)
 */
void C_GMM(char **argv)
{
  char *action;
  SIGNAL params, val;
  LISTV lv;
  IMAGE matrix;
  int i, nIter;

  argv = ParseArgv(argv, tWORD, &action, -1);

  if (!strcmp(action, "cut") || !strcmp(action, "dcut")) {
    GMMCut(argv, !strcmp(action, "dcut"));
    return;
  }

  if (!strcmp(action, "matrix")) {
    CheckGMMInit();
    argv = ParseArgv(argv, tSIGNALI, &params, 0);
    CheckGMMParams(params);
    SetGMMMatrix(params, NULL);
    SetResultValue(theGMMMatrix);
    return;
  }

  if (!strcmp(action, "N")) {
    NoMoreArgs(argv);
    SetResultFloat(theNumberOfSamples);
    return;
  }

  if (!strcmp(action, "moment")) {
    CheckGMMInit();
    argv = ParseArgv(argv, tSIGNALI, &params, 0);
    CheckGMMParams(params);
    val = TNewSignal();
    SizeSignal(val, theMoments->size, YSIG);
    for (i = 0; i < theMoments->size; i++) val->Y[i] = (*theMomFunc)(i, params);
    SetResultValue(val);
    return;
  }

  if (!strcmp(action, "start")) {
    CheckGMMInit();
    argv = ParseArgv(argv, tLISTV, &lv, tINT_, 1, &nIter, tIMAGEI_, NULL, &matrix, 0);
    if (nIter < 0) Errorf("Bad number of iteration '%d'", nIter);
    SetResultValue(EstimateGMM(lv, nIter, matrix));
    return;
  }

  if (!strcmp(action, "error")) {
    CheckGMMInit();
    argv = ParseArgv(argv, tSIGNALI, &params, 0);
    CheckGMMParams(params);
    SetResultValue(ErrorCovariance(params));
    return;
  }

  if (!strcmp(action, "therror")) {
    CheckGMMInit();
    argv = ParseArgv(argv, tSIGNALI, &params, 0);
    CheckGMMParams(params);
    SetResultValue(ThErrorCovariance(params));
    return;
  }

  Errorf("Unknown action '%s'", action);
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?