evalcost.m
function cost = evalcost(data, varargin),
% EVALCOST  Evaluate the cost of an NFA model
%
%   cost = EVALCOST(data, otherargs...)
%   returns the cost function value for the given NLFA model.
%   The supported arguments are a subset of those accepted by NLFA,
%   those that are relevant to the problem at hand.
%
%   One additional value for 'approximation' is supported, namely
%   'mc' which uses Monte Carlo sampling to evaluate the cost.
%
%   See also NLFA.

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package. This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Read the arguments
if (length(varargin) == 1) & (isstruct(varargin{1})),
  args = varargin{1};
elseif (mod(length(varargin), 2) ~= 0),
  if ~(isstruct(varargin{1})),
    error('Keyword arguments should appear in pairs');
  else
    args = varargin{1};
    for k=2:2:length(varargin),
      if ~ischar(varargin{k})
        error('Keyword argument names must be strings.');
      end
      eval(sprintf('args.%s = varargin{k+1};', varargin{k}));
    end
  end
else
  args = struct(varargin{:});
end

%args

if ~isfield(args, 'status'),
  [status.approximation, args] = getargdef(args, 'approximation', 'hermite');
else
  status = args.status;
  args = rmfield(args, 'status');
  if isfield(args, 'approximation'),
    status.approximation = args.approximation;
    args = rmfield(args, 'approximation');
  end
end

if ~isfield(args, 'sources'),
  error('Sources must be set!');
else
  sources = args.sources;
  args = rmfield(args, 'sources');
end

if ~isfield(args, 'net'),
  error('Net must be set!')
else
  net = args.net;
  args = rmfield(args, 'net');
end

if ~isfield(args, 'params'),
  error('Params must be set!');
else
  params = args.params;
  args = rmfield(args, 'params');
end

otherargs = fieldnames(args);
if length(otherargs) > 0,
  fprintf('Warning: nlfa: unused arguments:\n');
  fprintf(' %s\n', otherargs{:});
end

% Do the actual thing...
cost = ...
    really_eval_cost(data, sources, net, params, status);


function [val, args] = getargdef(args, name, default),

if isfield(args, name),
  % val = args.(name);
  eval(sprintf('val = args.%s;', name));
  args = rmfield(args, name);
else
  val = default;
end


function cost = really_eval_cost(data, sources, net, params, status)

nsampl = size(data, 2);
fs = probdist(zeros(size(data)), ones(size(data)));

% Do feedforward calculations
if strcmp(status.approximation, 'mc'),
  x = mc_feedfw(sources, net);
  fs = probdist(x.e, x.var);
else
  x = feedfw(sources, net, status.approximation);
  fs = probdist(x{4}.e, x{4}.var);
end

cost = kl_static(net, params) + kl_batch(fs, sources, data, params);


function x = mc_feedfw(s, net)
% MC_FEEDFW  Monte Carlo feedfw

npoints = 400;

sp = repmat(s.e, [1, 1, npoints]) + ...
     repmat(sqrt(s.var), [1, 1, npoints]) .* ...
     randn(size(s.e, 1), size(s.e, 2), npoints);

xp = mlpfw_mc(sp, net);
x.e = mean(xp, 3);
xd = (xp - repmat(x.e, [1, 1, npoints])).^2;
x.var = mean(xd, 3);


function x = mlpfw_mc(s, net),
% MLPFW_MC  Sample through an MLP

filler = ones(1, size(s, 2));
x = zeros(size(net.w2.e, 1), size(s, 2), size(s, 3));
for k = 1:size(s, 3),
  temp1 = (net.w1.e+randn(size(net.w1.e)).*sqrt(net.w1.var))*s(:, :, k) + ...
          (net.b1.e+randn(size(net.b1.e)).*sqrt(net.b1.var)) * filler;
  temp2 = feval(net.nonlin, temp1);
  x(:, :, k) = (net.w2.e+randn(size(net.w2.e)).*sqrt(net.w2.var))*temp2 + ...
               (net.b2.e+randn(size(net.b2.e)).*sqrt(net.b2.var)) * filler;
end
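

% ----------------------------------------------------------------------
% MC_MOMENTS_SKETCH is a hypothetical helper, not part of the original
% NLFA package.  It is a minimal sketch of the Monte Carlo moment
% estimate that MC_FEEDFW performs for the 'mc' approximation, assuming
% a generic elementwise function handle F in place of the sampled MLP
% (MLPFW_MC), so that the estimate can be checked against a closed form:
% if s ~ N(mu, sigma2), then E[exp(s)] = exp(mu + sigma2/2) and
% Var[exp(s)] = (exp(sigma2) - 1) * exp(2*mu + sigma2).
%
% Example (assumes the helper is callable, e.g. copied to its own file):
%   [m, v] = mc_moments_sketch(zeros(2, 3), 0.25 * ones(2, 3), @exp, 20000);
%   % every entry of m should be near exp(0.125), roughly 1.13

function [m, v] = mc_moments_sketch(mu, sigma2, f, npoints)
% MC_MOMENTS_SKETCH  Monte Carlo mean and variance of f(s), s ~ N(mu, sigma2)
%   Hypothetical helper that mirrors the sampling scheme of MC_FEEDFW.

% Draw npoints samples for every element, stacked along dimension 3
sp = repmat(mu, [1, 1, npoints]) + ...
     repmat(sqrt(sigma2), [1, 1, npoints]) .* ...
     randn(size(mu, 1), size(mu, 2), npoints);

% Push all samples through f, then take sample moments over dimension 3
% (the variance divides by npoints, as in MC_FEEDFW)
y = f(sp);
m = mean(y, 3);
v = mean((y - repmat(m, [1, 1, npoints])).^2, 3);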
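

% ----------------------------------------------------------------------
% PARSE_ARGS_SKETCH is a hypothetical helper, not part of the original
% NLFA package.  It sketches the argument handling at the top of
% EVALCOST (a single struct, keyword/value pairs, or a struct followed
% by pairs), assuming a MATLAB version that supports dynamic field
% names: ARGS.(NAME) = VALUE replaces the EVAL(SPRINTF(...)) assignment,
% as the commented-out line in GETARGDEF already suggests.  Pairs are
% assigned one at a time instead of via STRUCT(varargin{:}), because
% STRUCT turns cell-array values into a struct array.
%
% Example (assumes the helper is callable, e.g. copied to its own file):
%   args = parse_args_sketch(struct('net', 1), 'params', {2, 3});
%   % args is a 1x1 struct and args.params is the 1x2 cell {2, 3}

function args = parse_args_sketch(varargin)
% PARSE_ARGS_SKETCH  Collect keyword arguments into a scalar struct
%   Hypothetical helper that mirrors the argument parsing in EVALCOST.

% A lone struct is taken as the complete argument set
if length(varargin) == 1 && isstruct(varargin{1}),
  args = varargin{1};
  return;
end

if mod(length(varargin), 2) ~= 0,
  % An odd count is only valid as a leading struct plus name/value pairs
  if ~isstruct(varargin{1}),
    error('Keyword arguments should appear in pairs');
  end
  args = varargin{1};
  first = 2;
else
  args = struct();
  first = 1;
end

% Copy each name/value pair into a field of the same name
for k = first:2:length(varargin),
  if ~ischar(varargin{k}),
    error('Keyword argument names must be strings.');
  end
  args.(varargin{k}) = varargin{k+1};
end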