
📄 simplecart.java

📁 This code is an implementation of a classifier that reuses part of the Weka source code. The project can be imported into Eclipse and run (a short usage sketch follows below).
💻 JAVA
📖 Page 1 of 4
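Below is a minimal usage sketch (not part of the uploaded file) showing how the classifier could be driven once the project is imported into Eclipse. It assumes the project bundles the Weka core classes the listing depends on, that the SimpleCart class keeps Weka's standard Classifier API (buildClassifier and friends), and that an ARFF file is available locally; the file name iris.arff and the class name SimpleCartDemo are placeholders.

import java.util.Random;

import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class SimpleCartDemo {

  public static void main(String[] args) throws Exception {
    // load a dataset and use the last attribute as the class (placeholder file name)
    Instances data = DataSource.read("iris.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // build the CART classifier shipped with this project
    // (adjust the import/package to match how the project is organised)
    SimpleCart cart = new SimpleCart();
    cart.buildClassifier(data);

    // estimate its accuracy with 10-fold cross-validation
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(cart, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}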
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }
      preAlpha = nodeToPrune.m_Alpha;

      // update tree errors and alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }
  }

  /**
   * Method for performing one fold in the cross-validation of minimal
   * cost-complexity pruning. Generates a sequence of alpha-values with error
   * estimates for the corresponding (partially pruned) trees, given the test
   * set of that fold.
   *
   * @param alphas 	array to hold the generated alpha-values
   * @param errors 	array to hold the corresponding error estimates
   * @param test 	test set of that fold (to obtain error estimates)
   * @return 		the iteration of the pruning
   * @throws Exception 	if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test)
      throws Exception {

    Vector nodeList;

    // determine training error of subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();
    boolean prune = (nodeList.size() > 0);

    // alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of unpruned tree
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    double preAlpha = Double.MAX_VALUE;
    while (prune) {
      iteration++;

      // get node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // do not set m_sons null, want to unprune
      nodeToPrune.m_isLeaf = true;

      // normally would not happen
      if (nodeToPrune.m_Alpha == preAlpha) {
        iteration--;
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }

      // get alpha-value of node
      alphas[iteration] = nodeToPrune.m_Alpha;

      // log error
      if (errors != null) {
        eval = new Evaluation(test);
        eval.evaluateModel(this, test);
        errors[iteration] = eval.errorRate();
      }
      preAlpha = nodeToPrune.m_Alpha;

      // update errors/alphas
      treeErrors();
      calculateAlphas();
      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }

    // set last alpha to 1 to indicate the end
    alphas[iteration + 1] = 1.0;

    return iteration;
  }

  /**
   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
   * Faster than re-growing the tree because CART does not have to be fit again.
   */
  protected void unprune() {
    if (m_Successors != null) {
      m_isLeaf = false;
      for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
    }
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given numeric attribute.
   *
   * @param props 		proportions of each two branches for each attribute
   * @param dists 		class distributions of two branches for each attribute
   * @param att 		numeric attribute to split on
   * @param sortedIndices 	sorted indices of instances for the attribute
   * @param weights 		weights of instances for the attribute
   * @param subsetWeights 	total weight of two branches split based on the attribute
   * @param giniGains 		Gini gains for each attribute
   * @param data 		training instances
   * @return 			Gini gain for the given numeric attribute
   * @throws Exception 		if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data)
      throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i; // differ instances with or without missing values

    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into the second subset
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
        missingStart++;
        currDist[1][(int)inst.classValue()] += weights[j];
      }
      parentDist[(int)inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGiniGain;
    double bestGiniGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }

      if (inst.value(att) > currSplit) {
        double[][] tempDist = new double[2][numClasses];
        for (int k = 0; k < 2; k++) {
          //tempDist[k] = currDist[k];
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
        }

        double[] tempProps = new double[2];
        for (int k = 0; k < 2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }
        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          // clean split point
          splitPoint = Math.rint((inst.value(att) + currSplit) / 2.0 * 100000) / 100000.0;
          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // clean Gini gain
    giniGains[attIndex] = Math.rint(bestGiniGain * 10000000) / 10000000.0;
    dists[attIndex] = dist;
    return splitPoint;
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given nominal attribute.
   *
   * @param props 		proportions of each two branches for each attribute
   * @param dists 		class distributions of two branches for each attribute
   * @param att 		nominal attribute to split on
   * @param sortedIndices 	sorted indices of instances for the attribute
   * @param weights 		weights of instances for the attribute
   * @param subsetWeights 	total weight of two branches split based on the attribute
   * @param giniGains 		Gini gains for each attribute
   * @param data 		training instances
   * @param useHeuristic 	whether to use heuristic search
   * @return 			Gini gain for the given nominal attribute
   * @throws Exception 		if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data, boolean useHeuristic)
      throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGiniGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j = 0; j < numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int)inst.value(att)]++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // count the number of values whose class frequency is not 0
    int nonEmpty = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) nonEmpty++;
    }

    // attribute values whose class frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] != 0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex++;
      }
    }

    // attribute values whose class frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j = 0; j < numCat; j++) {
      if (classFreq[j] == 0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex++;
      }
    }

    if (nonEmpty <= 1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // for two-class problems
    if (data.numClasses() == 2) {

      //// Firstly, for attribute values whose class frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j = 0; j < nonEmpty; j++) {
        for (int k = 0; k < 2; k++) {
          valDist[j][k] = 0;
        }
      }

      for (int i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }
        for (int j = 0; j < nonEmpty; j++) {
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j]) == 0) {
            valDist[j][(int)inst.classValue()] += inst.weight();
            break;
          }
        }
      }

      for (int j = 0; j < nonEmpty; j++) {
        double distSum = Utils.sum(valDist[j]);
        if (distSum == 0) pClass0[j] = 0;
        else pClass0[j] = valDist[j][0] / distSum;
      }

      // sort categories according to the probability of the first class
      String[] sortedValues = new String[nonEmpty];
      for (int j = 0; j < nonEmpty; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }

      // Find a subset of attribute values that maximizes the Gini decrease,
      // using only the attribute values whose class frequency is not 0
      String tempStr = "";
      for (int j = 0; j < nonEmpty - 1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr == "") tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";
        for (int i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk = 0; kk < 2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk = 0; kk < 2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }
        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
          }
        }
      }
    }

    // multi-class problems - exhaustive search
    else if (!useHeuristic || nonEmpty <= 4) {

      // Firstly, for attribute values whose class frequency is not zero
      for (int i = 0; i < (int)Math.pow(2, nonEmpty - 1); i++) {
        String tempStr = "";
        currDist = new double[2][numClasses];
        int mod;
        int bit10 = i;
        for (int j = nonEmpty - 1; j >= 0; j--) {
          mod = bit10 % 2; // convert from decimal to binary
          if (mod == 1) {
            if (tempStr == "") tempStr = "(" + nonEmptyValues[j] + ")";
            else tempStr += "|" + "(" + nonEmptyValues[j] + ")";
          }
          bit10 = bit10 / 2;
        }
        for (int j = 0; j < sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }
          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")") != -1) {
            currDist[0][(int)inst.classValue()] += weights[j];
          } else currDist[1][(int)inst.classValue()] += weights[j];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int k = 0; k < 2; k++) {
          tempDist[k] = currDist[k];
        }

        double[] tempProps = new double[2];
        for (int k = 0; k < 2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }
        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // split missing values
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist, tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int j = 0; j < 2; j++) {
            //dist[jj] = new double[currDist[jj].length];
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
    }

    // heuristic search to solve multi-class problems
    else {

      // Firstly, for attribute values whose class frequency is not zero
      int n = nonEmpty;
      int k = data.numClasses();              // number of classes of the data
      double[][] P = new double[n][k];        // class probability matrix
      int[] numInstancesValue = new int[n];   // number of instances for an attribute value
      double[] meanClass = new double[k];     // vector of mean class probability
      int numInstances = data.numInstances(); // total number of instances

      // initialize the vector of mean class probability
      for (int j = 0; j < meanClass.length; j++) meanClass[j] = 0;

      for (int j = 0; j < numInstances; j++) {
        Instance inst = (Instance)data.instance(j);
        int valueIndex = 0; // attribute value index in nonEmptyValues
        for (int i = 0; i < nonEmpty; i++) {
          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i]) == 0) {
            valueIndex = i;
            break;
          }
        }
        P[valueIndex][(int)inst.classValue()]++;
        numInstancesValue[valueIndex]++;
        meanClass[(int)inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i = 0; i < P.length; i++) {
        for (int j = 0; j < P[0].length; j++) {
          if (numInstancesValue[i] == 0) P[i][j] = 0;
          else P[i][j] /= numInstancesValue[i];
        }
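
The listing above is page 1 of 4 and breaks off inside the heuristic-search branch. Both numericDistribution and nominalDistribution rank candidate splits by computeGiniGain, which is defined on a later page of the file. For reference, here is a minimal standalone sketch of the standard CART Gini-gain criterion; it illustrates the idea under that assumption and is not a copy of the project's implementation.

public final class GiniSketch {

  /** Gini impurity 1 - sum_k p_k^2 of a class-count (or class-weight) vector. */
  static double gini(double[] dist) {
    double total = 0;
    for (double d : dist) total += d;
    if (total == 0) return 0;
    double sumSq = 0;
    for (double d : dist) {
      double p = d / total;
      sumSq += p * p;
    }
    return 1.0 - sumSq;
  }

  /** Weighted impurity decrease when parentDist is split into childDist[0] and childDist[1]. */
  static double giniGain(double[] parentDist, double[][] childDist) {
    double total = 0;
    for (double d : parentDist) total += d;
    if (total == 0) return 0;
    double gain = gini(parentDist);
    for (double[] child : childDist) {
      double childTotal = 0;
      for (double d : child) childTotal += d;
      gain -= (childTotal / total) * gini(child);
    }
    return gain;
  }

  public static void main(String[] args) {
    // toy two-class example: parent counts 10/10 split into 8/2 and 2/8
    double[] parent = {10, 10};
    double[][] children = {{8, 2}, {2, 8}};
    System.out.println(giniGain(parent, children)); // prints ~0.18 (impurity drops from 0.5 to 0.32)
  }
}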
