📄 srng_model.c
字号:
mxFree(Np);
mxFree(distj);
mxFree(rank_distj);
mxFree(index_distj);
mxFree(hproto);
mxFree(sproto);
if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
{
mxFree(mtemp);
mxFree(stdtemp);
mxFree(Nk);
}
if(mxIsEmpty(prhs[4]))
{
mxFree(lambda);
}
}
/*-------------------------------------------------------------------------------------------------------------- */
/*
 * srng_model: iterative prototype (and optional relevance) training.
 *
 * For each of options.nb_iterations passes over the Ntrain samples (visited in
 * a random order when options.shuffle is set), and for each sample:
 *   - the relevance-weighted distance to every prototype is computed
 *     (sum of lambda_est[l]*diff^2 when options.metric_method is non-zero,
 *     sum of lambda_est[l]*diff^4 otherwise);
 *   - prototypes carrying the sample's label are ranked by distance, and the
 *     closest prototype with a different label is located (dlmin / lmin);
 *   - every same-label prototype whose normalised rank weight exceeds
 *     options.threshold is moved toward the sample, the closest other-label
 *     prototype is moved away from it, and (when options.updatelambda) the
 *     relevances are adjusted, clamped at 0 and renormalised to sum to 1.
 *
 * Xtrain      : Ntrain samples of d values; sample k is Xtrain[k*d .. k*d+d-1].
 * ytrain      : Ntrain sample labels.
 * Wproto_est  : Nproto prototypes of d values, same layout; updated in place.
 * yproto_est  : Nproto prototype labels.
 * E_SRNG      : nb_iterations entries; E_SRNG[t] receives the sum over samples
 *               of the weighted sigmoid terms computed in pass t.
 * lambda_est  : d relevances; read, and updated in place when updatelambda.
 * temp_train, index_train        : Ntrain-long work buffers (shuffling).
 * distj, rank_distj, index_distj : work buffers of at least Nproto entries.
 * hproto, sproto                 : Nproto_max-long work buffers (rank weights
 *                                  and their running sums).
 * m, lambda, labels are not read here.
 *
 * Callers must ensure Nproto <= Nproto_max (hproto/sproto are indexed by rank).
 */
void srng_model(double *Xtrain , double *ytrain , OPTIONS options , int d , int Ntrain , int Nproto , int m ,
                double *Wproto_est , double *yproto_est , double *E_SRNG , double *lambda , double *lambda_est ,
                double *labels , double *temp_train , int *index_train , double *distj , int *rank_distj, int *index_distj , double *hproto , double* sproto , int Nproto_max)
{
    int i , j , l , t , indice , ind , lmin , offkmin , offlmin , indtemp , jd;
    int coj , rank;
    /* double_max (~1.797e308, i.e. about DBL_MAX) is the "no candidate yet" distance. */
    double yi , temp , sum_srng , dlmin , double_max = 1.79769313486231*10e307 ;
    double lnu , cte , ctek , ctel , ctelamb , tmptmp , tmptmp2, tmptmp1 , cte_tmp;
    double sigma , qj , dist_proto , sum_lambda , dist_temp , weight;

    /* Rank 0 always has weight exp(-0/sigma) = 1. */
    hproto[0] = 1.0;
    sproto[0] = 1.0;

    for(i = 0 ; i < Ntrain ; i++)
    {
        index_train[i] = i;
    }

    for (t = 0 ; t < options.nb_iterations ; t++)
    {
        /* Shuffle the visiting order: sort random keys, carrying sample indices along.
           Ntrain == 0 would otherwise call qsindex on the range [0, -1]. */
        if(options.shuffle && (Ntrain > 0))
        {
            for(i = 0 ; i < Ntrain ; i++)
            {
                temp_train[i] = rand();
                index_train[i] = i;
            }
            qsindex(temp_train , index_train , 0 , Ntrain - 1 );
        }

        /* Neighbourhood size: sigmastart at t = 0, then moving toward sigmaend
           at a rate controlled by sigmastretch. */
        sigma = (options.sigmaend - options.sigmastart) *(1.0 - 2.0/(1.0 + exp(t*options.sigmastretch))) + options.sigmastart;
        /* hproto[r] = exp(-r/sigma) is the weight of rank r; sproto[r] = hproto[0] + ... + hproto[r]. */
        for (i = 1 ; i < Nproto_max ; i++)
        {
            hproto[i] = exp(-i/sigma);
            sproto[i] = hproto[i] + sproto[i - 1];
        }

        E_SRNG[t] = 0.0;
        for(i = 0 ; i < Ntrain ; i++)
        {
            indice = index_train[i];
            ind = indice*d;
            yi = ytrain[indice];
            dlmin = double_max;
            /* NOTE(review): if no prototype has a label different from yi, dlmin stays
               at double_max and lmin at 0, so prototype 0 receives the "repel" update
               below; confirm callers always provide at least two classes. */
            lmin = 0;
            coj = 0;

            /* Distances to all prototypes: same-label ones go to distj/index_distj,
               the closest different-label one is tracked in dlmin/lmin. */
            for(j = 0 ; j < Nproto ; j++)
            {
                dist_proto = 0.0;
                jd = j*d;
                if (options.metric_method)
                {
                    for(l = 0 ; l < d ; l++)
                    {
                        temp = (Xtrain[l + ind] - Wproto_est[l + jd]);
                        dist_proto += lambda_est[l]*temp*temp;
                    }
                }
                else
                {
                    for(l = 0 ; l < d ; l++)
                    {
                        temp = (Xtrain[l + ind] - Wproto_est[l + jd]);
                        dist_proto += lambda_est[l]*temp*temp*temp*temp;
                    }
                }
                if(yproto_est[j] == yi)
                {
                    distj[coj] = dist_proto;
                    index_distj[coj] = j;
                    coj++;
                }
                else if(dist_proto < dlmin)
                {
                    dlmin = dist_proto;
                    lmin = j;
                }
            }

            sum_srng = 0.0;
            /* With no prototype of the sample's own label there is nothing to rank;
               without this guard qsindex would scan [0, -1] and sproto[-1] would be read. */
            if (coj > 0)
            {
                for(j = 0 ; j < coj ; j++)
                {
                    rank_distj[j] = j;
                }
                /* Sort same-label distances ascending; rank_distj[j] then gives the
                   position (in index_distj) of the j-th closest same-label prototype. */
                qsindex(distj , rank_distj , 0 , coj - 1 );
                /* Normalise the rank weights over the coj ranks actually present. */
                temp = 1.0/sproto[coj - 1];
                offlmin = lmin*d;
                for(j = 0 ; j < coj ; j++)
                {
                    rank = rank_distj[j];
                    weight = hproto[j]*temp;
                    dist_temp = distj[j];
                    /* Skip negligible ranks, and the degenerate case where both distances
                       are exactly 0 (1/0 would make qj = 0*inf = NaN and poison all updates). */
                    if ((weight > options.threshold) && ((dist_temp + dlmin) != 0.0))
                    {
                        tmptmp = 1.0/(dist_temp + dlmin);
                        /* Relative distance difference; negative when the same-label
                           prototype is closer than the nearest other-label one. */
                        qj = (dist_temp - dlmin)*tmptmp;
                        lnu = 1.0/(1.0 + exp(-options.xi*qj));
                        cte = lnu*weight;
                        sum_srng += cte;
                        /* NOTE(review): d(qj)/d(dist_temp) = 2*dlmin/(dist_temp + dlmin)^2, but
                           only one factor of tmptmp is applied here (an older variant used
                           tmptmp*tmptmp). Looks deliberate; confirm. The updates below also use
                           the squared-distance gradient form even when metric_method selects
                           fourth powers. */
                        cte_tmp = options.xi*cte*(1.0 - lnu)*tmptmp;
                        ctek = options.epsilonk*cte_tmp*dlmin;
                        ctel = options.epsilonl*cte_tmp;
                        ctelamb = options.epsilonlambda*cte_tmp;
                        offkmin = index_distj[rank]*d;
                        for( l = 0 ; l < d ; l++)
                        {
                            /* Attract the same-label prototype toward the sample. */
                            indtemp = l + offkmin;
                            tmptmp1 = Xtrain[l + ind] - Wproto_est[indtemp];
                            Wproto_est[indtemp] += 4.0*ctek*lambda_est[l]*tmptmp1;
                            /* Repel the closest other-label prototype. */
                            indtemp = l + offlmin;
                            tmptmp2 = Xtrain[l + ind] - Wproto_est[indtemp];
                            Wproto_est[indtemp] -= 4.0*ctel*dist_temp*lambda_est[l]*tmptmp2;
                            if (options.updatelambda)
                            {
                                lambda_est[l] -= ctelamb*2.0*(dlmin*tmptmp1*tmptmp1 - dist_temp*tmptmp2*tmptmp2);
                            }
                        }
                    }
                }
            }
            E_SRNG[t] += sum_srng;

            if (options.updatelambda)
            {
                /* Clamp relevances at zero and renormalise them to sum to one. */
                sum_lambda = 0.0;
                for (l = 0 ; l < d ; l++)
                {
                    lambda_est[l] = MAX(0 , lambda_est[l]);
                    sum_lambda += lambda_est[l];
                }
                if (sum_lambda > 0.0)
                {
                    sum_lambda = 1.0/sum_lambda;
                    for (l = 0 ; l < d ; l++)
                    {
                        lambda_est[l] *= sum_lambda;
                    }
                }
                else
                {
                    /* Every relevance clamped to 0: dividing by the zero sum would turn
                       them all into NaN, so fall back to uniform relevances. */
                    for (l = 0 ; l < d ; l++)
                    {
                        lambda_est[l] = 1.0/d;
                    }
                }
            }
        }
    }
}
/*----------------------------------------------------------------------------------------------------------------------------------------- */
/*
 * qs: sort a[lo..hi] (both bounds inclusive) into ascending order, in place.
 * Recursive quicksort with Hoare-style partitioning around the middle element.
 * An empty or one-element range (lo >= hi) is left unchanged.
 */
void qs(double *a , int lo, int hi)
{
    int i, j;
    double x, h;

    /* Empty range: without this guard the scan below would read a[hi] with
       hi < lo (e.g. a[-1] for qs(a, 0, n - 1) with n == 0). */
    if (lo >= hi)
    {
        return;
    }
    i = lo;
    j = hi;
    /* Same index as (lo + hi)/2 for non-negative bounds, without int overflow. */
    x = a[lo + (hi - lo)/2];

    /* Partition: afterwards every a[lo..j] <= x and every a[i..hi] >= x. */
    do
    {
        while (a[i] < x)
        {
            i++;
        }
        while (a[j] > x)
        {
            j--;
        }
        if (i <= j)
        {
            h = a[i];
            a[i] = a[j];
            a[j] = h;
            i++;
            j--;
        }
    }
    while (i <= j);

    /* Sort both sides of the partition. */
    if (lo < j)
    {
        qs(a , lo , j);
    }
    if (i < hi)
    {
        qs(a , i , hi);
    }
}
/*----------------------------------------------------------------------------------------------------------------------------------------- */
/*
 * qsindex: sort a[lo..hi] (both bounds inclusive) into ascending order in
 * place, applying every swap to index[] as well. Afterwards index[k] holds
 * the value that travelled with the k-th smallest key.
 * Recursive quicksort with Hoare-style partitioning around the middle element.
 * An empty or one-element range (lo >= hi) is left unchanged.
 * The relative order of equal keys (and of their indices) is not preserved.
 */
void qsindex (double *a, int *index , int lo, int hi)
{
    int i, j, ind;
    double x, h;

    /* Empty range: without this guard the scan below would read a[hi] with
       hi < lo (e.g. a[-1] for qsindex(a, idx, 0, n - 1) with n == 0). */
    if (lo >= hi)
    {
        return;
    }
    i = lo;
    j = hi;
    /* Same index as (lo + hi)/2 for non-negative bounds, without int overflow. */
    x = a[lo + (hi - lo)/2];

    /* Partition: afterwards every a[lo..j] <= x and every a[i..hi] >= x. */
    do
    {
        while (a[i] < x)
        {
            i++;
        }
        while (a[j] > x)
        {
            j--;
        }
        if (i <= j)
        {
            h = a[i];
            a[i] = a[j];
            a[j] = h;
            ind = index[i];
            index[i] = index[j];
            index[j] = ind;
            i++;
            j--;
        }
    }
    while (i <= j);

    /* Sort both sides of the partition. */
    if (lo < j)
    {
        qsindex(a , index , lo , j);
    }
    if (i < hi)
    {
        qsindex(a , index , i , hi);
    }
}
/* --------------------------------------------------------------------------- */
/*
 * randini: seed the generator state from the wall clock.
 * The current time is stored in jsrseed and XOR-ed into the SHR3 state jsr.
 * When ranKISS is defined, the KISS state words z, w and jcong are each set
 * from the clock too and then passed through mix().
 */
void randini(void)
{
    jsrseed = (UL) time(NULL);
    jsr ^= jsrseed;

#ifdef ranKISS
    z     = (UL) time(NULL);
    w     = (UL) time(NULL);
    jcong = (UL) time(NULL);
    mix(z, w, jcong);
#endif
}
/* --------------------------------------------------------------------------- */
void randnini(void)
{
register const double m1 = 2147483648.0, m2 = 4294967296.0 ;
register double invm1;
register double dn = 3.442619855899 , tn = dn , vn = 9.91256303526217e-3 , q;
int i;
/* Ziggurat tables for randn */
invm1 = 1.0/m1;
q = vn/exp(-0.5*dn*dn);
kn[0] = (dn/q)*m1;
kn[1] = 0;
wn[0] = q*invm1;
wn[zigstep - 1 ] = dn*invm1;
fn[0] = 1.0;
fn[zigstep - 1] = exp(-0.5*dn*dn);
for(i = (zigstep - 2) ; i >= 1 ; i--)
{
dn = sqrt(-2.*log(vn/dn + exp(-0.5*dn*dn)));
kn[i+1] = (dn/tn)*m1;
tn = dn;
fn[i] = exp(-0.5*dn*dn);
wn[i] = dn*invm1;
}
}
/* --------------------------------------------------------------------------- */
/*
 * nfix: slow path of the Ziggurat normal sampler. randn() calls it when its
 * fast test fails; it continues from the file-level state hz / iz set there.
 * It uses the kn / wn / fn tables built by randnini().
 * Returns one sample; hz and iz may be overwritten while retrying.
 *
 * NOTE(review): rand() is used below as if it returned a uniform value in
 * (0, 1): it is passed to -log() and used as an interpolation fraction between
 * fn[] entries. This presumably relies on a file-level rand() macro. The C
 * library rand() returns an int in [0, RAND_MAX], for which these formulas
 * would be wrong. Confirm against the definitions at the top of the file.
 */
double nfix(void)
{
const double r = 3.442620f; /* Start of the right tail; written as a float literal, so only float-precise (randnini() uses 3.442619855899). */
/* x and y are always assigned before being read, so static storage is not
   needed; being static only makes this function non-reentrant. */
static double x, y;
for(;;)
{
x = hz*wn[iz];
if(iz == 0)
{ /* iz == 0: base strip -- sample beyond r from the tail by exponential rejection */
do
{
x = -log(rand())*0.2904764; /* .2904764 is 1/r */
y = -log(rand());
}
while( (y + y) < (x*x));
return (hz > 0) ? (r + x) : (-r - x);
}
/* Wedge test: accept x if a uniform height between fn[iz] and fn[iz-1] lies under exp(-x^2/2). */
if( (fn[iz] + rand()*(fn[iz-1] - fn[iz])) < ( exp(-0.5*x*x) ) )
{
return x;
}
/* Rejected: draw a fresh integer and strip, and retry the fast test before looping. */
hz = randint;
iz = (hz & (zigstep - 1));
if(abs(hz) < kn[iz])
{
return (hz*wn[iz]);
}
}
}
/* --------------------------------------------------------------------------- */
/*
 * randn: draw one normally distributed double with the Ziggurat method.
 * A raw random integer from randint is stored in hz, and its low bits select
 * the strip iz (the mask assumes zigstep is a power of two).
 * If |hz| is below kn[iz], the scaled value is returned at once.
 * Otherwise nfix() resolves the sample.
 */
double randn(void)
{
    hz = randint;
    iz = hz & (zigstep - 1);

    if (abs(hz) >= kn[iz])
    {
        return nfix();
    }
    return hz*wn[iz];
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -