⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 simplecart.java

📁 代码是一个分类器的实现，其中使用了部分 Weka 的源代码。可以将项目导入 Eclipse 运行。
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
      }      //calculate the vector of mean class probability      for (int i=0; i<meanClass.length; i++) {	meanClass[i]/=numInstances;      }      // calculate the covariance matrix      double[][] covariance = new double[k][k];      for (int i1=0; i1<k; i1++) {	for (int i2=0; i2<k; i2++) {	  double element = 0;	  for (int j=0; j<n; j++) {	    element += (P[j][i2]-meanClass[i2])*(P[j][i1]-meanClass[i1])	    *numInstancesValue[j];	  }	  covariance[i1][i2] = element;	}      }      Matrix matrix = new Matrix(covariance);      weka.core.matrix.EigenvalueDecomposition eigen =	new weka.core.matrix.EigenvalueDecomposition(matrix);      double[] eigenValues = eigen.getRealEigenvalues();      // find index of the largest eigenvalue      int index=0;      double largest = eigenValues[0];      for (int i=1; i<eigenValues.length; i++) {	if (eigenValues[i]>largest) {	  index=i;	  largest = eigenValues[i];	}      }      // calculate the first principle component      double[] FPC = new double[k];      Matrix eigenVector = eigen.getV();      double[][] vectorArray = eigenVector.getArray();      for (int i=0; i<FPC.length; i++) {	FPC[i] = vectorArray[i][index];      }      // calculate the first principle component scores      //System.out.println("the first principle component scores: ");      double[] Sa = new double[n];      for (int i=0; i<Sa.length; i++) {	Sa[i]=0;	for (int j=0; j<k; j++) {	  Sa[i] += FPC[j]*P[i][j];	}      }      // sort category according to Sa(s)      double[] pCopy = new double[n];      System.arraycopy(Sa,0,pCopy,0,n);      String[] sortedValues = new String[n];      Arrays.sort(Sa);      for (int j=0; j<n; j++) {	sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];	pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;      }      // for the attribute values that class frequency is not 0      String tempStr = "";      for (int j=0; j<nonEmpty-1; j++) {	currDist = new double[2][numClasses];	if (tempStr=="") tempStr="(" + sortedValues[j] + ")";	else 
tempStr += "|"+ "(" + sortedValues[j] + ")";	for (int i=0; i<sortedIndices.length;i++) {	  Instance inst = data.instance(sortedIndices[i]);	  if (inst.isMissing(att)) {	    break;	  }	  if (tempStr.indexOf	      ("(" + att.value((int)inst.value(att)) + ")")!=-1) {	    currDist[0][(int)inst.classValue()] += weights[i];	  } else currDist[1][(int)inst.classValue()] += weights[i];	}	double[][] tempDist = new double[2][numClasses];	for (int kk=0; kk<2; kk++) {	  tempDist[kk] = currDist[kk];	}	double[] tempProps = new double[2];	for (int kk=0; kk<2; kk++) {	  tempProps[kk] = Utils.sum(tempDist[kk]);	}	if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);	// split missing values	int mstart = missingStart;	while (mstart < sortedIndices.length) {	  Instance insta = data.instance(sortedIndices[mstart]);	  for (int jj = 0; jj < 2; jj++) {	    tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];	  }	  mstart++;	}	double currGiniGain = computeGiniGain(parentDist,tempDist);	if (currGiniGain>bestGiniGain) {	  bestGiniGain = currGiniGain;	  bestSplitString = tempStr;	  for (int jj = 0; jj < 2; jj++) {	    //dist[jj] = new double[currDist[jj].length];	    System.arraycopy(tempDist[jj], 0, dist[jj], 0,		dist[jj].length);	  }	}      }    }    // Compute weights    int attIndex = att.index();            props[attIndex] = new double[2];    for (int k = 0; k < 2; k++) {      props[attIndex][k] = Utils.sum(dist[k]);    }    if (!(Utils.sum(props[attIndex]) > 0)) {      for (int k = 0; k < props[attIndex].length; k++) {	props[attIndex][k] = 1.0 / (double)props[attIndex].length;      }    } else {      Utils.normalize(props[attIndex]);    }    // Compute subset weights    subsetWeights[attIndex] = new double[2];    for (int j = 0; j < 2; j++) {      subsetWeights[attIndex][j] += Utils.sum(dist[j]);    }    // Then, for the attribute values that class frequency is 0, split it into the    // most frequent branch    for (int j=0; j<empty; j++) {      if 
(props[attIndex][0]>=props[attIndex][1]) {	if (bestSplitString=="") bestSplitString = "(" + emptyValues[j] + ")";	else bestSplitString += "|" + "(" + emptyValues[j] + ")";      }    }    // clean Gini gain for the attribute    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;    dists[attIndex] = dist;    return bestSplitString;  }  /**   * Split data into two subsets and store sorted indices and weights for two   * successor nodes.   *    * @param subsetIndices 	sorted indecis of instances for each attribute    * 				for two successor node   * @param subsetWeights 	weights of instances for each attribute for    * 				two successor node   * @param att 		attribute the split based on   * @param splitPoint 		split point the split based on if att is numeric   * @param splitStr 		split subset the split based on if att is nominal   * @param sortedIndices 	sorted indices of the instances to be split   * @param weights 		weights of the instances to bes split   * @param data 		training data   * @throws Exception 		if something goes wrong     */  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,      double[][] weights, Instances data) throws Exception {    int j;    // For each attribute    for (int i = 0; i < data.numAttributes(); i++) {      if (i==data.classIndex()) continue;      int[] num = new int[2];      for (int k = 0; k < 2; k++) {	subsetIndices[k][i] = new int[sortedIndices[i].length];	subsetWeights[k][i] = new double[weights[i].length];      }      for (j = 0; j < sortedIndices[i].length; j++) {	Instance inst = data.instance(sortedIndices[i][j]);	if (inst.isMissing(att)) {	  // Split instance up	  for (int k = 0; k < 2; k++) {	    if (m_Props[k] > 0) {	      subsetIndices[k][i][num[k]] = sortedIndices[i][j];	      subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];	      num[k]++;	    }	  }	} else {	  int subset;	  if 
(att.isNumeric())  {	    subset = (inst.value(att) < splitPoint) ? 0 : 1;	  } else { // nominal attribute	    if (splitStr.indexOf		("(" + att.value((int)inst.value(att.index()))+")")!=-1) {	      subset = 0;	    } else subset = 1;	  }	  subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];	  subsetWeights[subset][i][num[subset]] = weights[i][j];	  num[subset]++;	}      }      // Trim arrays      for (int k = 0; k < 2; k++) {	int[] copy = new int[num[k]];	System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);	subsetIndices[k][i] = copy;	double[] copyWeights = new double[num[k]];	System.arraycopy(subsetWeights[k][i], 0 ,copyWeights, 0, num[k]);	subsetWeights[k][i] = copyWeights;      }    }  }  /**   * Updates the numIncorrectModel field for all nodes when subtree (to be    * pruned) is rooted. This is needed for calculating the alpha-values.   *    * @throws Exception 	if something goes wrong   */  public void modelErrors() throws Exception{    Evaluation eval = new Evaluation(m_train);    if (!m_isLeaf) {      m_isLeaf = true; //temporarily make leaf      // calculate distribution for evaluation      eval.evaluateModel(this, m_train);      m_numIncorrectModel = eval.incorrect();      m_isLeaf = false;      for (int i = 0; i < m_Successors.length; i++)	m_Successors[i].modelErrors();    } else {      eval.evaluateModel(this, m_train);      m_numIncorrectModel = eval.incorrect();    }         }  /**   * Updates the numIncorrectTree field for all nodes. This is needed for   * calculating the alpha-values.   *    * @throws Exception 	if something goes wrong   */  public void treeErrors() throws Exception {    if (m_isLeaf) {      m_numIncorrectTree = m_numIncorrectModel;    } else {      m_numIncorrectTree = 0;      for (int i = 0; i < m_Successors.length; i++) {	m_Successors[i].treeErrors();	m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;      }    }  }  /**   * Updates the alpha field for all nodes.   
*    * @throws Exception 	if something goes wrong   */  public void calculateAlphas() throws Exception {    if (!m_isLeaf) {      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;      if (errorDiff <=0) {	//split increases training error (should not normally happen).	//prune it instantly.	makeLeaf(m_train);	m_Alpha = Double.MAX_VALUE;      } else {	//compute alpha	errorDiff /= m_totalTrainInstances;	m_Alpha = errorDiff / (double)(numLeaves() - 1);	long alphaLong = Math.round(m_Alpha*Math.pow(10,10));	m_Alpha = (double)alphaLong/Math.pow(10,10);	for (int i = 0; i < m_Successors.length; i++) {	  m_Successors[i].calculateAlphas();	}      }    } else {      //alpha = infinite for leaves (do not want to prune)      m_Alpha = Double.MAX_VALUE;    }  }  /**   * Find the node with minimal alpha value. If two nodes have the same alpha,    * choose the one with more leave nodes.   *    * @param nodeList 	list of inner nodes   * @return 		the node to be pruned   */  protected SimpleCart nodeToPrune(Vector nodeList) {    if (nodeList.size()==0) return null;    if (nodeList.size()==1) return (SimpleCart)nodeList.elementAt(0);    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);    double baseAlpha = returnNode.m_Alpha;    for (int i=1; i<nodeList.size(); i++) {      SimpleCart node = (SimpleCart)nodeList.elementAt(i);      if (node.m_Alpha < baseAlpha) {	baseAlpha = node.m_Alpha;	returnNode = node;      } else if (node.m_Alpha == baseAlpha) { // break tie	if (node.numLeaves()>returnNode.numLeaves()) {	  returnNode = node;	}      }    }    return returnNode;  }  /**   * Compute sorted indices, weights and class probabilities for a given    * dataset. Return total weights of the data at the node.   
*    * @param data 		training data   * @param sortedIndices 	sorted indices of instances at the node   * @param weights 		weights of instances at the node   * @param classProbs 		class probabilities at the node   * @return total 		weights of instances at the node   * @throws Exception 		if something goes wrong   */  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,      double[] classProbs) throws Exception {    // Create array of sorted indices and weights    double[] vals = new double[data.numInstances()];    for (int j = 0; j < data.numAttributes(); j++) {      if (j==data.classIndex()) continue;      weights[j] = new double[data.numInstances()];      if (data.attribute(j).isNominal()) {	// Handling nominal attributes. Putting indices of	// instances with missing values at the end.	sortedIndices[j] = new int[data.numInstances()];	int count = 0;	for (int i = 0; i < data.numInstances(); i++) {	  Instance inst = data.instance(i);	  if (!inst.isMissing(j)) {	    sortedIndices[j][count] = i;	    weights[j][count] = inst.weight();	    count++;	  }	}	for (int i = 0; i < data.numInstances(); i++) {	  Instance inst = data.instance(i);	  if (inst.isMissing(j)) {	    sortedIndices[j][count] = i;	    weights[j][count] = inst.weight();	    count++;	  }	}      } else {	// Sorted indices are computed for numeric attributes	// missing values instances are put to end 	for (int i = 0; i < data.numInstances(); i++) {	  Instance inst = data.instance(i);	  vals[i] = inst.value(j);	}	sortedIndices[j] = Utils.sort(vals);	for (int i = 0; i < data.numInstances(); i++) {	  weights[j][i] = data.instance(sortedIndices[j][i]).weight();	}      }    }    // Compute initial class counts    double totalWeight = 0;    for (int i = 0; i < data.numInstances(); i++) {      Instance inst = data.instance(i);      classProbs[(int)inst.classValue()] += inst.weight();      totalWeight += inst.weight();    }    return totalWeight;  }  /**   * Compute and return 
gini gain for given distributions of a node and its    * successor nodes.   *    * @param parentDist 	class distributions of parent node   * @param childDist 	class distributions of successor nodes   * @return 		Gini gain computed   */  protected double computeGiniGain(double[] parentDist, double[][] childDist) {    double totalWeight = Utils.sum(parentDist);    if (totalWeight==0) return 0;    double leftWeight = Utils.sum(childDist[0]);    double rightWeight = Utils.sum(childDist[1]);    double parentGini = computeGini(parentDist, totalWeight);    double leftGini = computeGini(childDist[0],leftWeight);    double rightGini = computeGini(childDist[1], rightWeight);    return parentGini - leftWeight/totalWeight*leftGini -    rightWeight/totalWeight*rightGini;  }  /**   * Compute and return gini index for a given distribution of a node.   *    * @param dist 	class distributions   * @param total 	class distributions   * @return 		Gini index of the class distributions   */  protected double computeGini(double[] dist, double total) {    if (total==0) return 0;    double val = 0;    for (int i=0; i<dist.length; i++) {      val += (dist[i]/total)*(dist[i]/total);    }    return 1- val;  }  /**   * Computes class probabilities for instance using the decision tree.   
*    * @param instance 	the instance for which class probabilities is to be computed   * @return 		the class probabilities for the given instance   * @throws Exception 	if something goes wrong   */  public double[] distributionForInstance(Instance instance)  throws Exception {    if (!m_isLeaf) {      // value of split attribute is missing      if (instance.isMissing(m_Attribute)) {	double[] returnedDist = new double[m_ClassProbs.length];	for (int i = 0; i < m_Successors.length; i++) {	  double[] help =	    m_Successors[i].distributionForInstance(instance);	  if (help != null) {	    for (int j = 0; j < help.length; j++) {	      returnedDist[j] += m_Props[i] * help[j];	    }	  }	}	return returnedDist;      }      // split attribute is nonimal      else if (m_Attribute.isNominal()) {	if (m_SplitString.indexOf("(" +	    m_Attribute.value((int)instance.value(m_Attribute)) + ")")!=-1)	  return  m_Successors[0].distributionForInstance(instance);	else return  m_Successors[1].distributionForInstance(instance);      }      // split attribute is numeric      else {	if (instance.value(m_Attribute) < m_SplitValue)	  return m_Successors[0].distributionForInstance(instance);	else	  return m_Successors[1].distributionForInstance(instance);      }    }    // leaf node    else return m_ClassProbs;  }  /**   * Make the node leaf node.   *    * @param data 	trainging data   */  protected void makeLeaf(Instances data) {    m_Attribute = null;    m_isLeaf = true;    m_ClassValue=Utils.maxIndex(m_ClassProbs);    m_ClassAttribute = data.classAttribute();  }  /**   * Prints the decision tree using the protected toString method from below.   *    * @return 		a textual description of the classifier   */  public String toString() {    if ((m_ClassProbs == null) && (m_Successors == null)) {      return "CART Tree: No model built yet.";    }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -