% mobp_network01.m
function [deltaW1,deltaW2,deltab1,deltab2,Error] = MOBP_Network(W1,W2,b1,b2,deltaW1,deltaW2,deltab1,deltab2,sample_in,sample_out,study_eff,sample_num,mr)
% MOBP_Network  One batch epoch of momentum back-propagation (MOBP) for a
% network with one scalar input, a logistic-sigmoid hidden layer and a
% linear scalar output.
%
% Runs a forward and backward pass over the first SAMPLE_NUM samples,
% averages the gradients of the squared error (t - a2)^2, and returns the
% momentum-filtered increments
%     delta = mr * delta_prev - study_eff * (1 - mr) * mean_gradient
% which the caller adds to the corresponding weights/biases.
%
% Inputs:
%   W1, b1     hidden-layer weights and biases (column vectors; hidden x 1
%              as built by MOBP_Train)
%   W2, b2     output-layer weights (row vector, 1 x hidden) and bias (scalar)
%   deltaW1, deltaW2, deltab1, deltab2
%              increments returned by the previous epoch (0 on the first call)
%   sample_in  training inputs, one scalar input per sample
%   sample_out training targets, indexed like sample_in
%   study_eff  learning rate
%   sample_num number of leading samples to use, 1..numel(sample_in)
%   mr         momentum coefficient
%
% Outputs:
%   deltaW1, deltaW2, deltab1, deltab2   new momentum-filtered increments
%   Error      mean ABSOLUTE error of the network before this update
%              (note: the gradient itself is of the squared error)
%
% Raises MOBP_Network:badSampleNum when sample_num is out of range
% (previously 0 produced NaN increments that silently corrupted the weights).

if sample_num < 1 || sample_num > numel(sample_in)
    error('MOBP_Network:badSampleNum', ...
        'sample_num must be between 1 and numel(sample_in) (%d), got %d.', ...
        numel(sample_in), sample_num);
end

% Accumulators. The error accumulator is not named "error" so the builtin
% error() above stays callable.
abs_err_sum = 0;
grad_W1 = 0;
grad_W2 = 0;
grad_b1 = 0;
grad_b2 = 0;

for i = 1:sample_num
    p = sample_in(i);
    t = sample_out(i);

    % Forward pass. The hidden activation is the logistic sigmoid written
    % out explicitly (the formula logsig implements), so no toolbox is needed.
    n1 = W1 * p + b1;
    a1 = 1 ./ (1 + exp(-n1));
    a2 = W2 * a1 + b2;              % linear output layer
    abs_err_sum = abs_err_sum + abs(t - a2);

    % Backward pass. s2 = d(t - a2)^2 / d n2 for the linear output.
    s2 = -2 * (t - a2);
    % Hidden sensitivity diag(logsig'(n1)) * W2' * s2, with
    % logsig'(n1) = (1 - a1) .* a1. Computed element-wise instead of
    % building the diagonal matrix explicitly.
    s1 = (((1 - a1) .* a1) .* W2') .* s2;

    grad_W2 = grad_W2 + s2 * a1';
    grad_b2 = grad_b2 + s2;
    grad_W1 = grad_W1 + s1 * p;
    grad_b1 = grad_b1 + s1;
end

% Batch averages, then the momentum update.
mean_W1 = grad_W1 / sample_num;
mean_W2 = grad_W2 / sample_num;
mean_b1 = grad_b1 / sample_num;
mean_b2 = grad_b2 / sample_num;
Error = abs_err_sum / sample_num;

deltaW1 = mr * deltaW1 - study_eff * (1 - mr) * mean_W1;
deltaW2 = mr * deltaW2 - study_eff * (1 - mr) * mean_W2;
deltab1 = mr * deltab1 - study_eff * (1 - mr) * mean_b1;
deltab2 = mr * deltab2 - study_eff * (1 - mr) * mean_b2;
function [W1,W2,b1,b2,train_error] = MOBP_Train(training_error,study_eff,mr,hidden,max_epochs)
% MOBP_Train  Fit fun() on [-2, 2] with a one-hidden-layer network
% (logistic-sigmoid hidden layer, linear output) trained by momentum
% back-propagation via MOBP_Network, redrawing the fit every epoch.
%
% Inputs:
%   training_error  stop once the epoch error reported by MOBP_Network
%                   drops below this value
%   study_eff       learning rate passed to MOBP_Network
%   mr              momentum coefficient passed to MOBP_Network
%   hidden          number of hidden neurons
%   max_epochs      (optional) maximum number of training epochs.
%                   Default Inf: train until the error target is met, as
%                   before. A finite value guarantees termination.
%
% Outputs:
%   W1, b1          hidden-layer weights/biases (hidden x 1 each)
%   W2, b2          output-layer weights (1 x hidden) and bias (scalar)
%   train_error     epoch error reported by the last call to MOBP_Network
%                   (measured before that epoch's update)
%
% Side effects: clears and redraws the current figure every epoch; issues
% a MOBP_Train:notConverged warning if max_epochs ends training early.

if nargin < 5 || isempty(max_epochs)
    max_epochs = Inf;
end

% Momentum state starts at zero; random initial parameters.
deltaW1 = 0;
deltaW2 = 0;
deltab1 = 0;
deltab2 = 0;
W1 = rand(hidden,1);
b1 = rand(hidden,1);
W2 = rand(1,hidden);
b2 = rand(1);

% Training data. Inputs and targets are scaled into [-1, 1] when they
% exceed it. The SAME scale factors are reused for the reference points and
% the final curve, so all plotted data share one coordinate system.
% (Previously each set was normalized by its own maximum.)
sample_in = linspace(-2,2,40);
sample_out = fun(sample_in);
in_scale = max(1, max(abs(sample_in)));
out_scale = max(1, max(abs(sample_out)));
sample_in = sample_in / in_scale;
sample_out = sample_out / out_scale;

% Coarse reference points drawn as '+' markers.
sample_point_x = linspace(-2,2,10);
sample_point_y = fun(sample_point_x) / out_scale;
sample_point_x = sample_point_x / in_scale;

sample_num = length(sample_in);
y = zeros(size(sample_in));     % network output on the training inputs
train_error = Inf;
epoch = 0;
while epoch < max_epochs
    epoch = epoch + 1;

    % Redraw the reference points and the current network output.
    clf;
    plot(sample_point_x,sample_point_y,'+');
    hold on;
    for i = 1:sample_num
        % Logistic sigmoid written out explicitly (the formula logsig implements).
        y(i) = W2 * (1 ./ (1 + exp(-(W1 .* sample_in(i) + b1)))) + b2;
    end
    plot(sample_in,y,'r');
    pause(0.01);

    % One momentum back-propagation epoch, then apply the increments.
    [deltaW1,deltaW2,deltab1,deltab2,Error] = MOBP_Network(W1,W2,b1,b2,deltaW1,deltaW2,deltab1,deltab2,sample_in,sample_out,study_eff,sample_num,mr);
    W1 = W1 + deltaW1;
    W2 = W2 + deltaW2;
    b1 = b1 + deltab1;
    b2 = b2 + deltab2;
    train_error = Error;
    if train_error < training_error
        break;
    end
end

if ~(train_error < training_error)
    warning('MOBP_Train:notConverged', ...
        'Stopped after %d epoch(s) with error %g (target %g).', ...
        epoch, train_error, training_error);
end

% Overlay the dense target curve (black), scaled like the training data.
hold on;
sample_x = linspace(-2,2);
sample_y = fun(sample_x) / out_scale;
sample_x = sample_x / in_scale;
plot(sample_x,sample_y,'k');
function y = fun(X)
% fun  Target function the network is trained to approximate.
% Only the first row of X is read; y is a row vector with one value per
% column of X. To change the target, replace the active assignment with one
% of the alternatives listed below.
first_row = X(1, :);
y = atan(first_row);
% Alternative targets:
%   y = sin(first_row * pi);
%   y = sin(first_row * pi/4);
%   y = exp(atan(first_row));
%   y = first_row.^2;
%   y = first_row.^2 + sin(pi * first_row) + sqrt(abs(first_row));
%   y = first_row.^3 + 2.^exp(first_row);
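% (Residue of the code-viewer page this file was copied from; kept as a comment.)
% Keyboard shortcuts:
%   Copy code            Ctrl + C
%   Search code          Ctrl + F
%   Full-screen mode     F11
%   Toggle theme         Ctrl + Shift + D
%   Show shortcuts       ?
%   Increase font size   Ctrl + =
%   Decrease font size   Ctrl + -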