mk_ndxb.c

来自「麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!」· C语言 代码 · 共 129 行

C
129
字号
/* C mex version for mk_multiply_table_ndx.m in potential/tables directory */
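/* 3 inputs, 1 output:        */
/* bigdom, smalldom, ns       */
/* For each entry of the big table, the index of the matching entry of the
   small table, used to extend the small table when multiplying it into the
   big table. */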

#include "mex.h"

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
	int     i, j, siz_b, siz_s, count, NB, NS, ND, ndim, yrp, temp;
	double  *pb, *ps, *ns;
	int     *mask, *sx, *sy, *cpsy, *subs, *s, *cpsy2;
	int     *pr;
	int     dims[2];


	pb = mxGetPr(prhs[0]);
	siz_b = mxGetNumberOfElements(prhs[0]);
	ps = mxGetPr(prhs[1]);
	siz_s = mxGetNumberOfElements(prhs[1]);
	ns = mxGetPr(prhs[2]);

	if(siz_s == 0){
		dims[0] = 1;
		dims[1] = 1;
		plhs[0] = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL);
		pr = mxGetData(plhs[0]);
		*pr = 0;
		return;
	}

	mask = malloc(siz_s * sizeof(int));
	count = 0;
	for(i=0; i<siz_s; i++){
		for(j=0; j<siz_b; j++){
			if(ps[i] == pb[j]){
				mask[count] = j;
				count++;
				break;
			}
		}
	}
	
	ndim = siz_b;
	NB = 1;
	sx = (int *)malloc(sizeof(int)*ndim);
	sy = (int *)malloc(sizeof(int)*ndim);
	for(i=0; i<ndim; i++){
		temp = (int)pb[i] - 1;
		sx[i] = (int)ns[temp];
		NB *= (int)ns[temp];
		sy[i] = 1;
	}
	NS = 1;
	for(i=0; i<count; i++){
		temp = (int)ps[i] - 1;
		sy[mask[i]] = (int)ns[temp];
		NS *= (int)ns[temp];
	}
	
	ND = NB / NS;
	dims[0] = NS;
	dims[1] = ND;
	plhs[0] = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL);
	pr = mxGetData(plhs[0]);

	if(NS == 1){
		for(i=0; i<NB; i++){
			pr[i] = 0;
		}
		free(mask);
		free(sx);
		free(sy);
		return;
	}
	if(NS == NB){
		for(i=0; i<NB; i++){
			pr[i] = i;
		}
		free(mask);
		free(sx);
		free(sy);
		return;
	}

	s = (int *)malloc(sizeof(int)*ndim);
	*(cpsy = (int *)malloc(sizeof(int)*ndim)) = 1;
	subs =   (int *)malloc(sizeof(int)*ndim);
	cpsy2 =  (int *)malloc(sizeof(int)*ndim);
	for(i = 0; i < ndim; i++){
		subs[i] = 0;
		s[i] = sx[i] - 1;
	}
			
	for(i = 0; i < ndim-1; i++){
		cpsy[i+1] = cpsy[i]*sy[i]--;
		cpsy2[i] = cpsy[i]*sy[i];
	}
	cpsy2[ndim-1] = cpsy[ndim-1]*(--sy[ndim-1]);

	yrp = 0;
	for(j=0; j<NB; j++){
		pr[j] = yrp;
		for(i = 0; i < ndim; i++){
			if (subs[i] == s[i]){
				subs[i] = 0;
				if (sy[i])
					yrp -= cpsy2[i];
			}
			else{
				subs[i]++;
				if (sy[i])
					yrp += cpsy[i];
				break;
			}
		}
	}

	free(sx);
	free(sy);
	free(s);
	free(cpsy);
	free(subs);
	free(cpsy2);
    free(mask);
}



⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?