📄 grlvq_model.c
字号:
{
if (currentlabel != ytrainsorted[i])
{
labels = (double *)mxRealloc(labels , (m+2)*sizeof(double));
labels[++m] = ytrainsorted[i];
currentlabel = ytrainsorted[i];
}
}
m++;
/* Input 4 yproto */
if ((nrhs < 4) || mxIsEmpty(prhs[3]) )
{
    /* No yproto supplied: assign class labels to the Nproto prototypes in
       consecutive runs of Nprotom = ceil(Nproto/m) per class, the last class
       taking whatever slots remain. Because of the ceil() rounding the first
       m-1 runs can need more than Nproto slots in total (e.g. Nproto=5, m=4),
       so every write is bounded by Nproto to stay inside the 1 x Nproto
       output. In that situation (and whenever (m-1)*Nprotom >= Nproto) the
       trailing classes receive no prototype. */
    plhs[1] = mxCreateDoubleMatrix(1 , Nproto, mxREAL);
    yproto_est = mxGetPr(plhs[1]);
    co = 0;
    Nprotom = ceil((double)Nproto/(double)m);
    for(i = 0 ; i < m-1 ; i++)
    {
        for(j = 0 ; (j < Nprotom) && (co < Nproto) ; j++)
        {
            yproto_est[co] = labels[i];
            co++;
        }
    }
    /* Remaining slots (if any) go to the last class. */
    for( ; co < Nproto ; co++)
    {
        yproto_est[co] = labels[m-1];
    }
}
else
{
    /* User-supplied prototype labels: copy them to output 2; Nproto is taken
       from the number of columns of yproto. */
    yproto = mxGetPr(prhs[3]);
    if(mxGetNumberOfDimensions(prhs[3]) !=2)
    {
        mexErrMsgTxt("yproto must be (1 x Nproto)");
    }
    Nproto = mxGetN(prhs[3]);
    plhs[1] = mxCreateDoubleMatrix(1 , Nproto, mxREAL);
    yproto_est = mxGetPr(plhs[1]);
    for( i = 0 ; i < Nproto ; i++)
    {
        yproto_est[i] = yproto[i];
    }
}
/* Input 3 Wproto */
if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
{
    /* No Wproto supplied: initialise prototype l around the statistics of its
       class k (the class whose label equals yproto_est[l]):
           Wproto(:,l) = mean_k + std_k .* randn()
       mtemp and stdtemp are (d x m) column-major buffers (column k starts at
       offset k*d) holding the per-class feature mean and sample standard
       deviation; Nk[k] is the number of training samples of class k. */
    mtemp = mxMalloc(d*m*sizeof(double));
    stdtemp = mxMalloc(d*m*sizeof(double));
    Nk = mxMalloc(m*sizeof(int));
    for(i = 0 ; i < d*m ; i++)
    {
        mtemp[i] = 0.0;
        stdtemp[i] = 0.0;
    }
    /* Pass 1: per-class feature sums and sample counts. labels[l] is compared
       directly so the match does not depend on the type of a temporary. */
    for (l = 0 ; l < m ; l++)
    {
        ld = l*d;
        Nk[l] = 0;
        for (i = 0 ; i < Ntrain ; i++)
        {
            if(ytrain[i] == labels[l])
            {
                id = i*d;
                for(j = 0 ; j < d ; j++)
                {
                    mtemp[j + ld] += Xtrain[j + id];
                }
                Nk[l]++;
            }
        }
    }
    /* Sums -> means. Every label should occur in ytrain at least once; an
       empty class is left at zero rather than divided by zero. */
    for (l = 0 ; l < m ; l++)
    {
        if (Nk[l] < 1)
        {
            continue;
        }
        ld = l*d;
        temp = 1.0/Nk[l];
        for(j = 0 ; j < d ; j++)
        {
            mtemp[j + ld] *=temp;
        }
    }
    /* Pass 2: per-class sum of squared deviations from the mean. */
    for (l = 0 ; l < m ; l++)
    {
        ld = l*d;
        for (i = 0 ; i < Ntrain ; i++)
        {
            if(ytrain[i] == labels[l])
            {
                id = i*d;
                for(j = 0 ; j < d ; j++)
                {
                    temp = (Xtrain[j + id] - mtemp[j + ld]);
                    stdtemp[j + ld] += (temp*temp);
                }
            }
        }
    }
    /* Sample standard deviation sqrt(SS/(Nk-1)). A class with a single sample
       has SS == 0, so its spread stays 0 and the (Nk-1) division is avoided. */
    for (l = 0 ; l < m ; l++)
    {
        if (Nk[l] < 2)
        {
            continue;
        }
        ld = l*d;
        temp = 1.0/(Nk[l] - 1);
        for(j = 0 ; j < d ; j++)
        {
            stdtemp[j + ld] = sqrt(temp*stdtemp[j + ld]);
        }
    }
    plhs[0] = mxCreateDoubleMatrix(d , Nproto, mxREAL);
    Wproto_est = mxGetPr(plhs[0]);
    for(l = 0 ; l < Nproto ; l++)
    {
        ld = l*d;
        /* Locate the class of prototype l; labels are distinct, so the first
           match is the only one. */
        i = 0;
        while ((i < m) && (labels[i] != yproto_est[l]))
        {
            i++;
        }
        if (i == m)
        {
            mexErrMsgTxt("yproto contains a label that is not present in ytrain");
        }
        /* Column i of the (d x m) mean/std buffers starts at i*d. */
        indice = i*d;
        for(i = 0 ; i < d ; i++)
        {
            Wproto_est[i + ld] = mtemp[i + indice] + stdtemp[i + indice]*randn();
        }
    }
}
else
{
    /* User-supplied prototypes: validate the row count and copy to output 1;
       Nproto is taken from the number of columns of Wproto. */
    Wproto = mxGetPr(prhs[2]);
    if(mxGetNumberOfDimensions(prhs[2]) !=2 || mxGetM(prhs[2]) != d)
    {
        mexErrMsgTxt("Wproto must be (d x Nproto)");
    }
    Nproto = mxGetN(prhs[2]);
    plhs[0] = mxCreateDoubleMatrix(d , Nproto, mxREAL);
    Wproto_est = mxGetPr(plhs[0]);
    for( i = 0 ; i < d*Nproto ; i++)
    {
        Wproto_est[i] = Wproto[i];
    }
}
/* Input 5 lambda */
/* Output 3 receives a copy of the initial relevance vector (d x 1). */
plhs[2] = mxCreateDoubleMatrix(d , 1 , mxREAL);
lambda_est = mxGetPr(plhs[2]);
if ((nrhs < 5) || mxIsEmpty(prhs[4]) )
{
    /* Default: uniform relevances 1/d, held in an mxMalloc'd buffer. */
    lambda = (double *)mxMalloc(d*sizeof(double));
    for (i = 0 ; i < d ; i++)
    {
        lambda[i] = 1.0/d;
        lambda_est[i] = lambda[i] ;
    }
}
else
{
    /* d values are read from the input below, so reject anything that is not
       a real double array with exactly d elements instead of reading out of
       bounds or reinterpreting non-double data. */
    if (!mxIsDouble(prhs[4]) || mxIsComplex(prhs[4]) ||
        mxGetNumberOfElements(prhs[4]) != (size_t)d)
    {
        mexErrMsgTxt("lambda must be a real double vector with d elements");
    }
    /* NOTE(review): lambda aliases the caller's input data here; confirm that
       grlvq_model does not write through it. */
    lambda = mxGetPr(prhs[4]);
    for (i = 0 ; i < d ; i++)
    {
        lambda_est[i] = lambda[i] ;
    }
}
/* Input 6 option */
/* Every field of the options struct is optional: a missing or empty field
   keeps the default already stored in `options`. Values are read with
   mxGetScalar, which converts the first element of any numeric, logical or
   char array to double, so e.g. options.shuffle = true (a logical) is read
   correctly instead of reinterpreting its 1-byte storage as a double. */
if ( (nrhs > 5) && !mxIsEmpty(prhs[5]) )
{
    mxtemp = mxGetField(prhs[5] , 0, "epsilonk");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.epsilonk = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "epsilonl");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.epsilonl = mxGetScalar(mxtemp);
        /* Required: epsilonl <= epsilonk. */
        if (options.epsilonl>options.epsilonk)
        {
            mexErrMsgTxt("Epsilon_l < Epsilon_k");
        }
    }
    mxtemp = mxGetField(prhs[5] , 0, "epsilonlambda");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.epsilonlambda = mxGetScalar(mxtemp);
        /* Required: epsilonlambda <= epsilonl/10. */
        if (options.epsilonlambda>options.epsilonl/10)
        {
            mexErrMsgTxt("Epsilon_lambda < Epsilon_l/10");
        }
    }
    mxtemp = mxGetField(prhs[5] , 0, "xi");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.xi = mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "nb_iterations");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.nb_iterations = (int) mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "metric_method");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.metric_method = (int) mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "shuffle");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.shuffle = (int) mxGetScalar(mxtemp);
    }
    mxtemp = mxGetField(prhs[5] , 0, "updatelambda");
    if((mxtemp != NULL) && !mxIsEmpty(mxtemp))
    {
        options.updatelambda = (int) mxGetScalar(mxtemp);
    }
}
/*---------- Outputs --------*/
plhs[3] = mxCreateDoubleMatrix(1 , options.nb_iterations , mxREAL);
E_GRLVQ = mxGetPr(plhs[3]);
/*---------- Tempory matrices --------*/
temp_train = mxMalloc(Ntrain*sizeof(double));
index_train = mxMalloc(Ntrain*sizeof(int));
/* Main Call */
grlvq_model(Xtrain , ytrain , options , d , Ntrain , Nproto , m ,
Wproto_est , yproto_est , E_GRLVQ , lambda , lambda_est ,
labels , temp_train , index_train);
/*---------- Free memory --------*/
mxFree(labels);
mxFree(ytrainsorted);
mxFree(temp_train);
mxFree(index_train);
if ((nrhs < 3) || mxIsEmpty(prhs[2]) )
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -