feedback_hermite.m
function [dm, dv, dmv, dev] = ...
    feedback_hermite(dgm0, dgv0, dgmv, dgev, m_in, v_in, mv_in, ev_in, ...
                     m_out, v_out, nonlin, aux, status)
% FEEDBACK_HERMITE Evaluate the gradients of the Gauss-Hermite quadrature
% approximation of a nonlinearity

% The order of approximation and the related abscissas and weights.
% This is the 3-point Gauss-Hermite rule with the weights divided by
% sqrt(pi), so that they sum to 1 and, for x ~ N(m, v),
% E[f(x)] ~= sum_k wi(k) * f(m + xi(k) * sqrt(2 * v))
n = 3;
xi = [0, sqrt(6)/2, -sqrt(6)/2];
wi = [2/3, 1/6, 1/6];

% Basis points with extravar as variance
%ev_args = zeros([size(m_in), n]);
%for k=1:length(xi),
%  ev_args(:, :, k) = m_in + xi(k) * sqrt(2 * ev_in);
%end

% Components of the sum to evaluate output mean with extravar as
% input variance
%ev_sum = repmat(reshape(wi, [1, 1, n]), [size(m_in), 1]) .* ...
%         feval(nonlin, ev_args);

% Basis points with input var as variance
%v_args = zeros([size(m_in), n]);
%for k=1:length(xi),
%  v_args(:, :, k) = m_in + xi(k) * sqrt(2 * v_in);
%end

% Components of the sum to evaluate output mean with var as
% input variance
%v_sum = repmat(reshape(wi, [1, 1, n]), [size(m_in), 1]) .* ...
%        feval(nonlin, v_args);

% Output mean (now given as input)
% m_out = sum(v_sum, 3);

% Normalised sum components
v_sum0 = (repmat(reshape(wi, [1, 1, n]), [size(m_out), 1]) .* ...
          aux.f_vardevs);
ev_sum0 = (repmat(reshape(wi, [1, 1, n]), [size(m_out), 1]) .* ...
           aux.f_evdevs);

% Compensate for the use of the output mean in evaluating var and extravar
% dgm = dgm0 + 2 * (m_out - sum(ev_sum, 3)) .* dgev;
dgm = dgm0 - 2 * sum(ev_sum0, 3) .* dgev;

% Compensate for the use of the output variance in evaluating multivar.
% This value cannot be used in the gradient with respect to the variance,
% as it would break the fixed-point update rule used for the variances.
dgv_mv = dgv0 + 1 ./ (2 * sqrt(v_out .* v_in) + 1e-20) .* ...
         reshape(sum(dgmv .* mv_in, 2), size(dgv0));
dgv = dgv0;

% Easy case first: the multivars
d = sqrt(v_out ./ v_in);
dmv = repmat(reshape(d, [size(m_in, 1), 1, size(dgmv, 3)]), ...
             [1, size(dgmv, 2), 1]) .* dgmv;

% Evaluate the derivative of the nonlinearity at the basis points
der_vargs = feval(['d3' nonlin], aux.v_args, aux.f_varvals);
der_evargs = feval(['d3' nonlin], aux.ev_args, aux.f_evvals);

% Partial derivative with respect to input extravar
dev = sum(ev_sum0 .* ...
          der_evargs .* ...
          repmat(reshape(xi, [1, 1, n]), [size(m_out), 1]), 3) .* ...
      (ev_in .^ -0.5) .* dgev;

%if strcmp(status.updatealg, 'old'),
%  temp = ev_sum0 .* ...
%         der_evargs .* ...
%         repmat(reshape(xi, [1, 1, n]), [size(m_out), 1]);
%  if (any(any(any(temp < 0)))),
%    warning('Negative values for extravar, compensating');
%  end
%  temp = temp .* (temp > 0);
%  dev = sum(temp, 3) .* (ev_in .^ -0.5) .* dgev;

% Partial derivative with respect to input var
sqrtvi = v_in .^ -0.5;
temp = .5 * sum(repmat(reshape(wi .* xi, [1, 1, n]), [size(m_out), 1]) .* ...
                der_vargs, 3) .* sqrtvi .* dgm;
if strcmp(status.updatealg, 'old'),
  % The 'old' update algorithm keeps only the positive part of temp
  dv = sum(v_sum0 .* ...
           der_vargs .* ...
           repmat(reshape(xi, [1, 1, n]), [size(m_out), 1]), 3) .* ...
       sqrtvi .* dgv + ...
       temp .* (temp > 0);
else
  dv = sum(v_sum0 .* ...
           der_vargs .* ...
           repmat(reshape(xi, [1, 1, n]), [size(m_out), 1]), 3) .* ...
       sqrtvi .* dgv + temp;
end

% Partial derivative with respect to input mean
dm = sum(repmat(reshape(wi, [1, 1, n]), [size(m_out), 1]) .* ...
         der_vargs, 3) .* dgm + ...
     2 * sum(v_sum0 .* ...
             der_vargs, 3) .* dgv_mv + ...
     2 * sum(ev_sum0 .* ...
             der_evargs, 3) .* dgev;
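The abscissas `xi` and weights `wi` used in this file form the 3-point Gauss-Hermite rule rescaled by 1/sqrt(pi), so `sum(wi .* f(m + xi * sqrt(2 * v)))` approximates E[f(x)] for x ~ N(m, v), the same form as the commented-out basis-point code. The standalone MATLAB/Octave script below is a sketch, not part of the original file, and it rests on simplifying assumptions: scalar `m` and `v`, `tanh` as the nonlinearity, and a hypothetical helper `gh_mean`. It checks that the rule is exact for low-degree polynomials. It also checks that the first term of `dm`, a weighted sum of the nonlinearity's derivative at the basis points (which is what the file's comment says the `d3`-prefixed function evaluates), matches a finite-difference derivative of the approximate mean.

```matlab
% Normalised 3-point Gauss-Hermite rule (same values as in feedback_hermite.m)
xi = [0, sqrt(6)/2, -sqrt(6)/2];
wi = [2/3, 1/6, 1/6];

% Hypothetical helper: approximate E[f(x)] for x ~ N(m, v), scalar m and v
gh_mean = @(f, m, v) sum(wi .* f(m + xi * sqrt(2 * v)));

m = 0.3;
v = 0.5;

% An n-point Gauss-Hermite rule is exact for polynomials up to degree 2n-1 = 5
exact2 = m^2 + v;                       % E[x^2] for x ~ N(m, v)
exact4 = m^4 + 6 * m^2 * v + 3 * v^2;   % E[x^4] for x ~ N(m, v)
err2 = abs(gh_mean(@(x) x.^2, m, v) - exact2);
err4 = abs(gh_mean(@(x) x.^4, m, v) - exact4);

% Gradient of the approximate mean of tanh(x) with respect to m:
% differentiating under the sum gives sum_k wi(k) * f'(m + xi(k) * sqrt(2 * v))
f  = @(x) tanh(x);
df = @(x) 1 - tanh(x).^2;               % derivative of tanh
grad_m = sum(wi .* df(m + xi * sqrt(2 * v)));

% Central finite-difference estimate of the same derivative
h = 1e-6;
grad_fd = (gh_mean(f, m + h, v) - gh_mean(f, m - h, v)) / (2 * h);

fprintf('E[x^2] error %.2e, E[x^4] error %.2e\n', err2, err4);
fprintf('analytic %.8f, finite difference %.8f\n', grad_m, grad_fd);
```

Running it in MATLAB or Octave should print quadrature errors at round-off level and two gradient values that agree to several decimal places.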