📄 ndfa.m
字号:
if status.verbose, fprintf('Embedding data to continue a previous simulation.\n'); end data = embed(data, [], size(sources, 1)); end args = rmfield(args, 'sources');end% Modify sources to data length if necessary[Nx T] = size(data);[Ns Ts] = size(sources);if Ts > T, if status.verbose, fprintf('Trimming sources to data length.\n'); end sources = sources(:,1:T);elseif Ts < T, if status.verbose, fprintf('Padding sources to data length.\n'); end ns = ones(size(Ns, T - Ts)); newsources = acprobdist_alpha( ... repmat(sources(:,end).e, 1, T - Ts), ns * .0001); sources = [sources newsources];end% Initialise the clamped sourcesif isfield(args, 'initclamped'), control = args.initclamped; [Nu Tu] = size(control); % Remove old control values if isfield(status, 'controlchannels'), sources = sources(1:(Ns - Nu),:); Ns = Ns - Nu; end controls = Ns + (1:Nu); % Check that data and control lengths match if T ~= Tu, if status.embed.iters, control = [control nan*ones(Nu, T - Tu)]; Tu = T; else error('Data and control vectors must have equal number of columns!'); end end % Force control into acprobdist_alpha if ~isa(control, 'acprobdist_alpha'), control = acprobdist_alpha(control, .0001 * ones(Nu, T)); end % Find missing control values and set them to zeros missingcontrol = sparse(isnan(control.e)); clamped = sparse([false(Ns, T); ~missingcontrol]); control.e(find(missingcontrol)) = 0; % Add control variables to sources sources = [sources; control]; status.controlchannels = Nu; Ns = Ns + Nu; args = rmfield(args, 'initclamped'); if isfield(args, 'clamped'), args = rmfield(args, 'clamped'); endelse [clamped, args] = getargdef(args, 'clamped', sparse(false(Ns, T))); controls = []; if ~isfield(status, 'controlchannels'), status.controlchannels = 0; endend% Modify info on clamped sources to data lengthTms = size(clamped, 2);if Tms > T, clamped = clamped(1:Ns, 1:T);elseif Tms < T, clamped = [clamped repmat(clamped(:,end), 1, Tms - T)];end% Initialise ignored time dependenciesif isfield(args, 
'notimedep'), if ~isempty(args.notimedep), % If the number of columns doesn't match sample count, use matrix as % indices if (size(args.notimedep, 2) ~= size(sources, 2) - 1) & T > 1, notimedep = sparse(zeros(1, T - 1)); notimedep(1,args.notimedep) = 1; else notimedep = args.notimedep; end % If only one row is provided, use it for all channels if size(notimedep, 1) == 1, notimedep = repmat(notimedep, Ns, 1); end else notimedep = []; end args = rmfield(args, 'notimedep');else notimedep = [];end% Initialise missing data valuesmissing = sparse(isnan(data));data(find(missing)) = 0;% Initialise constraints for data and controlif isfield(args, 'constraints'), if ~isempty(args.constraints), status.constraints = args.constraints; end args = rmfield(args, 'constraints');end% Initialise the observation networkif ~isfield(args, 'net'), if ~isfield(args, 'hidneurons'), error('Either net or hidneurons must be set!') end if status.verbose, fprintf('Initialising a new observation MLP network with %d hidden neurons.\n', args.hidneurons); end net = createnet_alpha(Ns, args.hidneurons, Nx, ... 
'tanh', 1, 1, 1, 1, .01, .01); net.b2.e = net.b2.e + mean(data, 2); args = rmfield(args, 'hidneurons'); [net.identity, args] = getargdef(args, 'observationmapping', false);else if status.verbose, fprintf('Using a previously used observation MLP network with %d hidden neurons.\n', size(args.net.w1, 1)); end net = args.net; args = rmfield(args, 'net'); % Special observation network flags net.identity = getargdef(net, 'identity', false); [net.identity, args] = getargdef(args, 'observationmapping', net.identity);end% If identity mapping is used for observation mapping, we should% not waste time updating sources or the mappingif net.identity, status.updatesrcs = -inf; status.updatesrcvars = -inf; status.updatenet = -inf;end% Initialise the temporal networkif ~isfield(args, 'tnet'), if ~isfield(args, 'thidneurons'), error('Either tnet or thidneurons must be set!') end if status.verbose, fprintf('Initialising a new temporal MLP network with %d hidden neurons.\n', args.thidneurons); end tnet = createnet_alpha(size(sources, 1), args.thidneurons, size(sources, 1), ... 'tanh', 1, 1, .1, .01, .01, .01); args = rmfield(args, 'thidneurons');else if status.verbose, fprintf('Using a previously used temporal MLP network with %d hidden neurons.\n', size(args.tnet.w1, 1)); end tnet = args.tnet; args = rmfield(args, 'tnet');end% Initialise parameters and hyperparametersif ~isfield(args, 'params'), params.net.w2var = probdist(zeros(1, size(net.b1, 1)), ... .5 / size(data, 1) * ones(1, size(net.b1, 1))); params.tnet.w2var = probdist(zeros(1, size(tnet.b1, 1)), ... .5 / size(data, 1) * ones(1, size(tnet.b1, 1))); params.noise = probdist(.5 * log(.1) * ones(size(data, 1), 1), ... .5 / size(data, 2) * ones(size(data, 1), 1)); params.src = probdist(zeros(size(sources, 1), 1), ... 
.5 / size(data, 2) * ones(size(sources, 1), 1)); params.hyper.net.w2var = nlfa_inithyper(0, .1, 0, .1); params.hyper.net.b1 = nlfa_inithyper(0, .1, 0, .1); params.hyper.net.b2 = nlfa_inithyper(0, .1, 0, .1); params.hyper.tnet.w2var = nlfa_inithyper(0, .1, 0, .1); params.hyper.tnet.b1 = nlfa_inithyper(0, .1, 0, .1); params.hyper.tnet.b2 = nlfa_inithyper(0, .1, 0, .1); params.hyper.noise = nlfa_inithyper(0, .1, 0, .1); params.hyper.src = nlfa_inithyper(0, .1, 0, .1); params.prior.net.w2var = nlfa_initprior(0, log(100), 0, log(100)); params.prior.net.b1 = nlfa_initprior(0, log(100), 0, log(100)); params.prior.net.b2 = nlfa_initprior(0, log(100), 0, log(100)); params.prior.tnet.w2var = nlfa_initprior(0, log(100), 0, log(100)); params.prior.tnet.b1 = nlfa_initprior(0, log(100), 0, log(100)); params.prior.tnet.b2 = nlfa_initprior(0, log(100), 0, log(100)); params.prior.noise = nlfa_initprior(0, log(100), 0, log(100)); params.prior.src = nlfa_initprior(0, log(100), 0, log(100));else params = args.params; args = rmfield(args, 'params');end% No learning is done, the networks are not updated and the sources can be% updated from the beginningif isfield(args, 'nolearning'), status.updatenet = -inf; status.updatetnet = -inf; status.updateparams = 0; status.updatesrcs = 0; status.updatesrcvars = 0; args = rmfield(args, 'nolearning');end% Check that all the parameters were validotherargs = fieldnames(args);if length(otherargs) > 0, fprintf('Warning: ndfa: unused arguments:\n'); fprintf(' %s\n', otherargs{:});endif status.verbose == 1, status.verbose = 0;endstatus.cgreset = 0;% Do the actual iteration[sources, net, tnet, params, status, fs, tfs] = ...ndfa_iter(data, sources, net, tnet, params, status, missing, clamped, notimedep);% If only one return value is expected, pack everything to itif nargout == 1, val.sources = sources; val.clamped = clamped; val.notimedep = notimedep; val.net = net; val.tnet = tnet; val.params = params; val.status = status; sources = val;endfunction 
function hyper = nlfa_inithyper(mm, mv, vm, vv)
% NLFA_INITHYPER Build an initial hyperparameter structure.
%
%   hyper = NLFA_INITHYPER(mm, mv, vm, vv)
%
%   Returns a struct with two fields, each created with PROBDIST:
%     hyper.mean = probdist(mm, mv)
%     hyper.var  = probdist(vm, vv)
%   NOTE(review): presumably the first PROBDIST argument is a mean and the
%   second its variance -- confirm against probdist.m.
%
% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package. This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

hyper.mean = probdist(mm, mv);
hyper.var = probdist(vm, vv);


function prior = nlfa_initprior(mm, mv, vm, vv)
% NLFA_INITPRIOR Build an initial prior structure.
%
%   prior = NLFA_INITPRIOR(mm, mv, vm, vv)
%
%   Returns a nested struct whose four leaves are created with PROBDIST,
%   always passing 0 as the second argument:
%     prior.mean.mean = probdist(mm, 0)
%     prior.mean.var  = probdist(mv, 0)
%     prior.var.mean  = probdist(vm, 0)
%     prior.var.var   = probdist(vv, 0)
%   NOTE(review): presumably the 0 is a variance, so these are point values
%   rather than distributions to be learned -- confirm against probdist.m.
%
% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package. This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

prior.mean.mean = probdist(mm, 0);
prior.mean.var = probdist(mv, 0);
prior.var.mean = probdist(vm, 0);
prior.var.var = probdist(vv, 0);


function status = convert_obsolote(status)
% CONVERT_OBSOLOTE Convert a structure returned by an older version
%
%   status = CONVERT_OBSOLOTE(status)
%
%   Convert a result structure of an older version of NDFA to the
%   current format by filling in every status field that newer code
%   expects. Each field is set through GETARGDEF with the default shown.
%   NOTE(review): presumably GETARGDEF keeps a value that is already
%   present and only uses the default when the field is missing -- confirm
%   against getargdef.m.
%
% Copyright (C) 2004-2005 Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package. This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

status.controlchannels = getargdef(status, 'controlchannels', 0);

% Sub-structures are created (empty) first so that their own fields can
% then be defaulted.
status.embed = getargdef(status, 'embed', struct);
status.embed.iters = getargdef(status.embed, 'iters', 0);

status.prune = getargdef(status, 'prune', struct);
status.prune.sources = getargdef(status.prune, 'sources', 0);
status.prune.hidneurons = getargdef(status.prune, 'hidneurons', 0);
status.prune.thidneurons = getargdef(status.prune, 'thidneurons', 0);
status.prune.iters = getargdef(status.prune, 'iters', 0);

status.approximation = getargdef(status, 'approximation', ...
                                 'hermite');
status.updatealg = getargdef(status, 'updatealg', 'conjgrad');
status.epsilon = getargdef(status, 'epsilon', 1e-6);
status.verbose = getargdef(status, 'verbose', 1);
status.debug = getargdef(status, 'debug', 0);
status.t0 = getargdef(status, 't0', 1);
status.history = getargdef(status, 'history', {});
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -