📄 bayesmean.m
字号:
function y = bayesmean(mu, sigma, mu0, sigma0, N, x)
% y = bayesmean(mu, sigma, mu0, sigma0, N, x)
%
% Bayesian learning of the mean of a Gaussian with known variance.
% N samples are drawn from a Gaussian density with mean mu and
% standard deviation sigma.
%
% The mean of the samples is learned, one sample at a time, starting
% with a Gaussian prior, with mean mu0 and width (standard deviation)
% sigma0. Graphs of the posterior densities are plotted after each new
% datum is observed; press enter at the prompt to advance each step.
%
% Defaults are:
% mu      1
% sigma   0.3
% mu0     0.5
% sigma0  0.3
% N       50
%
% Posteriors are plotted at the locations in x (Default 200 points
% spanning 4*sigma on either side of mu, with the end points rounded
% outward to integers).
%
% If an output is requested, y is an (N+1)-by-2 matrix whose row k+1
% holds the posterior mean and posterior standard deviation after k
% data points (row 1 is the prior).

% Copyright (c) Richard Everson, Exeter University 2000
% $Id: bayesmean.m,v 1.1 2000/11/27 22:48:41 reverson Exp $

% Short-circuit || so an omitted argument is never referenced.
if nargin < 1 || isempty(mu),     mu = 1;      end
if nargin < 2 || isempty(sigma),  sigma = 0.3; end
if nargin < 3 || isempty(mu0),    mu0 = 0.5;   end
if nargin < 4 || isempty(sigma0), sigma0 = 0.3; end
if nargin < 5 || isempty(N),      N = 50;      end
if nargin < 6 || isempty(x)
    x = linspace(floor(mu - 4*sigma), ceil(mu + 4*sigma), 200);
end

% Gaussian density with mean m and variance v, evaluated on the grid x.
gauss = @(m, v) exp(-(x - m).^2/(2*v))/sqrt(2*pi*v);

fprintf('Actual mean %g\n', mu);
fprintf('Actual standard deviation %g\n', sigma);
fprintf('Prior mean %g\n', mu0);
fprintf('Prior standard deviation %g\n', sigma0);
disp(' ')
disp(' ')
disp('First plot the prior');
disp(' ')
% 's' returns the typed text as a string instead of evaluating it.
input('Press enter to continue: ', 's');

X = sigma*randn(N, 1) + mu;        % the N observations
runmean = cumsum(X) ./ (1:N)';     % runmean(n) = mean(X(1:n))

% f1 (left) shows the densities; f2 (right) shows the posterior trajectory.
f1 = figure; pos = get(f1, 'Position'); pos(1) = 20;   set(f1, 'Position', pos);
f2 = figure; pos = get(f2, 'Position'); pos(1) = 1000; set(f2, 'Position', pos);

prior = gauss(mu0, sigma0^2);
figure(f1);
plot(x, prior, 'k', 'LineWidth', 2);

% Posterior means and standard deviations after 0..N data; entry 1 is the prior.
map = zeros(1, N+1);
S = zeros(1, N+1);
map(1) = mu0;
S(1) = sigma0;

figure(f2);
h = errorbar(0.0, mu0, sigma0);
axis([-1, N, min(x), max(x)]);
set(h, 'LineWidth', 3);
set(gca, 'FontSize', 18);
input('Press enter to incorporate each successive data point ', 's');

pold = prior;
for n = 1:N
    % Conjugate update for a Gaussian mean with known variance sigma^2.
    denom = n*sigma0^2 + sigma^2;
    sigman2 = 1.0/(1.0/sigma0^2 + n/sigma^2);               % posterior variance
    mun = (n*sigma0^2*runmean(n) + sigma^2*mu0)/denom;      % posterior mean
    p = gauss(mun, sigman2);

    % Plot prior (black), previous posterior (green, from n = 2 on),
    % the data seen so far (red) and the current posterior (blue).
    figure(f1);
    clf
    plot(x, prior, 'k', 'LineWidth', 2);
    hold on
    plot(X(1:n), zeros(n, 1), 'ro', 'MarkerSize', 8, 'MarkerFaceColor', [1, 0, 0]);
    if n >= 2
        plot(x, pold, 'g', 'LineWidth', 2);
    end
    plot(x, p, 'b', 'LineWidth', 2);
    set(gca, 'FontSize', 18);

    % Posterior mean +/- one standard deviation vs. number of data.
    % f2 is deliberately not cleared: with hold on, the red marker for
    % each new datum accumulates on top of the earlier ones.
    figure(f2);
    map(n+1) = mun;
    S(n+1) = sqrt(sigman2);
    plot([0, 0], [min(x), max(x)], 'k');
    hold on;
    h = errorbar(0:n, map(1:n+1), S(1:n+1));
    axis([-1, N, min(x), max(x)]);
    set(h, 'LineWidth', 3);
    set(gca, 'FontSize', 18);
    plot(n, X(n), 'ro', 'MarkerSize', 5, 'MarkerFaceColor', [1, 0, 0]);
    plot([-1, N], [mu, mu], 'k');   % true mean
    input(sprintf('%d > ', n), 's');

    pold = p;
end

% Only assign y when asked for, so interactive calls print nothing extra.
if nargout > 0
    y = [map(:), S(:)];
end
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -