📄 trainsm.c
字号:
* max_nb = max(max(m));
*/
mlfAssign(
&max_nb,
mlfMax(NULL, mlfMax(NULL, mclVa(m, "m"), NULL, NULL), NULL, NULL));
/*
*
* % SIZES
* [R,Q] = size(p);
*/
mlfSize(mlfVarargout(&R, &Q, NULL), mclVa(p, "p"), NULL);
/*
* [S,R] = size(w);
*/
mlfSize(mlfVarargout(&S, &R, NULL), mclVa(w, "w"), NULL);
/*
*
* % PLOTTING
* message = sprintf('TRAINSM: %%g/%g epochs, neighborhood = %%g, lr = %%g.\n',max_pres);
*/
mlfAssign(
&message,
mlfSprintf(NULL, _mxarray9_, mclVv(max_pres, "max_pres"), NULL));
/*
* fprintf(message,0,max_nb,init_lr)
*/
mclPrintAns(
&ans,
mlfNFprintf(
0,
mclVv(message, "message"),
_mxarray11_,
mclVv(max_nb, "max_nb"),
mclVv(init_lr, "init_lr"),
NULL));
/*
* if R > 1 & S < 100
*/
{
mxArray * a_ = mclInitialize(mclGt(mclVv(R, "R"), _mxarray13_));
if (mlfTobool(a_)
&& mlfTobool(mclAnd(a_, mclLt(mclVv(S, "S"), _mxarray12_)))) {
mxDestroyArray(a_);
/*
* plotsm(w,m);
*/
mlfPlotsm(mclVa(w, "w"), mclVa(m, "m"));
} else {
mxDestroyArray(a_);
}
/*
* end
*/
}
/*
*
* lr_x = (1/max_pres)^(1/sqrt(max_pres*1000));
*/
mlfAssign(
&lr_x,
mclMpower(
mclMrdivide(_mxarray13_, mclVv(max_pres, "max_pres")),
mclMrdivide(
_mxarray13_,
mlfSqrt(mclMtimes(mclVv(max_pres, "max_pres"), _mxarray14_)))));
/*
* base_lr = init_lr / (1+lr_x);
*/
mlfAssign(
&base_lr,
mclMrdivide(
mclVv(init_lr, "init_lr"), mclPlus(_mxarray13_, mclVv(lr_x, "lr_x"))));
/*
* nb_x = (1/max_nb)^(10/max_pres);
*/
mlfAssign(
&nb_x,
mclMpower(
mclMrdivide(_mxarray13_, mclVv(max_nb, "max_nb")),
mclMrdivide(_mxarray15_, mclVv(max_pres, "max_pres"))));
/*
* z = ones(S,1)*(1/S);
*/
mlfAssign(
&z,
mclMtimes(
mlfOnes(mclVv(S, "S"), _mxarray13_, NULL),
mclMrdivide(_mxarray13_, mclVv(S, "S"))));
/*
* for i=1:max_pres
*/
{
int v_ = mclForIntStart(1);
int e_ = mclForIntEnd(mclVv(max_pres, "max_pres"));
if (v_ > e_) {
mlfAssign(&i, _mxarray6_);
} else {
/*
*
* % TRAINING PARAMETER UPDATE
* nb = max(1,max_nb*nb_x^i);
* lr = base_lr*(5/(4+i) + lr_x^i);
*
* % PRESENTATION PHASE
* j = fix(rand*Q) + 1;
* P = p(:,j);
* a = simusm(P,w,m,nb);
*
* % LEARNING PHASE
* dw = learnis(w,P,a,lr);
* w = w + dw;
*
* % PLOTTING
* if rem(i,df) == 0
* fprintf(message,i,nb,lr)
* if R > 1 & S < 100
* plotsm(w,m);
* end
* end
* end
*/
for (; ; ) {
mlfAssign(
&nb,
mlfMax(
NULL,
_mxarray13_,
mclMtimes(
mclVv(max_nb, "max_nb"),
mclMpower(mclVv(nb_x, "nb_x"), mlfScalar(v_))),
NULL));
mlfAssign(
&lr,
mclMtimes(
mclVv(base_lr, "base_lr"),
mclPlus(
mlfScalar(svDoubleScalarRdivide(5.0, (double) (4 + v_))),
mclMpower(mclVv(lr_x, "lr_x"), mlfScalar(v_)))));
mlfAssign(
&j,
mclPlus(
mlfFix(mclMtimes(mlfNRand(1, NULL), mclVv(Q, "Q"))),
_mxarray13_));
mlfAssign(
&P,
mclArrayRef2(
mclVa(p, "p"), mlfCreateColonIndex(), mclVv(j, "j")));
mlfAssign(
&a,
mlfSimusm(
mclVv(P, "P"),
mclVa(w, "w"),
mclVa(m, "m"),
mclVv(nb, "nb")));
mlfAssign(
&dw,
mlfLearnis(
NULL,
mclVa(w, "w"),
mclVv(P, "P"),
mclVv(a, "a"),
mclVv(lr, "lr"),
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL));
mlfAssign(&w, mclPlus(mclVa(w, "w"), mclVv(dw, "dw")));
if (mclEqBool(
mlfRem(mlfScalar(v_), mclVv(df, "df")), _mxarray11_)) {
mclPrintAns(
&ans,
mlfNFprintf(
0,
mclVv(message, "message"),
mlfScalar(v_),
mclVv(nb, "nb"),
mclVv(lr, "lr"),
NULL));
{
mxArray * a_
= mclInitialize(mclGt(mclVv(R, "R"), _mxarray13_));
if (mlfTobool(a_)
&& mlfTobool(
mclAnd(
a_, mclLt(mclVv(S, "S"), _mxarray12_)))) {
mxDestroyArray(a_);
mlfPlotsm(mclVa(w, "w"), mclVa(m, "m"));
} else {
mxDestroyArray(a_);
}
}
}
if (v_ == e_) {
break;
}
++v_;
}
mlfAssign(&i, mlfScalar(v_));
}
}
/*
*
* if rem(i,df) ~= 0
*/
if (mclNeBool(mlfRem(mclVv(i, "i"), mclVv(df, "df")), _mxarray11_)) {
/*
* fprintf(message,i,nb,lr)
*/
mclPrintAns(
&ans,
mlfNFprintf(
0,
mclVv(message, "message"),
mclVv(i, "i"),
mclVv(nb, "nb"),
mclVv(lr, "lr"),
NULL));
/*
* end
*/
}
mclValidateOutput(w, 1, nargout_, "w", "trainsm");
mxDestroyArray(ans);
mxDestroyArray(df);
mxDestroyArray(max_pres);
mxDestroyArray(init_lr);
mxDestroyArray(max_nb);
mxDestroyArray(R);
mxDestroyArray(Q);
mxDestroyArray(S);
mxDestroyArray(message);
mxDestroyArray(lr_x);
mxDestroyArray(base_lr);
mxDestroyArray(nb_x);
mxDestroyArray(z);
mxDestroyArray(i);
mxDestroyArray(nb);
mxDestroyArray(lr);
mxDestroyArray(j);
mxDestroyArray(P);
mxDestroyArray(a);
mxDestroyArray(dw);
mxDestroyArray(tp);
mxDestroyArray(p);
mxDestroyArray(m);
mclSetCurrentLocalFunctionTable(save_local_function_table_);
return w;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -