📄 classify_nn.m
字号:
%%用最近邻方法分类 采用欧氏距离
function [class,mis_num,correct_rate]=classify_NN(data,class_num,train_num,dis_mark)
[ dim,total_num ] = size( data ) ;
class_data_num = total_num/class_num ; %得到每类中的样本数目
for class_mark = 1:class_num %for循环得到用来训练的数据
for class_data_mark = 1:train_num
train_data( :,( class_mark-1 )*train_num+class_data_mark ) = data( :,( class_mark-1 )*class_data_num+class_data_mark );
%test_data(:,())
end
end
test_num = class_data_num-train_num;
for class_mark = 1:class_num %for循环得到用来测试的数据
for class_data_mark = 1:test_num
test_data( :,( class_mark-1 )*test_num+class_data_mark ) = data( :,( class_mark-1 )*class_data_num+class_data_mark+train_num );
%test_data(:,())
end
end
train_data_num = train_num*class_num;
test_data_num = test_num*class_num;
for test_data_mark = 1:test_data_num %for循环求测试数据与每个训练数据的距离
for train_data_mark = 1:train_data_num
switch dis_mark
case 'L2'
dis(test_data_mark,train_data_mark) = norm(test_data(:,test_data_mark)-train_data(:,train_data_mark));
case 'cos'
dis(test_data_mark,train_data_mark) = acos((test_data(:,test_data_mark)'*train_data(:,train_data_mark))/(norm(test_data(:,test_data_mark))*norm(train_data(:,train_data_mark))));
case 'L1'
dis(test_data_mark,train_data_mark) = norm(test_data(:,test_data_mark)-train_data(:,train_data_mark),1);
otherwise
dis(test_data_mark,train_data_mark) = norm(test_data(:,test_data_mark)-train_data(:,train_data_mark));
end
end
end
[mindis,index]=min(dis');
class = fix((index+train_num-1)/train_num); %求测试数据所属的类别
mis_num = 0;
for test_data_mark = 1:test_data_num
if class(test_data_mark)~=fix((test_data_mark+test_num-1)/test_num) %若算得的类别与实际所属类别不一样则分类错误
mis_num = mis_num+1;
end
end
correct_rate =(test_data_num-mis_num)/test_data_num;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -