⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 backpropagationplot.m

📁 人工神经网络(感知器模型和BP算法) 模式识别课程实验
💻 M
字号:
% M-file function, BackpropagationPlot.m 

% Backpropagation算法_带画图功能

% g1 第1组实际序号
% g2 第2组实际序号
% s1 第1组样本数
% s2 第2组样本数
% w1 w1初始值
% w2 w2初始值
% w3 w3初始值
% w4 w4初始值
% w5 w5初始值
% w6 w6初始值
% w7 w7初始值
% w8 w8初始值
% w9 w9初始值
% w10 w10初始值
% eta 学习速率
% alpha 动量
% iteraMax epoch的最大值

% correct 训练集的正确率
% correctAll 全部样本集的正确率

function [correct,correctAll] = BackpropagationPlot(g1,g2,s1,s2,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,eta,alpha,iteraMax)

load iris.dat;
r1 = Random(s1);
r2 = Random(s2);
for k = 1:s1
    I1(k) = iris(r1(k)+50*(g1-1),1);
    I2(k) = iris(r1(k)+50*(g1-1),2);
    I3(k) = iris(r1(k)+50*(g1-1),3);
    I4(k) = iris(r1(k)+50*(g1-1),4);
    d(k) = 0.99;
end
for k = 1:s2
    I1(k + s1) = iris(r2(k)+50*(g2-1),1);
    I2(k + s1) = iris(r2(k)+50*(g2-1),2);
    I3(k + s1) = iris(r2(k)+50*(g2-1),3);
    I4(k + s1) = iris(r2(k)+50*(g2-1),4);
    d(k + s1) = 0.01;
end
hold off;
for epoch = 1 : iteraMax 
    J(epoch) = 0;
    for itera = 1:s1+s2            
        %向前传播输入
        O1 = I1(itera);
        O2 = I2(itera);
        O3 = I3(itera);
        O4 = I4(itera);
        %隐藏层
        I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
        O5 = 1 / (1 + exp(-I5));
        I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
        O6 = 1 / (1 + exp(-I6));
        %输出层
        I7 = w9 * O5 + w10 * O6;
        O7 = 1 / (1 + exp(-I7));
        %反向传播误差
        %输出层
        T = d(itera);
        delta7 = O7 * (1 - O7) * (T - O7);
        J(epoch) = J(epoch) + (T - O7) ^ 2 / 2;
        %隐藏层
        delta5 = O5 * (1 - O5) * w9 * delta7;
        delta6 = O6 * (1 - O6) * w10 * delta7;
        %更新权值
        if (itera ~= 1)
            deltaw1(itera) = eta * delta5 * O1 + alpha * deltaw1(itera-1);
            deltaw2(itera) = eta * delta6 * O1 + alpha * deltaw2(itera-1);
            deltaw3(itera) = eta * delta5 * O2 + alpha * deltaw3(itera-1);
            deltaw4(itera) = eta * delta6 * O2 + alpha * deltaw4(itera-1);
            deltaw5(itera) = eta * delta5 * O3 + alpha * deltaw5(itera-1);
            deltaw6(itera) = eta * delta6 * O3 + alpha * deltaw6(itera-1);
            deltaw7(itera) = eta * delta5 * O4 + alpha * deltaw7(itera-1);
            deltaw8(itera) = eta * delta6 * O4 + alpha * deltaw8(itera-1);
            deltaw9(itera) = eta * delta7 * O5 + alpha * deltaw9(itera-1);
            deltaw10(itera) = eta * delta7 * O6 + alpha * deltaw10(itera-1);
        else
            deltaw1(itera) = eta * delta5 * O1;
            deltaw2(itera) = eta * delta6 * O1;
            deltaw3(itera) = eta * delta5 * O2;
            deltaw4(itera) = eta * delta6 * O2;
            deltaw5(itera) = eta * delta5 * O3;
            deltaw6(itera) = eta * delta6 * O3;
            deltaw7(itera) = eta * delta5 * O4;
            deltaw8(itera) = eta * delta6 * O4;
            deltaw9(itera) = eta * delta7 * O5;
            deltaw10(itera) = eta * delta7 * O6;
        end      
        w1 = w1 + deltaw1(itera);
        w2 = w2 + deltaw2(itera);
        w3 = w3 + deltaw3(itera);
        w4 = w4 + deltaw4(itera);
        w5 = w5 + deltaw5(itera);
        w6 = w6 + deltaw6(itera);
        w7 = w7 + deltaw7(itera);
        w8 = w8 + deltaw8(itera);
        w9 = w9 + deltaw9(itera);
        w10 = w10 + deltaw10(itera);
    end    
    JTest(epoch) = 0;
    for itera = 1:50        
        T = 0.99;
        O1 = iris(itera+50*(g1-1),1);
        O2 = iris(itera+50*(g1-1),2);
        O3 = iris(itera+50*(g1-1),3);
        O4 = iris(itera+50*(g1-1),4);
        I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
        O5 = 1 / (1 + exp(-I5));
        I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
        O6 = 1 / (1 + exp(-I6));
        I7 = w9 * O5 + w10 * O6;
        O7 = 1 / (1 + exp(-I7));
        JTest(epoch) = JTest(epoch) + (T - O7) ^ 2 / 2;
        T = 0.01;
        O1 = iris(itera+50*(g2-1),1);
        O2 = iris(itera+50*(g2-1),2);
        O3 = iris(itera+50*(g2-1),3);
        O4 = iris(itera+50*(g2-1),4);
        I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
        O5 = 1 / (1 + exp(-I5));
        I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
        O6 = 1 / (1 + exp(-I6));
        I7 = w9 * O5 + w10 * O6;
        O7 = 1 / (1 + exp(-I7));
        JTest(epoch) = JTest(epoch) + (T - O7) ^ 2 / 2;
    end
    axis([0,epoch,0,30]);
    plot(epoch, J(epoch),'ob');
    plot(epoch, JTest(epoch),'*r');
    hold on;
end
xlabel('Number of weight updates');
ylabel('Error');
title('Error versus weight updates');
text(2 * epoch / 3, 25,'  Training set error    o');
text(2 * epoch / 3, 20,'Validation set error    *');
correct = [0 0 0 0 0 0 0 0 0];
correctAll = [0 0 0 0 0 0 0 0 0];
for itera = 1:s1
    O1 = I1(itera);
    O2 = I2(itera);
    O3 = I3(itera);
    O4 = I4(itera);
    I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
    O5 = 1 / (1 + exp(-I5));
    I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
    O6 = 1 / (1 + exp(-I6));
    I7 = w9 * O5 + w10 * O6;
    O7 = 1 / (1 + exp(-I7));
    if (O7 >= 0.5)
        correct(1) = correct(1) + 1;
    end
end
for itera = 1:s2
    O1 = I1(itera+s1);
    O2 = I2(itera+s1);
    O3 = I3(itera+s1);
    O4 = I4(itera+s1);
    I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
    O5 = 1 / (1 + exp(-I5));
    I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
    O6 = 1 / (1 + exp(-I6));
    I7 = w9 * O5 + w10 * O6;
    O7 = 1 / (1 + exp(-I7));
    if (O7 < 0.5)
        correct(4) = correct(4) + 1;
    end
end
correct(2) = s1;
correct(3) = correct(1) / correct(2);
correct(5) = s2;
correct(6) = correct(4) / correct(5);
correct(7) = correct(1) + correct(4);
correct(8) = correct(2) + correct(5);
correct(9) = correct(7) / correct(8);
for itera = 1:50
    O1 = iris(itera+50*(g1-1),1);
    O2 = iris(itera+50*(g1-1),2);
    O3 = iris(itera+50*(g1-1),3);
    O4 = iris(itera+50*(g1-1),4);
    I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
    O5 = 1 / (1 + exp(-I5));
    I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
    O6 = 1 / (1 + exp(-I6));
    I7 = w9 * O5 + w10 * O6;
    O7 = 1 / (1 + exp(-I7));
    if (O7 >= 0.5)
        correctAll(1) = correctAll(1) + 1;
    end
    O1 = iris(itera+50*(g2-1),1);
    O2 = iris(itera+50*(g2-1),2);
    O3 = iris(itera+50*(g2-1),3);
    O4 = iris(itera+50*(g2-1),4);
    I5 = w1 * O1 + w3 * O2 + w5 * O3 + w7 * O4;
    O5 = 1 / (1 + exp(-I5));
    I6 = w2 * O1 + w4 * O2 + w6 * O3 + w8 * O4;
    O6 = 1 / (1 + exp(-I6));
    I7 = w9 * O5 + w10 * O6;
    O7 = 1 / (1 + exp(-I7));
    if (O7 < 0.5)
        correctAll(4) = correctAll(4) + 1;
    end
end
correctAll(2) = 50;
correctAll(3) = correctAll(1) / correctAll(2);
correctAll(5) = 50;
correctAll(6) = correctAll(4) / correctAll(5);
correctAll(7) = correctAll(1) + correctAll(4);
correctAll(8) = correctAll(2) + correctAll(5);
correctAll(9) = correctAll(7) / correctAll(8);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -