📄 abalone.asv
字号:
% Gaussian Bayes classifier for the abalone data.
% Step 1: load the data matrix and split it into training and test parts.
%   column 1        : sex code (0, 1, 2 -> M, F, I in the blocks below)
%   columns 2..end-1: numeric features
%   column end      : age label that is binned into ClassNumber classes
clear;
load data.mat                    % expected to provide the matrix 'data'
Total = size(data,1);            % number of samples (rows)
VaribleNumber = size(data,2);    % number of columns, label column included
ClassNumber = 12;                % number of age bins
SexNumber = 3;                   % sex codes 0, 1, 2
% SplitMatrixRandom is a project function (not visible here); it is assumed
% to return the test/train rows of 'data' together with their row indices.
[TestMatrix,TrainMatrix,TestIndex,TrainIndex] = SplitMatrixRandom(data);
TrainNumber = length(TrainIndex);
% NOTE(review): assumes the split covers every row exactly once.
TestNumber = Total - TrainNumber;
% Bin the training labels (last column) into the 12 age classes:
%   [1,5] -> 1, 6 .. 14 -> 2 .. 10 (one age each), [15,19] -> 11, >= 20 -> 12.
% Each TrainClassIndex_k holds the TrainMatrix row numbers of class k.
TrainAnswer = TrainMatrix(:,VaribleNumber);
AgeBin = @(lo,hi) find(TrainAnswer >= lo & TrainAnswer <= hi);
TrainClassIndex_1  = AgeBin(1,5);
TrainClassIndex_2  = AgeBin(6,6);
TrainClassIndex_3  = AgeBin(7,7);
TrainClassIndex_4  = AgeBin(8,8);
TrainClassIndex_5  = AgeBin(9,9);
TrainClassIndex_6  = AgeBin(10,10);
TrainClassIndex_7  = AgeBin(11,11);
TrainClassIndex_8  = AgeBin(12,12);
TrainClassIndex_9  = AgeBin(13,13);
TrainClassIndex_10 = AgeBin(14,14);
TrainClassIndex_11 = AgeBin(15,19);
TrainClassIndex_12 = AgeBin(20,Inf);
clear AgeBin
% P(sex | class): Ppostsex(k,s) is the fraction of class-k training rows whose
% sex code (column 1 of TrainMatrix) equals s-1, i.e. columns 1/2/3 = codes 0/1/2.
ClassIndexList = {TrainClassIndex_1, TrainClassIndex_2, TrainClassIndex_3, ...
    TrainClassIndex_4, TrainClassIndex_5, TrainClassIndex_6, ...
    TrainClassIndex_7, TrainClassIndex_8, TrainClassIndex_9, ...
    TrainClassIndex_10, TrainClassIndex_11, TrainClassIndex_12};
Ppostsex = zeros(ClassNumber,SexNumber);
for k = 1:ClassNumber
    ClassSex = TrainMatrix(ClassIndexList{k},1);
    if isempty(ClassSex)
        % No training rows in this class: keep zeros instead of 0/0 = NaN.
        continue;
    end
    for s = 1:SexNumber
        Ppostsex(k,s) = sum(ClassSex == s-1)/numel(ClassSex);
    end
end
% Class priors P(w_k): the share of all training rows that fall in class k.
ClassSizes = cellfun(@length, {TrainClassIndex_1, TrainClassIndex_2, ...
    TrainClassIndex_3, TrainClassIndex_4, TrainClassIndex_5, ...
    TrainClassIndex_6, TrainClassIndex_7, TrainClassIndex_8, ...
    TrainClassIndex_9, TrainClassIndex_10, TrainClassIndex_11, ...
    TrainClassIndex_12});
Pw = zeros(ClassNumber,1);
Pw(1:12,1) = ClassSizes(:)/length(TrainIndex);
clear ClassSizes
% Fit one Gaussian per (age class, sex) group on the feature columns
% (TrainMatrix columns 2..end-1). Sex code 0/1/2 maps to the _M/_F/_I
% variables and to columns 1/2/3 of FailGroup.
% A group with fewer than SampleNumber training rows is marked unusable
% (FailGroup = 0) and its mean/covariance are left at zero.
FeatureNumber = VaribleNumber-2;
MeanMatrix_M = zeros(ClassNumber,FeatureNumber);
MeanMatrix_F = zeros(ClassNumber,FeatureNumber);
MeanMatrix_I = zeros(ClassNumber,FeatureNumber);
ConvarianceMatrix_M = zeros(FeatureNumber,FeatureNumber,ClassNumber);
ConvarianceMatrix_F = zeros(FeatureNumber,FeatureNumber,ClassNumber);
ConvarianceMatrix_I = zeros(FeatureNumber,FeatureNumber,ClassNumber);
FailGroup = ones(ClassNumber,3);
% Minimum rows per group, applied to every class. A covariance estimated from
% n rows has rank at most n-1, so at least FeatureNumber+1 rows are required
% for an invertible covariance.
SampleNumber = max(10, FeatureNumber+1);
ClassIndexList = {TrainClassIndex_1, TrainClassIndex_2, TrainClassIndex_3, ...
    TrainClassIndex_4, TrainClassIndex_5, TrainClassIndex_6, ...
    TrainClassIndex_7, TrainClassIndex_8, TrainClassIndex_9, ...
    TrainClassIndex_10, TrainClassIndex_11, TrainClassIndex_12};
for k = 1:ClassNumber
    tmp = TrainMatrix(ClassIndexList{k},1:end-1);
    for s = 1:3
        sextmp = tmp(tmp(:,1) == s-1,2:end);
        % Count ROWS with size(...,1): length() returns the largest dimension,
        % which is at least FeatureNumber for any non-empty group.
        % Tiny groups must be rejected before fitting: mean()/cov() treat a
        % single row as one vector of observations and return scalars.
        if size(sextmp,1) < SampleNumber
            FailGroup(k,s) = 0;
            continue;
        end
        GroupMean = mean(sextmp,1);
        GroupCov = cov(sextmp);
        switch s
            case 1
                MeanMatrix_M(k,:) = GroupMean;
                ConvarianceMatrix_M(:,:,k) = GroupCov;
            case 2
                MeanMatrix_F(k,:) = GroupMean;
                ConvarianceMatrix_F(:,:,k) = GroupCov;
            case 3
                MeanMatrix_I(k,:) = GroupMean;
                ConvarianceMatrix_I(:,:,k) = GroupCov;
        end
    end
end
%zero-one loss function
% (Decision rule used later: pick the class with the largest posterior score,
% i.e. the Bayes rule under zero-one loss.)
% Precompute, per class and sex group, the inverse covariance and
% sqrt(|det(Sigma)|) used by the Gaussian density.
TestCondition = zeros(TestNumber,ClassNumber);
inv_ConvarianceMatrix_M = zeros(VaribleNumber-2,VaribleNumber-2,ClassNumber);
det_ConvarianceMatrix_M = zeros(1,ClassNumber);   % holds sqrt(|det|), not det
inv_ConvarianceMatrix_F = zeros(VaribleNumber-2,VaribleNumber-2,ClassNumber);
det_ConvarianceMatrix_F = zeros(1,ClassNumber);   % holds sqrt(|det|), not det
inv_ConvarianceMatrix_I = zeros(VaribleNumber-2,VaribleNumber-2,ClassNumber);
det_ConvarianceMatrix_I = zeros(1,ClassNumber);   % holds sqrt(|det|), not det
for i = 1:ClassNumber
    % Only invert groups that will be scored. Groups with FailGroup == 0 have
    % zero or rank-deficient covariances, for which inv() warns and returns Inf.
    % Their entries stay zero and are never read.
    if FailGroup(i,1) ~= 0
        inv_ConvarianceMatrix_M(:,:,i) = inv(ConvarianceMatrix_M(:,:,i));
        det_ConvarianceMatrix_M(:,i) = sqrt(abs(det(ConvarianceMatrix_M(:,:,i))));
    end
    if FailGroup(i,2) ~= 0
        inv_ConvarianceMatrix_F(:,:,i) = inv(ConvarianceMatrix_F(:,:,i));
        det_ConvarianceMatrix_F(:,i) = sqrt(abs(det(ConvarianceMatrix_F(:,:,i))));
    end
    if FailGroup(i,3) ~= 0
        inv_ConvarianceMatrix_I(:,:,i) = inv(ConvarianceMatrix_I(:,:,i));
        det_ConvarianceMatrix_I(:,i) = sqrt(abs(det(ConvarianceMatrix_I(:,:,i))));
    end
end
% Score every test sample i against every class j:
%   TestCondition(i,j) = sum over sex groups g of
%       N(x_i; mu_{j,g}, Sigma_{j,g}) * P(w_j) * P(g | w_j)
% where x_i = TestMatrix(i,2:end-1) holds the feature columns.
% A group with FailGroup(j,g) == 0 contributes 0.
% NOTE(review): the test sample's own sex code TestMatrix(i,1) is not used;
% the score sums over all three sex groups. Confirm this is intended.
% NOTE(review): exp() can underflow to 0 for every class on outlying samples.
% The later arg-max then reports class 1.
%
% Gaussian normalizer (2*pi)^(d/2), with d = VaribleNumber-2 feature columns.
% It is the same for all classes, so it does not change the arg-max.
pi2 = (2*pi)^((VaribleNumber-2)/2);
Final = zeros(3,1);
for i = 1:TestNumber
    TestFeature = TestMatrix(i,2:end-1);
    for j = 1:ClassNumber
        % Reset all three group scores for every class. A group skipped below
        % must not keep the value computed for the previous class.
        Final(:) = 0;
        if (FailGroup(j,1) ~= 0)
            Diff = TestFeature-MeanMatrix_M(j,:);
            Final(1,1) = exp(-.5*Diff*inv_ConvarianceMatrix_M(:,:,j)*Diff')/det_ConvarianceMatrix_M(:,j)/pi2*Pw(j,1)*Ppostsex(j,1);
        end
        if (FailGroup(j,2) ~= 0)
            Diff = TestFeature-MeanMatrix_F(j,:);
            Final(2,1) = exp(-.5*Diff*inv_ConvarianceMatrix_F(:,:,j)*Diff')/det_ConvarianceMatrix_F(:,j)/pi2*Pw(j,1)*Ppostsex(j,2);
        end
        if (FailGroup(j,3) ~= 0)
            Diff = TestFeature-MeanMatrix_I(j,:);
            Final(3,1) = exp(-.5*Diff*inv_ConvarianceMatrix_I(:,:,j)*Diff')/det_ConvarianceMatrix_I(:,j)/pi2*Pw(j,1)*Ppostsex(j,3);
        end
        TestCondition(i,j) = Final(1,1)+Final(2,1)+Final(3,1);
    end
end
% True class of each test sample, binned exactly like the training labels:
%   [1,5] -> 1, 6 .. 14 -> 2 .. 10 (one age each), [15,19] -> 11, >= 20 -> 12.
% Values that fall in no bin (e.g. 0 or a non-integer such as 5.5) are left
% unchanged. The bins are disjoint, so masking on the raw ages is equivalent
% to relabelling in sequence.
RawAge = data(TestIndex,VaribleNumber);
TestAnswer = RawAge;
TestAnswer(RawAge >= 1 & RawAge <= 5) = 1;
for Age = 6:14
    TestAnswer(RawAge == Age) = Age-4;
end
TestAnswer(RawAge >= 15 & RawAge <= 19) = 11;
TestAnswer(RawAge >= 20) = 12;
clear RawAge Age
% Optional clean-up of non-finite scores (disabled). NaN never compares equal,
% so NaNs must be found with isnan():
%TestCondition(isinf(TestCondition)) = 0;
%TestCondition(isnan(TestCondition)) = 0;
%
% Predicted class = column of the largest score in each row. The maximum value
% itself is discarded with ~; it must not be stored in a variable named 'max',
% which would shadow the built-in function for the rest of the workspace.
[~, Result] = max(TestCondition,[],2);
% ConfusionMatrix(i,j): number of test samples with true class i predicted as j.
ConfusionMatrix = zeros(ClassNumber,ClassNumber);
for i = 1:ClassNumber
    for j = 1:ClassNumber
        ConfusionMatrix(i,j) = sum(TestAnswer == i & Result == j);
    end
end
ConfusionMatrix
% Fraction of test samples classified correctly.
performance = sum(Result == TestAnswer) / TestNumber
%%%%%%%%%%%%%%%%%%%%
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -