📄 quadratic_search.m
字号:
function [s, net, tnet, status] = quadratic_search(...
    s0, net0, tnet0, basekls, restc, step, fs, tfs, data, ...
    params, t0, status, missing, notimedep),
% QUADRATIC_SEARCH Line search by method of quadratic approximation
%
%   [S, NET, TNET, STATUS] = QUADRATIC_SEARCH(S0, NET0, TNET0, BASEKLS,
%   RESTC, STEP, FS, TFS, DATA, PARAMS, T0, STATUS, MISSING, NOTIMEDEP)
%   searches for a step length t along the direction STEP.  Each trial
%   point is produced by update_s_and_net(S0, NET0, TNET0, t, STEP), and
%   its cost is RESTC + KLDIV(...) at that point.  The cost at t = 0 is
%   taken to be BASEKLS.
%
%   The search first builds a three-point bracket t1 < t2 < t3.  It expands
%   by doubling t3 while the cost keeps decreasing, and shrinks by halving
%   t2 while the cost rises from t1.  It then refines the bracket by
%   moving to the vertex of the parabola through the three (t, cost)
%   points.
%
%   Inputs:
%     S0, NET0, TNET0  Starting point handed to update_s_and_net.
%     BASEKLS          Cost at the starting point (t = 0).
%     RESTC            Constant added to every kldiv evaluation.
%     STEP             Search direction handed to update_s_and_net.
%     FS, TFS          Ignored: both are replaced by [] before kldiv is
%                      called.  NOTE(review): presumably so that kldiv
%                      recomputes them at each trial point -- confirm.
%     DATA, PARAMS, MISSING, NOTIMEDEP, STATUS
%                      Passed through to kldiv.  STATUS must also have
%                      the fields .debug and .kls.
%     T0               Initial step length; raised to at least 1e-6.
%
%   Outputs:
%     S, NET, TNET     State at the best step length found (t2).
%     STATUS           Input STATUS with .t0 set to that step length.
%                      At column size(STATUS.kls, 2), it also records
%                      .history.LS_iters (refinement iterations) and
%                      .history.LS_iters0 (bracketing points, starting
%                      at 3 for the initial points).
%
% Copyright (C) 1999-2005 Antti Honkela, Harri Valpola,
% Xavier Giannakopoulos and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package. This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Stopping criterion for step size
epsilon = 1e-6;
% Stopping criterion for cost function
epsilon2 = 1e-8;
% Maximum number of line search (refinement) iterations
maxiters = 30;

% The FS and TFS inputs are deliberately discarded; kldiv receives [].
fs = [];
tfs = [];

% Make sure the initial step size is large enough to avoid numerical problems
t0 = max(epsilon, t0);

% Initial points for the line search: t1 = 0 (the starting point, whose
% cost BASEKLS is already known), t2 = t0/2 and t3 = t0.
% Only the midpoint state (s2, net2, tnet2) is returned at the end; the
% end-point states are tracked alongside their costs for symmetry.
itercount0 = 3;
t1 = 0;
t2 = .5 * t0;
t3 = t0;
s1 = s0;
net1 = net0;
tnet1 = tnet0;
[s2, net2, tnet2] = update_s_and_net(s0, net0, tnet0, t2, step);
[s3, net3, tnet3] = update_s_and_net(s0, net0, tnet0, t3, step);
c1 = basekls;
c2 = restc + kldiv(tfs, fs, s2, data, net2, tnet2, params, ...
                   missing, notimedep, status, 1);
c3 = restc + kldiv(tfs, fs, s3, data, net3, tnet3, params, ...
                   missing, notimedep, status, 1);
show_bracket(status, c1, t1, c2, t2, c3, t3);

% Expand: while the cost is still decreasing at t3, move t2 up to t3 and
% double t3.  t1 stays at 0.
while ((c1 > c2) & (c2 > c3)),
  itercount0 = itercount0 + 1;
  c2 = c3;
  s2 = s3;
  net2 = net3;
  tnet2 = tnet3;
  t2 = t3;
  t3 = 2*t3;
  [s3, net3, tnet3] = update_s_and_net(s0, net0, tnet0, t3, step);
  c3 = restc + kldiv(tfs, fs, s3, data, net3, tnet3, params, ...
                     missing, notimedep, status, 1);
  show_bracket(status, c1, t1, c2, t2, c3, t3);
end

% Shrink: while the cost rises from t1 to t2, or either cost is
% non-finite, pull t3 back to t2 and halve t2.  The exit condition of this
% loop already rules out the non-convex case c1 < c2 > c3, so no separate
% check for it is needed afterwards.
while (((c1 < c2) & (c2 < c3)) | ((c1 < c2) & (c2 > c3)) | ...
       ~isfinite(c2) | ~isfinite(c3)),
  itercount0 = itercount0 + 1;
  c3 = c2;
  s3 = s2;
  net3 = net2;
  tnet3 = tnet2;
  t3 = t2;
  t2 = .5*t2;
  if t2 == t3,
    % t2 has already underflowed to 0 (and s2/c2 already hold the state
    % at t = 0), so further halving cannot change the bracket.  Without
    % this exit the loop never terminates when the cost is non-finite
    % even at t = 0.
    warning('NDFA:LINESEARCH:underflow', ...
            'LINESEARCH: step size underflowed while bracketing the minimum.');
    break;
  end
  [s2, net2, tnet2] = update_s_and_net(s0, net0, tnet0, t2, step);
  c2 = restc + kldiv(tfs, fs, s2, data, net2, tnet2, params, ...
                     missing, notimedep, status, 1);
  show_bracket(status, c1, t1, c2, t2, c3, t3);
end

% Refine: move to the vertex of the parabola through (t1,c1), (t2,c2),
% (t3,c3).  It replaces one of the three points so that the lowest cost
% stays in the middle.
itercount = 0;
while (((t3 - t1) > epsilon) & ...
       ((abs(c2-c3) + abs(c3-c1) + abs(c1-c2)) > epsilon2)),
  tnew = (t1.^2 .* (c2-c3) + t2.^2 .* (c3-c1) + t3.^2 .* (c1-c2)) ./ ...
         (2*(t1 .* (c2-c3) + t2 .* (c3-c1) + t3 .* (c1-c2)));
  if (tnew == t2),
    warning('NDFA:LINESEARCH:neweqold', ...
            'LINESEARCH: proposed new point equal to old midpoint.');
    tnew = offset_point(t1, t2, t3);
  elseif ~isfinite(tnew) || (tnew <= t1) || (tnew >= t3),
    % Degenerate (e.g. collinear) points: the vertex is not a usable
    % trial point strictly inside the bracket.
    warning('NDFA:LINESEARCH:outside', ...
            'LINESEARCH: proposed new point outside the bracket.');
    tnew = offset_point(t1, t2, t3);
  end
  [snew, netnew, tnetnew] = update_s_and_net(s0, net0, tnet0, tnew, step);
  cnew = restc + kldiv(tfs, fs, snew, data, netnew, tnetnew, params, ...
                       missing, notimedep, status, 1);
  % A NaN trial cost must count as "worse" so that it can never become
  % the midpoint (which is what gets returned); a plain cnew > c2 test
  % is false for NaN.
  worse = ~(cnew <= c2);
  if (tnew > t2),
    if worse,
      % New point becomes the right end of the bracket.
      t3 = tnew; s3 = snew; net3 = netnew; tnet3 = tnetnew; c3 = cnew;
    else
      % Old midpoint becomes the left end; new point the midpoint.
      t1 = t2; s1 = s2; net1 = net2; tnet1 = tnet2; c1 = c2;
      t2 = tnew; s2 = snew; net2 = netnew; tnet2 = tnetnew; c2 = cnew;
    end
  else % tnew < t2
    if worse,
      % New point becomes the left end of the bracket.
      t1 = tnew; s1 = snew; net1 = netnew; tnet1 = tnetnew; c1 = cnew;
    else
      % Old midpoint becomes the right end; new point the midpoint.
      t3 = t2; s3 = s2; net3 = net2; tnet3 = tnet2; c3 = c2;
      t2 = tnew; s2 = snew; net2 = netnew; tnet2 = tnetnew; c2 = cnew;
    end
  end
  show_bracket(status, c1, t1, c2, t2, c3, t3);
  itercount = itercount + 1;
  if itercount >= maxiters,
    fprintf('Line search failed, trouble...\n');
    break;
  end
end

Nh = size(status.kls, 2);
status.history.LS_iters(Nh) = itercount;
status.history.LS_iters0(Nh) = itercount0;

s = s2;
net = net2;
tnet = tnet2;
status.t0 = t2;

if status.debug,
  fprintf('Cost after mean update: c=%f\n', c2);
end

% Reference (pseudocode) for the parabola-vertex step used above:
% z = 2.0*(x1 * (y2-y3) + x2 * (y3-y1) + x3 * (y1-y2))
% if z == 0:
%   xnew = x2
% else:
%   xnew = (x1**2 * (y2-y3) + x2**2 * (y3-y1) +
%           x3**2 * (y1-y2)) / z
% return xnew


function tnew = offset_point(t1, t2, t3)
% OFFSET_POINT Fallback trial point for the refinement step
%   Returns the point 10% of the way from t2 into the larger of the two
%   sub-intervals [t1, t2] and [t2, t3].
if (t3 - t2) > (t2 - t1),
  tnew = t2 + .1 * (t3 - t2);
else
  tnew = t2 - .1 * (t2 - t1);
end


function show_bracket(status, c1, t1, c2, t2, c3, t3)
% SHOW_BRACKET Print the three bracket costs and step lengths
%   Prints only when status.debug is set.
if status.debug,
  fprintf('%.2f (%.4g) %.2f (%.4g) %.2f (%.4g)\n', c1, t1, c2, t2, c3, t3);
end


function [smstep, netmstep, tnetmstep] = unitstep(smstep, netmstep, tnetmstep),
% UNITSTEP Scale the search vector to unit vector
%   The search vector is SMSTEP stacked with the w1, w2, b1 and b2 fields
%   of NETMSTEP and TNETMSTEP.  Every part is divided by the Euclidean
%   norm of the whole stacked vector.  A zero vector is returned
%   unchanged instead of being turned into NaNs.
%   NOTE(review): not called from the code visible in this file.
gn = sqrt(sum([smstep(:); netmstep.w1(:); netmstep.w2(:); ...
               netmstep.b1(:); netmstep.b2(:); ...
               tnetmstep.w1(:); tnetmstep.w2(:); ...
               tnetmstep.b1(:); tnetmstep.b2(:)] .^2));
if gn == 0,
  % Nothing to normalise; avoid 0/0.
  return;
end
smstep = smstep / gn;
netmstep.w1 = netmstep.w1 / gn;
netmstep.w2 = netmstep.w2 / gn;
netmstep.b1 = netmstep.b1 / gn;
netmstep.b2 = netmstep.b2 / gn;
tnetmstep.w1 = tnetmstep.w1 / gn;
tnetmstep.w2 = tnetmstep.w2 / gn;
tnetmstep.b1 = tnetmstep.b1 / gn;
tnetmstep.b2 = tnetmstep.b2 / gn;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -