📄 mvgmmrnd.c
字号:
/* mvgmmrnd.c
Draw samples from a mixture of multivariate Gaussian PDFs
usage: [Z , index] = mvgmmrnd(N , mu , sigma , p , [n1] , ... , [nl])
mu : Mean vector (d x 1 x M x [n1] , ... , [nl])
sigma : Covariance (d x d x M x [n1] , ... , [nl])
p : Weights vector (1 x 1 x M x [n1] , ... , [nl]), or [] to draw from a single Gaussian (M = 1)
Z : Samples (d x N x ...); index : component (1..M) each sample was drawn from (1 x N x ...)
Author : Sébastien PARIS (sebastien.paris@lsis.org)
Example
mu = cat(3 , [-5 ; -5] , [0 ; 0] ,[ 5 ; 5]); %(d x 1 x M)
sigma = cat(3 , [2 0; 0 1] , [2 -.2; -.2 2] , [1 .9; .9 1] ); %(d x d x M)
p = cat(3 , [0.3] , [0.2] , [0.5]); %(1 x 1 x M)
N = 500;
[Z , index] = mvgmmrnd(N , mu , sigma , p);
[x , y] = ndellipse(mu , sigma);
plot(Z(1 , :) , Z(2 , :) , 'k+', x , y , 'g' , 'markersize' , 2 , 'linewidth' , 2);
hold on
plot(reshape(mu(1 , : , :) , 1 , 3) , reshape(mu(2 , : , :) , 1 , 3) , 'r+' , 'markersize' , 6);
hold off
To compile
mex -DranSHR3 mvgmmrnd.c or mex -DranKISS mvgmmrnd.c
With the Intel C++ compiler, I use:
mex -DranKISS -f mexopts_intelamd.bat mvgmmrnd.c
or
mex -DranSHR3 -f mexopts_intelamd.bat mvgmmrnd.c
Ver 1.2
Changelog V 1.2 (06/05/05) General Call syntax.
V 1.1 (03/26/05) Bug fix : Z = mvgmmrnd(N , mu , sigma , []); now works when mu is (d x 1) and sigma is (d x d).
This allows drawing samples from a single multivariate Gaussian pdf instead of a mixture of multivariate Gaussian pdfs.
V 1.0 (03/04/05) Initial release
*/
#include <math.h>
#include <stdint.h>
#include <time.h>
#include "mex.h"
/*---------------- Basic generators definition ------------------- */
/* 96-bit mixing step (it matches Bob Jenkins' lookup2 mix()); used by randini()
   to scramble the clock-derived KISS seeds.  Wrapped in do { } while (0) so
   that "mix(a , b , c);" is a single statement, e.g. inside an if / else
   (the former bare { } block made "if (x) mix(a,b,c); else ..." a syntax error). */
#define mix(a , b , c) \
do { \
  a -= b; a -= c; a ^= (c>>13); \
  b -= c; b -= a; b ^= (a<<8); \
  c -= a; c -= b; c ^= (b>>13); \
  a -= b; a -= c; a ^= (c>>12); \
  b -= c; b -= a; b ^= (a<<16); \
  c -= a; c -= b; c ^= (b>>5); \
  a -= b; a -= c; a ^= (c>>3); \
  b -= c; b -= a; b ^= (a<<10); \
  c -= a; c -= b; c ^= (b>>15); \
} while (0)
#define zigstep 128 // Number of ziggurat steps; must be a power of two (used as a "& (zigstep - 1)" mask)
/* Marsaglia-style generators; shifts and multipliers assume a 32-bit unsigned state */
#define znew (z = 36969*(z&65535) + (z>>16) )   /* multiply-with-carry, state z */
#define wnew (w = 18000*(w&65535) + (w>>16) )   /* multiply-with-carry, state w */
#define MWC ((znew<<16) + wnew )
#define SHR3 ( jsr ^= (jsr<<17), jsr ^= (jsr>>13), jsr ^= (jsr<<5) )  /* 3-shift xorshift, state jsr */
#define CONG (jcong = 69069*jcong + 1234567)    /* linear congruential, state jcong */
#define KISS ((MWC^CONG) + SHR3)
/* Select the uniform generator with -DranKISS or -DranSHR3.  With neither,
   randint was never defined and the file did not compile, so fall back to SHR3. */
#if !defined(ranKISS) && !defined(ranSHR3)
#define ranSHR3
#endif
#ifdef ranKISS
#define randint KISS
#define rand() (randint*2.328306e-10)               /* U[0,1) : unsigned draw * ~2^-32 */
#endif
#ifdef ranSHR3
#define randint SHR3
#define rand() (0.5 + (signed)randint*2.328306e-10) /* U(0,1) : centred signed draw */
#endif
/*--------------------------------------------------------------- */
/* The generators are 32-bit algorithms: their shifts, multipliers and the
   2.328306e-10 (~2^-32) scaling assume a 32-bit unsigned word.  unsigned long
   is 64 bits on LP64 platforms (Linux / macOS), where it made rand() exceed 1
   under ranKISS and randn() return huge values, so use an exact-width type. */
typedef uint32_t UL;
/*--------------------------------------------------------------- */
static UL jsrseed = 31340134 , jsr;   /* SHR3 state; jsr is seeded by randini() */
#ifdef ranKISS
static UL z=362436069, w=521288629, jcong=380116160;   /* MWC / CONG state for KISS */
#endif
static UL jz , iz , kn[zigstep];      /* iz : current ziggurat strip, kn : acceptance thresholds */
static int32_t hz;                    /* last raw draw, as a signed 32-bit integer */
static float wn[zigstep] , fn[zigstep];
/*--------------------------------------------------------------- */
void randini(void);
void randnini(void);
float nfix(void);
double randn(void);
void matvect(double *A , double *v , double *w , int d , int n , int off);
void chol(double *Q , double *D , int d , int M);
void mvgmmrnd(double *mu , double *sigma , double *p , int d , int M , int N , int K , int V ,
              double *Z , double *index ,
              double *choles , double *v , double *b);
/*--------------------------------------------------------------- */
void mexFunction( int nlhs, mxArray *plhs[],
                  int nrhs, const mxArray *prhs[] )
/*
  MATLAB gateway : [Z , index] = mvgmmrnd(N , mu , sigma , p , [n1] , ... , [nl])

  Validates the inputs, allocates Z (d x N x mu-dims(4:end) x n1 x ... x nl)
  and index (same with a leading 1), then calls mvgmmrnd().
  Dimension arrays are handled as mwSize, the element type of
  mxGetDimensions / mxCreateNumericArray; reading them through an int pointer
  gives garbage on 64-bit MATLAB.
*/
{
    double *mu , *sigma , *p;
    double *Z , *index;
    double *choles , *v , *b;
    const mwSize *dimsmu , *dimssigma , *dimsp;
    mwSize *dimsZ;
    int numdimsmu , numextra , numdimsZ;
    int i , d , N , K = 1 , V = 1 , M = 1;
    int ownp = 0 , ownindex = 0;
    double scal;

    /* Check input */
    if(nrhs < 4)
    {
        mexErrMsgTxt("At least 4 inputs argument are required for mvgmmrnd");
    }

    /* Input 1 : N, number of samples per mixture */
    scal = mxGetScalar(prhs[0]);
    if (!(scal >= 0.0))                    /* negated test also rejects NaN */
    {
        mexErrMsgTxt("N must be a nonnegative integer");
    }
    N = (int) scal;

    /* Input 2 : mu (d x 1 x M x n1 x ... x nl) */
    if (!mxIsDouble(prhs[1]))
    {
        mexErrMsgTxt("mu must be (d x 1 x M x n1 x ... x nl)");
    }
    mu = mxGetPr(prhs[1]);
    numdimsmu = (int) mxGetNumberOfDimensions(prhs[1]);
    dimsmu = mxGetDimensions(prhs[1]);
    /* d = 0 or M = 0 would make chol() / mvgmmrnd() read past empty buffers */
    if ((dimsmu[0] < 1) || (dimsmu[1] != 1) || ((numdimsmu > 2) && (dimsmu[2] < 1)))
    {
        mexErrMsgTxt("mu must be (d x 1 x M x n1 x ... x nl)");
    }
    d = (int) dimsmu[0];
    if (numdimsmu > 2)
    {
        M = (int) dimsmu[2];
    }
    for(i = 3 ; i < numdimsmu ; i++)
    {
        K *= (int) dimsmu[i];
    }

    /* Input 3 : sigma (d x d x M x n1 x ... x nl) */
    if (!mxIsDouble(prhs[2]))
    {
        mexErrMsgTxt("sigma must be (d x d x M x n1 x ... x nl)");
    }
    sigma = mxGetPr(prhs[2]);
    dimssigma = mxGetDimensions(prhs[2]);
    /* Both leading dimensions must equal d (the former && only caught a double mismatch) */
    if ((dimssigma[0] != (mwSize) d) || (dimssigma[1] != (mwSize) d))
    {
        mexErrMsgTxt("sigma must be (d x d x M x n1 x ... x nl)");
    }

    /* Input 4 : p (1 x 1 x M x n1 x ... x nl), or any empty array for a single Gaussian */
    if (mxIsEmpty(prhs[3]))
    {
        /* One component of weight 1 for each of the K mixtures: mvgmmrnd reads p[l*M] */
        M = 1;
        p = (double *) mxMalloc((size_t) (K > 0 ? K : 1)*sizeof(*p));
        for (i = 0 ; i < K ; i++)
        {
            p[i] = 1.0;
        }
        ownp = 1;
    }
    else
    {
        dimsp = mxGetDimensions(prhs[3]);
        if (!mxIsDouble(prhs[3]) || (dimsp[0] != 1) || (dimsp[1] != 1) ||
            (mxGetNumberOfElements(prhs[3]) < (mwSize) M*K))
        {
            mexErrMsgTxt("p must be (1 x 1 x M)");
        }
        p = mxGetPr(prhs[3]);
    }
    /* sigma must hold a d x d block for every component that will be read */
    if (mxGetNumberOfElements(prhs[2]) < (mwSize) d*d*M*K)
    {
        mexErrMsgTxt("sigma must be (d x d x M x n1 x ... x nl)");
    }

    /* Output dimensions : d x N x (mu dims beyond the 3rd) x n1 x ... x nl.
       numextra is clamped at 0: for a (d x 1) mu, numdimsmu - 3 = -1 made dimsZ
       one element short and the n1.. loop overwrote dimsZ[1] = N. */
    numextra = (numdimsmu > 3) ? (numdimsmu - 3) : 0;
    numdimsZ = 2 + numextra + (nrhs - 4);
    dimsZ = (mwSize *) mxMalloc(numdimsZ*sizeof(*dimsZ));
    dimsZ[0] = (mwSize) d;
    dimsZ[1] = (mwSize) N;
    for(i = 0 ; i < numextra ; i++)
    {
        dimsZ[2 + i] = dimsmu[3 + i];
    }
    for (i = 4 ; i < nrhs ; i++)
    {
        scal = mxGetScalar(prhs[i]);
        if (!(scal >= 0.0))
        {
            mexErrMsgTxt("n1 , ... , nl must be nonnegative integers");
        }
        dimsZ[2 + numextra + (i - 4)] = (mwSize) scal;
        V *= (int) scal;
    }

    /* Output 1 : Z */
    plhs[0] = mxCreateNumericArray(numdimsZ , dimsZ, mxDOUBLE_CLASS, mxREAL);
    Z = mxGetPr(plhs[0]);
    /* Output 2 : index, only assigned when the caller asked for it */
    dimsZ[0] = 1;
    if (nlhs > 1)
    {
        plhs[1] = mxCreateNumericArray(numdimsZ , dimsZ, mxDOUBLE_CLASS, mxREAL);
        index = mxGetPr(plhs[1]);
    }
    else
    {
        /* mvgmmrnd always writes the labels, so give it a scratch buffer */
        index = (double *) mxMalloc(((size_t) N*K*V + 1)*sizeof(*index));
        ownindex = 1;
    }

    /* Temporary buffers */
    choles = (double *) mxMalloc(((size_t) d*d*M*K)*sizeof(*choles));
    v = (double *) mxMalloc(d*sizeof(*v));
    b = (double *) mxMalloc(d*sizeof(*b));

    /* Rand ~U[0,1] Seed initialization */
    randini();
    /* Initialize Ziggurat Table with zigstep steps for Normal(0,1) */
    randnini();

    /* Main call */
    mvgmmrnd(mu , sigma , p , d , M , N , K , V , Z , index , choles , v , b);

    /* Free resources */
    mxFree(choles);
    mxFree(v);
    mxFree(b);
    mxFree(dimsZ);
    if (ownindex)
    {
        mxFree(index);
    }
    if (ownp)
    {
        mxFree(p);
    }
}
/* ----------------------------------------------------------------------- */
void mvgmmrnd(double *mu , double *sigma , double *p , int d , int M , int N , int K , int V ,
              double *Z , double *index ,
              double *choles , double *v , double *b)
/*
  Core sampler.

  For each of the V replications and each of the K mixtures (M components
  each), draws N samples:
    - a component label comp in 1..M by walking the cumulative weights of
      p(:, :, :, g) against one rand() draw (the last component absorbs any
      remaining mass when the weights sum to less than the draw);
    - a point mu(:, comp, g) + L(:, :, comp, g) * v with v = d randn() draws,
      where L = chol(sigma)' is computed once up front into choles.
  Z (d values per sample) and index (one label per sample) are written
  contiguously in (sample, mixture, replication) order.
  v and b are caller-provided scratch vectors of length d.
*/
{
    const int dd = d*d;
    double *zout = Z;      /* next slot of Z     */
    double *lab  = index;  /* next slot of index */
    int rep , g , s , c , comp;

    /* choles(:, :, m, g) = chol(sigma(:, :, m, g))' */
    chol(sigma , choles , d , M*K);

    for (rep = 0 ; rep < V ; rep++)
    {
        for (g = 0 ; g < K ; g++)
        {
            const double *wts   = p  + g*M;    /* weights of mixture g   */
            const double *mean  = mu + g*d*M;  /* means of mixture g     */
            const int     cbase = g*dd*M;      /* offset of its factors  */

            for (s = 0 ; s < N ; s++)
            {
                double u   = rand();
                double cum = wts[0];

                /* inverse-CDF pick of the 1-based component label */
                comp = 1;
                while ((u > cum) && (comp < M))
                {
                    cum += wts[comp];
                    comp++;
                }
                *lab++ = comp;

                /* correlated draw : b = L * v with v ~ N(0, I) */
                for (c = 0 ; c < d ; c++)
                {
                    v[c] = randn();
                }
                matvect(choles , v , b , d , d , cbase + (comp - 1)*dd);

                for (c = 0 ; c < d ; c++)
                {
                    *zout++ = b[c] + mean[(comp - 1)*d + c];
                }
            }
        }
    }
}
/*----------------------------------------------------------*/
void matvect(double *A , double *v , double *w, int d , int n , int off)
/*
  Dense matrix-vector product  w = A*v.
  A   : column-major (d x n) matrix starting at element off,
        i.e. A(row, col) = A[off + row + col*d]
  v   : input vector of length n
  w   : output vector of length d (overwritten)
*/
{
    int row , col;

    for (row = 0 ; row < d ; row++)
    {
        double acc = 0.0;

        for (col = 0 ; col < n ; col++)
        {
            acc += A[off + row + col*d]*v[col];
        }
        w[row] = acc;
    }
}
/*----------------------------------------------------------*/
void chol(double *Q , double *D , int d , int M)
{
int i , j , r , d2=d*d;
int id , d1 = d - 1 , i1 , i1d , l , knnn , jd , jv , v , iv;
double sum , p , inv_p;
for (r = 0 ; r < M ; r++)
{
v = r*d2;
for (i = 0 ; i < d2 ; i++)
{
D[i + v] = Q[i + v];
}
p = sqrt(D[0 + v]);
inv_p = 1.0/p;
D[0 + v] = p;
for(i = 1 ; i < d; i++)
{
D[d*i + v] *= inv_p;
}
for(i = 1 ; i < d; i++)
{
id = i*d;
i1d = id - d;
i1 = i - 1;
iv = i + v;
sum = D[iv + id]; //sum = B[i][i]
for(l = 0; l < i; ++l)
{
knnn = id + l;
sum -= D[knnn + v]*D[knnn + v];
}
p = sqrt(sum);
inv_p = 1.0/p;
for(j = d1; j > i ; --j)
{
jd = j*d;
sum = D[jd + iv];
for(l = 0; l < i ; ++l)
{
sum -= D[jd + l + v]*D[id + l + v];
}
D[jd + iv] = sum*inv_p;
}
D[iv + id] = p;
for(l = d1 ; l>i1 ; l--)
{
D[l + i1d + v] = 0.0;
}
}
// D = D';
for (j = 0 ; j < d ; j++)
{
jd = j*d + v;
jv = j + v;
for(i = j + 1 ; i < d ; i++)
{
D[i + jd] = D[jv + i*d];
D[jv + i*d] = 0.0;
}
}
}
}
/* ----------------------------------------------------------------------- */
void randini(void)
{
/* SHR3 Seed initialization */
jsrseed = (UL) time( NULL );
jsr ^= jsrseed;
/* KISS Seed initialization */
#ifdef ranKISS
z = (UL) time( NULL );
w = (UL) time( NULL );
jcong = (UL) time( NULL );
mix(z , w , jcong);
#endif
}
/* --------------------------------------------------------------------------- */
void randnini(void)
{
register const double m1 = 2147483648.0, m2 = 4294967296.0 ;
register double invm1;
register double dn = 3.442619855899 , tn = dn , vn = 9.91256303526217e-3 , q;
int i;
/* Ziggurat tables for randn */
invm1 = 1.0/m1;
q = vn/exp(-0.5*dn*dn);
kn[0] = (dn/q)*m1;
kn[1] = 0;
wn[0] = q*invm1;
wn[zigstep - 1 ] = dn*invm1;
fn[0] = 1.0;
fn[zigstep - 1] = exp(-0.5*dn*dn);
for(i = (zigstep - 2) ; i >= 1 ; i--)
{
dn = sqrt(-2.*log(vn/dn + exp(-0.5*dn*dn)));
kn[i+1] = (dn/tn)*m1;
tn = dn;
fn[i] = exp(-0.5*dn*dn);
wn[i] = dn*invm1;
}
}
/* --------------------------------------------------------------------------- */
float nfix(void)
/*
  Slow path of the ziggurat normal sampler, entered from randn() when the
  fast test failed for the current draw (hz, iz).
  - iz == 0 : sample the tail beyond r with the exponential rejection loop.
  - otherwise : accept x = hz*wn[iz] if a uniform point in the wedge lies
    under exp(-x^2/2); else draw a new (hz, iz) and retry the fast test.
*/
{
    const float r = 3.442620f; /* The starting of the right tail */
    float x , y;               /* always assigned before use; no need for static */

    for(;;)
    {
        x = hz*wn[iz];
        if(iz == 0)
        { /* iz==0, handle the base strip */
            do
            {
                x = -log(rand())*0.2904764; /* .2904764 is 1/r */
                y = -log(rand());
            }
            while( (y + y) < (x*x));
            return (hz > 0) ? (r + x) : (-r - x);
        }
        if( (fn[iz] + rand()*(fn[iz-1] - fn[iz])) < ( exp(-0.5*x*x) ) )
        {
            return x;
        }
        /* Reduce the raw draw to signed 32 bits: the kn / wn tables are scaled for
           2^31, and a wider draw (64-bit unsigned long) would never pass the test. */
        hz = (int32_t)(uint32_t) randint;
        iz = (hz & (zigstep - 1));
        /* fabs on a double: abs() takes an int (truncating a long) and overflows on INT_MIN */
        if(fabs((double) hz) < kn[iz])
        {
            return (hz*wn[iz]);
        }
    }
}
/* --------------------------------------------------------------------------- */
double randn(void)
/*
  Standard normal N(0,1) draw by the ziggurat method (tables built by randnini()).
  Fast path: take a signed 32-bit draw hz, use its low bits as the strip index
  iz and return hz*wn[iz] when |hz| < kn[iz]; otherwise defer to nfix().
*/
{
    /* Reduce the raw draw to signed 32 bits: the kn / wn tables are scaled for 2^31 */
    hz = (int32_t)(uint32_t) randint;
    iz = (hz & (zigstep - 1));
    /* fabs on a double: abs() takes an int (truncating a long) and overflows on INT_MIN */
    return (fabs((double) hz) < kn[iz]) ? (hz*wn[iz]) : ( nfix() );
}
/* --------------------------------------------------------------------------- */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -