function [v_total, v_data, v_param, v_dyn, v_nets] = ...
    kldiv(tfs, fs, s, x, net, tnet, params, missing, notimedep, status, ...
          no_static)
% KLDIV  Calculate the Kullback-Leibler divergence cost
%
% Usage:
%   [kl_total, kl_data, kl_param, kl_dyn, kl_nets] = ...
%      kldiv(tfs, fs, sources, data, net, tnet, params, ...
%            missing, notimedep, status, no_static)
%
%   [kl_total, kl_data, kl_param, kl_dyn, kl_nets] = ...
%      kldiv(data, result)
%
%   [kl_total, kl_data, kl_param, kl_dyn, kl_nets] = ...
%      kldiv(data, result, sources)
%
% If STATUS is omitted, status.approximation = 'hermite' and
% status.freeinitial = 1 are used. NO_STATIC defaults to false.
% Empty FS or TFS are recomputed from the sources with FEEDFW and ACFEEDFW.
%
% With 3 or 4 output arguments (legacy call format), the network cost
% kl_nets is included in kl_param; with 5 outputs it is returned
% separately. kl_total is the same in both cases.

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.
%
% NOTE: The prior variance of net.w1 (and tnet.w1) is assumed to be unity.

% If no status is provided, assume default values
if nargin < 10
  status.approximation = 'hermite';
  status.freeinitial = 1;
end

% If only data and the result of a previous run are provided, extract all
% the parameters from the result
if nargin == 2 || nargin == 3
  if nargin == 3
    % Optional third argument: sources, either a probdist or a plain mean
    if isa(s, 'probdist')
      fs.sources.e = s.e;
      fs.sources.var = s.var;
    else
      fs.sources.e = s;
    end
  end
  status = fs.status;
  notimedep = fs.notimedep;
  missing = isnan(tfs);
  params = fs.params;
  tnet = fs.tnet;
  net = fs.net;
  x = tfs;
  s = fs.sources;
  % Force recomputation of the reconstructions and predictions below
  fs = [];
  tfs = [];
end

if nargin < 11
  no_static = false;
end

% Reconstruct the data through the observation network
if isempty(fs)
  x_tmp = feedfw(s, net, status.approximation);
  fs = probdist(x_tmp{4}.e, x_tmp{4}.var);
end

% Predict the sources through the dynamics (temporal) network
if isempty(tfs)
  tx_tmp = acfeedfw(s, tnet, status.approximation);
  if status.freeinitial
    tfs = probdist([s.e(:,1) tx_tmp{4}.e], ...
                   [zeros(size(s, 1), 1) tx_tmp{4}.var]);
  else
    tfs = probdist([zeros(size(s, 1), 1) tx_tmp{4}.e], ...
                   [zeros(size(s, 1), 1) tx_tmp{4}.var]);
  end
end

% Calculate the cost for reconstruction error
if net.identity
  v_data = 0;
else
  v_data = kl_data(probdist(fs.e - x, fs.var), probdist(0), ...
                   params.noise, 2, missing);
end

% Calculate the cost for prediction error
v_dyn = kl_acparam(s, tfs, params.src, 2, notimedep);

% Calculate the cost for (hyper)parameters
if no_static
  v_param = 0;
else
  v_param = kl_static(params);
end

% Calculate the cost for networks
v_nets = ...
    kl_param(net.w2, probdist(0), params.net.w2var, 1) + ...
    kl_param(net.b1, params.hyper.net.b1.mean, params.hyper.net.b1.var) + ...
    kl_param(net.b2, params.hyper.net.b2.mean, params.hyper.net.b2.var) + ...
    kl_param(net.w1, probdist(0), probdist(0)) + ...
    kl_param(tnet.w2, probdist(0), params.tnet.w2var, 1) + ...
    kl_param(tnet.b1, params.hyper.tnet.b1.mean, ...
             params.hyper.tnet.b1.var) + ...
    kl_param(tnet.b2, params.hyper.tnet.b2.mean, ...
             params.hyper.tnet.b2.var) + ...
    kl_param(tnet.w1, probdist(0), probdist(0));

v_total = v_data + v_dyn + v_param + v_nets;

% Support for legacy call format: fold the network cost into v_param
if nargout >= 3 && nargout <= 4
  v_param = v_param + v_nets;
end