⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bftree.java

📁 Weka
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
    for (int i = 0; i < m_numFoldsPruning; i++) {      train[i] = cvData.trainCV(m_numFoldsPruning, i);      test[i] = cvData.testCV(m_numFoldsPruning, i);      parallelBFElements[i] = new FastVector();      m_roots[i] = new BFTree();      // calculate sorted indices, weights, initial class counts and total weights for each training data      totalWeight[i] = computeSortedInfo(train[i],sortedIndices[i], weights[i],	  classProbs[i]);      // compute information of the best split for this node (include split attribute,      // split value and gini gain (or information gain)) in this fold      nodeInfo[i] = computeSplitInfo(m_roots[i], train[i], sortedIndices[i],	  weights[i], dists[i], props[i], totalSubsetWeights[i], m_Heuristic, m_UseGini);      // compute information for root nodes      int attIndex = ((Attribute)nodeInfo[i].elementAt(1)).index();      m_roots[i].m_SortedIndices = new int[sortedIndices[i].length][0];      m_roots[i].m_Weights = new double[weights[i].length][0];      m_roots[i].m_Dists = new double[dists[i].length][0][0];      m_roots[i].m_ClassProbs = new double[classProbs[i].length];      m_roots[i].m_Distribution = new double[classProbs[i].length];      m_roots[i].m_Props = new double[2];      for (int j=0; j<m_roots[i].m_SortedIndices.length; j++) {	m_roots[i].m_SortedIndices[j] = sortedIndices[i][j];	m_roots[i].m_Weights[j] = weights[i][j];	m_roots[i].m_Dists[j] = dists[i][j];      }      System.arraycopy(classProbs[i], 0, m_roots[i].m_ClassProbs, 0,	  classProbs[i].length);      if (Utils.sum(m_roots[i].m_ClassProbs)!=0)	Utils.normalize(m_roots[i].m_ClassProbs);      System.arraycopy(classProbs[i], 0, m_roots[i].m_Distribution, 0,	  classProbs[i].length);      System.arraycopy(props[i][attIndex], 0, m_roots[i].m_Props, 0,	  props[i][attIndex].length);      m_roots[i].m_TotalWeight = totalWeight[i];      parallelBFElements[i].addElement(nodeInfo[i]);    }    // build a pre-pruned tree    if (m_PruningStrategy == PRUNING_PREPRUNING) {      double previousError = Double.MAX_VALUE;      double currentError = previousError;      double minError = Double.MAX_VALUE;      int minExpansion = 0;      FastVector errorList = new FastVector();      while(true) {	// compute average error	double expansionError = 0;	int count = 0;	for (int i=0; i<m_numFoldsPruning; i++) {	  Evaluation eval;	  // calculate error rate if only root node	  if (expansion==0) {	    m_roots[i].m_isLeaf = true;	    eval = new Evaluation(test[i]);	    eval.evaluateModel(m_roots[i], test[i]);	    if (m_UseErrorRate) expansionError += eval.errorRate();	    else expansionError += eval.rootMeanSquaredError();	    count ++;	  }	  // make tree - expand one node at a time	  else {	    if (m_roots[i] == null) continue; // if the tree cannot be expanded, go to next fold	    m_roots[i].m_isLeaf = false;	    BFTree nodeToSplit = (BFTree)	    (((FastVector)(parallelBFElements[i].elementAt(0))).elementAt(0));	    if (!m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i],		nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights,		nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs,		nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj,		m_Heuristic, m_UseGini)) {	      m_roots[i] = null; // cannot be expanded	      continue;	    }	    eval = new Evaluation(test[i]);	    eval.evaluateModel(m_roots[i], test[i]);	    if (m_UseErrorRate) expansionError += eval.errorRate();	    else expansionError += eval.rootMeanSquaredError();	    count ++;	  }	}	// no tree can be expanded any more	if (count==0) break;	expansionError /=count;	errorList.addElement(new Double(expansionError));	currentError = expansionError;	if (!m_UseOneSE) {	  if (currentError>previousError)	    break;	}	else {	  if (expansionError < minError) {	    minError = expansionError;	    minExpansion = expansion;	  }	  if (currentError>previousError) {	    double oneSE = Math.sqrt(minError*(1-minError)/		data.numInstances());	    if (currentError > minError + oneSE) {	      break;	    }	  }	}	expansion ++;	previousError = currentError;      }      if (!m_UseOneSE) expansion = expansion - 1;      else {	double oneSE = Math.sqrt(minError*(1-minError)/data.numInstances());	for (int i=0; i<errorList.size(); i++) {	  double error = ((Double)(errorList.elementAt(i))).doubleValue();	  if (error<=minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {	    expansion = i;	    break;	  }	}      }    }    // build a postpruned tree    else {      FastVector[] modelError = new FastVector[m_numFoldsPruning];      // calculate error of each expansion for each fold      for (int i = 0; i < m_numFoldsPruning; i++) {	modelError[i] = new FastVector();	m_roots[i].m_isLeaf = true;	Evaluation eval = new Evaluation(test[i]);	eval.evaluateModel(m_roots[i], test[i]);	double error;	if (m_UseErrorRate) error = eval.errorRate();	else error = eval.rootMeanSquaredError();	modelError[i].addElement(new Double(error));	m_roots[i].m_isLeaf = false;	BFTree nodeToSplit = (BFTree)	(((FastVector)(parallelBFElements[i].elementAt(0))).elementAt(0));	m_roots[i].makeTree(parallelBFElements[i], m_roots[i], train[i], test[i],	    modelError[i],nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights,	    nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs,	    nodeToSplit.m_TotalWeight, nodeToSplit.m_Props, m_minNumObj,	    m_Heuristic, m_UseGini, m_UseErrorRate);	m_roots[i] = null;      }      // find the expansion with minimal error rate      double minError = Double.MAX_VALUE;      int maxExpansion = modelError[0].size();      for (int i=1; i<modelError.length; i++) {	if (modelError[i].size()>maxExpansion)	  maxExpansion = modelError[i].size();      }      double[] error = new double[maxExpansion];      int[] counts = new int[maxExpansion];      for (int i=0; i<maxExpansion; i++) {	counts[i] = 0;	error[i] = 0;	for (int j=0; j<m_numFoldsPruning; j++) {	  if (i<modelError[j].size()) {	    error[i] += ((Double)modelError[j].elementAt(i)).doubleValue();	    counts[i]++;	  }	}	error[i] = error[i]/counts[i]; //average error for each expansion	if (error[i]<minError) {// && counts[i]>=m_numFoldsPruning/2) {	  minError = error[i];	  expansion = i;	}      }      // the 1 SE rule choosen      if (m_UseOneSE) {	double oneSE = Math.sqrt(minError*(1-minError)/	    data.numInstances());	for (int i=0; i<maxExpansion; i++) {	  if (error[i]<=minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {	    expansion = i;	    break;	  }	}      }    }    // make tree on all data based on the expansion caculated    // from cross-validation    // calculate sorted indices, weights and initial class counts    int[][] prune_sortedIndices = new int[data.numAttributes()][0];    double[][] prune_weights = new double[data.numAttributes()][0];    double[] prune_classProbs = new double[data.numClasses()];    double prune_totalWeight = computeSortedInfo(data, prune_sortedIndices,	prune_weights, prune_classProbs);    // compute information of the best split for this node (include split attribute,    // split value and gini gain)    double[][][] prune_dists = new double[data.numAttributes()][2][data.numClasses()];    double[][] prune_props = new double[data.numAttributes()][2];    double[][] prune_totalSubsetWeights = new double[data.numAttributes()][2];    FastVector prune_nodeInfo = computeSplitInfo(this, data, prune_sortedIndices,	prune_weights, prune_dists, prune_props, prune_totalSubsetWeights, m_Heuristic,m_UseGini);    // add the root node (with its split info) to BestFirstElements    FastVector BestFirstElements = new FastVector();    BestFirstElements.addElement(prune_nodeInfo);    int attIndex = ((Attribute)prune_nodeInfo.elementAt(1)).index();    m_Expansion = 0;    makeTree(BestFirstElements, data, prune_sortedIndices, prune_weights, prune_dists,	prune_classProbs, prune_totalWeight, prune_props[attIndex] ,m_minNumObj,	m_Heuristic, m_UseGini, expansion);  }  /**   * Recursively build a best-first decision tree.   * Method for building a Best-First tree for a given number of expansions.   * preExpasion is -1 means that no expansion is specified (just for a   * tree without any pruning method). Pre-pruning and post-pruning methods also   * use this method to build the final tree on all training data based on the   * expansion calculated from internal cross-validation.   *   * @param BestFirstElements 	list to store BFTree nodes   * @param data 		training data   * @param sortedIndices 	sorted indices of the instances   * @param weights 		weights of the instances   * @param dists 		class distributions for each attribute   * @param classProbs 		class probabilities of this node   * @param totalWeight 	total weight of this node (note if the node    * 				can not split, this value is not calculated.)   * @param branchProps 	proportions of two subbranches   * @param minNumObj 		minimal number of instances at leaf nodes   * @param useHeuristic 	if use heuristic search for nominal attributes    * 				in multi-class problem   * @param useGini 		if use Gini index as splitting criterion   * @param preExpansion 	the number of expansions the tree to be expanded   * @throws Exception 		if something goes wrong   */  protected void makeTree(FastVector BestFirstElements,Instances data,      int[][] sortedIndices, double[][] weights, double[][][] dists,      double[] classProbs, double totalWeight, double[] branchProps,      int minNumObj, boolean useHeuristic, boolean useGini, int preExpansion)  	throws Exception {    if (BestFirstElements.size()==0) return;    ///////////////////////////////////////////////////////////////////////    // All information about the node to split (the first BestFirst object in    // BestFirstElements)    FastVector firstElement = (FastVector)BestFirstElements.elementAt(0);    // split attribute    Attribute att = (Attribute)firstElement.elementAt(1);    // info of split value or split string    double splitValue = Double.NaN;    String splitStr = null;    if (att.isNumeric())      splitValue = ((Double)firstElement.elementAt(2)).doubleValue();    else {      splitStr=((String)firstElement.elementAt(2)).toString();    }    // the best gini gain or information gain of this node    double gain = ((Double)firstElement.elementAt(3)).doubleValue();    ///////////////////////////////////////////////////////////////////////    if (m_ClassProbs==null) {      m_SortedIndices = new int[sortedIndices.length][0];      m_Weights = new double[weights.length][0];      m_Dists = new double[dists.length][0][0];      m_ClassProbs = new double[classProbs.length];      m_Distribution = new double[classProbs.length];      m_Props = new double[2];      for (int i=0; i<m_SortedIndices.length; i++) {	m_SortedIndices[i] = sortedIndices[i];	m_Weights[i] = weights[i];	m_Dists[i] = dists[i];      }      System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);      System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);      System.arraycopy(branchProps, 0, m_Props, 0, m_Props.length);      m_TotalWeight = totalWeight;      if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);    }    // If no enough data or this node can not be split, find next node to split.    if (totalWeight < 2*minNumObj || branchProps[0]==0	|| branchProps[1]==0) {      // remove the first element      BestFirstElements.removeElementAt(0);      makeLeaf(data);      if (BestFirstElements.size()!=0) {	FastVector nextSplitElement = (FastVector)BestFirstElements.elementAt(0);	BFTree nextSplitNode = (BFTree)nextSplitElement.elementAt(0);	nextSplitNode.makeTree(BestFirstElements,data,	    nextSplitNode.m_SortedIndices, nextSplitNode.m_Weights,	    nextSplitNode.m_Dists,	    nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight,	    nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, preExpansion);      }      return;    }    // If gini gain or information gain is 0, make all nodes in the BestFirstElements leaf nodes    // because these nodes are sorted descendingly according to gini gain or information gain.    // (namely, gini gain or information gain of all nodes in BestFirstEelements is 0).    if (gain==0 || preExpansion==m_Expansion) {      for (int i=0; i<BestFirstElements.size(); i++) {	FastVector element = (FastVector)BestFirstElements.elementAt(i);	BFTree node = (BFTree)element.elementAt(0);	node.makeLeaf(data);      }      BestFirstElements.removeAllElements();    }    // gain is not 0    else {      // remove the first element      BestFirstElements.removeElementAt(0);      m_Attribute = att;      if (m_Attribute.isNumeric()) m_SplitValue = splitValue;      else m_SplitString = splitStr;      int[][][] subsetIndices = new int[2][data.numAttributes()][0];      double[][][] subsetWeights = new double[2][data.numAttributes()][0];      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,	  m_SplitString, sortedIndices, weights, data);      // If split will generate node(s) which has total weights less than m_minNumObj,      // do not split.      int attIndex = att.index();      if (subsetIndices[0][attIndex].length<minNumObj ||	  subsetIndices[1][attIndex].length<minNumObj) {	makeLeaf(data);      }      // split the node      else {	m_isLeaf = false;	m_Attribute = att;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -