📄 tlm3.c
字号:
* j = [j1, ext_d1', j2, ext_d2', j3, ext_d3'];
*
* % CHECK MAGNITUDE OF GRADIENT
* je = j' * e(:);
* grad = norm(je);
* if grad < grad_min, i=i-1; break, end
*
* % INNER LOOP, INCREASE MU UNTIL THE ERRORS ARE REDUCED
* jj = j'*j;
*
* while (mu <= mu_max)
* dx = -(jj+ii*mu) \ je;
* dw1(:) = dx(w1_ind); db1 = dx(b1_ind);
* dw2(:) = dx(w2_ind); db2 = dx(b2_ind);
* dw3(:) = dx(w3_ind); db3 = dx(b3_ind);
* new_w1 = w1 + dw1; new_b1 = b1 + db1;
* new_w2 = w2 + dw2; new_b2 = b2 + db2;
* new_w3 = w3 + dw3; new_b3 = b3 + db3;
*
* % EVALUATE NEW NETWORK
* [a1,a2,a3] = ...
* simuff(p,new_w1,new_b1,f1,new_w2,new_b2,f2,new_w3,new_b3,f3);
* new_e = t-a3;
* new_SSE = sumsqr(new_e);
*
* if (new_SSE < SSE), break, end
* mu = mu * mu_inc;
* end
* if (mu > mu_max), i = i-1; break, end
* mu = mu * mu_dec;
*
* % UPDATE NETWORK
* w1 = new_w1; b1 = new_b1;
* w2 = new_w2; b2 = new_b2;
* w3 = new_w3; b3 = new_b3;
* e = new_e; SSE = new_SSE;
*
* % TRAINING RECORD
* tr(i+1) = SSE;
*
* % PLOTTING
* if rem(i,df) == 0
* fprintf(message,i,mu,SSE)
* if plottype
* delete(h); h = plot(p,a2,'m'); drawnow;
* else
* h = ploterr(tr(1:(i+1)),eg,h);
* end
* end
* end
*/
for (; ; ) {
mlfAssign(i, mlfScalar(v_));
if (mclLtBool(mclVv(SSE, "SSE"), mclVv(eg, "eg"))) {
mlfAssign(i, mclMinus(mclVv(*i, "i"), _mxarray11_));
break;
}
mlfAssign(&ext_a1, mlfNncpyi(mclVv(a1, "a1"), mclVv(s3, "s3")));
mlfAssign(&ext_a2, mlfNncpyi(mclVv(a2, "a2"), mclVv(s3, "s3")));
mlfAssign(
&d3,
mlfFeval(
mclValueVarargout(),
mclVv(df3, "df3"),
mclVv(a3, "a3"),
NULL));
mlfAssign(&ext_d3, mclUminus(mlfNncpyd(mclVv(d3, "d3"))));
mlfAssign(
&ext_d2,
mlfFeval(
mclValueVarargout(),
mclVv(df2, "df2"),
mclVv(ext_a2, "ext_a2"),
mclVv(ext_d3, "ext_d3"),
mclVa(*w3, "w3"),
NULL));
mlfAssign(
&ext_d1,
mlfFeval(
mclValueVarargout(),
mclVv(df1, "df1"),
mclVv(ext_a1, "ext_a1"),
mclVv(ext_d2, "ext_d2"),
mclVa(*w2, "w2"),
NULL));
mlfAssign(
&j1,
mlfLearnlm(mclVv(ext_p, "ext_p"), mclVv(ext_d1, "ext_d1")));
mlfAssign(
&j2,
mlfLearnlm(mclVv(ext_a1, "ext_a1"), mclVv(ext_d2, "ext_d2")));
mlfAssign(
&j3,
mlfLearnlm(mclVv(ext_a2, "ext_a2"), mclVv(ext_d3, "ext_d3")));
mlfAssign(
&j,
mlfHorzcat(
mclVv(j1, "j1"),
mlfCtranspose(mclVv(ext_d1, "ext_d1")),
mclVv(j2, "j2"),
mlfCtranspose(mclVv(ext_d2, "ext_d2")),
mclVv(j3, "j3"),
mlfCtranspose(mclVv(ext_d3, "ext_d3")),
NULL));
mlfAssign(
&je,
mlf_times_transpose(
mclVv(j, "j"),
mclArrayRef1(mclVv(e, "e"), mlfCreateColonIndex()),
_mxarray15_));
mlfAssign(&grad, mlfNorm(mclVv(je, "je"), NULL));
if (mclLtBool(
mclVv(grad, "grad"), mclVv(grad_min, "grad_min"))) {
mlfAssign(i, mclMinus(mclVv(*i, "i"), _mxarray11_));
break;
}
mlfAssign(
&jj,
mlf_times_transpose(
mclVv(j, "j"), mclVv(j, "j"), _mxarray15_));
while (mclLeBool(mclVv(mu, "mu"), mclVv(mu_max, "mu_max"))) {
mlfAssign(
&dx,
mlfMldivide(
mclUminus(
mclPlus(
mclVv(jj, "jj"),
mclMtimes(mclVv(ii, "ii"), mclVv(mu, "mu")))),
mclVv(je, "je")));
mclArrayAssign1(
&dw1,
mclArrayRef1(mclVv(dx, "dx"), mclVv(w1_ind, "w1_ind")),
mlfCreateColonIndex());
mlfAssign(
&db1,
mclArrayRef1(mclVv(dx, "dx"), mclVv(b1_ind, "b1_ind")));
mclArrayAssign1(
&dw2,
mclArrayRef1(mclVv(dx, "dx"), mclVv(w2_ind, "w2_ind")),
mlfCreateColonIndex());
mlfAssign(
&db2,
mclArrayRef1(mclVv(dx, "dx"), mclVv(b2_ind, "b2_ind")));
mclArrayAssign1(
&dw3,
mclArrayRef1(mclVv(dx, "dx"), mclVv(w3_ind, "w3_ind")),
mlfCreateColonIndex());
mlfAssign(
&db3,
mclArrayRef1(mclVv(dx, "dx"), mclVv(b3_ind, "b3_ind")));
mlfAssign(
&new_w1, mclPlus(mclVa(w1, "w1"), mclVv(dw1, "dw1")));
mlfAssign(
&new_b1, mclPlus(mclVa(*b1, "b1"), mclVv(db1, "db1")));
mlfAssign(
&new_w2, mclPlus(mclVa(*w2, "w2"), mclVv(dw2, "dw2")));
mlfAssign(
&new_b2, mclPlus(mclVa(*b2, "b2"), mclVv(db2, "db2")));
mlfAssign(
&new_w3, mclPlus(mclVa(*w3, "w3"), mclVv(dw3, "dw3")));
mlfAssign(
&new_b3, mclPlus(mclVa(*b3, "b3"), mclVv(db3, "db3")));
mlfAssign(
&a1,
mlfNSimuff(
3,
&a2,
&a3,
mclVa(p, "p"),
mclVv(new_w1, "new_w1"),
mclVv(new_b1, "new_b1"),
mclVa(f1, "f1"),
mclVv(new_w2, "new_w2"),
mclVv(new_b2, "new_b2"),
mclVa(f2, "f2"),
mclVv(new_w3, "new_w3"),
mclVv(new_b3, "new_b3"),
mclVa(f3, "f3")));
mlfAssign(&new_e, mclMinus(mclVa(t, "t"), mclVv(a3, "a3")));
mlfAssign(&new_SSE, mlfSumsqr(mclVv(new_e, "new_e")));
if (mclLtBool(
mclVv(new_SSE, "new_SSE"), mclVv(SSE, "SSE"))) {
break;
}
mlfAssign(
&mu, mclMtimes(mclVv(mu, "mu"), mclVv(mu_inc, "mu_inc")));
}
if (mclGtBool(mclVv(mu, "mu"), mclVv(mu_max, "mu_max"))) {
mlfAssign(i, mclMinus(mclVv(*i, "i"), _mxarray11_));
break;
}
mlfAssign(
&mu, mclMtimes(mclVv(mu, "mu"), mclVv(mu_dec, "mu_dec")));
mlfAssign(&w1, mclVv(new_w1, "new_w1"));
mlfAssign(b1, mclVv(new_b1, "new_b1"));
mlfAssign(w2, mclVv(new_w2, "new_w2"));
mlfAssign(b2, mclVv(new_b2, "new_b2"));
mlfAssign(w3, mclVv(new_w3, "new_w3"));
mlfAssign(b3, mclVv(new_b3, "new_b3"));
mlfAssign(&e, mclVv(new_e, "new_e"));
mlfAssign(&SSE, mclVv(new_SSE, "new_SSE"));
mclArrayAssign1(
tr, mclVv(SSE, "SSE"), mclPlus(mclVv(*i, "i"), _mxarray11_));
if (mclEqBool(
mlfRem(mclVv(*i, "i"), mclVv(df, "df")), _mxarray14_)) {
mclPrintAns(
&ans,
mlfNFprintf(
0,
mclVv(message, "message"),
mclVv(*i, "i"),
mclVv(mu, "mu"),
mclVv(SSE, "SSE"),
NULL));
if (mlfTobool(mclVv(plottype, "plottype"))) {
mlfDelete(mclVv(h, "h"), NULL);
mlfAssign(
&h,
mlfNPlot(
1,
mclVa(p, "p"), mclVv(a2, "a2"), _mxarray16_, NULL));
mlfDrawnow(NULL);
} else {
mlfAssign(
&h,
mlfNPloterr(
1,
mclArrayRef1(
mclVv(*tr, "tr"),
mlfColon(
_mxarray11_,
mclPlus(mclVv(*i, "i"), _mxarray11_),
NULL)),
mclVv(eg, "eg"),
mclVv(h, "h")));
}
}
if (v_ == e_) {
break;
}
++v_;
}
}
}
/*
*
* % TRAINING RECORD
* tr = tr(1:(i+1));
*/
mlfAssign(
tr,
mclArrayRef1(
mclVv(*tr, "tr"),
mlfColon(_mxarray11_, mclPlus(mclVv(*i, "i"), _mxarray11_), NULL)));
/*
*
* % PLOTTING
* if rem(i,df) ~= 0
*/
if (mclNeBool(mlfRem(mclVv(*i, "i"), mclVv(df, "df")), _mxarray14_)) {
/*
* fprintf(message,i,mu,SSE)
*/
mclPrintAns(
&ans,
mlfNFprintf(
0,
mclVv(message, "message"),
mclVv(*i, "i"),
mclVv(mu, "mu"),
mclVv(SSE, "SSE"),
NULL));
/*
* if plottype
*/
if (mlfTobool(mclVv(plottype, "plottype"))) {
/*
* delete(h);
*/
mlfDelete(mclVv(h, "h"), NULL);
/*
* plot(p,a2,'m');
*/
mclAssignAns(
&ans,
mlfNPlot(0, mclVa(p, "p"), mclVv(a2, "a2"), _mxarray16_, NULL));
/*
* drawnow;
*/
mlfDrawnow(NULL);
/*
* else
*/
} else {
/*
* ploterr(tr,eg,h);
*/
mclAssignAns(
&ans,
mlfNPloterr(0, mclVv(*tr, "tr"), mclVv(eg, "eg"), mclVv(h, "h")));
/*
* end
*/
}
/*
* end
*/
}
/*
*
* % WARNINGS
* if SSE > eg
*/
if (mclGtBool(mclVv(SSE, "SSE"), mclVv(eg, "eg"))) {
/*
* disp(' ')
*/
mlfDisp(_mxarray18_);
/*
* if (mu > mu_max)
*/
if (mclGtBool(mclVv(mu, "mu"), mclVv(mu_max, "mu_max"))) {
/*
* disp('TRAINLM: Error gradient is too small to continue learning.')
*/
mlfDisp(_mxarray20_);
/*
* else
*/
} else {
/*
* disp('TRAINLM: Network error did not reach the error goal.')
*/
mlfDisp(_mxarray22_);
/*
* end
*/
}
/*
* disp(' Further training may be necessary, or try different')
*/
mlfDisp(_mxarray24_);
/*
* disp(' initial weights and biases and/or more hidden neurons.')
*/
mlfDisp(_mxarray26_);
/*
* disp(' ')
*/
mlfDisp(_mxarray18_);
/*
* end
*/
}
mclValidateOutput(w1, 1, nargout_, "w1", "tlm3");
mclValidateOutput(*b1, 2, nargout_, "b1", "tlm3");
mclValidateOutput(*w2, 3, nargout_, "w2", "tlm3");
mclValidateOutput(*b2, 4, nargout_, "b2", "tlm3");
mclValidateOutput(*w3, 5, nargout_, "w3", "tlm3");
mclValidateOutput(*b3, 6, nargout_, "b3", "tlm3");
mclValidateOutput(*i, 7, nargout_, "i", "tlm3");
mclValidateOutput(*tr, 8, nargout_, "tr", "tlm3");
mxDestroyArray(ans);
mxDestroyArray(df);
mxDestroyArray(me);
mxDestroyArray(eg);
mxDestroyArray(grad_min);
mxDestroyArray(mu_init);
mxDestroyArray(mu_inc);
mxDestroyArray(mu_dec);
mxDestroyArray(mu_max);
mxDestroyArray(df1);
mxDestroyArray(df2);
mxDestroyArray(df3);
mxDestroyArray(s1);
mxDestroyArray(r);
mxDestroyArray(s2);
mxDestroyArray(s3);
mxDestroyArray(w1_ind);
mxDestroyArray(b1_ind);
mxDestroyArray(w2_ind);
mxDestroyArray(b2_ind);
mxDestroyArray(w3_ind);
mxDestroyArray(b3_ind);
mxDestroyArray(ii);
mxDestroyArray(dw1);
mxDestroyArray(db1);
mxDestroyArray(dw2);
mxDestroyArray(db2);
mxDestroyArray(dw3);
mxDestroyArray(db3);
mxDestroyArray(ext_p);
mxDestroyArray(a1);
mxDestroyArray(a2);
mxDestroyArray(a3);
mxDestroyArray(e);
mxDestroyArray(SSE);
mxDestroyArray(plottype);
mxDestroyArray(message);
mxDestroyArray(h);
mxDestroyArray(mu);
mxDestroyArray(ext_a1);
mxDestroyArray(ext_a2);
mxDestroyArray(d3);
mxDestroyArray(ext_d3);
mxDestroyArray(ext_d2);
mxDestroyArray(ext_d1);
mxDestroyArray(j1);
mxDestroyArray(j2);
mxDestroyArray(j3);
mxDestroyArray(j);
mxDestroyArray(je);
mxDestroyArray(grad);
mxDestroyArray(jj);
mxDestroyArray(dx);
mxDestroyArray(new_w1);
mxDestroyArray(new_b1);
mxDestroyArray(new_w2);
mxDestroyArray(new_b2);
mxDestroyArray(new_w3);
mxDestroyArray(new_b3);
mxDestroyArray(new_e);
mxDestroyArray(new_SSE);
mxDestroyArray(tp);
mxDestroyArray(t);
mxDestroyArray(p);
mxDestroyArray(f3);
mxDestroyArray(f2);
mxDestroyArray(f1);
mclSetCurrentLocalFunctionTable(save_local_function_table_);
return w1;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -