mk_ndxb.c
来自「麻省理工学院的人工智能工具箱,很珍贵,希望对大家有用!」· C语言 代码 · 共 129 行
C
129 行
/* C MEX version of mk_multiply_table_ndx.m (potential/tables directory). */
/* 3 inputs, 1 output. */
/* Inputs: bigdom, smalldom, ns */
/* Output: indices that extend the small table so it can be multiplied with the big table. */
#include "mex.h"
/*
 * MEX gateway:  ndx = mk_ndxb(bigdom, smalldom, ns)
 *
 * bigdom, smalldom : double vectors of 1-based variable ids.
 * ns               : double vector; ns(v) is the size of variable v.
 *
 * Returns an int32 NS-by-ND matrix, filled column-major with NB = NS*ND entries.
 *   NB is the product of the sizes of the bigdom variables.
 *   NS is the product of the sizes of the distinct smalldom variables that
 *   also occur in bigdom.
 * Entry k is the 0-based linear index into the small table that lines up with
 * linear index k of the big table. The small table is indexed with its
 * dimensions taken in the order in which they appear in bigdom.
 * An empty smalldom yields the 1x1 value 0.
 *
 * Errors (via mexErrMsgTxt) on:
 *   - fewer than 3 inputs or more than 1 output;
 *   - non-double inputs;
 *   - a bigdom id outside 1..numel(ns);
 *   - a node size that is not a positive 32-bit value.
 */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
    int i, j, siz_b, siz_s, siz_ns, count, NB, NS, ND, ndim, nalloc, yrp;
    double *pb, *ps, *ns, id, sz;
    int *mask, *sx, *sy, *cpsy, *subs, *s, *cpsy2;
    int *pr;

    if(nrhs < 3)
        mexErrMsgTxt("mk_ndxb: requires 3 inputs (bigdom, smalldom, ns).");
    if(nlhs > 1)
        mexErrMsgTxt("mk_ndxb: too many output arguments.");
    for(i = 0; i < 3; i++){
        /* mxGetPr only yields meaningful data for double arrays. */
        if(!mxIsDouble(prhs[i]))
            mexErrMsgTxt("mk_ndxb: all inputs must be double arrays.");
    }

    pb = mxGetPr(prhs[0]);
    siz_b = (int)mxGetNumberOfElements(prhs[0]);
    ps = mxGetPr(prhs[1]);
    siz_s = (int)mxGetNumberOfElements(prhs[1]);
    ns = mxGetPr(prhs[2]);
    siz_ns = (int)mxGetNumberOfElements(prhs[2]);

    if(siz_s == 0){
        plhs[0] = mxCreateNumericMatrix(1, 1, mxINT32_CLASS, mxREAL);
        pr = (int *)mxGetData(plhs[0]);
        *pr = 0;
        return;
    }

    ndim = siz_b;
    /* Never request a zero-byte block; the buffers are unused when ndim == 0. */
    nalloc = ndim > 0 ? ndim : 1;

    /* mask[k]: position in bigdom of the k-th smalldom entry found there.
       Smalldom entries that are absent from bigdom are skipped. */
    mask = (int *)mxMalloc(siz_s * sizeof(int));
    count = 0;
    for(i = 0; i < siz_s; i++){
        for(j = 0; j < siz_b; j++){
            if(ps[i] == pb[j]){
                mask[count++] = j;
                break;
            }
        }
    }

    /* sx[i]: size of the i-th bigdom variable.
       sy[i]: its size in the small table (1 = not a small-table dimension).
       NB:    number of big-table cells. */
    sx = (int *)mxMalloc(nalloc * sizeof(int));
    sy = (int *)mxMalloc(nalloc * sizeof(int));
    NB = 1;
    for(i = 0; i < ndim; i++){
        id = pb[i];
        /* Validate before indexing ns and before converting to int. */
        if(!(id >= 1.0 && id <= (double)siz_ns))
            mexErrMsgTxt("mk_ndxb: bigdom entries must lie in 1..numel(ns).");
        sz = ns[(int)id - 1];
        if(!(sz >= 1.0 && sz < 2147483648.0))
            mexErrMsgTxt("mk_ndxb: node sizes must be positive 32-bit values.");
        sx[i] = (int)sz;
        NB *= sx[i];
        sy[i] = 1;
    }

    /* Give each small-domain dimension its size. The size comes from the
       matched bigdom position, so ids and positions stay paired even when some
       smalldom entries were not found. NS is the product over distinct
       dimensions, so it always divides NB; duplicate smalldom ids therefore
       cannot make the output smaller than the NB entries written below. */
    for(i = 0; i < count; i++)
        sy[mask[i]] = sx[mask[i]];
    NS = 1;
    for(i = 0; i < ndim; i++)
        NS *= sy[i];
    ND = NB / NS;

    plhs[0] = mxCreateNumericMatrix(NS, ND, mxINT32_CLASS, mxREAL);
    pr = (int *)mxGetData(plhs[0]);

    if(NS == 1){
        /* No shared non-singleton dimension: every big cell maps to cell 0. */
        for(i = 0; i < NB; i++)
            pr[i] = 0;
    }
    else if(NS == NB){
        /* Every non-singleton big dimension is shared, so the small-table
           strides equal the big-table strides: identity mapping. */
        for(i = 0; i < NB; i++)
            pr[i] = i;
    }
    else{
        /* Walk the big table in column-major order (odometer over subs),
           updating the matching small-table index yrp incrementally. */
        s = (int *)mxMalloc(nalloc * sizeof(int));     /* last subscript per dim */
        subs = (int *)mxMalloc(nalloc * sizeof(int));  /* current subscript per dim */
        cpsy = (int *)mxMalloc(nalloc * sizeof(int));  /* small-table stride per dim */
        cpsy2 = (int *)mxMalloc(nalloc * sizeof(int)); /* stride*(size-1): undo on wrap */
        for(i = 0; i < ndim; i++){
            subs[i] = 0;
            s[i] = sx[i] - 1;
        }
        cpsy[0] = 1;
        for(i = 0; i < ndim - 1; i++)
            cpsy[i+1] = cpsy[i] * sy[i];
        /* Afterwards, sy[i] != 0 exactly when dimension i is a small-table
           dimension of size > 1, i.e. when stepping it moves yrp. */
        for(i = 0; i < ndim; i++){
            sy[i]--;
            cpsy2[i] = cpsy[i] * sy[i];
        }

        yrp = 0;
        for(j = 0; j < NB; j++){
            pr[j] = yrp;
            for(i = 0; i < ndim; i++){
                if(subs[i] == s[i]){
                    /* Dimension wraps: reset it and carry into the next one. */
                    subs[i] = 0;
                    if(sy[i])
                        yrp -= cpsy2[i];
                }
                else{
                    subs[i]++;
                    if(sy[i])
                        yrp += cpsy[i];
                    break;
                }
            }
        }

        mxFree(s);
        mxFree(subs);
        mxFree(cpsy);
        mxFree(cpsy2);
    }

    mxFree(mask);
    mxFree(sx);
    mxFree(sy);
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?