📄 update_everything.m
字号:
function [sources, net, oldgrads, status] = update_everything(...
    sources, net, dcp_dsm, dcp_dsv, fs, params, dcp_dnetm, dcp_dnetv, ...
    data, oldc, status, oldgrads)
% UPDATE_EVERYTHING Perform a line search to find optimal step length
% and update all parameters
%
%   [SOURCES, NET, OLDGRADS, STATUS] = UPDATE_EVERYTHING(SOURCES, NET,
%   DCP_DSM, DCP_DSV, FS, PARAMS, DCP_DNETM, DCP_DNETV, DATA, OLDC,
%   STATUS, OLDGRADS) first updates the variances of the sources and of
%   the network weights/biases (UPDATE_VARIANCES, using DCP_DSV and
%   DCP_DNETV), then builds a descent direction for the source means and
%   for the network parameters W1, B1, W2, B2 by negating DCP_DSM and
%   DCP_DNETM, and finally calls LINESEARCH along that direction.
%
%   A component's direction is zero (it is frozen) when
%   STATUS.UPDATESRCS < 0 (source means) or STATUS.UPDATENET < 0 (network).
%
%   If STATUS.UPDATEALG is 'conjgrad', the direction is combined with the
%   previous direction stored in OLDGRADS using the Fletcher-Reeves
%   coefficient beta = |g_k|^2 / |g_{k-1}|^2. OLDGRADS.NORM, OLDGRADS.S
%   and OLDGRADS.NET are then overwritten for the next call. The previous
%   direction is reused only if OLDGRADS holds a nonzero norm and
%   components whose sizes match the current ones.
%
%   STATUS.T0 is passed to LINESEARCH and replaced by the value it
%   returns. OLDC is used as the acceptance threshold of the variance
%   update; presumably it is the cost at the incoming parameters
%   (NOTE(review): confirm with callers).

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package. This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

[sources, net, newc, restc, status] = update_variances(...
    sources, net, dcp_dsv, fs, params, dcp_dnetv, data, oldc, status);

% Descent direction for the source means; zero when they are frozen.
if status.updatesrcs < 0,
  dc_dsm = zeros(size(dcp_dsm));
else
  dc_dsm = -dcp_dsm;
end

% Descent direction for the network weights and biases; zero when frozen.
if status.updatenet < 0,
  dc_dnetm.w1 = zeros(size(dcp_dnetm.w1));
  dc_dnetm.b1 = zeros(size(dcp_dnetm.b1));
  dc_dnetm.w2 = zeros(size(dcp_dnetm.w2));
  dc_dnetm.b2 = zeros(size(dcp_dnetm.b2));
else
  dc_dnetm.w1 = -dcp_dnetm.w1;
  dc_dnetm.b1 = -dcp_dnetm.b1;
  dc_dnetm.w2 = -dcp_dnetm.w2;
  dc_dnetm.b2 = -dcp_dnetm.b2;
end

if strcmp(status.updatealg, 'conjgrad'),
  % Squared norm of the gradient that is actually being followed. Every
  % component uses its (possibly zeroed) direction, so frozen network
  % parameters do not distort beta, just as frozen sources already did not.
  gn = sum(dc_dsm(:).^2) + ...
       sum(dc_dnetm.w1(:).^2) + sum(dc_dnetm.w2(:).^2) + ...
       sum(dc_dnetm.b1(:).^2) + sum(dc_dnetm.b2(:).^2);
  if can_reuse_direction(oldgrads, dc_dsm, dc_dnetm),
    % Fletcher-Reeves update: d_k = -g_k + beta * d_{k-1}
    beta = gn / oldgrads.norm;
    dc_dsm = dc_dsm + beta * oldgrads.s;
    dc_dnetm.w1 = dc_dnetm.w1 + beta * oldgrads.net.w1;
    dc_dnetm.w2 = dc_dnetm.w2 + beta * oldgrads.net.w2;
    dc_dnetm.b1 = dc_dnetm.b1 + beta * oldgrads.net.b1;
    dc_dnetm.b2 = dc_dnetm.b2 + beta * oldgrads.net.b2;
  end
  % Remember the gradient norm and the new search direction.
  oldgrads.norm = gn;
  oldgrads.s = dc_dsm;
  oldgrads.net = dc_dnetm;
end

[sources, net, status.t0] = linesearch(...
    sources, net, newc, restc, dc_dsm, dc_dnetm, fs, data, params, ...
    status.t0, status);


function ok = can_reuse_direction(oldgrads, dc_dsm, dc_dnetm)
% CAN_REUSE_DIRECTION True when OLDGRADS holds a usable previous search
% direction: a nonempty, nonzero stored norm, plus source and network
% (W1, B1, W2, B2) components whose sizes equal those of DC_DSM and DC_DNETM.
ok = isfield(oldgrads, 'norm') && ~isempty(oldgrads.norm) && ...
     oldgrads.norm ~= 0 && ...
     isfield(oldgrads, 's') && isequal(size(oldgrads.s), size(dc_dsm)) && ...
     isfield(oldgrads, 'net');
names = {'w1', 'b1', 'w2', 'b2'};
for k = 1:length(names),
  if ~ok,
    break;
  end
  ok = isfield(oldgrads.net, names{k}) && ...
       isequal(size(oldgrads.net.(names{k})), size(dc_dnetm.(names{k})));
end


function [sources, net, newc, restc, status] = update_variances(...
    sources0, net0, dcp_dsv, fs, params, dcp_dnetv, ...
    data, oldcost, status)
% UPDATE_VARIANCES Update the variances of the network and the sources
%
%   Moves SOURCES0.VAR (if STATUS.UPDATESRCVARS >= 0) and the W1, W2, B1,
%   B2 variances of NET0 (if STATUS.UPDATENET >= 0) toward fixed-point
%   targets. The move is a geometric interpolation:
%       var = exp(alpha*log(target) + (1-alpha)*log(old))
%   ALPHA is halved until the total cost NEWC is at most OLDCOST + 1e-6.
%
%   The starting ALPHA is min(1, 2*sqrt(STATUS.VARALPHA)). STATUS.VARALPHA
%   defaults to 1 and is set on return to the last step size tried.
%
%   If no acceptable step is found within MAXHALVINGS halvings, the
%   original variances are restored and their cost is returned.
%
%   RESTC is the second output of KL_STATIC_SPLIT at the returned NET.
%   FS is accepted for interface compatibility but is not used here.

epsilon = 1e-6;     % tolerance on the cost increase that is still accepted
maxhalvings = 50;   % bound on cost evaluations before giving up

sources = sources0;
net = net0;

if ~isfield(status, 'varalpha'),
  status.varalpha = 1;
end

alpha0 = min([1, 2*sqrt(status.varalpha)]);
alpha = alpha0;

% Fixed-point targets 0.5/gradient. The max() keeps the denominator at
% least .45/var, so a target can exceed the current variance by at most
% a factor 1/0.9 (this also covers nonpositive gradients).
if status.updatesrcvars >= 0,
  newsvar = .5 ./ max(dcp_dsv, .45 ./ sources.var);
end
if status.updatenet >= 0,
  neww1var = .5 ./ max(dcp_dnetv.w1, .45 ./ net.w1.var);
  neww2var = .5 ./ max(dcp_dnetv.w2, .45 ./ net.w2.var);
  newb1var = .5 ./ max(dcp_dnetv.b1, .45 ./ net.b1.var);
  newb2var = .5 ./ max(dcp_dnetv.b2, .45 ./ net.b2.var);
end

newc = inf;
nhalvings = 0;
while newc > oldcost + epsilon,
  if nhalvings >= maxhalvings,
    % The threshold cannot be reached (e.g. nothing is being updated, or
    % OLDCOST is below the cost at the starting point). Without this exit
    % the loop would never terminate once ALPHA underflows to zero.
    sources = sources0;
    net = net0;
    [newc, restc] = evaluate_cost(sources, net, data, params, status);
    if status.debug,
      fprintf('Variance update failed, keeping old variances: c=%f\n', newc);
    end
    break;
  end
  if status.updatesrcvars >= 0,
    sources.var = exp(alpha * log(newsvar) + (1-alpha) * log(sources0.var));
  end
  if status.updatenet >= 0,
    net.w1.var = exp(alpha * log(neww1var) + (1-alpha) * log(net0.w1.var));
    net.w2.var = exp(alpha * log(neww2var) + (1-alpha) * log(net0.w2.var));
    net.b1.var = exp(alpha * log(newb1var) + (1-alpha) * log(net0.b1.var));
    net.b2.var = exp(alpha * log(newb2var) + (1-alpha) * log(net0.b2.var));
  end
  [newc, restc] = evaluate_cost(sources, net, data, params, status);
  if status.debug,
    fprintf('Variance update: alpha=%f, c=%f\n', alpha, newc);
  end
  alpha = .5 * alpha;
  nhalvings = nhalvings + 1;
end

% ALPHA was halved once after the last evaluation, so 2*ALPHA is the step
% actually used. Report a decrease relative to the step we started from.
if (2*alpha < alpha0) && status.debug,
  fprintf('Decreased variance step...\n');
end
status.varalpha = 2*alpha;


function [c, restc] = evaluate_cost(sources, net, data, params, status)
% EVALUATE_COST Total cost at SOURCES/NET: the KL_BATCH term on the fourth
% output cell of FEEDFW, plus both outputs of KL_STATIC_SPLIT. RESTC is
% that second KL_STATIC_SPLIT output.
fs_tmp = feedfw(sources, net, status.approximation);
[netc, restc] = kl_static_split(net, params);
c = kl_batch(fs_tmp{4}, sources, data, params) + netc + restc;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -