⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mobp_network01.m

📁 MOBP的代码。3层网络
💻 M
字号:
function [deltaW1,deltaW2,deltab1,deltab2,Error] = MOBP_Network(W1,W2,b1,b2,deltaW1,deltaW2,deltab1,deltab2,sample_in,sample_out,study_eff,sample_num,mr)
% MOBP_Network  One batch epoch of momentum backpropagation (MOBP) for a
% network with a log-sigmoid hidden layer and a linear output neuron.
%
%   For each of the first sample_num samples (scalar input p, target t):
%       a1 = logsig(W1*p + b1)              hidden layer
%       a2 = W2*a1 + b2                     linear output
%       s2 = -2*(t - a2)                    output sensitivity
%       s1 = diag(a1.*(1-a1)) * W2' * s2    hidden-layer sensitivity
%   The per-sample gradients (s2*a1', s2, s1*p', s1) are averaged and
%   blended with the previous updates using momentum:
%       delta = mr*delta_prev - study_eff*(1 - mr)*mean_gradient
%
%   Inputs
%     W1, b1        hidden-layer weights and biases (MOBP_Train creates them
%                   as hidden-by-1 column vectors)
%     W2, b2        output-layer weights and bias (MOBP_Train creates a
%                   1-by-hidden row vector and a scalar)
%     deltaW1, deltaW2, deltab1, deltab2
%                   updates from the previous epoch (momentum memory; a
%                   scalar 0 works on the first call)
%     sample_in     inputs; sample_in(i) is used as one scalar input
%     sample_out    targets; sample_out(i) is one scalar target
%     study_eff     learning rate
%     sample_num    number of leading samples to use (must be >= 1)
%     mr            momentum coefficient
%
%   Outputs
%     deltaW1, deltaW2, deltab1, deltab2
%                   new updates; the caller ADDS them to W1, W2, b1, b2
%     Error         mean absolute error over the samples, measured with the
%                   weights passed in (i.e. before this update is applied)
%
%   Raises MOBP_Network:noSamples when sample_num < 1 (the averages would
%   otherwise be 0/0 = NaN and poison every later update).
%
%   NOTE: s2 is the derivative of the SQUARED error, whereas Error reports
%   the mean ABSOLUTE error; Error only serves as a stopping criterion.

if sample_num < 1
    error('MOBP_Network:noSamples', 'sample_num must be at least 1.');
end

abs_err_sum = 0;
gradW1 = 0;
gradW2 = 0;
gradb1 = 0;
gradb2 = 0;
for i = 1:sample_num
    p = sample_in(i);
    t = sample_out(i);

    % Forward pass. logsig(n) = 1 ./ (1 + exp(-n)) is written out so the
    % Deep Learning Toolbox is not required.
    a1 = 1 ./ (1 + exp(-(W1 * p + b1)));
    a2 = W2 * a1 + b2;
    abs_err_sum = abs_err_sum + abs(t - a2);

    % Backward pass. logsig'(n1) = a1.*(1-a1); scaling W2' element-wise by
    % it is equivalent to the diagonal-matrix product diag(a1.*(1-a1))*W2'.
    s2 = -2 * (t - a2);
    s1 = ((1 - a1) .* a1) .* W2' .* s2;

    % Accumulate this sample's gradients.
    gradW2 = gradW2 + s2 * a1';
    gradb2 = gradb2 + s2;
    gradW1 = gradW1 + s1 * p';
    gradb1 = gradb1 + s1;
end

% Batch averages.
W1_delta = gradW1 / sample_num;
W2_delta = gradW2 / sample_num;
b1_delta = gradb1 / sample_num;
b2_delta = gradb2 / sample_num;
Error = abs_err_sum / sample_num;

% Momentum: blend the previous update with the new averaged gradient step.
deltaW1 = mr * deltaW1 - study_eff * (1 - mr) * W1_delta;
deltaW2 = mr * deltaW2 - study_eff * (1 - mr) * W2_delta;
deltab1 = mr * deltab1 - study_eff * (1 - mr) * b1_delta;
deltab2 = mr * deltab2 - study_eff * (1 - mr) * b2_delta;
function [W1,W2,b1,b2,train_err] = MOBP_Train(training_error,study_eff,mr,hidden,max_epochs)
% MOBP_Train  Fit a one-hidden-layer network to fun(x) on [-2, 2] with
% momentum backpropagation, redrawing the fit after every epoch.
%
%   [W1,W2,b1,b2,err] = MOBP_Train(training_error, study_eff, mr, hidden)
%   [W1,W2,b1,b2,err] = MOBP_Train(..., max_epochs)
%
%   Inputs
%     training_error  stop once the error reported by MOBP_Network drops
%                     below this value
%     study_eff       learning rate passed to MOBP_Network
%     mr              momentum coefficient passed to MOBP_Network
%     hidden          number of hidden (log-sigmoid) neurons
%     max_epochs      optional cap on the number of epochs (default Inf:
%                     train until training_error is reached); a warning is
%                     issued if the cap stops training first
%
%   Outputs
%     W1, b1      hidden-layer weights/biases (hidden-by-1)
%     W2, b2      output-layer weights (1-by-hidden) / bias (scalar)
%     train_err   last error reported by MOBP_Network, i.e. measured with
%                 the weights before the final update; Inf if no epoch ran
%
%   Inputs and targets are divided by their maximum absolute value when it
%   exceeds 1, so training happens on data in [-1, 1].
%   Side effects: clears and redraws the current figure every epoch and
%   finally overlays the scaled target curve in black.

if nargin < 5 || isempty(max_epochs)
    max_epochs = Inf;
end

% Momentum memory starts at zero.
deltaW1 = 0;
deltaW2 = 0;
deltab1 = 0;
deltab2 = 0;

% Random initial weights (rand draws from (0,1)); call order kept fixed.
W1 = rand(hidden,1);
b1 = rand(hidden,1);
W2 = rand(1,hidden);
b2 = rand(1);

% 40 training samples; targets are evaluated before the inputs are scaled.
sample_in = linspace(-2,2,40);
sample_out = mobp_scale_to_unit(fun(sample_in));
sample_in = mobp_scale_to_unit(sample_in);

% 10 target points drawn as '+' markers.
sample_point_x = linspace(-2,2,10);
sample_point_y = mobp_scale_to_unit(fun(sample_point_x));
sample_point_x = mobp_scale_to_unit(sample_point_x);

sample_num = length(sample_in);
train_err = Inf;
epoch = 0;
while epoch < max_epochs
    % Redraw the target markers and the network's current output.
    clf;
    plot(sample_point_x,sample_point_y,'+');
    hold on;
    y = zeros(size(sample_in));
    for i = 1:sample_num
        % logsig written out as 1./(1+exp(-n)) (no toolbox dependency).
        y(i) = W2 * (1 ./ (1 + exp(-(W1 .* sample_in(i) + b1)))) + b2;
    end
    plot(sample_in,y,'r');
    pause(0.01);

    [deltaW1,deltaW2,deltab1,deltab2,train_err] = MOBP_Network(W1,W2,b1,b2, ...
        deltaW1,deltaW2,deltab1,deltab2,sample_in,sample_out,study_eff,sample_num,mr);
    W1 = W1 + deltaW1;
    W2 = W2 + deltaW2;
    b1 = b1 + deltab1;
    b2 = b2 + deltab2;
    epoch = epoch + 1;

    if train_err < training_error
        break;
    end
end

if ~(train_err < training_error)
    warning('MOBP_Train:maxEpochs', ...
        'Stopped after %d epochs with error %g (target %g).', ...
        epoch, train_err, training_error);
end

% Overlay the (scaled) target function on 100 points.
hold on;
sample_x = linspace(-2,2);
sample_y = mobp_scale_to_unit(fun(sample_x));
sample_x = mobp_scale_to_unit(sample_x);
plot(sample_x,sample_y,'k');

function v = mobp_scale_to_unit(v)
% mobp_scale_to_unit  Divide v by max(abs(v)) when that exceeds 1 so the
% values lie in [-1, 1]; otherwise return v unchanged.
peak = max(abs(v));
if peak > 1
    v = v / peak;
end

function y = fun(X)
% fun  Target function the network is trained to approximate.
%
%   y = fun(X) returns atan applied element-wise to the first row of X,
%   giving one output value per column of X.
%
%   Other targets that have been tried are kept below as comments; swap one
%   in for the active assignment to train on a different curve.
first_row = X(1, :);
y = atan(first_row);
%y = sin(first_row * pi);
%y = sin(first_row * pi/4);
%y = exp(atan(first_row));
%y = first_row.^2;
%y = first_row.^2 + sin(pi * first_row) + sqrt(abs(first_row));
%y = first_row.^3 + 2.^exp(first_row);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -