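/*
 * gmmtrainmex.cpp
 *
 * 来自「一个关于数据聚类和模式识别的程序,在生物化学,化学中应该都可以用到.希望对大家有帮助」· C++ 代码 · 共 120 行
 * (From "A program for data clustering and pattern recognition; it should also be
 * usable in biochemistry and chemistry. Hope it helps everyone." · C++ code · 120 lines)
 *
 * CPP
 * 120
 * 字号 (font size)
 */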
// This is equivalent to gmmTrain.m
// How to compile:
// MATLAB 7.1: mex gmmTrainMex.cpp d:/users/jang/c/lib/dcpr.cpp d:/users/jang/c/lib/utility.cpp -output gmmTrainMex.dll
// Others: mex gmmTrainMex.cpp d:/users/jang/c/lib/dcpr.cpp d:/users/jang/c/lib/utility.cpp

#include <string.h>
#include <math.h>
#include "mex.h"
#include "d:\users\jang\c\lib\dcpr\dcpr.h"

/* Right-hand-side (input) arguments of mexFunction */
#define DATA        (prhs[0])   /* data matrix: dim rows, dataNum columns */
#define GAUSSIANNUM (prhs[1])   /* component count (scalar) or matrix of initial means */
#define PLOTOPT     (prhs[2])   /* optional print option; treated as 0 when omitted */
/* Left-hand-side (output) arguments of mexFunction: [M, V, W, logProb] */
#define MEAN        (plhs[0])
#define VARIANCE    (plhs[1])
#define WEIGHT      (plhs[2])
#define LOGPROB     (plhs[3])
			
void mexFunction(
	int nlhs,	mxArray *plhs[],
	int nrhs, const mxArray *prhs[])
{
	double minVariance=1e-6, minImprove=1e-6;
	double *data, *mean, *variance, *weight, *u, prevDistortion, *distortion, *distVec, squaredDist, diff;
	int dim, dataNum, gaussianNum, m, n, i, j, plotOpt, *membership, *clusterSize, maxLoopCount=100, actualLoopCount;
	int beginI, beginJ;
	char message[200];

	/* Check for proper number of arguments */
	if (nrhs<2) {
		strcpy(message, mexFunctionName());
		strcat(message, " requires 2 or 3 input arguments.\n");
		strcat(message, "Usage: [M, V, W, logProb] = ");
		strcat(message, mexFunctionName());
		strcat(message, "(data, gaussianNum, plotOpt)");
		mexErrMsgTxt(message);
	}

	/* Dimensions of the input matrix */
	dim = mxGetM(DATA);
	dataNum = mxGetN(DATA);
	gaussianNum = mxGetM(GAUSSIANNUM)*mxGetN(GAUSSIANNUM)==1? (int)mxGetScalar(GAUSSIANNUM):mxGetN(GAUSSIANNUM);
	
	/* Create a matrix for the return argument */
	MEAN		= mxCreateDoubleMatrix(dim, gaussianNum, mxREAL);
	VARIANCE	= mxCreateDoubleMatrix(1, gaussianNum, mxREAL);
	WEIGHT		= mxCreateDoubleMatrix(1, gaussianNum, mxREAL);
	LOGPROB		= mxCreateDoubleMatrix(maxLoopCount, 1, mxREAL);
	
	/* Assign pointers to the various parameters */
	data = mxGetPr(DATA);
	mean = mxGetPr(MEAN);
	plotOpt = nrhs>=3? (int)mxGetScalar(PLOTOPT):0;
	
	membership=(int *)malloc(dataNum*sizeof(int));
	distVec=(double *)malloc(gaussianNum*sizeof(double));
	clusterSize=(int *)malloc(gaussianNum*sizeof(int));
	distortion=(double *)malloc(maxLoopCount*sizeof(double));
	
	// Set initial parameters
	// Set initial mean
	if (mxGetM(GAUSSIANNUM)*mxGetN(GAUSSIANNUM)==1){
		// Use vqKmeans to find the initial M
		it (plotOpt) fprintf('Using KMEANS to find the initial mu...\n');
		initCenter(data, dim, dataNum, mean, gaussianNum);	// Create new centers for vqKmeans
		actualLoopCount=vqKmeans(data, dim, dataNum, center, centerNum, membership, distVec, clusterSize, distortion, maxLoopCount, plotOpt);
		if (plotOpt) fprintf("Done with vqKmeans."\n');
	} else { 
		// Copy centers to "mean"
		memcpy(mean, mxGetPr(GAUSSIANNUM), dim*gaussianNum*sizeof(double));	// Use the passed centers
	// Set initial variance
	for (i=0; i<gaussianNum; i++)
		variance[i]=-INF
	for (i=0; i<gaussianNum; i++){
		beginI=dim*i;
		for (j=0; j<gaussianNum; j++){
			if (j==i) continue;
			beginJ=dim*j
			squaredDist=0;
			for (k=0; k<dim; k++){
				diff=mean[beginI+k]-mean[beginJ+k];
				squaredDist+=diff*diff;
			}
			if (squaredDist<variance[i])
				variance[i]=squaredDist;
		}
	}
	// Set initial weight
	for (i=0; i<gaussianNum; i++)
		weight[i]=1.0/gaussianNum;

	prevDistortion=1000000;
	for (i=0; i<maxLoopCount; i++){
		distortion[i]=updateCenter(data, dim, dataNum, mean, gaussianNum, membership, distVec, clusterSize);
		if (plotOpt==1) printf("i = %d, distortion = %f\n", i, distortion[i]);
//		printArray(mean, dim, gaussianNum);
		if (fabs(prevDistortion-distortion[i])/prevDistortion<eps){
			actualLoopCount=i+1;
			break;
		}
		prevDistortion=distortion[i];
	}
	
	// Assign additional output arguments
	if (nlhs>=2){
		U = mxCreateDoubleMatrix(gaussianNum, dataNum, mxREAL);
		u = mxGetPr(U);
		for (i=0; i<dataNum; i++)
			u[i*gaussianNum+membership[i]]=1;
	}
	if (nlhs>=3){
		DISTORTION = mxCreateDoubleMatrix(actualLoopCount, 1, mxREAL);
		for (i=0; i<actualLoopCount; i++)
			mxGetPr(DISTORTION)[i]=distortion[i];
	}
	
	free(membership); free(distVec); free(clusterSize); free(distortion);
}

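/*
 * Web-page UI text captured when this file was copied from the site:
 * ⌨️ 快捷键说明 (keyboard shortcuts)
 *
 * 复制代码 Ctrl + C (copy code)
 * 搜索代码 Ctrl + F (search code)
 * 全屏模式 F11 (full-screen mode)
 * 增大字号 Ctrl + = (increase font size)
 * 减小字号 Ctrl + - (decrease font size)
 * 显示快捷键 ? (show shortcuts)
 */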