% rpem_adaptive.m
% Adaptive RPEM algorithm
function [new_betta, new_m, new_S, betta_trace, alpha_trace, m_trace, S_trace] = rpem_adaptive(x, k, lrate, betta, m, S, epochs)
%RPEM_ADAPTIVE Learn a k-component Gaussian mixture with adaptive RPEM.
%
%   [new_betta, new_m, new_S, betta_trace, alpha_trace, m_trace, S_trace] = ...
%       RPEM_ADAPTIVE(x, k, lrate, betta, m, S)
%   scans the columns of x one sample at a time, for a number of epochs
%   read interactively. After each batch of epochs it prints the learned
%   parameters and plots the learning curves of the m_js (and, for 2-D data,
%   the data together with the current m_js). It then asks whether to
%   continue.
%
%   RPEM_ADAPTIVE(..., epochs) runs exactly `epochs` epochs without any
%   prompting, printing or plotting (batch mode).
%
%   Inputs:
%     x      - d-by-N data matrix, one sample per column.
%     k      - number of mixture components.
%     lrate  - learning rate for m and S; betta uses 0.1*lrate.
%     betta  - k mixing-weight logits; alpha = exp(betta)/sum(exp(betta)).
%              Row or column is accepted; it is handled as a 1-by-k row.
%     m      - d-by-k component centres.
%     S      - d-by-d-by-k component precision (inverse covariance)
%              matrices. The printed "sigma" is inv(S(:,:,j)).
%     epochs - (optional) number of epochs for non-interactive batch mode.
%
%   Outputs:
%     new_betta, new_m, new_S - learned parameters (new_betta is 1-by-k).
%     betta_trace, alpha_trace, m_trace, S_trace - one column per epoch,
%         with column 1 holding the initial values. Each column is the
%         corresponding parameter stacked with (:).

batch_mode = nargin >= 7 && ~isempty(epochs);
if batch_mode
    epoch_no = epochs;
else
    epoch_no = input('Please input the number of epoches [default: 10]');
    if isempty(epoch_no)
        epoch_no = 10;
    end
end
[d, N] = size(x);
% Force betta to a row so alpha, px and h all stay 1-by-k. A column betta
% would otherwise broadcast alpha.*px into a k-by-k matrix.
betta = reshape(betta, 1, []);
% Set the learning rate of betta.
lrate_betta = 0.1 * lrate;

if ~batch_mode
    mf = figure;
    if d == 2
        xf = figure;
    end
    % One distinct colour per m_js learning curve.
    colorset_m = lines(d*k);
end

% Record the initial parameters as column 1 of every trace.
epoch_count = 0;
alpha_trace = rpem_softmax(betta).';
betta_trace = betta.';
m_trace = m(:);
S_trace = S(:);

while true
    % A for-loop (instead of counting epoch_no down to exactly 0) means an
    % epoch count of 0, a negative one or a non-integer cannot loop forever.
    for ep = 1:epoch_no
        for t = 1:N
            [betta, m, S] = rpem_step(x(:, t), betta, m, S, lrate, lrate_betta);
        end
        epoch_count = epoch_count + 1;
        if ~batch_mode
            fprintf('Epoch = %d \n', epoch_count);
        end
        % Record the parameters.
        col = epoch_count + 1;
        alpha_trace(:, col) = rpem_softmax(betta).';
        betta_trace(:, col) = betta.';
        m_trace(:, col) = m(:);
        S_trace(:, col) = S(:);
    end

    % Outputs always reflect the current parameters (also when N == 0).
    new_betta = betta;
    new_m = m;
    new_S = S;

    if batch_mode
        break;
    end

    % Report. The missing semicolons are deliberate: they echo the values.
    disp('The learned alpha are:');
    alpha = rpem_softmax(new_betta)
    disp('The learned positions of m_js are:');
    new_m
    disp('The learned sigma are:');
    for j = 1:k
        inv(new_S(:, :, j))
    end

    figure(mf);
    hold off;
    for r = 1:d*k
        ploth = plot(0:epoch_count, m_trace(r, :));
        set(ploth, 'Color', colorset_m(r, :));
        hold on;
    end
    title('Learning Curve of Parameter m_js');
    % Each epoch (one unit on the x axis) scans all N data points.
    xlabel(sprintf('No. of data points scanned (x %d)', N));
    drawnow;
    if d == 2
        figure(xf);
        hold off;
        plot(x(1, :), x(2, :), '.');
        hold on;
        plot(m(1, :), m(2, :), 'k*');
        title('Positions of Parameter m_js in Input Space');
        drawnow;
    end

    flag_c = input('Do you want to continue to learning the parameters (yes:1, no:0) [default: 1]?');
    if isempty(flag_c)
        flag_c = 1;
    end
    if flag_c ~= 1
        break;
    end
    epoch_no = input('Please input the number of epoches [default 10]');
    if isempty(epoch_no)
        epoch_no = 10;
    end
end
end

function [betta, m, S] = rpem_step(xt, betta, m, S, lrate, lrate_betta)
%RPEM_STEP One adaptive RPEM update of (betta, m, S) for the sample xt (d-by-1).
k = numel(betta);

% Step 1: mixing weights and posteriors h.
% Softmax is shift-invariant: subtracting 100 from every betta leaves alpha
% unchanged while keeping betta bounded.
if max(betta) > 100
    betta = betta - 100;
end
alpha = rpem_softmax(betta);
z = zeros(size(m));
log_px = zeros(1, k);
for j = 1:k
    z(:, j) = xt - m(:, j);
    % log of exp(-0.5*z'*S*z)*sqrt(det(S)). The (2*pi)^(-d/2) factor is common
    % to all j and cancels in h. Working in logs avoids 0/0 = NaN when every
    % px underflows. Assumes each S(:,:,j) stays positive definite.
    log_px(j) = -0.5 * z(:, j)' * S(:, :, j) * z(:, j) + 0.5 * log(det(S(:, :, j)));
end
h = rpem_softmax(log(alpha) + log_px);   % == alpha.*px / sum(alpha.*px)
[~, c] = max(h);

% Step 2: weights g_j are 2 - h_c for the winner c and -h_j for the rivals,
% so the rivals are pushed away.
g = -h;
g(c) = 2 - h(c);
betta = betta + lrate_betta * (g - alpha);
for j = 1:k
    Sj = S(:, :, j);
    Sz = Sj * z(:, j);
    m(:, j) = m(:, j) + lrate * g(j) * Sz;
    S(:, :, j) = (1 + 0.5*lrate*g(j)) * Sj - 0.5*lrate*g(j) * Sz * (z(:, j)' * Sj);
end
end

function p = rpem_softmax(v)
%RPEM_SOFTMAX exp(v)/sum(exp(v)), shifted by max(v) so exp cannot overflow.
p = exp(v - max(v));
p = p / sum(p);
end
% End of rpem_adaptive.m