
📄 bftree.java

📁 Weka
💻 JAVA
📖 Page 1 of 5
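This page shows part of the split-selection code of BFTree, Weka's best-first decision tree learner. For orientation, here is a minimal usage sketch (not part of bftree.java); it assumes the class is available as weka.classifiers.trees.BFTree, as in Weka 3.6-era distributions, and that an ARFF file named iris.arff exists locally; both are assumptions, not facts taken from this listing.

// Hypothetical usage sketch; the BFTree package location and "iris.arff" are assumptions.
import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.trees.BFTree;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class BFTreeDemo {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("iris.arff");   // load a dataset (path is an assumption)
    data.setClassIndex(data.numAttributes() - 1);    // last attribute is the class

    BFTree tree = new BFTree();                      // best-first decision tree learner
    tree.buildClassifier(data);                      // runs the split search shown in this file

    Evaluation eval = new Evaluation(data);          // 10-fold cross-validation summary
    eval.crossValidateModel(tree, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}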
        // If gini gain is less than that of last node in FastVector
        if (gGain<((Double)(lastNode.elementAt(3))).doubleValue()) {
          BestFirstElements.insertElementAt(splitInfo, vectorSize);
        } else {
          for (int j=0; j<vectorSize; j++) {
            FastVector node = (FastVector)BestFirstElements.elementAt(j);
            double nodeGain = ((Double)(node.elementAt(3))).doubleValue();
            if (gGain>=nodeGain) {
              BestFirstElements.insertElementAt(splitInfo, j);
              break;
            }
          }
        }
      }
    }
  }

  /**
   * Compute sorted indices, weights and class probabilities for a given
   * dataset. Return total weights of the data at the node.
   *
   * @param data            training data
   * @param sortedIndices   sorted indices of instances at the node
   * @param weights         weights of instances at the node
   * @param classProbs      class probabilities at the node
   * @return                total weights of instances at the node
   * @throws Exception      if something goes wrong
   */
  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
      double[] classProbs) throws Exception {

    // Create array of sorted indices and weights
    double[] vals = new double[data.numInstances()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j==data.classIndex()) continue;
      weights[j] = new double[data.numInstances()];

      if (data.attribute(j).isNominal()) {
        // Handling nominal attributes. Putting indices of
        // instances with missing values at the end.
        sortedIndices[j] = new int[data.numInstances()];
        int count = 0;
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (!inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
      } else {
        // Sorted indices are computed for numeric attributes;
        // instances with missing values are put at the end (through Utils.sort())
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          vals[i] = inst.value(j);
        }
        sortedIndices[j] = Utils.sort(vals);
        for (int i = 0; i < data.numInstances(); i++) {
          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
        }
      }
    }

    // Compute initial class counts and total weight
    double totalWeight = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      classProbs[(int)inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    }
    return totalWeight;
  }

  /**
   * Compute the best splitting attribute, split point or subset and the best
   * Gini gain or information gain for a given dataset.
   *
   * @param node                node to be split
   * @param data                training data
   * @param sortedIndices       sorted indices of the instances
   * @param weights             weights of the instances
   * @param dists               class distributions for each attribute
   * @param props               proportions of the two branches
   * @param totalSubsetWeights  total weights of the two subsets
   * @param useHeuristic        whether to use heuristic search for nominal attributes
   *                            in multi-class problems
   * @param useGini             whether to use the Gini index as the splitting criterion
   * @return                    split information about the node
   * @throws Exception          if something goes wrong
   */
  protected FastVector computeSplitInfo(BFTree node, Instances data, int[][] sortedIndices,
      double[][] weights, double[][][] dists, double[][] props,
      double[][] totalSubsetWeights, boolean useHeuristic, boolean useGini) throws Exception {

    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] gains = new double[data.numAttributes()];

    for (int i = 0; i < data.numAttributes(); i++) {
      if (i==data.classIndex()) continue;
      Attribute att = data.attribute(i);
      if (att.isNumeric()) {
        // numeric attribute
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, gains, data, useGini);
      } else {
        // nominal attribute
        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, gains, data, useHeuristic, useGini);
      }
    }

    int index = Utils.maxIndex(gains);
    double mBestGain = gains[index];

    Attribute att = data.attribute(index);
    double mValue = Double.NaN;
    String mString = null;
    if (att.isNumeric()) mValue = splits[index];
    else {
      mString = splitString[index];
      if (mString==null) mString = "";
    }

    // split information
    FastVector splitInfo = new FastVector();
    splitInfo.addElement(node);
    splitInfo.addElement(att);
    if (att.isNumeric()) splitInfo.addElement(new Double(mValue));
    else splitInfo.addElement(mString);
    splitInfo.addElement(new Double(mBestGain));

    return splitInfo;
  }

  /**
   * Compute distributions, proportions and total weights of two successor nodes for
   * a given numeric attribute.
   *
   * @param props           proportions of the two branches for each attribute
   * @param dists           class distributions of the two branches for each attribute
   * @param att             numeric attribute to split on
   * @param sortedIndices   sorted indices of instances for the attribute
   * @param weights         weights of instances for the attribute
   * @param subsetWeights   total weights of the two branches of the split on the attribute
   * @param gains           Gini gains or information gains for each attribute
   * @param data            training instances
   * @param useGini         whether to use the Gini index as the splitting criterion
   * @return                Gini gain or information gain for the given attribute
   * @throws Exception      if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] gains, Instances data, boolean useGini) throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i; // separates instances with and without missing values

    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into second subset
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
        missingStart++;
        currDist[1][(int)inst.classValue()] += weights[j];
      }
      parentDist[(int)inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGain;
    double bestGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }

      if (inst.value(att) > currSplit) {
        double[][] tempDist = new double[2][numClasses];
        for (int k=0; k<2; k++) {
          //tempDist[k] = currDist[k];
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
        }

        double[] tempProps = new double[2];
        for (int k=0; k<2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        if (useGini) currGain = computeGiniGain(parentDist, tempDist);
        else currGain = computeInfoGain(parentDist, tempDist);

        if (currGain > bestGain) {
          bestGain = currGain;
          // clean split point
          splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;
          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // clean gain
    gains[attIndex] = Math.rint(bestGain*10000000)/10000000.0;

    dists[attIndex] = dist;
    return splitPoint;
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given nominal attribute.
   *
   * @param props           proportions of the two branches for each attribute
   * @param dists           class distributions of the two branches for each attribute
   * @param att             nominal attribute to split on
   * @param sortedIndices   sorted indices of instances for the attribute
   * @param weights         weights of instances for the attribute
   * @param subsetWeights   total weights of the two branches of the split on the attribute
   * @param gains           Gini gains for each attribute
   * @param data            training instances
   * @param useHeuristic    whether to use heuristic search
   * @param useGini         whether to use the Gini index as the splitting criterion
   * @return                Gini gain for the given attribute
   * @throws Exception      if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] gains, Instances data, boolean useHeuristic, boolean useGini) throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j=0; j<numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int)inst.value(att)]++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // count the number of values whose class frequency is not 0
    int nonEmpty = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) nonEmpty++;
    }

    // attribute values whose class frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex++;
      }
    }

    // attribute values whose class frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]==0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex++;
      }
    }

    if (nonEmpty<=1) {
      gains[att.index()] = 0;
      return "";
    }

    // for two-class problems
    if (data.numClasses()==2) {

      //// Firstly, for attribute values whose class frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j=0; j<nonEmpty; j++) {
        for (int k=0; k<2; k++) {
          valDist[j][k] = 0;
        }
      }

      for (int i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }

        for (int j=0; j<nonEmpty; j++) {
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
            valDist[j][(int)inst.classValue()] += inst.weight();
            break;
          }
        }
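The listing is truncated here (page 1 of 5); the remaining pages contain the rest of nominalDistribution and the computeGiniGain / computeInfoGain helpers it calls. Since those helpers are not visible on this page, the sketch below is a self-contained illustration of the standard weighted Gini gain for a binary split that such a helper conventionally computes; it is not the file's own implementation.

// Illustrative sketch only (not from bftree.java): weighted Gini gain of a binary split.
// parentDist holds the class weights at the node; childDists[0] and childDists[1]
// hold the class weights of the left and right branches.
static double giniGainSketch(double[] parentDist, double[][] childDists) {
  double totalWeight = 0;
  for (double w : parentDist) totalWeight += w;
  if (totalWeight == 0) return 0;

  double gain = gini(parentDist);                        // impurity before the split
  for (double[] child : childDists) {
    double childWeight = 0;
    for (double w : child) childWeight += w;
    gain -= (childWeight / totalWeight) * gini(child);   // minus weighted impurity after
  }
  return gain;
}

// Gini impurity: 1 - sum_k p_k^2, where p_k are the class proportions.
static double gini(double[] dist) {
  double total = 0;
  for (double w : dist) total += w;
  if (total == 0) return 0;
  double sumSquared = 0;
  for (double w : dist) sumSquared += (w / total) * (w / total);
  return 1.0 - sumSquared;
}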
