📄 update_natgrad.m
字号:
function [sources, net, tnet, oldgrads, status] = update_natgrad(...
    sources, net, tnet, dcp_dsm, dcp_dsvn, fs, tfs, params, ...
    dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, ...
    data, oldc, status, oldgrads, clamped, missing, notimedep)
% UPDATE_NATGRAD Perform a line search to find optimal step length
% and update all parameters based on natural conjugate gradient
%
% Steps performed:
%   1. The variances are updated first by UPDATE_VARIANCES.
%   2. The mean gradients DCP_* are negated (or zeroed for parameter
%      groups whose status.update* flag is negative). They are stacked
%      with the variances and scaled elementwise by those variances to
%      get the natural gradient.
%   3. If the state in OLDGRADS is compatible, the natural gradient is
%      combined with the previous direction using the Polak-Ribiere
%      formula (beta clipped at zero), with Powell-Beale style restarts.
%   4. CUBIC_SEARCH is called along the resulting direction.
%
% Inputs used directly here (the rest are passed through):
%   sources, net, tnet   Current parameters; their .var fields are the
%                        variances used as the metric
%   dcp_dsm, dcp_dnetm, dcp_dtnetm
%                        Gradient terms of the means; their negations
%                        form the search direction
%   clamped              Entries of the source-mean gradient at
%                        FIND(clamped) are zeroed (assumes the same
%                        layout as dcp_dsm - TODO confirm)
%   oldgrads             CG state returned by the previous call;
%                        oldgrads.norm == 0 requests a fresh start
%   status               Uses .debug, .updatesrcs, .updatenet,
%                        .updatetnet, .cgreset, .kls and .t0
%
% Outputs:
%   sources, net, tnet   Updated parameters
%   oldgrads             CG state for the next call
%   status               Updated control state
%
% Copyright (C) 1999-2005 Antti Honkela, Harri Valpola,
% Xavier Giannakopoulos and Matti Tornio
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

[sources, net, tnet, newc, restc, status] = update_variances(...
    sources, net, tnet, dcp_dsvn, fs, tfs, params, dcp_dnetv, dcp_dtnetv, ...
    newac, data, oldc, status, clamped, missing, notimedep);

if status.debug > 4,
  keyboard;
end

% Search directions for the means (all zeros when a group is frozen)
if status.updatesrcs < 0,
  dc_dsm = zeros(size(dcp_dsm));
else
  dc_dsm = -dcp_dsm;
  % Set gradient for the clamped sources to zeros
  dc_dsm(find(clamped)) = 0;
end
dc_dnetm = natgrad_direction(dcp_dnetm, status.updatenet < 0);
dc_dtnetm = natgrad_direction(dcp_dtnetm, status.updatetnet < 0);

% Collect the gradients and variances of the sources and weights
grad = [dc_dsm(:); dc_dnetm.w1(:); dc_dnetm.w2(:); ...
        dc_dnetm.b1(:); dc_dnetm.b2(:); ...
        dc_dtnetm.w1(:); dc_dtnetm.w2(:); ...
        dc_dtnetm.b1(:); dc_dtnetm.b2(:)];
vars = [sources.var(:); ...
        net.w1.var(:); net.w2.var(:); ...
        net.b1.var(:); net.b2.var(:); ...
        tnet.w1.var(:); tnet.w2.var(:); ...
        tnet.b1.var(:); tnet.b2.var(:)];

% Norm in Riemannian space
gn = sum(vars.^(-1) .* grad.^2);

% Natural gradient
f = vars .* grad;
fn = sum(vars.^(-1) .* f.^2);

% Extract the individual components of the natural gradient.  The pops
% consume a copy so that f itself stays available for the CG formulas.
stack = f;
[stack dc_dsm] = pop(stack, dc_dsm);
[stack dc_dnetm.w1] = pop(stack, dc_dnetm.w1);
[stack dc_dnetm.w2] = pop(stack, dc_dnetm.w2);
[stack dc_dnetm.b1] = pop(stack, dc_dnetm.b1);
[stack dc_dnetm.b2] = pop(stack, dc_dnetm.b2);
[stack dc_dtnetm.w1] = pop(stack, dc_dtnetm.w1);
[stack dc_dtnetm.w2] = pop(stack, dc_dtnetm.w2);
[stack dc_dtnetm.b1] = pop(stack, dc_dtnetm.b1);
[stack dc_dtnetm.b2] = pop(stack, dc_dtnetm.b2);

% Reuse the previous direction only if the stored state is complete and
% shape-compatible.  && short-circuits, so the fields are only read
% after their existence has been checked.
canreuse = isfield(oldgrads, 'norm') && oldgrads.norm ~= 0 && ...
    isfield(oldgrads, 'grad') && isfield(oldgrads, 'f') && ...
    isfield(oldgrads, 'fnorm') && isfield(oldgrads, 's') && ...
    isfield(oldgrads, 'net') && isfield(oldgrads, 'tnet') && ...
    isequal(size(oldgrads.s), size(dc_dsm)) && ...
    isequal(size(oldgrads.net.w1), size(dc_dnetm.w1)) && ...
    isequal(size(oldgrads.tnet.w1), size(dc_dtnetm.w1)) && ...
    isequal(size(oldgrads.grad), size(grad));

if canreuse,
  % Overlap of the previous and current natural gradients in the metric
  overlap = abs(oldgrads.f' * (vars.^(-1) .* f));
  % Powell-Beale restarts (rude approximation in R-S, fix!)
  % if abs(oldgrads.grad'*(var.^(-1).*grad)) >= .5*gn && ...
  if overlap >= .2*fn && status.cgreset > 10,
    fprintf('Resetting CG (Powell-Beale restart)\n');
    oldgrads.s = zeros(size(dc_dsm));
    oldgrads.net = netgrad_zeros(net);
    oldgrads.tnet = netgrad_zeros(tnet);
    oldgrads.norm = 0;
    oldgrads.grad = zeros(size(grad));
    beta = 0;
    status.cgreset = 0;
    oldgrads.cgreset(size(status.kls, 2)) = true;
  else
    % Fletcher-Reeves formula (gives inferior results in most cases)
    %beta = fn / oldgrads.fnorm;
    % Polak-Ribiere formula, clipped at zero
    beta = f' * (vars.^(-1) .* (f - oldgrads.f)) / oldgrads.fnorm;
    beta = max([0 beta]);
    if status.debug,
      fprintf('Beta=%f ratio=%f\n', beta, overlap / fn);
    end
    status.cgreset = status.cgreset + 1;
  end
  dc_dsm = dc_dsm + beta * oldgrads.s;
  dc_dnetm = natgrad_add_scaled(dc_dnetm, oldgrads.net, beta);
  dc_dtnetm = natgrad_add_scaled(dc_dtnetm, oldgrads.tnet, beta);
end

% Store the gradient and state information for the next iteration
oldgrads.norm = gn;
oldgrads.fnorm = fn;
oldgrads.s = dc_dsm;
oldgrads.net = dc_dnetm;
oldgrads.tnet = dc_dtnetm;
oldgrads.grad = grad;
oldgrads.var = vars;
oldgrads.f = f;

step.s = dc_dsm;
step.net = dc_dnetm;
step.tnet = dc_dtnetm;

[sources, net, tnet, status] = cubic_search(...
    sources, net, tnet, newc, restc, step, ...
    fs, tfs, data, params, status.t0, status, missing, notimedep);


function [sources, net, tnet, newc, restc, status] = update_variances(...
    sources0, net0, tnet0, dcp_dsvn, fs, tfs, params, dcp_dnetv, dcp_dtnetv, ...
    newac, data, oldcost, status, clamped, missing, notimedep)
% UPDATE_VARIANCES Update the variances of the network and the sources
%
% Runs a backtracking line search over a scalar step length ALPHA.
% Each trial applies UPDATE_VAR(sources0, net0, tnet0, alpha, step,
% status) and evaluates the cost with KLDIV.
%   - Per-group step terms are .5 ./ max(gradient, .45 ./ variance).
%     They are only set for groups whose status.update* flag is >= 0.
%   - The first rejection shrinks alpha by 10. Later rejections fit a
%     parabola through (0, oldcost) and the last two trials, keeping the
%     new alpha within [.1, .5] times the current one.
%   - A trial is accepted once its cost is <= oldcost + 1e-6.
%   - If alpha falls below 1e-10, a warning is issued and the original
%     parameters are returned with newc = oldcost.
%
% status.varalpha stores the step-size memory used to choose the first
% trial alpha of the next call (initialised to 1 if absent).

epsilon = 1e-6;    % tolerated cost increase for accepting a step
epsilon2 = 1e-10;  % smallest step length tried before giving up

sources = sources0;
net = net0;
tnet = tnet0;

if ~isfield(status, 'varalpha'),
  status.varalpha = 1;
end

newc = inf;
alpha = min([1, 10*sqrt(status.varalpha)]);

% Start from an empty struct so that STEP exists even when no group is
% being updated (otherwise UPDATE_VAR would get an undefined variable).
step = struct();
if status.updatesrcvars >= 0,
  step.s = .5 ./ max(dcp_dsvn, .45 ./ sources.nvar);
  step.ac = newac;
end
if status.updatenet >= 0,
  step.net = natgrad_var_step(dcp_dnetv, net);
end
if status.updatetnet >= 0,
  step.tnet = natgrad_var_step(dcp_dtnetv, tnet);
end

% History of (alpha, cost) pairs; entry 1 is the starting point
itercount = 0;
hcost(1) = oldcost;
halpha(1) = 0;

% Variation of backtracking linesearch (with quadratic interpolation)
while newc > (oldcost + epsilon),
  itercount = itercount + 1;
  if alpha < epsilon2,
    warning('Variance update failed');
    sources = sources0;
    net = net0;
    tnet = tnet0;
    newc = oldcost;
    if itercount == 1,
      % No trial has been evaluated yet, so RESTC was never assigned;
      % evaluate it at the restored original parameters.
      [c0, datac, restc] = kldiv(...
          [], [], sources, data, net, tnet, params, missing, notimedep, status);
    end
    break;
  end

  [sources, net, tnet] = update_var(...
      sources0, net0, tnet0, alpha, step, status);
  [newc, datac, restc, dync, netc] = kldiv(...
      [], [], sources, data, net, tnet, params, missing, notimedep, status);

  hcost(itercount + 1) = newc;
  halpha(itercount + 1) = alpha;
  if status.debug,
    fprintf('Variance update: c=%.16g, alpha=%f\n', newc, alpha);
  end

  if itercount > 1,
    % After first step use quadratic interpolation through the start
    % point, the previous trial and the current trial
    x1 = halpha(1);
    x2 = halpha(itercount);
    x3 = halpha(itercount + 1);
    c1 = hcost(1);
    c2 = hcost(itercount);
    c3 = hcost(itercount + 1);
    newalpha = (x1.^2 .* (c2-c3) + x2.^2 .* (c3-c1) + x3.^2 .* (c1-c2)) ./ ...
        (2*(x1 .* (c2-c3) + x2 .* (c3-c1) + x3 .* (c1-c2)));
    % Clamp to [.1, .5] * alpha.  MIN/MAX ignore NaN, so a degenerate
    % (0/0) fit falls back to .5 * alpha.
    alpha = max([min([newalpha .5 * alpha]) .1 * alpha]);
  else
    alpha = .1 * alpha;
  end
end

% NOTE(review): ALPHA here is the next (already reduced) trial value, not
% the accepted one; the factor 10 seems to compensate - confirm intent.
status.varalpha = min([10*alpha 1]);


function g = natgrad_direction(d, frozen)
% NATGRAD_DIRECTION Negated w1/b1/w2/b2 fields of D, or zeros of the same
% sizes when FROZEN is true (fields created in the order w1, b1, w2, b2).
if frozen,
  g.w1 = zeros(size(d.w1));
  g.b1 = zeros(size(d.b1));
  g.w2 = zeros(size(d.w2));
  g.b2 = zeros(size(d.b2));
else
  g.w1 = -d.w1;
  g.b1 = -d.b1;
  g.w2 = -d.w2;
  g.b2 = -d.b2;
end


function a = natgrad_add_scaled(a, b, beta)
% NATGRAD_ADD_SCALED Add BETA times the w1/w2/b1/b2 fields of B to those of A.
a.w1 = a.w1 + beta * b.w1;
a.w2 = a.w2 + beta * b.w2;
a.b1 = a.b1 + beta * b.b1;
a.b2 = a.b2 + beta * b.b2;


function s = natgrad_var_step(dcp_dv, nw)
% NATGRAD_VAR_STEP Variance step terms .5 ./ max(gradient, .45 ./ variance)
% for the w1/w2/b1/b2 fields of the network NW.
s.w1 = .5 ./ max(dcp_dv.w1, .45 ./ nw.w1.var);
s.w2 = .5 ./ max(dcp_dv.w2, .45 ./ nw.w2.var);
s.b1 = .5 ./ max(dcp_dv.b1, .45 ./ nw.b1.var);
s.b2 = .5 ./ max(dcp_dv.b2, .45 ./ nw.b2.var);


function [A C] = pop(A, B),
% POP Implements stack style pop operation with reshaping
%
% Returns matrix C with shape and size of matrix B from the vector
% stack A.
%
% [A, C] = pop(A, B)
%
% Pop elements from vector A to create a size B matrix C; the remaining
% elements of A are returned as the new stack.
n = numel(B);
C = reshape(A(1:n), size(B));
A = A(n+1:end);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -