📄 feedback.m
字号:
function [dc_dsm, dc_dsv, dc_dnetm, dc_dnetv, dx] = ...
    feedback(x, net, sources, data, noiseparam, status)
% FEEDBACK  Do feedback phase calculations
%
%   [dc_dsm, dc_dsv, dc_dnetm, dc_dnetv, dx] = ...
%       feedback(x, net, sources, data, noiseparam, status)
%
%   Propagates gradients backwards from the data layer through the
%   linear second-layer mapping (net.w2, net.b2), the nonlinearity
%   (net.nonlin) and the first-layer mapping (net.w1, net.b1).
%
%   Inputs (as used by this file):
%     x          - cell array of per-layer values; x{1}..x{4} are
%                  accessed through the fields .e, .var, .multi and
%                  .extra, and x{5} is passed on unchanged to the
%                  Gauss-Hermite branch as auxiliary data.
%     net        - network with fields w1, b1, w2, b2 (each read through
%                  .e / .var) and nonlin (name of the nonlinearity; the
%                  function ['d3' nonlin] must exist).
%     sources    - read through .var and size(sources).
%                  NOTE(review): size(sources) and size(net.w2, 2) are
%                  assumed to return the matrix dimensions of the
%                  underlying values (i.e. these are objects with an
%                  overloaded size, not plain structs) - confirm.
%     data       - observation matrix; size(data, 2) is the sample count.
%     noiseparam - passed to normalvar() to obtain the noise variance.
%     status     - struct; .approximation must be 'hermite' or 'taylor',
%                  .updatealg is consulted by the 'hermite' branch.
%
%   Outputs:
%     dc_dsm   - dx{1}.e
%     dc_dsv   - dx{1}.var plus the contribution accumulated in multivar
%     dc_dnetm - struct with fields w1, b1, w2, b2 (first outputs of
%                netgradstop / netgrads)
%     dc_dnetv - struct with fields w1, b1, w2, b2 (second outputs of
%                netgradstop / netgrads)
%     dx       - cell array dx{1}..dx{4} of the propagated gradients
%
%   Raises an error ('Unsupported approximation') for any other value of
%   status.approximation.

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

noisevar = normalvar(noiseparam);

nsampl = size(data, 2);
nsources = size(sources, 1);

% Replicate the noise variance over all samples
datavars = noisevar * ones(1, nsampl);

% Gradients at the data layer
dx{4}.var = .5 ./ datavars;
dx{4}.e = (x{4}.e - data(:, 1:nsampl)) ./ datavars;
dx{4}.extra = dx{4}.var;
dx{4}.multi = repmat(shiftdim(sources.var, -1), ...
                     [size(x{4}.multi, 1) 1 1]) ...
    .* repmat(reshape(dx{4}.var, [size(data, 1) 1 nsampl]), ...
              [1 nsources 1]) ...
    .* (2 * x{4}.multi);

multivar = zeros(size(sources));

% The first layer (linear)
temp = (x{4}.multi) .^ 2;

% Somewhat more efficient way to calculate, for every sample k,
%   multivar(:,k) = temp(:,:,k)' * dx{4}.var(:,k);
for i = 1:nsources
  multivar(i, :) = ...
      sum(reshape(temp(:, i, :), size(dx{4}.var)) .* dx{4}.var, 1);
end

dx{3}.var = net.w2.var' * dx{4}.extra;
dx{3}.e = net.w2.e' * dx{4}.e + (2 * net.w2.var' * dx{4}.extra) .* x{3}.e;
dx{3}.extra = net.w2.e' .^ 2 * dx{4}.extra;

% Vectorised form of
%   dx{3}.multi = zeros(size(x{3}.multi));
%   dx{3}.multi(:,:,i) = net.w2.e' * dx{4}.multi(:,:,i);
d0 = size(net.w2, 2);
[d1, d2, d3] = size(dx{4}.multi);
dx{3}.multi = ...
    reshape(net.w2.e' * reshape(dx{4}.multi, [d1 d2*d3]), [d0 d2 d3]);

[dc_dnetm.w2, dc_dnetv.w2, dc_dnetm.b2, dc_dnetv.b2] = ...
    netgrads(x{3}, dx{4}, net.w2, net.b2);

% The second layer (nonlinear)
if strcmp(status.approximation, 'hermite')
  [dx{2}.e, dx{2}.var, dx{2}.multi, dx{2}.extra] = ...
      feedback_hermite(dx{3}.e, dx{3}.var, dx{3}.multi, dx{3}.extra, ...
                       x{2}.e, x{2}.var, x{2}.multi, x{2}.extra, ...
                       x{3}.e, x{3}.var, net.nonlin, x{5}, status);
elseif strcmp(status.approximation, 'taylor')
  [dx{2}.e, dx{2}.var, dx{2}.multi, dx{2}.extra] = ...
      feedback_taylor(dx{3}.e, dx{3}.var, dx{3}.multi, dx{3}.extra, ...
                      x{2}.e, x{2}.var, x{2}.multi, x{2}.extra, ...
                      net.nonlin);
else
  error('Unsupported approximation')
end

dx{1}.e = net.w1.e' * dx{2}.e + ...
          (2 * net.w1.var' * (dx{2}.var + dx{2}.extra)) .* x{1}.e;
dx{1}.var = (net.w1.e' .^ 2 + net.w1.var') * dx{2}.var + ...
            net.w1.var' * dx{2}.extra;

[dc_dnetm.w1, dc_dnetv.w1, dc_dnetm.b1, dc_dnetv.b1] = ...
    netgradstop(x{1}, dx{2}, net.w1, net.b1);

dc_dsm = dx{1}.e;
dc_dsv = dx{1}.var + multivar;


function [dm, dv, dmv, dev] = ...
    feedback_hermite(dgm0, dgv0, dgmv, dgev, m_in, v_in, mv_in, ev_in, ...
                     m_out, v_out, nonlin, aux, status)
% FEEDBACK_HERMITE  Evaluate the gradients of Gauss-Hermite quadrature
%    approximation of nonlinearity
%
%   Inputs:
%     dgm0, dgv0, dgmv, dgev     - incoming gradients with respect to the
%                                  output mean, var, multi and extra
%     m_in, v_in, mv_in, ev_in   - input mean, var, multi and extra
%     m_out, v_out               - output mean and var
%     nonlin                     - name of the nonlinearity; ['d3' nonlin]
%                                  is called to get its derivative
%     aux                        - struct with fields v_args, ev_args,
%                                  f_varvals, f_evvals, f_vardevs and
%                                  f_evdevs.
%                                  NOTE(review): presumably precomputed
%                                  by the feedforward pass at the
%                                  quadrature points sketched in the
%                                  reference code below - confirm.
%     status                     - struct; status.updatealg == 'old'
%                                  selects the clipped variance gradient
%
%   Outputs:
%     dm, dv, dmv, dev - gradients with respect to the input mean, var,
%                        multi and extra

% The order of approximation and related abscissas and weights
n = 3;
xi = [0, sqrt(6)/2, -sqrt(6)/2];
wi = [2/3, 1/6, 1/6];

% Reference (disabled) code documenting the quantities this function
% now expects to receive through m_out and aux:
%
% Basis points with extravar as variance
%ev_args = zeros([size(m_in), n]);
%for k=1:length(xi),
%  ev_args(:, :, k) = m_in + xi(k) * sqrt(2 * ev_in);
%end
%
% Components of the sum to evaluate output mean with extravar as
% input variance
%ev_sum = repmat(reshape(wi, [1, 1, n]), [size(m_in), 1]) .* ...
%    feval(nonlin, ev_args);
%
% Basis points with input var as variance
%v_args = zeros([size(m_in), n]);
%for k=1:length(xi),
%  v_args(:, :, k) = m_in + xi(k) * sqrt(2 * v_in);
%end
%
% Components of the sum to evaluate output mean with var as
% input variance
%v_sum = repmat(reshape(wi, [1, 1, n]), [size(m_in), 1]) .* ...
%    feval(nonlin, v_args);
%
% Output mean (now given as input)
% m_out = sum(v_sum, 3);

% Weights and abscissas replicated along the third dimension so that
% they can be multiplied elementwise with the per-point arrays in aux
wi3 = repmat(reshape(wi, [1, 1, n]), [size(m_out), 1]);
xi3 = repmat(reshape(xi, [1, 1, n]), [size(m_out), 1]);
wxi3 = repmat(reshape(wi .* xi, [1, 1, n]), [size(m_out), 1]);

% Normalised sum components
v_sum0 = (wi3 .* aux.f_vardevs);
ev_sum0 = (wi3 .* aux.f_evdevs);

% Compensate the use of output mean in evaluation of var and extravar
% dgm = dgm0 + 2 * (m_out - sum(ev_sum, 3)) .* dgev;
dgm = dgm0 - 2 * sum(ev_sum0, 3) .* dgev;

% Compensate the use of output variance in evaluation of multivar
% This value cannot be used in gradients with respect to variance as
% it breaks the fixed point update rule used for the variances
% (the 1e-20 term keeps the denominator away from exact zero)
dgv_mv = dgv0 + 1 ./ (2 * sqrt(v_out .* v_in) + 1e-20) .* ...
         reshape(sum(dgmv .* mv_in, 2), size(dgv0));
dgv = dgv0;

% Easy case first: the multivars
d = sqrt(v_out ./ v_in);
dmv = repmat(reshape(d, [size(m_in, 1), 1, size(dgmv, 3)]), ...
             [1, size(dgmv, 2), 1]) .* dgmv;

% Evaluate the derivative of the nonlinearity at basis points
der_vargs = feval(['d3' nonlin], aux.v_args, aux.f_varvals);
der_evargs = feval(['d3' nonlin], aux.ev_args, aux.f_evvals);

% Partial derivative with respect to input extravar
dev = sum(ev_sum0 .* der_evargs .* xi3, 3) .* ...
      (ev_in .^ -0.5) .* dgev;

% Partial derivative with respect to input var
sqrtvi = v_in .^ -0.5;
temp = .5 * sum(wxi3 .* der_vargs, 3) .* sqrtvi .* dgm;

% Term shared by both update algorithms
dv_var = sum(v_sum0 .* der_vargs .* xi3, 3) .* sqrtvi .* dgv;

if strcmp(status.updatealg, 'old')
  % The old update rule keeps only the positive part of temp
  dv = dv_var + temp .* (temp > 0);
else
  dv = dv_var + temp;
end

% Partial derivative with respect to input mean
dm = sum(wi3 .* der_vargs, 3) .* dgm + ...
     2 * sum(v_sum0 .* der_vargs, 3) .* dgv_mv + ...
     2 * sum(ev_sum0 .* der_evargs, 3) .* dgev;


function [dm, dv, dmv, dev] = ...
    feedback_taylor(dgm0, dgv0, dgmv, dgev, m_in, v_in, mv_in, ev_in, nonlin)
% FEEDBACK_TAYLOR  Evaluate the gradients of Taylor
%    approximation of nonlinearity
%
%   Inputs:
%     dgm0, dgv0, dgmv, dgev     - incoming gradients with respect to the
%                                  output mean, var, multi and extra
%     m_in, v_in, mv_in, ev_in   - input mean, var, multi and extra
%     nonlin                     - name of the nonlinearity; ['d3' nonlin]
%                                  is called with m_in and must return
%                                  three outputs, used here as the first,
%                                  second and third derivative
%
%   Outputs:
%     dm, dv, dmv, dev - gradients with respect to the input mean, var,
%                        multi and extra

[der1, der2, der3] = feval(['d3' nonlin], m_in);

temp = .5 * der2 .* dgm0;

% Only the positive part of temp contributes to the variance gradient;
% the same mask gates the third-derivative term of the mean gradient
dv = temp .* (temp > 0) + (der1 .^ 2) .* dgv0;
dm = dgm0 .* (der1 + .5 * v_in .* der3 .* (temp > 0)) + ...
     2 * dgv0 .* v_in .* der2 .* der1 + ...
     2 * dgev .* ev_in .* der2 .* der1 + ...
     reshape(sum(dgmv .* mv_in, 2), size(der2)) .* der2;
dev = (der1 .^ 2) .* dgev;
dmv = repmat(reshape(der1, [size(m_in, 1) 1 size(dgmv, 3)]), ...
             [1 size(dgmv, 2) 1]) .* dgmv;


function [dcp_dwm, dcp_dwv, dcp_dbm, dcp_dbv] = netgrads(x, dx, w, b)
% NETGRADS  Calculate partial derivatives of kldiv with respect to
%    network weights
%
%   Inputs:
%     x  - layer input, read through .e, .var, .extra and .multi
%     dx - gradients at the layer output, read through .e, .extra
%          and .multi
%     w  - weights, read through .e
%     b  - biases (not used by the computation)
%
%   Outputs:
%     dcp_dwm, dcp_dwv - gradients for the weight means and variances
%     dcp_dbm, dcp_dbv - gradients for the bias means and variances

% A more efficient way to calculate
%temp = x.multivar;
%for i=1:nsampl
%  bonus = bonus + dx.multi(:,:,i) * temp(:,:,i)';
%end
d0 = size(x.multi, 1);
[d1, d2, d3] = size(dx.multi);
bonus = reshape(dx.multi, [d1 d2*d3]) * reshape(x.multi, [d0 d2*d3])';

dcp_dwm = dx.e * x.e' + ...
          2 * (dx.extra * x.extra') .* w.e ...
          + bonus;
dcp_dwv = dx.extra * (x.var + x.e .^ 2)';
dcp_dbm = sum(dx.e, 2);
dcp_dbv = sum(dx.extra, 2);


function [dcp_dwm, dcp_dwv, dcp_dbm, dcp_dbv] = netgradstop(x, dx, w, b)
% NETGRADSTOP  Calculate partial derivatives of kldiv with respect to
%    network weights
%
%   Variant of NETGRADS used by FEEDBACK for the w1/b1 mapping.
%
%   Inputs:
%     x  - layer input, read through .e and .var
%     dx - gradients at the layer output, read through .e, .var, .extra
%          and .multi
%     w  - weights, read through .e
%     b  - biases (not used by the computation)
%
%   Outputs:
%     dcp_dwm, dcp_dwv - gradients for the weight means and variances
%     dcp_dbm, dcp_dbv - gradients for the bias means and variances

dcp_dwm = dx.e * x.e' + ...
          2 * (dx.var * x.var') .* w.e + ...
          sum(dx.multi, 3);
dcp_dwv = (dx.extra + dx.var) * (x.var + x.e .^ 2)';
dcp_dbm = sum(dx.e, 2);
dcp_dbv = sum(dx.var + dx.extra, 2);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -