📄 srng_model.c
字号:
co++;
}
}
else
{
    /* Input 4 yproto: user-supplied prototype labels, copied verbatim into
       output 2 (always returned as a 1 x Nproto row).
       A row or a column vector is accepted: the previous mxGetN-only sizing
       silently turned an (Nproto x 1) column into a single prototype. */
    yproto = mxGetPr(prhs[3]);
    if( (mxGetNumberOfDimensions(prhs[3]) != 2) ||
        ((mxGetM(prhs[3]) != 1) && (mxGetN(prhs[3]) != 1)) )
    {
        mexErrMsgTxt("yproto must be (1 x Nproto)");
    }
    Nproto = mxGetNumberOfElements(prhs[3]);
    plhs[1] = mxCreateDoubleMatrix(1 , Nproto, mxREAL);
    yproto_est = mxGetPr(plhs[1]);
    for( i = 0 ; i < Nproto ; i++)
    {
        yproto_est[i] = yproto[i];
    }
}
/* Np[l]      : number of prototypes whose label equals labels[l].
   Nproto_max : the largest of those counts (used below to size the
                per-class work buffers). */
Np = mxMalloc(m*sizeof(int));
for (l = 0 ; l < m ; l++)
{
    ind   = labels[l];
    Np[l] = 0;
    for (i = 0 ; i < Nproto ; i++)
    {
        /* an equality test evaluates to 0 or 1 */
        Np[l] += (yproto_est[i] == ind);
    }
    Nproto_max = (Np[l] > Nproto_max) ? Np[l] : Nproto_max;
}
/* Input 3 Wproto (d x Nproto).
   When absent or empty, prototype k is drawn around the statistics of the
   class whose label equals yproto_est[k]:
       W(:,k) = mean_c + std_c .* randn
   where mean_c / std_c are computed per feature over the training samples
   of that class (columns of the d x m matrices mtemp / stdtemp). */
if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
{
    mtemp   = mxMalloc(d*m*sizeof(double));   /* per-class mean, d x m */
    stdtemp = mxMalloc(d*m*sizeof(double));   /* per-class std,  d x m */
    Nk      = mxMalloc(m*sizeof(int));        /* training samples per class */
    for(i = 0 ; i < d*m ; i++)
    {
        mtemp[i]   = 0.0;
        stdtemp[i] = 0.0;
    }
    /* Per-class feature sums and sample counts. */
    for (l = 0 ; l < m ; l++)
    {
        ind   = labels[l];
        ld    = l*d;
        Nk[l] = 0;
        for (i = 0 ; i < Ntrain ; i++)
        {
            if(ytrain[i] == ind)
            {
                id = i*d;
                for(j = 0 ; j < d ; j++)
                {
                    mtemp[j + ld] += Xtrain[j + id];
                }
                Nk[l]++;
            }
        }
    }
    /* Sums -> means. An empty class keeps a zero mean instead of 0 * inf = NaN. */
    for (l = 0 ; l < m ; l++)
    {
        ld   = l*d;
        temp = (Nk[l] > 0) ? 1.0/Nk[l] : 0.0;
        for(j = 0 ; j < d ; j++)
        {
            mtemp[j + ld] *= temp;
        }
    }
    /* Per-class sums of squared deviations from the class mean. */
    for (l = 0 ; l < m ; l++)
    {
        ind = labels[l];
        ld  = l*d;
        for (i = 0 ; i < Ntrain ; i++)
        {
            if(ytrain[i] == ind)
            {
                id = i*d;
                for(j = 0 ; j < d ; j++)
                {
                    temp = (Xtrain[j + id] - mtemp[j + ld]);
                    stdtemp[j + ld] += (temp*temp);
                }
            }
        }
    }
    /* Unbiased standard deviation sqrt(sum / (Nk - 1)).
       The former (1/(Nk-1)) * sqrt(sum) was not a standard deviation, and
       produced NaN (inf * 0) for classes with a single sample; those now
       get a zero spread, i.e. the prototype sits on the class mean. */
    for (l = 0 ; l < m ; l++)
    {
        ld   = l*d;
        temp = (Nk[l] > 1) ? 1.0/(Nk[l] - 1) : 0.0;
        for(j = 0 ; j < d ; j++)
        {
            stdtemp[j + ld] = sqrt(temp*stdtemp[j + ld]);
        }
    }
    plhs[0] = mxCreateDoubleMatrix(d , Nproto, mxREAL);
    Wproto_est = mxGetPr(plhs[0]);
    for(l = 0 ; l < Nproto ; l++)
    {
        int found = 0;

        ld = l*d;
        /* Locate the column of the class owning prototype l. Columns of the
           d x m matrices are d apart (the former i*m offset read outside the
           arrays or from the wrong class whenever m != d). */
        for(i = 0 ; i < m ; i++)
        {
            if(labels[i] == yproto_est[l] )
            {
                indice = i*d;
                found  = 1;
                break;
            }
        }
        /* Without a match, indice would be uninitialised or stale. */
        if(!found)
        {
            mexErrMsgTxt("yproto contains a label that matches no class label of ytrain");
        }
        for(i = 0 ; i < d ; i++)
        {
            Wproto_est[i + ld] = mtemp[i + indice] + stdtemp[i + indice]*randn();
        }
    }
}
else
{
    Wproto = mxGetPr(prhs[2]);
    /* Nproto is already fixed by yproto (yproto_est has Nproto entries and
       Np / Nproto_max were counted with it); a Wproto with another number of
       columns would make the prototype matrix and its labels disagree. */
    if(mxGetNumberOfDimensions(prhs[2]) !=2 || mxGetM(prhs[2]) != d ||
       mxGetN(prhs[2]) != Nproto)
    {
        mexErrMsgTxt("Wproto must be (d x Nproto)");
    }
    plhs[0] = mxCreateDoubleMatrix(d , Nproto, mxREAL);
    Wproto_est = mxGetPr(plhs[0]);
    for( i = 0 ; i < d*Nproto ; i++)
    {
        Wproto_est[i] = Wproto[i];
    }
}
/* Input 5 lambda: d feature weights, copied into output 3 (d x 1).
   When absent or empty, every weight defaults to 1/d. */
plhs[2] = mxCreateDoubleMatrix(d , 1 , mxREAL);
lambda_est = mxGetPr(plhs[2]);
if ((nrhs < 5) || mxIsEmpty(prhs[4]) )
{
    lambda = mxMalloc(d*sizeof(double));
    for (i = 0 ; i < d ; i++)
    {
        lambda[i] = 1.0/d;
        lambda_est[i] = lambda[i] ;
    }
}
else
{
    /* d values are read below; a shorter array would be read past its end. */
    if(mxGetNumberOfElements(prhs[4]) != d)
    {
        mexErrMsgTxt("lambda must be (d x 1)");
    }
    lambda = mxGetPr(prhs[4]);
    for (i = 0 ; i < d ; i++)
    {
        lambda_est[i] = lambda[i] ;
    }
}
/* Input 6 option: optional struct; each field present overrides the
   corresponding entry of `options` (fields not present keep their values).
   Fields are read with mxGetScalar, which converts any numeric or logical
   class to double; mxGetPr is only valid for double arrays, so e.g. a
   logical `shuffle = true` was previously misread. Empty fields are ignored
   instead of dereferencing their (empty) data pointer. */
if ( (nrhs > 5) && !mxIsEmpty(prhs[5]) )
{
    mxtemp = mxGetField(prhs[5] , 0, "epsilonk");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.epsilonk = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "epsilonl");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.epsilonl = mxGetScalar(mxtemp);
        if (options.epsilonl>options.epsilonk)
        {
            mexErrMsgTxt("Epsilon_l < Epsilon_k");
        }
    }
    mxtemp = mxGetField(prhs[5] , 0, "epsilonlambda");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.epsilonlambda = mxGetScalar(mxtemp);
        if (options.epsilonlambda>options.epsilonl/10)
        {
            mexErrMsgTxt("Epsilon_lambda < Epsilon_l/10");
        }
    }
    mxtemp = mxGetField(prhs[5] , 0, "sigmastart");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.sigmastart = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "sigmaend");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.sigmaend = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "sigmastretch");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.sigmastretch = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "threshold");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.threshold = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "xi");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.xi = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "nb_iterations");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.nb_iterations = (int) mxGetScalar(mxtemp);
        /* Used as the column count of output 4; a negative value would be
           converted to a huge unsigned size by mxCreateDoubleMatrix. */
        if (options.nb_iterations < 0)
        {
            mexErrMsgTxt("nb_iterations must be >= 0");
        }
    }
    mxtemp = mxGetField(prhs[5] , 0, "metric_method");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.metric_method = (int) mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "shuffle");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.shuffle = (int) mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "updatelambda");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.updatelambda = (int) mxGetScalar(mxtemp);
    }
}
/*---------- Outputs --------*/
/* Output 4: a 1 x nb_iterations row handed to srng_model as E_SRNG
   (presumably one cost value per iteration -- TODO confirm in srng_model). */
plhs[3] = mxCreateDoubleMatrix(1 , options.nb_iterations , mxREAL);
E_SRNG  = mxGetPr(plhs[3]);

/*---------- Temporary work buffers --------*/
/* Buffers sized by Nproto_max, the largest number of prototypes per class. */
distj       = mxMalloc(Nproto_max*sizeof(double));
hproto      = mxMalloc(Nproto_max*sizeof(double));
sproto      = mxMalloc(Nproto_max*sizeof(double));
rank_distj  = mxMalloc(Nproto_max*sizeof(int));
index_distj = mxMalloc(Nproto_max*sizeof(int));
/* Buffers sized by the number of training samples. */
temp_train  = mxMalloc(Ntrain*sizeof(double));
index_train = mxMalloc(Ntrain*sizeof(int));

/* Main Call */
srng_model(Xtrain , ytrain , options ,
           d , Ntrain , Nproto , m ,
           Wproto_est , yproto_est , E_SRNG ,
           lambda , lambda_est , labels ,
           temp_train , index_train ,
           distj , rank_distj , index_distj ,
           hproto , sproto , Nproto_max);
/*---------- Free memory --------*/
mxFree(labels);
mxFree(ytrainsorted);
mxFree(temp_train);
mxFree(index_train);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -