gmmtrainmex.cpp
来自「一个关于数据聚类和模式识别的程序,在生物化学,化学中因该都可以用到.希望对大家有」· C++ 代码 · 共 120 行
CPP
120 行
// This is equivalent to gmmTrain.m
// How to compile:
// MATLAB 7.1: mex gmmTrainMex.cpp d:/users/jang/c/lib/dcpr.cpp d:/users/jang/c/lib/utility.cpp -output gmmTrainMex.dll
// Others: mex gmmTrainMex.cpp d:/users/jang/c/lib/dcpr.cpp d:/users/jang/c/lib/utility.cpp
#include <string.h>
#include <math.h>
#include "mex.h"
#include "d:\users\jang\c\lib\dcpr\dcpr.h"
/* Input Arguments */
#define DATA prhs[0]
#define GAUSSIANNUM prhs[1]
#define PLOTOPT prhs[2]
/* Output Arguments */
#define MEAN plhs[0]
#define VARIANCE plhs[1]
#define WEIGHT plhs[2]
#define LOGPROB plhs[3]
void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[])
{
double minVariance=1e-6, minImprove=1e-6;
double *data, *mean, *variance, *weight, *u, prevDistortion, *distortion, *distVec, squaredDist, diff;
int dim, dataNum, gaussianNum, m, n, i, j, plotOpt, *membership, *clusterSize, maxLoopCount=100, actualLoopCount;
int beginI, beginJ;
char message[200];
/* Check for proper number of arguments */
if (nrhs<2) {
strcpy(message, mexFunctionName());
strcat(message, " requires 2 or 3 input arguments.\n");
strcat(message, "Usage: [M, V, W, logProb] = ");
strcat(message, mexFunctionName());
strcat(message, "(data, gaussianNum, plotOpt)");
mexErrMsgTxt(message);
}
/* Dimensions of the input matrix */
dim = mxGetM(DATA);
dataNum = mxGetN(DATA);
gaussianNum = mxGetM(GAUSSIANNUM)*mxGetN(GAUSSIANNUM)==1? (int)mxGetScalar(GAUSSIANNUM):mxGetN(GAUSSIANNUM);
/* Create a matrix for the return argument */
MEAN = mxCreateDoubleMatrix(dim, gaussianNum, mxREAL);
VARIANCE = mxCreateDoubleMatrix(1, gaussianNum, mxREAL);
WEIGHT = mxCreateDoubleMatrix(1, gaussianNum, mxREAL);
LOGPROB = mxCreateDoubleMatrix(maxLoopCount, 1, mxREAL);
/* Assign pointers to the various parameters */
data = mxGetPr(DATA);
mean = mxGetPr(MEAN);
plotOpt = nrhs>=3? (int)mxGetScalar(PLOTOPT):0;
membership=(int *)malloc(dataNum*sizeof(int));
distVec=(double *)malloc(gaussianNum*sizeof(double));
clusterSize=(int *)malloc(gaussianNum*sizeof(int));
distortion=(double *)malloc(maxLoopCount*sizeof(double));
// Set initial parameters
// Set initial mean
if (mxGetM(GAUSSIANNUM)*mxGetN(GAUSSIANNUM)==1){
// Use vqKmeans to find the initial M
it (plotOpt) fprintf('Using KMEANS to find the initial mu...\n');
initCenter(data, dim, dataNum, mean, gaussianNum); // Create new centers for vqKmeans
actualLoopCount=vqKmeans(data, dim, dataNum, center, centerNum, membership, distVec, clusterSize, distortion, maxLoopCount, plotOpt);
if (plotOpt) fprintf("Done with vqKmeans."\n');
} else {
// Copy centers to "mean"
memcpy(mean, mxGetPr(GAUSSIANNUM), dim*gaussianNum*sizeof(double)); // Use the passed centers
// Set initial variance
for (i=0; i<gaussianNum; i++)
variance[i]=-INF
for (i=0; i<gaussianNum; i++){
beginI=dim*i;
for (j=0; j<gaussianNum; j++){
if (j==i) continue;
beginJ=dim*j
squaredDist=0;
for (k=0; k<dim; k++){
diff=mean[beginI+k]-mean[beginJ+k];
squaredDist+=diff*diff;
}
if (squaredDist<variance[i])
variance[i]=squaredDist;
}
}
// Set initial weight
for (i=0; i<gaussianNum; i++)
weight[i]=1.0/gaussianNum;
prevDistortion=1000000;
for (i=0; i<maxLoopCount; i++){
distortion[i]=updateCenter(data, dim, dataNum, mean, gaussianNum, membership, distVec, clusterSize);
if (plotOpt==1) printf("i = %d, distortion = %f\n", i, distortion[i]);
// printArray(mean, dim, gaussianNum);
if (fabs(prevDistortion-distortion[i])/prevDistortion<eps){
actualLoopCount=i+1;
break;
}
prevDistortion=distortion[i];
}
// Assign additional output arguments
if (nlhs>=2){
U = mxCreateDoubleMatrix(gaussianNum, dataNum, mxREAL);
u = mxGetPr(U);
for (i=0; i<dataNum; i++)
u[i*gaussianNum+membership[i]]=1;
}
if (nlhs>=3){
DISTORTION = mxCreateDoubleMatrix(actualLoopCount, 1, mxREAL);
for (i=0; i<actualLoopCount; i++)
mxGetPr(DISTORTION)[i]=distortion[i];
}
free(membership); free(distVec); free(clusterSize); free(distortion);
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?