📄 update_everything.m
字号:
function [sources, net, tnet, oldgrads, status] = update_everything(...
    sources, net, tnet, dcp_dsm, dcp_dsvn, fs, tfs, params, ...
    dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, ...
    data, oldc, status, oldgrads, clamped, missing, notimedep)
% UPDATE_EVERYTHING  Perform a line search to find optimal step length
%    and update all parameters
%
%  The function works in three stages:
%    1. UPDATE_VARIANCES (below) updates the variance parameters with a
%       backtracking search and returns the resulting costs.
%    2. A search direction for the means is built from the negated
%       dcp_dsm / dcp_dnetm / dcp_dtnetm inputs.  A part whose
%       status.updatesrcs / status.updatenet / status.updatetnet flag is
%       negative gets an all-zero direction; entries selected by
%       'clamped' are zeroed in the source direction.  When
%       status.updatealg is 'conjgrad' the direction is combined with the
%       previous one stored in 'oldgrads' (Hestenes-Stiefel formula with
%       Powell-Beale style restarts).
%    3. CUBIC_SEARCH (defined elsewhere) is called with that direction
%       and status.t0.
%
%  Inputs (semantics inferred from names where not visible here --
%  TODO confirm against the callers):
%    sources, net, tnet     current model state, passed to the helpers
%    dcp_dsm                array negated to form the source direction
%    dcp_dsvn               used by UPDATE_VARIANCES for the source step
%    dcp_dnetm, dcp_dtnetm  structs with fields w1, b1, w2, b2 (means)
%    dcp_dnetv, dcp_dtnetv  structs with fields w1, b1, w2, b2 (variances)
%    fs, tfs, params, data, missing, notimedep
%                           passed through to the helpers unchanged
%    newac                  stored as step.ac in UPDATE_VARIANCES
%    oldc                   cost before the update
%    status                 struct; fields read here: updatesrcs,
%                           updatenet, updatetnet, updatealg, debug, t0,
%                           kls (debug only)
%    oldgrads               conjugate-gradient state from the previous
%                           call (fields norm, s, net, tnet, p, grad)
%    clamped                mask/indices of source entries kept fixed
%
%  Outputs:
%    sources, net, tnet     updated model state
%    oldgrads               updated conjugate-gradient state
%    status                 updated status (history fields are filled
%                           when status.debug >= 2)

% Copyright (C) 1999-2005 Antti Honkela, Harri Valpola,
% Xavier Giannakopoulos and Matti Tornio
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

[sources, net, tnet, newc, restc, status] = update_variances(...
    sources, net, tnet, dcp_dsvn, fs, tfs, params, dcp_dnetv, dcp_dtnetv, ...
    newac, data, oldc, status, clamped, missing, notimedep);

% Direction for the source means
if status.updatesrcs < 0
  dc_dsm = zeros(size(dcp_dsm));
else
  dc_dsm = -dcp_dsm;
  % Set gradient for the clamped sources to zeros
  dc_dsm(find(clamped)) = 0;
end

% Directions for the means of the two networks
dc_dnetm = negated_netgrad(dcp_dnetm, status.updatenet < 0);
dc_dtnetm = negated_netgrad(dcp_dtnetm, status.updatetnet < 0);

% CG for networks and CG or Kalman for sources
if strcmp(status.updatealg, 'conjgrad')
  grad = stack_direction(dc_dsm, dc_dnetm, dc_dtnetm);
  gn = sum(grad.^2);

  % Only reuse the stored state when it is complete and has the same
  % shapes as the current problem.  isequal on the size vectors also
  % copes with arrays that differ in their number of dimensions.
  have_state = isscalar(oldgrads.norm) && (oldgrads.norm ~= 0) && ...
      isfield(oldgrads, 'grad') && isfield(oldgrads, 'p') && ...
      isequal(size(oldgrads.s), size(dc_dsm)) && ...
      isequal(size(oldgrads.net.w1), size(dc_dnetm.w1)) && ...
      isequal(size(oldgrads.tnet.w1), size(dc_dtnetm.w1)) && ...
      isequal(size(oldgrads.grad), size(grad)) && ...
      isequal(size(oldgrads.p), size(grad));

  if have_state
    % Powell-Beale restarts
    % NOTE(review): the textbook criterion uses the absolute value of the
    % inner product; the original one-sided test is kept -- confirm intent.
    if oldgrads.grad'*grad >= .2*gn
      fprintf('Resetting CG (Powell-Beale)\n');
      oldgrads.s = zeros(size(sources));
      oldgrads.net = netgrad_zeros(net);
      oldgrads.tnet = netgrad_zeros(tnet);
      oldgrads.norm = 0;
      oldgrads.grad = zeros(size(grad));
      beta = 0;
    else
      % Fletcher-Reeves formula (from Antti's version)
      %beta = gn / oldgrads.norm;
      % Polak-Ribiere formula
      %beta = grad' * (grad - oldgrads.grad) / oldgrads.norm;
      %beta = max([0 beta]);
      % Hestenes-Stiefel formula
      beta = (gn - grad'*oldgrads.grad) / ...
             (oldgrads.p' * (grad - oldgrads.grad));
      % A vanishing denominator gives NaN/Inf, which would contaminate
      % every direction below; restart from the plain direction instead.
      if ~isfinite(beta)
        beta = 0;
      end
    end
    if status.debug >= 2
      Nh = size(status.kls, 2);
      status.history.beta(Nh) = beta;
    end
    dc_dsm = dc_dsm + beta * oldgrads.s;
    dc_dnetm.w1 = dc_dnetm.w1 + beta * oldgrads.net.w1;
    dc_dnetm.w2 = dc_dnetm.w2 + beta * oldgrads.net.w2;
    dc_dnetm.b1 = dc_dnetm.b1 + beta * oldgrads.net.b1;
    dc_dnetm.b2 = dc_dnetm.b2 + beta * oldgrads.net.b2;
    dc_dtnetm.w1 = dc_dtnetm.w1 + beta * oldgrads.tnet.w1;
    dc_dtnetm.w2 = dc_dtnetm.w2 + beta * oldgrads.tnet.w2;
    dc_dtnetm.b1 = dc_dtnetm.b1 + beta * oldgrads.tnet.b1;
    dc_dtnetm.b2 = dc_dtnetm.b2 + beta * oldgrads.tnet.b2;
  end

  % Remember the state needed by the next call
  oldgrads.norm = gn;
  oldgrads.s = dc_dsm;
  oldgrads.net = dc_dnetm;
  oldgrads.tnet = dc_dtnetm;
  % p is the negated, stacked final search direction
  oldgrads.p = -stack_direction(dc_dsm, dc_dnetm, dc_dtnetm);
  oldgrads.grad = grad;
end

step.s = dc_dsm;
step.net = dc_dnetm;
step.tnet = dc_dtnetm;

[sources, net, tnet, status] = cubic_search(...
    sources, net, tnet, newc, restc, step, ...
    fs, tfs, data, params, status.t0, status, missing, notimedep);

if status.debug >= 2
  Nh = size(status.kls, 2);
  status.history.t0(Nh) = status.t0;
  status.history.linestep(Nh) = sqrt(oldgrads.norm) * status.t0;
end


function d = negated_netgrad(g, frozen)
% NEGATED_NETGRAD  Negate the w1, b1, w2, b2 fields of G, or return zeros
%    of the same sizes when FROZEN is true.

if frozen
  d.w1 = zeros(size(g.w1));
  d.b1 = zeros(size(g.b1));
  d.w2 = zeros(size(g.w2));
  d.b2 = zeros(size(g.b2));
else
  d.w1 = -g.w1;
  d.b1 = -g.b1;
  d.w2 = -g.w2;
  d.b2 = -g.b2;
end


function v = stack_direction(s, netm, tnetm)
% STACK_DIRECTION  Stack all direction components into one column vector
%    in the order s, net (w1, w2, b1, b2), tnet (w1, w2, b1, b2).

v = [s(:); ...
     netm.w1(:); netm.w2(:); ...
     netm.b1(:); netm.b2(:); ...
     tnetm.w1(:); tnetm.w2(:); ...
     tnetm.b1(:); tnetm.b2(:)];


function [sources, net, tnet, newc, restc, status] = update_variances(...
    sources0, net0, tnet0, dcp_dsvn, fs, tfs, params, dcp_dnetv, dcp_dtnetv, ...
    newac, data, oldcost, status, clamped, missing, notimedep)
% UPDATE_VARIANCES  Update the variances of the network and the sources
%
%  Builds a step structure and tries it with a shrinking step length
%  alpha until the cost returned by KLDIV is at most oldcost + 1e-6.
%  The search is abandoned, with the inputs restored and a warning
%  issued, once alpha drops below 1e-10.  status.varalpha carries the
%  step length from one call to the next (created as 1 if missing).
%
%  fs, tfs and clamped are accepted but not used in this function.
%
%  NOTE(review): restc is only assigned by a KLDIV call.  On the failure
%  path it therefore holds the value from the last rejected trial, and it
%  is unassigned if the loop exits before any trial -- confirm that the
%  callers can live with this.

epsilon = 1e-6;     % tolerated cost increase
epsilon2 = 1e-10;   % smallest step length that is still tried

sources = sources0;
net = net0;
tnet = tnet0;

if ~isfield(status, 'varalpha')
  status.varalpha = 1;
end

newc = inf;
alpha = min([1, 10*sqrt(status.varalpha)]);

% Step fields are only created for the parts that are being updated
if status.updatesrcvars >= 0
  step.s = .5 ./ max(dcp_dsvn, .45 ./ sources.nvar);
end
step.ac = newac;
if status.updatenet >= 0
  step.net.w1 = .5 ./ max(dcp_dnetv.w1, .45 ./ net.w1.var);
  step.net.w2 = .5 ./ max(dcp_dnetv.w2, .45 ./ net.w2.var);
  step.net.b1 = .5 ./ max(dcp_dnetv.b1, .45 ./ net.b1.var);
  step.net.b2 = .5 ./ max(dcp_dnetv.b2, .45 ./ net.b2.var);
end
if status.updatetnet >= 0
  step.tnet.w1 = .5 ./ max(dcp_dtnetv.w1, .45 ./ tnet.w1.var);
  step.tnet.w2 = .5 ./ max(dcp_dtnetv.w2, .45 ./ tnet.w2.var);
  step.tnet.b1 = .5 ./ max(dcp_dtnetv.b1, .45 ./ tnet.b1.var);
  step.tnet.b2 = .5 ./ max(dcp_dtnetv.b2, .45 ./ tnet.b2.var);
end

itercount = 0;
hcost(1) = oldcost;
halpha(1) = 0;

% Variation of backtracking linesearch (with quadratic interpolation)
while newc > (oldcost + epsilon)
  itercount = itercount + 1;
  if alpha < epsilon2
    warning('Variance update failed');
    sources = sources0;
    net = net0;
    tnet = tnet0;
    newc = oldcost;
    break;
  end
  % Every trial starts from the original parameters
  [sources, net, tnet] = update_var(...
      sources0, net0, tnet0, alpha, step, status);
  [newc, datac, restc, dync, netc] = kldiv(...
      [], [], sources, data, net, tnet, params, missing, notimedep, status);
  hcost(itercount + 1) = newc;
  halpha(itercount + 1) = alpha;
  if status.debug
    fprintf('Variance update: c=%.16g, alpha=%f\n', newc, alpha);
  end
  if itercount > 1
    % After first step use quadratic interpolation through the starting
    % point and the two most recent trials
    x1 = halpha(1);
    x2 = halpha(itercount);
    x3 = halpha(itercount + 1);
    c1 = hcost(1);
    c2 = hcost(itercount);
    c3 = hcost(itercount + 1);
    newalpha = (x1.^2 .* (c2-c3) + x2.^2 .* (c3-c1) + x3.^2 .* (c1-c2)) ./ ...
        (2*(x1 .* (c2-c3) + x2 .* (c3-c1) + x3 .* (c1-c2)));
    % Keep the new step between 0.1 and 0.5 times the previous one
    alpha = max([min([newalpha .5 * alpha]) .1 * alpha]);
  else
    alpha = .1 * alpha;
  end
end

% alpha has already been shrunk once after the last trial
status.varalpha = min([10*alpha 1]);


function [A, C] = pop(A, B)
% POP  Implements stack style pop operation with reshaping
%
%   [A, C] = pop(A, B)
%
%   Pop elements from A to create a size B matrix C

n = numel(B);
C = reshape(A(1:n), size(B));
A = A(n+1:end);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -