learn_params.m

Package: Bayesian algorithms (written in MATLAB). To install, add the directory /home/ai2/murphyk/matlab/FullBNT.
Page 1 of 2
function CPD = learn_params(CPD, fam, data, ns, cnodes, varargin)
% LEARN_PARAMS Construct classification/regression tree given complete data
% CPD = learn_params(CPD, fam, data, ns, cnodes)
%
% fam(i) is the node id of the i-th node in the family of nodes; the self node is the last one
% data(i,m) is the value of node i in case m (can be a cell array).
% ns(i) is the node size for the i-th node in the whole bnet
% cnodes(i) is the node id for the i-th continuous node in the whole bnet
%
% The following optional arguments can be specified in the form of name/value pairs:
% stop_cases: for early stop (pruning). A node is not split if it has fewer than k cases. Default is 0.
% min_gain: for early stop (pruning).
%     For discrete output: a node is not split when the gain of the best split is less than min_gain. Default is 0.
%     For continuous (cts) output: a node is not split when the gain of the best split is less than min_gain*score(root)
%                                  (we denote it cts_min_gain). Default is 0.006.
%
%%%%%%%%%%%%%%%%%%% Structure definition of dtree_CPD.tree %%%%%%%%%%%%%%%%%%%
% tree.num_node               the last position in the tree.nodes array for adding new nodes;
%                             it is not always the same as the number of nodes in the tree, because some positions
%                             in the tree.nodes array can be set to unused (e.g. in tree pruning)
% tree.nodes is the array of nodes in the tree plus some unused nodes.
% tree.nodes(1) is the root of the tree.
%
% Below are the attributes of each node:
% tree.nodes(i).used;     % flag whether this node is used (0 means not used; it can be removed from the tree to save memory)
% tree.nodes(i).is_leaf;  % 1 means this node is a leaf, 0 means it is not.
% tree.nodes(i).children; % children(i) is the index in the tree.nodes array of the i-th child node
% tree.nodes(i).split_id; % the attribute id used to split this node
% tree.nodes(i).split_threshhold; % the threshold used to split this node on a continuous attribute
% %%%%% attributes specific to classification trees (discrete output)
% tree.nodes(i).probs     % probs(i) is the probability of the i-th value of the class node.
%                         % For three output classes, probs = [0.9 0.1 0.0] means the probability of
%                         % class 1 is 0.9, of class 2 is 0.1, and of class 3 is 0.0.
% %%%%% attributes specific to regression trees (continuous output)
% tree.nodes(i).mean      % mean output value for this node
% tree.nodes(i).std       % standard deviation of the output values in this node
%
% Author: yimin.zhang@intel.com
% Last updated: Jan. 19, 2002
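%
% Usage sketch (hypothetical values; the name/value pairs are the ones documented above):
%
%   CPD = learn_params(CPD, fam, data, ns, cnodes, 'stop_cases', 5, 'min_gain', 0.01);
%   root = CPD.tree.nodes(1);               % inspect the learned tree, e.g. root.split_id
%   if root.is_leaf, p = root.probs; end    % class distribution at a leaf (discrete output)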
% Want list:
% (1) more efficient handling of cts attributes: get the values of cts attributes first (at the beginning of the
%     build_tree function), then do a binary search when finding the threshold
% (2) pruning the classification tree using Pessimistic Error Pruning
% (3) binary search for strings (used to transform data to BNT format)

global tree      % tree must be global so that it can be accessed in the recursive splitting function
global cts_min_gain
tree=[];         % clear the tree
tree.num_node=0;
cts_min_gain=0;

stop_cases=0;
min_gain=0;

args = varargin;
nargs = length(args);
if (nargs>0)
  if isstr(args{1})
    for i=1:2:nargs
      switch args{i},
        case 'stop_cases', stop_cases = args{i+1};
        case 'min_gain', min_gain = args{i+1};
      end
    end
  else
    error(['error in input parameters']);
  end
end

if iscell(data)
  local_data = cell2num(data(fam,:));
else
  local_data = data(fam, :);
end
%counts = compute_counts(local_data, CPD.sizes);
%CPD.CPT = mk_stochastic(counts + CPD.prior); % bug fix 11/5/01

node_types = zeros(1,size(ns,2)); % all nodes are discrete by default
node_types(cnodes)=1;
% make the data BNT compliant (values for discrete nodes are 1..n, where n is the node size)
%trans_data=transform_data(local_data,'tmp.dat',[]); % here no cts nodes

build_dtree (CPD, local_data, ns(fam), node_types(fam), stop_cases, min_gain);
%CPD.tree=copy_tree(tree);
CPD.tree=tree; % copy the constructed tree into the CPD


function new_tree = copy_tree(tree)
% copy the tree to new_tree
new_tree.num_node=tree.num_node;
new_tree.root = tree.root;
for i=1:tree.num_node
  new_tree.nodes(i)=tree.nodes(i);
end


function build_dtree (CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain)
global tree
global cts_min_gain

tree.num_node=0; % the current number of nodes in the tree
tree.root=1;

T = 1:size(fam_ev,2);                       % all cases
candidate_attrs = 1:(size(node_sizes,2)-1); % all attributes
node_id=1;                                  % the root node
lastnode=size(node_sizes,2); % the last element among all nodes is the dependent variable (category node)
num_cat=node_sizes(lastnode);

% get the minimum gain for cts output (used to stop splitting)
if (node_types(size(fam_ev,1))==1) % cts output
  N = size(fam_ev,2);
  output_id = size(fam_ev,1);
  cases_T = fam_ev(output_id,:); % get all the output values for cases T
  std_T = std(cases_T);
  avg_y_T = mean(cases_T);
  sqr_T = cases_T - avg_y_T;
  cts_min_gain = min_gain*(sum(sqr_T.*sqr_T)/N);  % min_gain * R(root), where R(root) = 1/N * SUM((y-avg_y)^2)
end

split_dtree (CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain, T, candidate_attrs, num_cat);

% Pruning methods
% (1) Restriction on minimum node size: a node is not split if it has fewer than k cases.
% (2) Threshold on impurity: a threshold is imposed on the splitting test score. The threshold can be
%     imposed on a local goodness measure (the gain_ratio of a node) or on global goodness.
% (3) Minimum Error Pruning (MEP) (no pruning set needed):
%     Prune if static error <= backed-up error.
%       Static error at node v: e(v) = (Nc + 1)/(N + k) (Laplace estimate, equal prior for each class),
%         where N is # of all examples, Nc is # of majority-class examples, k is the number of classes.
%       Backed-up error at node v (Ti is the i-th subtree root):
%         E(T) = Sum_1_to_n(pi * e(Ti))
% (4) Pessimistic Error Pruning (PEP), used in Quinlan's C4.5 (no pruning set needed; efficient because pruning is top-down):
%       Probability of error (apparent error rate)
%           q = (N - Nc + 0.5)/N
%         where N = #examples and Nc = #examples in the majority class.
%     Error of a node v (if pruned):  q(v) = (Nv - Nc,v + 0.5)/Nv
%     Error of a subtree:             q(T) = Sum_over_leaves_l(Nl - Nc,l + 0.5) / Sum_over_leaves_l(Nl)
%     Prune if q(v) <= q(T).
%
% Implementation status:
% (1)(2) have been implemented as the input parameters of learn_params.
% (4) is implemented in this function.
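%
% Worked PEP example (hypothetical numbers, only to illustrate the formulas above):
%   Suppose node v covers Nv = 20 cases, 17 of them in the majority class, and its subtree
%   has two leaves with (Nl, Nc,l) = (12, 11) and (8, 6). Then
%     q(v) = (20 - 17 + 0.5)/20               = 0.175
%     q(T) = ((12-11+0.5) + (8-6+0.5))/(12+8) = 4/20 = 0.2
%   Since q(v) <= q(T), the subtree rooted at v would be pruned to a leaf.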
function pruning(fam_ev, node_sizes, node_types)
% PRUNING Prune the constructed tree using PEP
% pruning(fam_ev, node_sizes, node_types)
%
% fam_ev(i,j)   is the value of attribute i in the j-th training case (for the whole tree); the last row is the class label (self_ev)
% node_sizes(i) is the node size for the i-th node in the family
% node_types(i) is the node type for the i-th node in the family, 0 for a discrete node, 1 for a continuous node
% the global variable 'tree' stores the input tree and the pruned tree


function split_T = split_cases(fam_ev, node_sizes, node_types, T, node_i, threshhold)
% SPLIT_CASES Split the cases T according to the values of node_i in the family
% split_T = split_cases(fam_ev, node_sizes, node_types, T, node_i)
%
% fam_ev(i,j)   is the value of attribute i in the j-th training case (for the whole tree); the last row is the class label (self_ev)
% node_sizes(i) is the node size for the i-th node in the family
% node_types(i) is the node type for the i-th node in the family, 0 for a discrete node, 1 for a continuous node
% node_i is the attribute we need to split on

if (node_types(node_i)==0) % discrete attribute
  % init the subsets of T
  split_T = cell(1,node_sizes(node_i)); % T will be separated into |node_size of i| subsets according to the values of node i
  for i=1:node_sizes(node_i)            % here we assume that the values of an attribute are 1:node_size
    split_T{i}=zeros(1,0);
  end
  size_t = size(T,2);
  for i=1:size_t
    case_id = T(i);
    % put this case into the subset of split_T corresponding to its value for node_i
    value = fam_ev(node_i,case_id);
    pos = size(split_T{value},2)+1;
    split_T{value}(pos)=case_id;        % here we assume the values of an attribute are 1:node_size
  end
else % continuous attribute
  % init the subsets of T
  split_T = cell(1,2); % T will be separated into 2 subsets: (<= threshhold) and (> threshhold)
  for i=1:2
    split_T{i}=zeros(1,0);
  end
  size_t = size(T,2);
  for i=1:size_t
    case_id = T(i);
    % put this case into the subset of split_T corresponding to its value for node_i
    value = fam_ev(node_i,case_id);
    subset_num=1;
    if (value>threshhold)
      subset_num=2;
    end
    pos = size(split_T{subset_num},2)+1;
    split_T{subset_num}(pos)=case_id;
  end
end
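%
% split_cases sketch (hypothetical values): if row node_i of fam_ev is [1 2 1 2], node_sizes(node_i) = 2,
% and T = [1 2 3 4], the discrete branch returns split_T = { [1 3], [2 4] }, i.e. the case indices
% grouped by attribute value.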
function new_node = split_dtree (CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain, T, candidate_attrs, num_cat)
% SPLIT_DTREE Split the tree at the current node with cases T (actually just indices into the family evidence).
% new_node = split_dtree (fam_ev, node_sizes, node_types, T, node_id, num_cat, method)
%
% fam_ev(i,j)   is the value of attribute i in the j-th training case (for the whole tree); the last row is the class label (self_ev)
% node_sizes{i} is the node size for the i-th node in the family
% node_types{i} is the node type for the i-th node in the family, 0 for a discrete node, 1 for a continuous node
% stop_cases is the threshold on the number of cases below which splitting stops
% min_gain is the minimum gain needed to split a node
% T(i) is the index of the i-th case in the current decision tree node; we need to split it further
% candidate_attrs(i) is the node id of the i-th attribute that still needs to be considered as a split attribute
%%%%% node_id is the index of the current node considered for a split
% num_cat is the number of output categories for the decision tree
% output:
% new_node is the new node created

global tree
global cts_min_gain

size_fam = size(fam_ev,1);            % size of the family
output_type = node_types(size_fam);   % the type of output for the tree (0 is discrete, 1 is continuous)
size_attrs = size(candidate_attrs,2); % number of candidate attributes
size_t = size(T,2);                   % number of training cases in this tree node

%(1) computeFrequencyForEachClass(T)
if (output_type==0) % discrete output
  class_freqs = zeros(1,num_cat);
  for i=1:size_t
    case_id = T(i);
    case_class = fam_ev(size_fam,case_id); % get the class label for this case
    class_freqs(case_class)=class_freqs(case_class)+1;
  end
else % cts output
  N = size(fam_ev,2);
  cases_T = fam_ev(size(fam_ev,1),T); % get the output values for cases T
  std_T = std(cases_T);
end

%(2) if OneClass (for discrete output) or same output value (for cts output) or class with #examples < stop_cases
%        return a leaf;
%    create a decision node N;

% get the majority class in this node
if (output_type == 0)
  top1_class = 0;       % the class with the largest number of cases
  top1_class_cases = 0; % the number of cases in top1_class
  [top1_class_cases,top1_class]=max(class_freqs);
end

if (size_t==0)  % impossible
  new_node=-1;
  fprintf('Fatal error: please contact the author. \n');
  return;
end
% stop splitting if needed:
%   for discrete output: only one class
%   for cts output: all output values in the cases are the same
%   too few cases
if ( (output_type==0 & top1_class_cases == size_t) | (output_type==1 & std_T == 0) | (size_t < stop_cases))
  % create one new leaf node
  tree.num_node=tree.num_node+1;
  tree.nodes(tree.num_node).used=1; % flag this node as used (0 means not used; it will be removed from the tree at the end to save memory)
  tree.nodes(tree.num_node).is_leaf=1;
  tree.nodes(tree.num_node).children=[];
  tree.nodes(tree.num_node).split_id=0; % the attribute (parent) id used to split this tree node
  tree.nodes(tree.num_node).split_threshhold=0;
  if (output_type==0)
    tree.nodes(tree.num_node).probs=class_freqs/size_t; % the prob of each value of the class node
    %  tree.nodes(tree.num_node).probs=zeros(1,num_cat); % the prob of each value of the class node
    %  tree.nodes(tree.num_node).probs(top1_class)=1;    % use the majority class of the parent node, e.g. for a binary class
    %                                                    % where the majority is class 2, the CPT is [0 1];
    %                                                    % we may need to use a prior for smoothing, to get [0.001 0.999]
    tree.nodes(tree.num_node).error.self_error=1-top1_class_cases/size_t; % the classification error in this tree node when the default class is used
    tree.nodes(tree.num_node).error.all_error=1-top1_class_cases/size_t;  % the total classification error in this tree node and its subtree
    tree.nodes(tree.num_node).error.all_error_num=size_t - top1_class_cases;
    fprintf('Create leaf node(onecla) %d. Class %d Cases %d Error %d \n',tree.num_node, top1_class, size_t, size_t - top1_class_cases );
  else
    avg_y_T = mean(cases_T);
    tree.nodes(tree.num_node).mean = avg_y_T;
    tree.nodes(tree.num_node).std = std_T;
    fprintf('Create leaf node(samevalue) %d. Mean %8.4f Std %8.4f Cases %d \n',tree.num_node, avg_y_T, std_T, size_t);
  end
  new_node = tree.num_node;
  return;
end

% create one new node
tree.num_node=tree.num_node+1;
tree.nodes(tree.num_node).used=1; % flag this node as used (0 means not used; it will be removed from the tree at the end to save memory)
tree.nodes(tree.num_node).is_leaf=1;
tree.nodes(tree.num_node).children=[];
tree.nodes(tree.num_node).split_id=0;
tree.nodes(tree.num_node).split_threshhold=0;
if (output_type==0)
  tree.nodes(tree.num_node).error.self_error=1-top1_class_cases/size_t;
  tree.nodes(tree.num_node).error.all_error=0;
  tree.nodes(tree.num_node).error.all_error_num=0;
else
  avg_y_T = mean(cases_T);
  tree.nodes(tree.num_node).mean = avg_y_T;
  tree.nodes(tree.num_node).std = std_T;
end
new_node = tree.num_node;

% stop splitting if no attributes are left in this node
if (size_attrs==0)
  if (output_type==0)
    tree.nodes(tree.num_node).probs=class_freqs/size_t; % the prob of each value of the class node
    tree.nodes(tree.num_node).error.all_error=1-top1_class_cases/size_t;
    tree.nodes(tree.num_node).error.all_error_num=size_t - top1_class_cases;
    fprintf('Create leaf node(noattr) %d. Class %d Cases %d Error %d \n',tree.num_node, top1_class, size_t, size_t - top1_class_cases );
  else
    fprintf('Create leaf node(noattr) %d. Mean %8.4f Std %8.4f Cases %d \n',tree.num_node, avg_y_T, std_T, size_t);
  end
  return;
end

%(3) for each attribute A
%        ComputeGain(A);
max_gain=0;  % the max gain score (for discrete output: information gain or gain ratio; for cts nodes: R(T))
