vqkmeansmex.cpp

来自「一个关于数据聚类和模式识别的程序,在生物化学,化学中应该都可以用到.希望对大家有」· C++ 代码 · 共 77 行

CPP
77
字号
// This is equivalent to vqKmeans.m
// How to compile:
// MATLAB 7.1: mex vqKmeansMex.cpp d:/users/jang/c/lib/dcpr.cpp d:/users/jang/c/lib/utility.cpp -output vqKmeansMex.dll
// Others: mex vqKmeansMex.cpp d:/users/jang/c/lib/dcpr.cpp d:/users/jang/c/lib/utility.cpp

#include <limits.h>
#include <math.h>
#include <string.h>

#include "mex.h"
#include "d:\users\jang\c\lib\dcpr\dcpr.h"

/*
 * Readable aliases for the positional MEX argument slots used by mexFunction.
 * Each expansion is parenthesized so it behaves as a single operand.
 */
/* Input slots (prhs) */
#define DATA		(prhs[0])	/* data matrix */
#define CLUSTERNUM	(prhs[1])	/* center count (scalar) or initial-center matrix */
#define PLOTOPT		(prhs[2])	/* optional plot flag */
/* Output slots (plhs) */
#define CENTER		(plhs[0])	/* resulting centers */
#define U		(plhs[1])	/* membership matrix */
#define DISTORTION	(plhs[2])	/* distortion values returned by vqKmeans */

void mexFunction(
	int nlhs,	mxArray *plhs[],
	int nrhs, const mxArray *prhs[])
{
	double *data, *center, *u, prevDistortion, *distortion, *distVec;
	int dim, dataNum, centerNum, m, n, i, j, plotOpt, *membership, *clusterSize, maxLoopCount=100, actualLoopCount;
	char message[200];

	/* Check for proper number of arguments */
	if (nrhs<2) {
		strcpy(message, mexFunctionName());
		strcat(message, " requires 2 or 3 input arguments.\n");
		strcat(message, "Usage: center = ");
		strcat(message, mexFunctionName());
		strcat(message, "(data, centerNum, plotOpt)");
		mexErrMsgTxt(message);
	}

	/* Dimensions of the input matrix */
	dim = mxGetM(DATA);
	dataNum = mxGetN(DATA);
	centerNum = mxGetM(CLUSTERNUM)*mxGetN(CLUSTERNUM)==1? (int)mxGetScalar(CLUSTERNUM):mxGetN(CLUSTERNUM);

	/* Create a matrix for the return argument */
	CENTER = mxCreateDoubleMatrix(dim, centerNum, mxREAL);

	/* Assign pointers to the various parameters */
	data = mxGetPr(DATA);
	center = mxGetPr(CENTER);
	plotOpt = nrhs>=3? (int)mxGetScalar(PLOTOPT):0;
	
	if (mxGetM(CLUSTERNUM)*mxGetN(CLUSTERNUM)==1)
		initVqCenter(data, dim, dataNum, center, centerNum);	// Create new centers
	else
		memcpy(center, mxGetPr(CLUSTERNUM), dim*centerNum*sizeof(double));	// Use the passed centers
	
	membership=(int *)malloc(dataNum*sizeof(int));
	distVec=(double *)malloc(centerNum*sizeof(double));
	clusterSize=(int *)malloc(centerNum*sizeof(int));
	distortion=(double *)malloc(maxLoopCount*sizeof(double));
	
	actualLoopCount=vqKmeans(data, dim, dataNum, center, centerNum, membership, distVec, clusterSize, distortion, maxLoopCount, plotOpt);
	
	// Assign additional output arguments
	if (nlhs>=2){
		U = mxCreateDoubleMatrix(centerNum, dataNum, mxREAL);
		u = mxGetPr(U);
		for (i=0; i<dataNum; i++)
			u[i*centerNum+membership[i]]=1;
	}
	if (nlhs>=3){
		DISTORTION = mxCreateDoubleMatrix(actualLoopCount, 1, mxREAL);
		for (i=0; i<actualLoopCount; i++)
			mxGetPr(DISTORTION)[i]=distortion[i];
	}
	
	free(membership); free(distVec); free(clusterSize); free(distortion);
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?