📄 grlvq_model.c
字号:
{
mxFree(mtemp);
mxFree(stdtemp);
mxFree(Nk);
}
if(mxIsEmpty(prhs[4]))
{
mxFree(lambda);
}
}
/*-------------------------------------------------------------------------------------------------------------- */
/* Relevance-weighted distance between one sample x and one prototype w, both of length d:
 *   euclidean != 0 : sum over l of lambda[l]*(x[l] - w[l])^2
 *   euclidean == 0 : sum over l of lambda[l]*(x[l] - w[l])^4
 * Used by grlvq_model so that the metric is selected in exactly one place.
 */
static double grlvq_weighted_dist(const double *x , const double *w , const double *lambda , int d , int euclidean)
{
    double dist = 0.0 , diff;
    int l;
    if (euclidean)
    {
        for (l = 0 ; l < d ; l++)
        {
            diff = x[l] - w[l];
            dist += lambda[l]*diff*diff;
        }
    }
    else
    {
        for (l = 0 ; l < d ; l++)
        {
            diff = x[l] - w[l];
            dist += lambda[l]*diff*diff*diff*diff;
        }
    }
    return dist;
}
/*--------------------------------------------------------------------------------------------------------------*/
/* Train a Generalized Relevance LVQ model by stochastic gradient descent (in place).
 *
 * Xtrain       training samples; sample i occupies Xtrain[i*d .. i*d + d - 1].
 * ytrain       (Ntrain) label of each sample.
 * options      fields read here: nb_iterations, shuffle, metric_method (non-zero: weighted squared
 *              Euclidean distance, 0: weighted 4th-power distance), xi, epsilonk, epsilonl,
 *              epsilonlambda, updatelambda.
 * d, Ntrain, Nproto   sample dimension, number of samples, number of prototypes
 *              (assumes Nproto >= 1: prototype 0 is the fallback winner).
 * m, lambda, labels   not used by this routine (kept for interface compatibility).
 * Wproto_est   prototypes, prototype j occupies Wproto_est[j*d .. j*d + d - 1]; updated in place.
 * yproto_est   (Nproto) prototype labels.
 * E_GRLVQ      (nb_iterations) output: for each epoch, sum over samples of sigmoid(xi*mu).
 * lambda_est   (d) relevance weights; when options.updatelambda is set they are updated,
 *              clamped to >= 0 and renormalized to sum to 1 after every sample.
 * temp_train   (Ntrain) scratch buffer for the random shuffle keys.
 * index_train  (Ntrain) scratch buffer holding the order in which samples are visited.
 */
void grlvq_model(double *Xtrain , double *ytrain , OPTIONS options , int d , int Ntrain , int Nproto , int m ,
                 double *Wproto_est , double *yproto_est , double *E_GRLVQ , double *lambda , double *lambda_est ,
                 double *labels , double *temp_train , int *index_train)
{
    int i , j , l , t , indice , ind , kmin , lmin , offkmin , offlmin , indtemp , indtemp1;
    /* The distance and its gradients must use the same metric, so a single flag selects both */
    const int euclidean = (options.metric_method != 0);
    /* double_max is just below DBL_MAX and marks "no prototype found yet" */
    double yi , dkmin , dlmin , dsum , double_max = 1.79769313486231*10e307 , nu , lnu , ctek , ctel , ctelamb ,
           tmptmp , tmptmp1 , tmptmp2 , tmptmp3 , tmptmp4 , cte_tmp , cte_lambda , dist_proto;

    (void) m;
    (void) lambda;
    (void) labels;

    for (i = 0 ; i < Ntrain ; i++)
    {
        index_train[i] = i;
    }
    for (t = 0 ; t < options.nb_iterations ; t++)
    {
        if (options.shuffle)
        {
            /* Random visiting order: one random key per sample, indices sorted by key */
            for (i = 0 ; i < Ntrain ; i++)
            {
                temp_train[i] = rand();
                index_train[i] = i;
            }
            if (Ntrain > 1)
            {
                qsindex(temp_train , index_train , 0 , Ntrain - 1);
            }
        }
        E_GRLVQ[t] = 0.0;
        for (i = 0 ; i < Ntrain ; i++)
        {
            indice = index_train[i];
            ind = indice*d;
            yi = ytrain[indice];
            dkmin = double_max;
            dlmin = double_max;
            kmin = 0;
            lmin = 0;
            /* Closest prototype with the sample's label (k) and with any other label (l) */
            for (j = 0 ; j < Nproto ; j++)
            {
                dist_proto = grlvq_weighted_dist(Xtrain + ind , Wproto_est + j*d , lambda_est , d , euclidean);
                if (yproto_est[j] == yi)
                {
                    if (dist_proto < dkmin)
                    {
                        dkmin = dist_proto;
                        kmin = j;
                    }
                }
                else if (dist_proto < dlmin)
                {
                    dlmin = dist_proto;
                    lmin = j;
                }
            }
            /* mu = (dk - dl)/(dk + dl). When dk + dl == 0, mu is 0/0; using NaN would poison every
               prototype and relevance, so treat it as mu = 0 with a zero-size step instead. */
            dsum = dkmin + dlmin;
            tmptmp = (dsum != 0.0) ? 1.0/dsum : 0.0;
            nu = (dkmin - dlmin)*tmptmp;
            lnu = 1.0/(1.0 + exp(-options.xi*nu));
            E_GRLVQ[t] += lnu;
            /* NOTE(review): the textbook GRLVQ gradient carries 1/(dk + dl)^2; this code uses a single
               power of 1/(dk + dl) (the squared variant was left commented out originally). */
            cte_tmp = options.xi*lnu*(1.0 - lnu)*tmptmp;
            ctek = cte_tmp*options.epsilonk*dlmin;
            ctel = cte_tmp*options.epsilonl*dkmin;
            ctelamb = options.epsilonlambda*cte_tmp;
            offkmin = kmin*d;
            offlmin = lmin*d;
            cte_lambda = 0.0;
            if (euclidean)
            {
                /* Gradient of the weighted squared distance: attract w_k, repel w_l */
                for (l = 0 ; l < d ; l++)
                {
                    indtemp = l + offkmin;
                    tmptmp1 = Xtrain[l + ind] - Wproto_est[indtemp];
                    Wproto_est[indtemp] += 4.0*ctek*lambda_est[l]*tmptmp1;
                    indtemp1 = l + offlmin;
                    tmptmp2 = Xtrain[l + ind] - Wproto_est[indtemp1];
                    Wproto_est[indtemp1] -= 4.0*ctel*lambda_est[l]*tmptmp2;
                    if (options.updatelambda)
                    {
                        lambda_est[l] -= ctelamb*2.0*(dlmin*tmptmp1*tmptmp1 - dkmin*tmptmp2*tmptmp2);
                        lambda_est[l] = MAX(0 , lambda_est[l]);
                        cte_lambda += lambda_est[l];
                    }
                }
            }
            else
            {
                /* Gradient of the weighted 4th-power distance: attract w_k, repel w_l */
                for (l = 0 ; l < d ; l++)
                {
                    indtemp = l + offkmin;
                    tmptmp1 = Xtrain[l + ind] - Wproto_est[indtemp];
                    Wproto_est[indtemp] += ctek*8.0*lambda_est[l]*tmptmp1*tmptmp1*tmptmp1;
                    indtemp1 = l + offlmin;
                    tmptmp2 = Xtrain[l + ind] - Wproto_est[indtemp1];
                    Wproto_est[indtemp1] -= ctel*8.0*lambda_est[l]*tmptmp2*tmptmp2*tmptmp2;
                    if (options.updatelambda)
                    {
                        tmptmp3 = tmptmp1*tmptmp1;
                        tmptmp4 = tmptmp2*tmptmp2;
                        /* NOTE(review): by analogy with the Euclidean branch this factor would be 2.0;
                           4.0 is kept unchanged (it only rescales the relevance step). */
                        lambda_est[l] -= ctelamb*4.0*(dlmin*tmptmp3*tmptmp3 - dkmin*tmptmp4*tmptmp4);
                        lambda_est[l] = MAX(0 , lambda_est[l]);
                        cte_lambda += lambda_est[l];
                    }
                }
            }
            if (options.updatelambda)
            {
                if (cte_lambda > 0.0)
                {
                    /* Renormalize the relevances so they sum to 1 */
                    cte_lambda = 1.0/cte_lambda;
                    for (l = 0 ; l < d ; l++)
                    {
                        lambda_est[l] *= cte_lambda;
                    }
                }
                else
                {
                    /* All relevances were clamped to 0: dividing by the zero sum would yield NaN,
                       so restart from uniform relevances instead */
                    for (l = 0 ; l < d ; l++)
                    {
                        lambda_est[l] = 1.0/d;
                    }
                }
            }
        }
    }
}
/*----------------------------------------------------------------------------------------------------------------------------------------- */
/* Sort a[lo..hi] (inclusive bounds) in ascending order, in place, with Hoare-partition quicksort.
 *
 * An empty or single-element range (hi <= lo) is a no-op, so qs(a, 0, n - 1) is safe for n == 0.
 * Only the smaller partition is sorted recursively; the larger one is handled by the loop, which
 * keeps the recursion depth O(log n) even for inputs that trigger worst-case partitioning.
 */
void qs(double *a , int lo, int hi)
{
    int i , j;
    double x , h;
    while (lo < hi)
    {
        i = lo;
        j = hi;
        /* middle element as pivot; lo + (hi - lo)/2 cannot overflow, unlike (lo + hi)/2 */
        x = a[lo + (hi - lo)/2];
        /* partition */
        do
        {
            while (a[i] < x)
            {
                i++;
            }
            while (a[j] > x)
            {
                j--;
            }
            if (i <= j)
            {
                h = a[i];
                a[i] = a[j];
                a[j] = h;
                i++;
                j--;
            }
        }
        while (i <= j);
        /* a[lo..j] <= x <= a[i..hi]: recurse into the smaller side, loop on the larger one */
        if (j - lo < hi - i)
        {
            if (lo < j)
            {
                qs(a , lo , j);
            }
            lo = i;
        }
        else
        {
            if (i < hi)
            {
                qs(a , i , hi);
            }
            hi = j;
        }
    }
}
/*----------------------------------------------------------------------------------------------------------------------------------------- */
/* Sort a[lo..hi] (inclusive bounds) in ascending order, in place, and apply the same swaps to
 * index[lo..hi], so each index[] entry keeps following the key it was paired with.
 *
 * An empty or single-element range (hi <= lo) is a no-op, so qsindex(a, idx, 0, n - 1) is safe
 * for n == 0. Only the smaller partition is sorted recursively; the larger one is handled by the
 * loop, which keeps the recursion depth O(log n) even for worst-case partitioning.
 */
void qsindex (double *a, int *index , int lo, int hi)
{
    int i , j , ind;
    double x , h;
    while (lo < hi)
    {
        i = lo;
        j = hi;
        /* middle element as pivot; lo + (hi - lo)/2 cannot overflow, unlike (lo + hi)/2 */
        x = a[lo + (hi - lo)/2];
        /* partition, moving keys and their indices together */
        do
        {
            while (a[i] < x)
            {
                i++;
            }
            while (a[j] > x)
            {
                j--;
            }
            if (i <= j)
            {
                h = a[i];
                a[i] = a[j];
                a[j] = h;
                ind = index[i];
                index[i] = index[j];
                index[j] = ind;
                i++;
                j--;
            }
        }
        while (i <= j);
        /* a[lo..j] <= x <= a[i..hi]: recurse into the smaller side, loop on the larger one */
        if (j - lo < hi - i)
        {
            if (lo < j)
            {
                qsindex(a , index , lo , j);
            }
            lo = i;
        }
        else
        {
            if (i < hi)
            {
                qsindex(a , index , i , hi);
            }
            hi = j;
        }
    }
}
/* --------------------------------------------------------------------------- */
void randini(void)
{
/* SHR3 Seed initialization */
jsrseed = (UL) time( NULL );
jsr ^= jsrseed;
/* KISS Seed initialization */
#ifdef ranKISS
z = (UL) time( NULL );
w = (UL) time( NULL );
jcong = (UL) time( NULL );
mix(z , w , jcong);
#endif
}
/* --------------------------------------------------------------------------- */
void randnini(void)
{
register const double m1 = 2147483648.0 ;
register double invm1;
register double dn = 3.442619855899 , tn = dn , vn = 9.91256303526217e-3 , q;
int i;
/* Ziggurat tables for randn */
invm1 = 1.0/m1;
q = vn/exp(-0.5*dn*dn);
kn[0] = (dn/q)*m1;
kn[1] = 0;
wn[0] = q*invm1;
wn[zigstep - 1 ] = dn*invm1;
fn[0] = 1.0;
fn[zigstep - 1] = exp(-0.5*dn*dn);
for(i = (zigstep - 2) ; i >= 1 ; i--)
{
dn = sqrt(-2.*log(vn/dn + exp(-0.5*dn*dn)));
kn[i+1] = (dn/tn)*m1;
tn = dn;
fn[i] = exp(-0.5*dn*dn);
wn[i] = dn*invm1;
}
}
/* --------------------------------------------------------------------------- */
/* Slow path of the ziggurat normal generator, used by randn() when its fast test rejects.
 * Reads and updates the shared state hz/iz and uses the tables kn, wn, fn filled by randnini().
 * Returns one standard normal deviate.
 *
 * NOTE(review): rand() is used as a uniform double in (0, 1) (it feeds -log(rand()) and scales
 * table differences). The C library rand() returns an int in [0, RAND_MAX], so this file
 * presumably redefines rand; confirm.
 */
double nfix(void)
{
    const double r = 3.442620; /* The starting of the right tail */
    double x , y;              /* automatic, not static: no state needs to survive between calls */
    for(;;)
    {
        x = hz*wn[iz];
        if(iz == 0)
        { /* iz==0, handle the base strip */
            do
            {
                x = -log(rand())*0.2904764; /* .2904764 is 1/r */
                y = -log(rand());
            }
            while( (y + y) < (x*x));
            return (hz > 0) ? (r + x) : (-r - x);
        }
        if( (fn[iz] + rand()*(fn[iz-1] - fn[iz])) < ( exp(-0.5*x*x) ) )
        {
            return x;
        }
        hz = randint;
        iz = (hz & (zigstep - 1));
        /* fabs on a double instead of abs on an int: abs(INT_MIN) is undefined behaviour */
        if(fabs((double) hz) < kn[iz])
        {
            return (hz*wn[iz]);
        }
    }
}
/* --------------------------------------------------------------------------- */
/* Draw one standard normal deviate with the ziggurat method (as in Marsaglia & Tsang's RNOR).
 * A random integer hz selects strip iz from its low bits. The value hz*wn[iz] is accepted
 * immediately when |hz| < kn[iz]; otherwise nfix() takes over.
 * NOTE(review): relies on randnini() having filled kn/wn/fn beforehand; confirm callers do this.
 */
double randn(void)
{
    hz = randint;
    iz = (hz & (zigstep - 1));
    /* fabs on a double instead of abs on an int: abs(INT_MIN) is undefined behaviour, and a
       full-period 32-bit generator presumably produces that bit pattern */
    return (fabs((double) hz) < kn[iz]) ? (hz*wn[iz]) : ( nfix() );
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -