⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 main.m

📁 MATLAB implementation of ID3 and Naive Bayes classifiers. It also includes an example dataset.
💻 M
字号:
%
% main.m
%
% Sunghwan Yoo

%% Step 1. Load file
% house-votes-84.ndata is a comma-separated file that has already been
% converted to numbers (see Step 2), so dlmread can read it directly;
% dlmread only reads numeric data.

dataset = dlmread('house-votes-84.ndata', ',');

%% Step 2. Check the numeric encoding.
% The conversion from the original symbols was done before this script runs,
% using
%   n => 0, y => 1, ? => 2, democrat => 0, republican => 1
% so every entry of dataset should be 0, 1 or 2. No conversion happens here:
% an in-script conversion loop could not work on the numeric matrix returned
% by dlmread (it never contains 'y'/'n' characters, and comparing a number
% with a multi-character string such as 'democrat' compares it element-wise
% against each character).

[nrows, ncols] = size(dataset);

% Fail early with a clear message instead of passing an unexpectedly
% encoded file on to the classifiers.
if ~all(ismember(dataset(:), [0 1 2]))
	error('main:badEncoding', 'house-votes-84.ndata contains values other than 0, 1 and 2.');
end


%% Step 3. Learn with ID3 and Naive Bayes, test on held-out rows
% For each training-set size in nsamples, the data is shuffled nrepeat times.
% On every shuffle the first nsamples(i) rows are the training set and the
% rows in test_rows (201:400) are the test set. Each shuffle runs:
%   - id3 with p-value threshold 1 (no p-value pre-pruning),
%   - id3 once per threshold in psamples (p-value pre-pruning),
%   - naive_bayes.
% The training/test accuracies they return are averaged over the repeats
% and printed. The ID3-with-pruning and Naive Bayes averages are also stored
% in two pairs of tables:
%   - id3table_train / id3table_test: row i = nsamples(i), column p = psamples(p)
%   - nbctable_train / nbctable_test: entry i = nsamples(i)

nsamples = [5 10 20 40 80 120 200];
psamples = [0.05, 0.1, 0.9];
nrepeat = 10;

% Test rows are fixed. Training rows come from the top of each shuffle, so
% every training size must end before the first test row for the two sets
% to stay disjoint.
test_rows = 201:400;
if max(nsamples) >= test_rows(1)
	error('main:overlap', 'Training size %d overlaps the test rows starting at %d.', max(nsamples), test_rows(1));
end
if nrows < test_rows(end)
	error('main:tooFewRows', 'Dataset has %d rows; the test set needs %d.', nrows, test_rows(end));
end

nsizes = length(nsamples);
npvals = length(psamples);

% Preallocate the result tables at their final size so they hold no stale
% entries from an earlier run in the same workspace.
id3table_train = zeros(nsizes, npvals);
id3table_test = zeros(nsizes, npvals);
nbctable_train = zeros(1, nsizes);
nbctable_test = zeros(1, nsizes);

for i = 1 : nsizes
	id3_training_acc_sum = 0;
	id3_test_acc_sum = 0;
	% One running sum per p-value, as a 1-by-npvals row vector
	% (zeros(npvals) alone would allocate an npvals-by-npvals matrix).
	id3p_training_acc_sum = zeros(1, npvals);
	id3p_test_acc_sum = zeros(1, npvals);
	nb_training_acc_sum = 0;
	nb_test_acc_sum = 0;

	for j = 1 : nrepeat
		% Shuffle all rows in one indexing step.
		index = randperm(nrows);
		random_dataset = dataset(index, :);

		training_set = random_dataset(1:nsamples(i), :);
		test_set = random_dataset(test_rows, :);

		% Last argument 1: a p-value threshold of 1, i.e. no p-value pre-pruning.
		% NOTE(review): the meaning of the third argument (0) is defined by id3 - confirm.
		[tree, id3_training_acc, id3_test_acc] = id3(training_set, test_set, 0, 1);
		id3_training_acc_sum = id3_training_acc_sum + id3_training_acc;
		id3_test_acc_sum = id3_test_acc_sum + id3_test_acc;

		% Same learner with p-value pre-pruning at each threshold in psamples.
		for p = 1 : npvals
			[tree, id3p_training_acc(p), id3p_test_acc(p)] = id3(training_set, test_set, 0, psamples(p));
			id3p_training_acc_sum(p) = id3p_training_acc_sum(p) + id3p_training_acc(p);
			id3p_test_acc_sum(p) = id3p_test_acc_sum(p) + id3p_test_acc(p);
		end

		[param, nb_training_acc, nb_test_acc] = naive_bayes(training_set, test_set);
		nb_training_acc_sum = nb_training_acc_sum + nb_training_acc;
		nb_test_acc_sum = nb_test_acc_sum + nb_test_acc;
	end

	% Averages over the nrepeat shuffles.
	id3table_train(i, :) = id3p_training_acc_sum / nrepeat;
	id3table_test(i, :) = id3p_test_acc_sum / nrepeat;
	nbctable_train(i) = nb_training_acc_sum / nrepeat;
	nbctable_test(i) = nb_test_acc_sum / nrepeat;

	txt = sprintf('ID3 : Sample = %d | training_acc = %f | test_acc = %f', nsamples(i), id3_training_acc_sum/nrepeat, id3_test_acc_sum/nrepeat);
	disp(txt);

	for p = 1 : npvals
		txt = sprintf('ID3P: Sample = %d | training_acc = %f | test_acc = %f (p = %f)', nsamples(i), id3table_train(i,p), id3table_test(i,p), psamples(p));
		disp(txt);
	end

	% The trailing \n prints a blank line after each training-set size.
	txt = sprintf('NBC : Sample = %d | training_acc = %f | test_acc = %f\n', nsamples(i), nbctable_train(i), nbctable_test(i));
	disp(txt);
end

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -