📄 backpropagationxor.m
字号:
% M-file function, BackpropagationXOR.m
% Backpropagation算法_处理异或数据
% s1 第1组样本数
% s2 第2组样本数
% w1 w1初始值
% w2 w2初始值
% w3 w3初始值
% w4 w4初始值
% w5 w5初始值
% w6 w6初始值
% w7 w7初始值
% w8 w8初始值
% w9 w9初始值
% w10 w10初始值
% eta 学习速率
% alpha 动量
% iteraMax epoch的最大值
% correct 训练集的正确率
function [correct] = BackpropagationXOR(s1,s2,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,eta,alpha,iteraMax)
r0 = RandomXOR(s1,-0.5,-0.5,0,0);
r1 = RandomXOR(s2,-0.5,0.5,0,1);
r2 = RandomXOR(s2,0.5,-0.5,1,0);
r3 = RandomXOR(s1,0.5,0.5,1,1);
for k = 1:s1
I1(k) = r0(k,1);
I2(k) = r0(k,2);
I3(k) = r3(k,1);
I4(k) = r3(k,2);
d(k) = 0.99;
end
for k = 1:s2
I1(k + s1) = r1(k,1);
I2(k + s1) = r1(k,2);
I3(k + s1) = r2(k,1);
I4(k + s1) = r2(k,2);
d(k + s1) = 0.01;
end
for itera = 1:(s1+s2)*iteraMax
%向前传播输入
O1 = I1(mod(itera-1,s1+s2)+1);
O2 = I2(mod(itera-1,s1+s2)+1);
O3 = I3(mod(itera-1,s1+s2)+1);
O4 = I4(mod(itera-1,s1+s2)+1);
%隐藏层
I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
O5 = 1 / (1 + exp(-I5));
I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
O6 = 1 / (1 + exp(-I6));
%输出层
I7 = w9 * O5 + w10 * O6;
O7 = 1 / (1 + exp(-I7));
%反向传播误差
%输出层
T = d(mod(itera-1,s1+s2)+1);
delta7 = O7 * (1 - O7) * (T - O7);
%隐藏层
delta5 = O5 * (1 - O5) * w9 * delta7;
delta6 = O6 * (1 - O6) * w10 * delta7;
%更新权值
if (itera ~= 1)
deltaw1(itera) = eta * delta5 * O1 + alpha * deltaw1(itera-1);
deltaw2(itera) = eta * delta6 * O1 + alpha * deltaw2(itera-1);
deltaw3(itera) = eta * delta5 * O2 + alpha * deltaw3(itera-1);
deltaw4(itera) = eta * delta6 * O2 + alpha * deltaw4(itera-1);
deltaw5(itera) = eta * delta5 * O3 + alpha * deltaw5(itera-1);
deltaw6(itera) = eta * delta6 * O3 + alpha * deltaw6(itera-1);
deltaw7(itera) = eta * delta5 * O4 + alpha * deltaw7(itera-1);
deltaw8(itera) = eta * delta6 * O4 + alpha * deltaw8(itera-1);
deltaw9(itera) = eta * delta7 * O5 + alpha * deltaw9(itera-1);
deltaw10(itera) = eta * delta7 * O6 + alpha * deltaw10(itera-1);
else
deltaw1(itera) = eta * delta5 * O1;
deltaw2(itera) = eta * delta6 * O1;
deltaw3(itera) = eta * delta5 * O2;
deltaw4(itera) = eta * delta6 * O2;
deltaw5(itera) = eta * delta5 * O3;
deltaw6(itera) = eta * delta6 * O3;
deltaw7(itera) = eta * delta5 * O4;
deltaw8(itera) = eta * delta6 * O4;
deltaw9(itera) = eta * delta7 * O5;
deltaw10(itera) = eta * delta7 * O6;
end
w1 = w1 + deltaw1(itera);
w2 = w2 + deltaw2(itera);
w3 = w3 + deltaw3(itera);
w4 = w4 + deltaw4(itera);
w5 = w5 + deltaw5(itera);
w6 = w6 + deltaw6(itera);
w7 = w7 + deltaw7(itera);
w8 = w8 + deltaw8(itera);
w9 = w9 + deltaw9(itera);
w10 = w10 + deltaw10(itera);
end
correct = [0 0 0 0 0 0 0 0 0];
for itera = 1:s1
O1 = I1(itera);
O2 = I2(itera);
O3 = I3(itera);
O4 = I4(itera);
I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
O5 = 1 / (1 + exp(-I5));
I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
O6 = 1 / (1 + exp(-I6));
I7 = w9 * O5 + w10 * O6;
O7 = 1 / (1 + exp(-I7));
if (O7 >= 0.5)
correct(1) = correct(1) + 1;
end
end
for itera = 1:s2
O1 = I1(itera+s1);
O2 = I2(itera+s1);
O3 = I3(itera+s1);
O4 = I4(itera+s1);
I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
O5 = 1 / (1 + exp(-I5));
I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
O6 = 1 / (1 + exp(-I6));
I7 = w9 * O5 + w10 * O6;
O7 = 1 / (1 + exp(-I7));
if (O7 < 0.5)
correct(4) = correct(4) + 1;
end
end
correct(2) = s1;
correct(3) = correct(1) / correct(2);
correct(5) = s2;
correct(6) = correct(4) / correct(5);
correct(7) = correct(1) + correct(4);
correct(8) = correct(2) + correct(5);
correct(9) = correct(7) / correct(8);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -