⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rpem_adaptive.m

📁 EM算法
💻 M
字号:
% Adaptive RPEM algorithm

function [new_betta, new_m, new_S, betta_trace, alpha_trace, m_trace, S_trace] = rpem_adaptive(x, k, lrate, betta, m, S)
% RPEM_ADAPTIVE  Interactive, sample-by-sample (online) RPEM-style learning of
% a k-component Gaussian mixture.
%
% Inputs
%   x      d-by-N data matrix; each column x(:,t) is one sample.
%   k      number of mixture components.
%   lrate  learning rate for the means m and the matrices S; the mixing
%          parameters betta are updated with 0.1*lrate.
%   betta  unnormalized mixing parameters; the mixing proportions are
%          alpha = exp(betta)/sum(exp(betta)). NOTE(review): must be a
%          1-by-k row so that it combines element-wise with the 1-by-k
%          per-component log-densities - confirm callers pass a row.
%   m      d-by-k matrix of component means, one per column.
%   S      d-by-d-by-k array. S(:,:,j) is used as the PRECISION (inverse
%          covariance) of component j: the density is evaluated as
%          exp(-0.5*z'*S*z)*sqrt(det(S)), and inv(S) is what gets printed
%          as "sigma". Assumed symmetric positive definite.
%
% Outputs
%   new_betta, new_m, new_S   parameters after the last per-sample update.
%   betta_trace, alpha_trace, m_trace, S_trace
%          one column per epoch, each parameter flattened with (:).
%          Column 1 holds the initial parameters; column e+1 holds the
%          parameters at the end of epoch e.
%
% Interaction: asks (via INPUT) for a number of epochs, trains for that many
% passes over x, prints the learned alpha / m / inv(S), plots the trajectories
% of m (and, when d == 2, the data with the current means), then asks whether
% to continue and for how many more epochs.

total_epoch_no = input('Please input the number of epoches [default: 10]');
if isempty(total_epoch_no)
    total_epoch_no = 10;
end
epoch_no = total_epoch_no;   % epochs left before the next report/prompt

[d, N] = size(x);

% Working storage
logpx = zeros(1, k);         % log of the unnormalized component densities
new_betta = zeros(1, k);
new_m = zeros(d, k);
new_S = zeros(d, d, k);
z = zeros(d, k);             % z(:,j) = x(:,t) - m(:,j) for the current sample

flag = 1;                    % 1 while the user wants to keep training
epoch_count = 0;
colorset_m = zeros(d*k, 3);  % one plot colour per m-trajectory (all black)

mf = figure;
if d == 2
    xf = figure;
end

% Record the initial parameters (column 1 of every trace).
alpha_now = stable_softmax(betta);
alpha_trace(:, epoch_count+1) = alpha_now(:);
betta_trace(:, epoch_count+1) = betta(:);
m_trace(:, epoch_count+1) = m(:);
S_trace(:, epoch_count+1) = S(:);

% The mixing parameters learn ten times slower than m and S.
lrate_betta = 0.1 * lrate;

while flag == 1
    for t = 1:N
        % Step 1: mixing proportions and posterior responsibilities h.
        if max(betta) > 100
            % alpha is invariant to a common shift of betta; this just keeps
            % the stored betta values bounded.
            betta = betta - 100;
        end
        alpha = stable_softmax(betta);

        for j = 1:k
            z(:, j) = x(:, t) - m(:, j);
            % log( exp(-0.5*z'*S*z) * sqrt(det(S)) ). Working in the log
            % domain avoids 0/0 = NaN when every density underflows to 0
            % for a sample far from all components.
            logpx(j) = -0.5 * z(:, j)' * S(:, :, j) * z(:, j) ...
                       + 0.5 * log(det(S(:, :, j)));
        end
        logw = log(alpha) + logpx;
        w = exp(logw - max(logw));   % largest term becomes exactly 1
        h = w / sum(w);
        [~, c] = max(h);             % winning component for this sample

        % Step 2: weighted updates. The winner c gets weight g = 2 - h(c);
        % every other component gets g = -h(j).
        for j = 1:k
            if j == c
                g = 2 - h(c);
            else
                g = -h(j);
            end
            % Same as the per-branch forms betta + lr*(2-h(c)-alpha(c)) for
            % the winner and betta - lr*(h(j)+alpha(j)) for the others.
            new_betta(j) = betta(j) + lrate_betta * (g - alpha(j));
            Sz = S(:, :, j) * z(:, j);
            new_m(:, j) = m(:, j) + lrate * g * Sz;
            % Sz*Sz' equals S*z*z'*S for symmetric S and keeps the
            % precision matrix exactly symmetric in floating point.
            new_S(:, :, j) = (1 + 0.5*lrate*g) * S(:, :, j) - 0.5*lrate*g * (Sz * Sz');
        end
        betta = new_betta;
        m = new_m;
        S = new_S;
    end

    epoch_no = epoch_no - 1;
    epoch_count = epoch_count + 1;
    fprintf('Epoch = %d \n', epoch_count);

    % Record the parameters reached at the end of this epoch.
    alpha_now = stable_softmax(betta);
    alpha_trace(:, epoch_count+1) = alpha_now(:);
    betta_trace(:, epoch_count+1) = betta(:);
    m_trace(:, epoch_count+1) = m(:);
    S_trace(:, epoch_count+1) = S(:);

    % "<= 0" rather than "== 0": a zero or negative epoch count entered at a
    % prompt must still stop here instead of looping forever.
    if epoch_no <= 0
        disp('The learned alpha are:');
        alpha = stable_softmax(new_betta)
        disp('The learned positions of m_js are:');
        new_m
        disp('The learned sigma are:');
        for j = 1:k
            inv(new_S(:, :, j))
        end

        % One curve per entry of m, plotted against the epoch index.
        figure(mf);
        hold off;
        for t = 1:d*k
            ploth = plot(0:epoch_count, m_trace(t, :));
            set(ploth, 'Color', colorset_m(t, :));
            hold on;
        end
        title('Learning Curve of Parameter m_js');
        xlabel('No. of epochs');
        drawnow;

        if d == 2
            figure(xf);
            hold off;
            plot(x(1, :), x(2, :), '.');
            hold on;
            plot(m(1, :), m(2, :), 'black*');
            title('Positions of Parameter m_js in Input Space');
            drawnow;
        end

        flag_c = input('Do you want to continue to learning the parameters (yes:1, no:0) [default: 1]?');
        if isempty(flag_c)
            flag_c = 1;
        end
        if flag_c ~= 1
            flag = 0;
        else
            epoch_no = input('Please input the number of epoches [default 10]');
            if isempty(epoch_no)
                epoch_no = 10;
            end
            total_epoch_no = total_epoch_no + epoch_no;
        end
    end
end


function p = stable_softmax(b)
% STABLE_SOFTMAX  exp(b)/sum(exp(b)), computed after subtracting max(b) so
% that large entries of b cannot overflow exp().
p = exp(b - max(b));
p = p / sum(p);
   

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -