📄 h2m_glvq_model.c
字号:
for (l = 0 ; l < m ; l++)
{
ind = labels[l];
ld = l*d;
for (i = 0 ; i < Ntrain ; i++)
{
if(ytrain[i] == ind)
{
id = i*d;
for(j = 0 ; j < d ; j++)
{
temp = (Xtrain[j + id] - mtemp[j + ld]);
stdtemp[j + ld] += (temp*temp);
}
}
}
}
for (l = 0 ; l < m ; l++)
{
ld = l*d;
temp = 1.0/(Nk[l] - 1);
for(j = 0 ; j < d ; j++)
{
stdtemp[j + ld] = temp*sqrt(stdtemp[j + ld]);
}
}
plhs[0] = mxCreateDoubleMatrix(d , Nproto, mxREAL);
Wproto_est = mxGetPr(plhs[0]);
for(l = 0 ; l < Nproto ; l++)
{
ld = l*d;
for(i = 0 ; i < m ; i++)
{
if(labels[i] == yproto_est[l] )
{
indice = i*m;
}
}
for(i = 0 ; i < d ; i++)
{
Wproto_est[i + ld] = mtemp[i + indice] + stdtemp[i + indice]*randn();
}
}
}
else
{
Wproto = mxGetPr(prhs[2]);
if(mxGetNumberOfDimensions(prhs[2]) !=2 || mxGetM(prhs[2]) != d)
{
mexErrMsgTxt("Wproto must be (d x Nproto)");
}
Nproto = mxGetN(prhs[2]);
plhs[0] = mxCreateDoubleMatrix(d , Nproto, mxREAL);
Wproto_est = mxGetPr(plhs[0]);
for( i = 0 ; i < d*Nproto ; i++)
{
Wproto_est[i] = Wproto[i];
}
}
// Outputs
plhs[2] = mxCreateDoubleMatrix(1 , options.nb_iterations , mxREAL);
E_H2MGLVQ = mxGetPr(plhs[2]);
temp_train = mxMalloc(Ntrain*sizeof(double));
index_train = mxMalloc(Ntrain*sizeof(int));
dist = mxMalloc(Nproto*sizeof(double));
powdist = mxMalloc(Nproto*sizeof(double));
/* Main Call */
h2m_glvq_model(Xtrain , ytrain , options , d , Ntrain , Nproto , m ,
Wproto_est , yproto_est , E_H2MGLVQ,
labels , temp_train , index_train , dist , powdist ,
Nk);
mxFree(labels);
mxFree(ytrainsorted);
mxFree(temp_train);
mxFree(index_train);
mxFree(dist);
mxFree(powdist);
mxFree(Nk);
if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
{
mxFree(mtemp);
mxFree(stdtemp);
}
}
/*-------------------------------------------------------------------------------------------------------------- */
/*
 * h2m_glvq_model
 * --------------
 * Online (one sample at a time) training of the prototypes Wproto_est with
 * the H2M-GLVQ update rule implemented below.  Wproto_est is modified in place.
 *
 * Memory layout: each sample / prototype is d contiguous doubles, i.e. sample
 * i is Xtrain[i*d .. i*d + d - 1] and prototype j is Wproto_est[j*d .. j*d + d - 1].
 *
 * Parameters
 *   Xtrain       training samples (Ntrain blocks of d values)
 *   ytrain       label of each training sample (Ntrain)
 *   options      fields read: tmax, tmin, nb_iterations, epsilonk, epsilonl,
 *                shuffle, xi
 *   d            dimension of a sample
 *   Ntrain       number of training samples
 *   Nproto       number of prototypes
 *   m            number of entries in labels[]
 *   Wproto_est   prototypes, updated in place (Nproto blocks of d values)
 *   yproto_est   label of each prototype (Nproto)
 *   E_H2MGLVQ    output, one value per epoch: sum over the samples of
 *                1/(1 + exp(-xi*nu)), nu = (dkH - dlH)/(dkH + dlH)
 *   labels       the m class labels
 *   temp_train   scratch (Ntrain doubles) for the shuffle keys
 *   index_train  scratch (Ntrain ints) holding the visiting order
 *   dist         scratch (Nproto) squared distances to the current sample
 *   powdist      scratch (Nproto) (1/dist)^tcurrent
 *   Nk           output (m): number of prototypes carrying labels[l]
 *
 * Schedules over the nb_iterations epochs:
 *   - temperature tcurrent goes linearly from tmax (first) to tmin (last);
 *   - same-class learning rate is the constant options.epsilonk;
 *   - other-class learning rate decays geometrically from options.epsilonl
 *     to options.epsilonk.
 */
void h2m_glvq_model(double *Xtrain , double *ytrain , OPTIONS options , int d , int Ntrain , int Nproto , int m ,
                    double *Wproto_est , double *yproto_est , double *E_H2MGLVQ,
                    double *labels , double *temp_train , int *index_train , double *dist , double *powdist ,
                    int *Nk)
{
    int i , j , l , t , ind_label , indice , ind , kmin , lmin , indtemp , jd , Nkk , Nll;
    /* double_max is a "no candidate yet" sentinel just below DBL_MAX (10e307 == 1e308). */
    double yi , temp , sum , dkmin , dlmin , double_max = 1.79769313486231*10e307 , nu , lnu , cte , tmp ;
    /* Number of epoch-to-epoch steps, in floating point.  Clamped to 1 when
       nb_iterations <= 1 so that neither schedule divides by zero. */
    double nsteps = (options.nb_iterations > 1) ? (double)(options.nb_iterations - 1) : 1.0;
    double stept = (options.tmax - options.tmin)/nsteps;
    double epsilonl , epsilonk = options.epsilonk , cte_epsilonl = (options.epsilonk/options.epsilonl);
    double sumdkt , sumdlt , tcurrent , dkH , dlH , powdkmin , powdlmin , ak , al , alpha , tmpkmin , tmplmin;
    double ctetmp1 , ctetmp2 , ctetmp3 , ctetmp4 , ctetmp5 , ctetmp6;

    /* Nk[l] = number of prototypes whose label equals labels[l]. */
    for (l = 0 ; l < m ; l++)
    {
        Nk[l] = 0;
        for (i = 0 ; i < Nproto ; i++)
        {
            /* Compare as doubles, like every other label test in this
               function; truncating labels[l] to int would mis-count
               non-integer labels. */
            if (yproto_est[i] == labels[l])
            {
                Nk[l]++;
            }
        }
    }

    /* Default visiting order: samples in storage order. */
    for (i = 0 ; i < Ntrain ; i++)
    {
        index_train[i] = i;
    }

    tcurrent = options.tmax;
    for (t = 0 ; t < options.nb_iterations ; t++)
    {
        /* Geometric decay from options.epsilonl (t = 0) to options.epsilonk
           (last epoch).  The exponent must be a floating-point ratio: t and
           nb_iterations may be ints, and integer division would keep the
           exponent at 0 for every epoch but the last. */
        epsilonl = options.epsilonl*pow(cte_epsilonl , (double)t/nsteps);

        if (options.shuffle)
        {
            /* Random visiting order: sort random keys, carrying the indices. */
            for (i = 0 ; i < Ntrain ; i++)
            {
                temp_train[i] = rand();
                index_train[i] = i;
            }
            /* qsindex needs a non-empty range (Ntrain == 0 would give hi = -1). */
            if (Ntrain > 1)
            {
                qsindex(temp_train , index_train , 0 , Ntrain - 1);
            }
        }

        E_H2MGLVQ[t] = 0.0;
        for (i = 0 ; i < Ntrain ; i++)
        {
            indice = index_train[i];
            ind = indice*d;
            yi = ytrain[indice];

            /* Position of the sample's label in labels[] (last match wins).
               Initialised so it is never read uninitialised. */
            ind_label = 0;
            for (j = 0 ; j < m ; j++)
            {
                if (labels[j] == yi)
                {
                    ind_label = j;
                }
            }

            /* Squared distance to every prototype, its inverse power
               (1/dist)^tcurrent, and for the same-class (k) and other-class (l)
               groups: the nearest prototype and the sum of the powers. */
            dkmin = double_max;
            dlmin = double_max;
            kmin = 0;
            lmin = 0;
            sumdkt = 0.0;
            sumdlt = 0.0;
            for (j = 0 ; j < Nproto ; j++)
            {
                sum = 0.0;
                jd = j*d;
                for (l = 0 ; l < d ; l++)
                {
                    temp = (Xtrain[l + ind] - Wproto_est[l + jd]);
                    sum += (temp*temp);
                }
                dist[j] = sum;
                powdist[j] = pow(1.0/sum , tcurrent);
                if (yproto_est[j] == yi)
                {
                    if (sum < dkmin)
                    {
                        dkmin = sum;
                        kmin = j;
                    }
                    sumdkt += powdist[j];
                }
                else
                {
                    if (sum < dlmin)
                    {
                        dlmin = sum;
                        lmin = j;
                    }
                    sumdlt += powdist[j];
                }
            }

            /* NOTE(review): if the sample's class has no prototype (Nkk == 0),
               if every prototype carries it (Nll == 0), or if the sample
               coincides with a prototype (dist == 0), the expressions below
               yield Inf/NaN that then spread into Wproto_est.  The caller is
               assumed to rule these cases out -- confirm. */
            Nkk = Nk[ind_label];
            Nll = Nproto - Nkk;
            tmpkmin = sumdkt/powdist[kmin];
            dkH = (Nkk*dkmin)/(tmpkmin);
            powdkmin = tmpkmin/dkmin;
            tmplmin = sumdlt/powdist[lmin];
            dlH = (Nll*dlmin)/(tmplmin);
            powdlmin = tmplmin/dlmin;

            /* Relative difference of the two group distances, squashed by a
               logistic of slope xi; accumulated as this epoch's cost. */
            tmp = 1.0/(dkH + dlH);
            nu = (dkH - dlH)*tmp;
            lnu = 1.0/(1.0 + exp(-options.xi*nu));
            E_H2MGLVQ[t] += lnu;

            /* Step coefficients: same-class prototypes use epsilonk, other-class
               ones use epsilonl; each group's nearest prototype has its own. */
            ak = Nkk/(powdkmin*powdkmin);
            al = Nll/(powdlmin*powdlmin);
            cte = 4.0*options.xi*lnu*(1.0 - lnu)*(tmp*tmp);
            ctetmp1 = epsilonk*dlH*cte*ak;
            ctetmp2 = epsilonl*dkH*cte*al;
            ctetmp3 = -ctetmp1*((tcurrent - 1.0)*(powdkmin/dkmin) - (tcurrent/(dkmin*dkmin)));
            ctetmp4 = ctetmp2*((tcurrent - 1.0)*(powdlmin/dlmin) - (tcurrent/(dlmin*dlmin)));
            ctetmp5 = ctetmp1*(powdkmin/sumdkt)*tcurrent;
            ctetmp6 = -ctetmp2*(powdlmin/sumdlt)*tcurrent;

            /* Move every prototype along (x - w) by its coefficient alpha. */
            for (j = 0 ; j < Nproto ; j++)
            {
                jd = j*d;
                if (yproto_est[j] == yi)
                {
                    alpha = (kmin == j) ? ctetmp3 : ctetmp5*(powdist[j]/dist[j]);
                }
                else
                {
                    alpha = (lmin == j) ? ctetmp4 : ctetmp6*(powdist[j]/dist[j]);
                }
                for (l = 0 ; l < d ; l++)
                {
                    indtemp = l + jd;
                    Wproto_est[indtemp] += alpha*(Xtrain[l + ind] - Wproto_est[indtemp]);
                }
            }
        }
        /* Linear temperature schedule: reaches tmin at the last epoch. */
        tcurrent -= stept;
    }
}
/*----------------------------------------------------------------------------------------------------------------------------------------- */
/*
 * qs: sort a[lo..hi] (inclusive bounds) in ascending order, in place,
 * using quicksort with Hoare partitioning around the middle element.
 *
 * - An empty or single-element range (lo >= hi) is a no-op.  Without this
 *   guard, a call with hi == lo - 1 would read a[lo - 1].
 * - The midpoint is computed as lo + (hi - lo)/2 so that lo + hi cannot
 *   overflow.
 * - The function recurses only into the smaller partition and loops over
 *   the larger one, so stack depth stays O(log n) even on bad pivots.
 */
void qs(double *a , int lo, int hi)
{
    int i , j;
    double x , h;
    while (lo < hi)
    {
        i = lo;
        j = hi;
        x = a[lo + (hi - lo)/2];
        // partition: afterwards a[lo..j] <= x <= a[i..hi]
        do
        {
            while (a[i] < x) i++;
            while (a[j] > x) j--;
            if (i <= j)
            {
                h = a[i];
                a[i] = a[j];
                a[j] = h;
                i++;
                j--;
            }
        }
        while (i <= j);
        // recurse on the smaller side, iterate on the larger one
        if ((j - lo) < (hi - i))
        {
            qs(a , lo , j);
            lo = i;
        }
        else
        {
            qs(a , i , hi);
            hi = j;
        }
    }
}
/*----------------------------------------------------------------------------------------------------------------------------------------- */
/*
 * qsindex: sort the keys a[lo..hi] (inclusive bounds) in ascending order, in
 * place, applying every swap to index[lo..hi] as well, so each index entry
 * travels with its key.  If index starts as the identity, index[k] ends up
 * holding the original position of the key now stored in a[k].
 *
 * - An empty or single-element range (lo >= hi) is a no-op.  Without this
 *   guard, a call with hi == lo - 1 would read a[lo - 1].
 * - The midpoint is computed as lo + (hi - lo)/2 so that lo + hi cannot
 *   overflow.
 * - The function recurses only into the smaller partition and loops over
 *   the larger one, so stack depth stays O(log n) even on bad pivots.
 */
void qsindex (double *a, int *index , int lo, int hi)
{
    int i , j , ind;
    double x , h;
    while (lo < hi)
    {
        i = lo;
        j = hi;
        x = a[lo + (hi - lo)/2];
        // partition: afterwards a[lo..j] <= x <= a[i..hi]
        do
        {
            while (a[i] < x) i++;
            while (a[j] > x) j--;
            if (i <= j)
            {
                h = a[i];
                a[i] = a[j];
                a[j] = h;
                ind = index[i];
                index[i] = index[j];
                index[j] = ind;
                i++;
                j--;
            }
        }
        while (i <= j);
        // recurse on the smaller side, iterate on the larger one
        if ((j - lo) < (hi - i))
        {
            qsindex(a , index , lo , j);
            lo = i;
        }
        else
        {
            qsindex(a , index , i , hi);
            hi = j;
        }
    }
}
/* --------------------------------------------------------------------------- */
void randini(void)
{
/* SHR3 Seed initialization */
jsrseed = (UL) time( NULL );
jsr ^= jsrseed;
/* KISS Seed initialization */
#ifdef ranKISS
z = (UL) time( NULL );
w = (UL) time( NULL );
jcong = (UL) time( NULL );
mix(z , w , jcong);
#endif
}
/* --------------------------------------------------------------------------- */
void randnini(void)
{
register const double m1 = 2147483648.0, m2 = 4294967296.0 ;
register double invm1;
register double dn = 3.442619855899 , tn = dn , vn = 9.91256303526217e-3 , q;
int i;
/* Ziggurat tables for randn */
invm1 = 1.0/m1;
q = vn/exp(-0.5*dn*dn);
kn[0] = (dn/q)*m1;
kn[1] = 0;
wn[0] = q*invm1;
wn[zigstep - 1 ] = dn*invm1;
fn[0] = 1.0;
fn[zigstep - 1] = exp(-0.5*dn*dn);
for(i = (zigstep - 2) ; i >= 1 ; i--)
{
dn = sqrt(-2.*log(vn/dn + exp(-0.5*dn*dn)));
kn[i+1] = (dn/tn)*m1;
tn = dn;
fn[i] = exp(-0.5*dn*dn);
wn[i] = dn*invm1;
}
}
/* --------------------------------------------------------------------------- */
/*
 * nfix: slow path of the Ziggurat normal generator, called by randn() when
 * the fast rectangle test fails.  Reads and updates the file-level state
 * hz (last raw integer draw, from randint) and iz (its strip index).
 *
 *  - iz == 0: sample from the tail beyond r.
 *  - iz  > 0: accept x = hz*wn[iz] if it falls under exp(-x^2/2) in the
 *    wedge; otherwise draw a new hz/iz and retry the fast test.
 *
 * NOTE(review): rand() here must be the file's uniform (0,1) generator
 * (presumably a macro), not stdlib rand(); otherwise -log(rand()) makes no
 * sense -- confirm.
 */
double nfix(void)
{
    const double r = 3.442620; /* The starting of the right tail */
    double x , y;  /* always assigned before use; no need for static storage */
    for(;;)
    {
        x = hz*wn[iz];
        if (iz == 0)
        { /* iz == 0: handle the base strip (tail beyond r) */
            do
            {
                x = -log(rand())*0.2904764; /* .2904764 is 1/r */
                y = -log(rand());
            }
            while ((y + y) < (x*x));
            return (hz > 0) ? (r + x) : (-r - x);
        }
        /* iz > 0: wedge between the rectangle and the density */
        if ((fn[iz] + rand()*(fn[iz-1] - fn[iz])) < exp(-0.5*x*x))
        {
            return x;
        }
        /* fresh draw, then retry the fast rectangle test */
        hz = randint;
        iz = (hz & (zigstep - 1));
        /* fabs on a double: abs(INT_MIN) is undefined behaviour and abs()
           would truncate a long hz; the reference implementation uses fabs. */
        if (fabs((double)hz) < kn[iz])
        {
            return (hz*wn[iz]);
        }
    }
}
/* --------------------------------------------------------------------------- */
/*
 * randn: draw one standard normal deviate with the Ziggurat method.
 * Takes a raw integer hz from randint and uses its low bits as the strip
 * index iz.  If |hz| is below the strip threshold kn[iz], the result is
 * hz*wn[iz] (fast path); otherwise nfix() handles the wedge or tail.
 * Requires randnini() to have filled kn/wn/fn first.
 */
double randn(void)
{
    hz = randint;
    iz = (hz & (zigstep - 1));
    /* fabs on a double: abs(INT_MIN) is undefined behaviour and abs()
       would truncate a long hz; the reference implementation uses fabs. */
    return (fabs((double)hz) < kn[iz]) ? (hz*wn[iz]) : ( nfix() );
}
/* --------------------------------------------------------------------------- */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -