📄 mlp.m
字号:
function [u1,u2, v1, v2, b1, b2, times] = MLP(eps, initU1, initV1, initb1, initU2, initV2, initb2, alpha, input_P, target_T)
u1 = initU1;
v1 = initV1;
b1 = initb1;
u2 = initU2;
v2 = initV2;
b2 = initb2;
[len, dim] = size(input_P);
times=0;
index=1;
correct = 0;
lastcorrect=0;
while(true)
x0 = input_P(index,:)';
%先计算x1,x2,x3
x1 = sigmoid(u1*(x0.^2) + v1*x0 + b1);
x2 = sigmoid(u2*(x1.^2) + v2*x1 + b2);
t = target_T(index,:)';
%这里可能还有问题
result = sum((x2-t).^2);
if(result<eps)
correct = correct+1;
else
if (correct>=lastcorrect )
correct
times
lastcorrect=correct;
end
correct = 0;
end
if(correct>=len)
break;
end
%再计算s3,s2,s1
s2 = -2*difF(x2)*(t-x2);
s1 = difF(x1)*(2*u2*diag(x1)+v2)'*s2;
%再迭代更新权值和偏置值
u1 = u1-alpha*s1*(x0.^2)';
v1 = v1-alpha*s1*x0';
b1 = b1-alpha*s1;
u2 = u2-alpha*s2*(x1.^2)';
v2 = v2-alpha*s2*x1';
b2 = b2-alpha*s2;
times = times+1;
index = mod(index+1,len+1);
if(index==0)
index=1;
end
end
len
times = times-len;
function y = difF(x)
y=diag((1-x)'*diag(x));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -