function [percent,C,ave_planes,stdev] = cross_val(A,B,numgroups,maxdepth,tol,CF,min_points)
% CROSS_VAL estimates, by cross-validation, the percentage of points
%	correctly classified by the MSM-T algorithm (decision tree), and
%	returns the confusion matrix and the average number of planes
%	needed.
%
%	[percent,C,ave_planes,stdev] = cross_val(A,B,numgroups,maxdepth,tol,CF,min_points)
%
%	percent: percentage of points correctly classified by MSM-T
%	C represents the confusion matrix as follows:
%
%				OBSERVED
%
%				A		B
%
%			A  # pts. of A	    # pts. of B	
%	PREDICTED	   properly 	    misclassified
%			   classified
%
%			B  # pts. of A	    # pts. of B
%			   misclassified    properly
%					    classified
%
%	(C is a 2 x 2 matrix)
%
%	ave_planes: average number of planes needed for separation
%	ave_planes = numplanes/numgroups
%	
%	stdev: standard deviation of percentage correct for each group
%
%	A: m x n matrix representing the m n-dimensional points in set A
%	B: k x n matrix representing the k n-dimensional points in set B
%	numgroups: number of groups to use for cross-validation
%	maxdepth: maximum depth in decision trees (if not given, entire tree
%		  will be created).
%	tol: tolerance such that a region will not be separated further
%	if either: 1. it contains only points of A or only points of B.
%		   2. the percentage error in that region is less than tol.
%
%	CF: certainty factor used in the C4.5 pruning algorithm (Quinlan)
%		if not given or CF = 0.0, the tree will not be pruned using
%		the C4.5 pruning algorithm
%	min_points: pruning threshold: a node is pruned if the number of
%		misclassified points at that node is less than min_points
%		if not given or min_points = 0, the tree will not be pruned
%		using this threshold
%
%	outline of cross-validation process:
%		1. shuffle the rows of A,B.
%		2. determine groups to be used in cross-validation
%		3. for all the groups:
%			i. determine the correct test, train matrices
%			ii. train (build tree)
%			iii. test (obtain total number of errors)
%		4. determine percentage of correctly classified points.
%		   determine confusion matrix C.
%		   determine average number of planes needed.
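%
%	example (a sketch only: the data are synthetic, and msmt_tree,
%	shuffle, cross_makegroups, cross_sets and cross_classify are
%	assumed to be on the path):
%		A = randn(100,2) + 1;	% 100 points of class A
%		B = randn(80,2) - 1;	% 80 points of class B
%		[percent,C,ave_planes,stdev] = cross_val(A,B,10);
%		% 10 groups; maxdepth, tol, CF and min_points take their
%		% defaults (full tree, no pruning)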
global n;

if nargin < 7
	min_points = 0;
end
if nargin < 6
	CF = 0.0;
end
if nargin < 5
	tol = 0.0;
end
if nargin < 4 
	maxdepth = round(realmax);
end
n = size(A,2);
% shuffle the decks:
A = shuffle(A);
B = shuffle(B);
%
% think up the groups:
Agroups = cross_makegroups(size(A,1),numgroups);
Bgroups = cross_makegroups(size(B,1),numgroups);
%
% do it:
stanvec = zeros(1,numgroups); % percentage correct for each group
totmissA = 0;
totmissB = 0;
totcorrectA = 0;
totcorrectB = 0;
numplanes = 0;
% print out a header
disp(['Count' blanks(2) 'Total Correct' blanks(2) 'Total Incorrect' blanks(2) 'Percentage Correct'])
disp(blanks(1)')

for i = 1:numgroups

	[Atest,Btest,Atrain,Btrain] = cross_sets(A,B,Agroups,Bgroups,i);

% Train! (grow the Tree for these training points)

	T = [];
	T = msmt_tree(Atrain,Btrain,maxdepth,tol,CF,min_points);
	numplanes = numplanes + size(T,2); % number of planes is the 
					      % number of columns in Tree

% Test! (test with the test points)

	[missA,correctA,missB,correctB] = cross_classify(T,Atest,Btest);
	totmissA = totmissA + missA;
	totmissB = totmissB + missB;
	totcorrectA = totcorrectA + correctA;
	totcorrectB = totcorrectB + correctB;
	pc = 100*((correctA + correctB)/(missA + missB + correctA + correctB));
	stanvec(i) = pc;

% Print out some information regarding this test:
	disp([ num2str(i) blanks(7) sprintf('%7.3f ',correctA + correctB) blanks(6) sprintf('%7.3f ',missA + missB) blanks(9) sprintf('%7.3f ',pc)])
		
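% NOTE: the four assignments below were dead code: their results are
% overwritten by cross_sets at the top of the next iteration and are not
% used after the loop, so they are kept only as commented-out reference.
%	Atest = A(Agroups(i):Agroups(i+1)-1,:);
%	Btest = B(Bgroups(i):Bgroups(i+1)-1,:);
%	Atrain = [ A(1:Agroups(i)-1,:);A(Agroups(i+1):size(A,1),:)];
%	Btrain = [ B(1:Bgroups(i)-1,:);B(Bgroups(i+1):size(B,1),:)];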
end % for

% compute percentage correctly classified:

percent = ((totcorrectA + totcorrectB)/(size(A,1)+size(B,1)))*100;
C = [ totcorrectA, totmissB;
      totmissA, totcorrectB ];
ave_planes = numplanes/numgroups;
stdev = std(stanvec);
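
% ---------------------------------------------------------------------
% Illustrative sketch, not part of the original MSM-T code: the local
% function below (hypothetical name example_summarize) repeats the
% summary arithmetic used above on per-group counts, so that percent, C
% and stdev can be checked by hand. It assumes each input is a vector
% with one entry per cross-validation group and that every point is
% tested exactly once, so the total count equals size(A,1)+size(B,1).
%
%	[p,C,s] = example_summarize([4 5],[1 0],[3 4],[2 1])
%	gives p = 80, C = [9 3; 1 7] and s = std([70 90]).
function [percent,C,stdev] = example_summarize(correctA,missA,correctB,missB)
% per-group percentage correct (same formula as pc in the loop)
pc = 100*((correctA + correctB)./(correctA + missA + correctB + missB));
% overall percentage correct over all groups
total = sum(correctA) + sum(missA) + sum(correctB) + sum(missB);
percent = 100*(sum(correctA) + sum(correctB))/total;
% rows are the predicted class, columns are the observed class
C = [ sum(correctA), sum(missB);
      sum(missA), sum(correctB) ];
% spread of the per-group percentages
stdev = std(pc);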

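% ---------------------------------------------------------------------
% Illustrative sketch, not part of the original MSM-T code: one way the
% group boundaries and the test/train split could be formed. The
% boundary layout (numgroups+1 start indices, the last equal to m+1) is
% an assumption inferred from the index expression
% Agroups(i):Agroups(i+1)-1 that appears in the loop above; the real
% cross_makegroups and cross_sets helpers are not shown here and may
% differ. The name example_fold is hypothetical. If m < numgroups, some
% groups come out empty.
%
%	[te,tr,g] = example_fold((1:10)',5,2)
%	gives g = [1 3 5 7 9 11], te = [3;4], tr = [1;2;5;6;7;8;9;10].
function [test,train,groups] = example_fold(X,numgroups,i)
m = size(X,1);
% start index of each group, plus a final sentinel m+1
groups = round(linspace(1,m+1,numgroups+1));
% rows of group i form the test set
test = X(groups(i):groups(i+1)-1,:);
% all remaining rows form the training set
train = [ X(1:groups(i)-1,:); X(groups(i+1):m,:) ];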