
📄 reptree.java

📁 wekaUT is a Weka-based semi-supervised learning classifier package developed at the University of Texas at Austin.
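REPTree, the class in this file, is Weka's fast decision/regression tree learner: it splits on information gain for a nominal class or variance reduction for a numeric class, and prunes with reduced-error pruning against a hold-out set. A minimal sketch of training and querying such a classifier through the standard Weka API is shown below; the file name train.arff and the package path weka.classifiers.trees.REPTree are assumptions and may differ in the wekaUT distribution.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.Classifier;
import weka.classifiers.trees.REPTree;   // package path may differ in older or forked Weka builds
import weka.core.Instance;
import weka.core.Instances;

public class REPTreeDemo {

  public static void main(String[] args) throws Exception {

    // Load an ARFF file (train.arff is a placeholder name)
    Instances data = new Instances(new BufferedReader(new FileReader("train.arff")));
    data.setClassIndex(data.numAttributes() - 1);   // last attribute is the class

    // Build the reduced-error-pruning tree with default settings
    Classifier tree = new REPTree();
    tree.buildClassifier(data);

    // Predict the class of the first training instance
    Instance first = data.instance(0);
    double prediction = tree.classifyInstance(first);
    System.out.println("Predicted class index/value: " + prediction);
  }
}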
💻 JAVA
📖 Page 1 of 3
        return size;
      }
    }

    /**
     * Splits instances into subsets.
     */
    protected void splitData(int[][][] subsetIndices,
                             double[][][] subsetWeights,
                             int att, double splitPoint,
                             int[][] sortedIndices, double[][] weights,
                             Instances data) throws Exception {

      int j;
      int[] num;

      // For each attribute
      for (int i = 0; i < data.numAttributes(); i++) {
        if (i != data.classIndex()) {
          if (data.attribute(att).isNominal()) {

            // For nominal attributes
            num = new int[data.attribute(att).numValues()];
            for (int k = 0; k < num.length; k++) {
              subsetIndices[k][i] = new int[sortedIndices[i].length];
              subsetWeights[k][i] = new double[sortedIndices[i].length];
            }
            for (j = 0; j < sortedIndices[i].length; j++) {
              Instance inst = data.instance(sortedIndices[i][j]);
              if (inst.isMissing(att)) {

                // Split instance up
                for (int k = 0; k < num.length; k++) {
                  if (m_Prop[k] > 0) {
                    subsetIndices[k][i][num[k]] = sortedIndices[i][j];
                    subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
                    num[k]++;
                  }
                }
              } else {
                int subset = (int)inst.value(att);
                subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
                subsetWeights[subset][i][num[subset]] = weights[i][j];
                num[subset]++;
              }
            }
          } else {

            // For numeric attributes
            num = new int[2];
            for (int k = 0; k < 2; k++) {
              subsetIndices[k][i] = new int[sortedIndices[i].length];
              subsetWeights[k][i] = new double[weights[i].length];
            }
            for (j = 0; j < sortedIndices[i].length; j++) {
              Instance inst = data.instance(sortedIndices[i][j]);
              if (inst.isMissing(att)) {

                // Split instance up
                for (int k = 0; k < num.length; k++) {
                  if (m_Prop[k] > 0) {
                    subsetIndices[k][i][num[k]] = sortedIndices[i][j];
                    subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
                    num[k]++;
                  }
                }
              } else {
                int subset = (inst.value(att) < splitPoint) ? 0 : 1;
                subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
                subsetWeights[subset][i][num[subset]] = weights[i][j];
                num[subset]++;
              }
            }
          }

          // Trim arrays
          for (int k = 0; k < num.length; k++) {
            int[] copy = new int[num[k]];
            System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
            subsetIndices[k][i] = copy;
            double[] copyWeights = new double[num[k]];
            System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
            subsetWeights[k][i] = copyWeights;
          }
        }
      }
    }
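    // Missing values on the split attribute: instead of being dropped, the
    // instance is sent down every branch with its weight scaled by m_Prop[k],
    // the proportion of training weight that went to branch k (similar to
    // C4.5's fractional-instance treatment).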
    /**
     * Computes class distribution for an attribute.
     */
    protected double distribution(double[][] props,
                                  double[][][] dists, int att,
                                  int[] sortedIndices,
                                  double[] weights,
                                  double[][] subsetWeights,
                                  Instances data)
      throws Exception {

      double splitPoint = Double.NaN;
      Attribute attribute = data.attribute(att);
      double[][] dist = null;
      int i;

      if (attribute.isNominal()) {

        // For nominal attributes
        dist = new double[attribute.numValues()][data.numClasses()];
        for (i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i];
        }
      } else {

        // For numeric attributes
        double[][] currDist = new double[2][data.numClasses()];
        dist = new double[2][data.numClasses()];

        // Move all instances into second subset
        for (int j = 0; j < sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }
          currDist[1][(int)inst.classValue()] += weights[j];
        }
        double priorVal = priorVal(currDist);
        System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

        // Try all possible split points
        double currSplit = data.instance(sortedIndices[0]).value(att);
        double currVal, bestVal = -Double.MAX_VALUE;
        for (i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          if (inst.value(att) > currSplit) {
            currVal = gain(currDist, priorVal);
            if (currVal > bestVal) {
              bestVal = currVal;
              splitPoint = (inst.value(att) + currSplit) / 2.0;
              for (int j = 0; j < currDist.length; j++) {
                System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
              }
            }
          }
          currSplit = inst.value(att);
          currDist[0][(int)inst.classValue()] += weights[i];
          currDist[1][(int)inst.classValue()] -= weights[i];
        }
      }

      // Compute weights
      props[att] = new double[dist.length];
      for (int k = 0; k < props[att].length; k++) {
        props[att][k] = Utils.sum(dist[k]);
      }
      if (!(Utils.sum(props[att]) > 0)) {
        for (int k = 0; k < props[att].length; k++) {
          props[att][k] = 1.0 / (double)props[att].length;
        }
      } else {
        Utils.normalize(props[att]);
      }

      // Distribute counts
      while (i < sortedIndices.length) {
        Instance inst = data.instance(sortedIndices[i]);
        for (int j = 0; j < dist.length; j++) {
          dist[j][(int)inst.classValue()] += props[att][j] * weights[i];
        }
        i++;
      }

      // Compute subset weights
      subsetWeights[att] = new double[dist.length];
      for (int j = 0; j < dist.length; j++) {
        subsetWeights[att][j] += Utils.sum(dist[j]);
      }

      // Return distribution and split point
      dists[att] = dist;
      return splitPoint;
    }
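    // For a numeric attribute the loop above sweeps the instances in ascending
    // attribute order, moving class weight from the right subset to the left
    // one and scoring a candidate split at the midpoint between consecutive
    // distinct values; the split with the largest entropy-based gain is kept.
    // Instances with missing values sit at the end of sortedIndices (the
    // caller sorts them there) and are spread over the subsets in proportion
    // to props[att].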
    /**
     * Computes class distribution for an attribute.
     */
    protected double numericDistribution(double[][] props,
                                         double[][][] dists, int att,
                                         int[] sortedIndices,
                                         double[] weights,
                                         double[][] subsetWeights,
                                         Instances data,
                                         double[] vals)
      throws Exception {

      double splitPoint = Double.NaN;
      Attribute attribute = data.attribute(att);
      double[][] dist = null;
      double[] sums = null;
      double[] sumSquared = null;
      double[] sumOfWeights = null;
      double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
      int i;

      if (attribute.isNominal()) {

        // For nominal attributes
        sums = new double[attribute.numValues()];
        sumSquared = new double[attribute.numValues()];
        sumOfWeights = new double[attribute.numValues()];
        int attVal;
        for (i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          attVal = (int)inst.value(att);
          sums[attVal] += inst.classValue() * weights[i];
          sumSquared[attVal] += inst.classValue() * inst.classValue() * weights[i];
          sumOfWeights[attVal] += weights[i];
        }
        totalSum = Utils.sum(sums);
        totalSumSquared = Utils.sum(sumSquared);
        totalSumOfWeights = Utils.sum(sumOfWeights);
      } else {

        // For numeric attributes
        sums = new double[2];
        sumSquared = new double[2];
        sumOfWeights = new double[2];
        double[] currSums = new double[2];
        double[] currSumSquared = new double[2];
        double[] currSumOfWeights = new double[2];

        // Move all instances into second subset
        for (int j = 0; j < sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }
          currSums[1] += inst.classValue() * weights[j];
          currSumSquared[1] += inst.classValue() * inst.classValue() * weights[j];
          currSumOfWeights[1] += weights[j];
        }
        totalSum = currSums[1];
        totalSumSquared = currSumSquared[1];
        totalSumOfWeights = currSumOfWeights[1];
        sums[1] = currSums[1];
        sumSquared[1] = currSumSquared[1];
        sumOfWeights[1] = currSumOfWeights[1];

        // Try all possible split points
        double currSplit = data.instance(sortedIndices[0]).value(att);
        double currVal, bestVal = Double.MAX_VALUE;
        for (i = 0; i < sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }
          if (inst.value(att) > currSplit) {
            currVal = variance(currSums, currSumSquared, currSumOfWeights);
            if (currVal < bestVal) {
              bestVal = currVal;
              splitPoint = (inst.value(att) + currSplit) / 2.0;
              for (int j = 0; j < 2; j++) {
                sums[j] = currSums[j];
                sumSquared[j] = currSumSquared[j];
                sumOfWeights[j] = currSumOfWeights[j];
              }
            }
          }
          currSplit = inst.value(att);
          double classVal = inst.classValue() * weights[i];
          double classValSquared = inst.classValue() * classVal;
          currSums[0] += classVal;
          currSumSquared[0] += classValSquared;
          currSumOfWeights[0] += weights[i];
          currSums[1] -= classVal;
          currSumSquared[1] -= classValSquared;
          currSumOfWeights[1] -= weights[i];
        }
      }

      // Compute weights
      props[att] = new double[sums.length];
      for (int k = 0; k < props[att].length; k++) {
        props[att][k] = sumOfWeights[k];
      }
      if (!(Utils.sum(props[att]) > 0)) {
        for (int k = 0; k < props[att].length; k++) {
          props[att][k] = 1.0 / (double)props[att].length;
        }
      } else {
        Utils.normalize(props[att]);
      }
      // Distribute counts for missing values
      while (i < sortedIndices.length) {
        Instance inst = data.instance(sortedIndices[i]);
        for (int j = 0; j < sums.length; j++) {
          sums[j] += props[att][j] * inst.classValue() * weights[i];
          sumSquared[j] += props[att][j] * inst.classValue() * inst.classValue() * weights[i];
          sumOfWeights[j] += props[att][j] * weights[i];
        }
        totalSum += inst.classValue() * weights[i];
        totalSumSquared += inst.classValue() * inst.classValue() * weights[i];
        totalSumOfWeights += weights[i];
        i++;
      }

      // Compute final distribution
      dist = new double[sums.length][data.numClasses()];
      for (int j = 0; j < sums.length; j++) {
        if (sumOfWeights[j] > 0) {
          dist[j][0] = sums[j] / sumOfWeights[j];
        } else {
          dist[j][0] = totalSum / totalSumOfWeights;
        }
      }

      // Compute variance gain
      double priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
      double var = variance(sums, sumSquared, sumOfWeights);
      double gain = priorVar - var;

      // Return distribution and split point
      subsetWeights[att] = sumOfWeights;
      dists[att] = dist;
      vals[att] = gain;
      return splitPoint;
    }

    /**
     * Computes variance for subsets.
     */
    protected double variance(double[] s, double[] sS,
                              double[] sumOfWeights) {

      double var = 0;

      for (int i = 0; i < s.length; i++) {
        if (sumOfWeights[i] > 0) {
          var += singleVariance(s[i], sS[i], sumOfWeights[i]);
        }
      }

      return var;
    }

    /**
     * Computes the variance for a single set
     */
    protected double singleVariance(double s, double sS, double weight) {

      return sS - ((s * s) / weight);
    }

    /**
     * Computes value of splitting criterion before split.
     */
    protected double priorVal(double[][] dist) {

      return ContingencyTables.entropyOverColumns(dist);
    }

    /**
     * Computes value of splitting criterion after split.
     */
    protected double gain(double[][] dist, double priorVal) {

      return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
    }

    /**
     * Prunes the tree using the hold-out data (bottom-up).
     */
    protected double reducedErrorPrune() throws Exception {

      // Is node leaf ?
      if (m_Attribute == -1) {
        return m_HoldOutError;
      }

      // Prune all sub trees
      double errorTree = 0;
      for (int i = 0; i < m_Successors.length; i++) {
        errorTree += m_Successors[i].reducedErrorPrune();
      }

      // Replace sub tree with leaf if error doesn't get worse
      if (errorTree >= m_HoldOutError) {
        m_Attribute = -1;
        m_Successors = null;
        return m_HoldOutError;
      } else {
        return errorTree;
      }
    }

    /**
     * Inserts hold-out set into tree.
     */
    protected void insertHoldOutSet(Instances data) throws Exception {

      for (int i = 0; i < data.numInstances(); i++) {
        insertHoldOutInstance(data.instance(i), data.instance(i).weight(), this);
      }
    }
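    // singleVariance(s, sS, w) returns the weighted sum of squared deviations
    // from the mean, sS - s*s/w, i.e. sum(w_i*y_i^2) - (sum(w_i*y_i))^2 / sum(w_i);
    // numericDistribution() keeps the split that minimises this quantity.
    //
    // Reduced-error pruning: insertHoldOutSet() routes every pruning instance
    // through the tree so that each node accumulates its own hold-out error
    // (m_HoldOutError); reducedErrorPrune() then works bottom-up and collapses
    // an internal node into a leaf whenever the summed error of its subtrees
    // is no lower than the error the node would make as a leaf.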
    /**
     * Inserts an instance from the hold-out set into the tree.
     */
    protected void insertHoldOutInstance(Instance inst, double weight,
                                         Tree parent) throws Exception {

      // Insert instance into hold-out class distribution
      if (inst.classAttribute().isNominal()) {

        // Nominal case
        m_HoldOutDist[(int)inst.classValue()] += weight;
        int predictedClass = 0;
        if (m_ClassProbs == null) {
          predictedClass = Utils.maxIndex(parent.m_ClassProbs);
        } else {
          predictedClass = Utils.maxIndex(m_ClassProbs);
        }
        if (predictedClass != (int)inst.classValue()) {
          m_HoldOutError += weight;
        }
      } else {

        // Numeric case
        m_HoldOutDist[0] += weight;
        double diff = 0;
        if (m_ClassProbs == null) {
          diff = parent.m_ClassProbs[0] - inst.classValue();
        } else {
          diff = m_ClassProbs[0] - inst.classValue();
        }
        m_HoldOutError += diff * diff * weight;
      }

      // The process is recursive
      if (m_Attribute != -1) {

        // If node is not a leaf
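The hold-out routines above prune the tree against data that was not used to grow it. A minimal sketch of checking such a model on a separate test file through Weka's standard Evaluation API follows; the file names train.arff and test.arff and the weka.classifiers.trees.REPTree package path are assumptions, not taken from this listing.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.Evaluation;
import weka.classifiers.trees.REPTree;   // package path may differ in the wekaUT sources
import weka.core.Instances;

public class REPTreeHoldOutCheck {

  public static void main(String[] args) throws Exception {

    // Load training and test data (placeholder file names)
    Instances train = new Instances(new BufferedReader(new FileReader("train.arff")));
    Instances test = new Instances(new BufferedReader(new FileReader("test.arff")));
    train.setClassIndex(train.numAttributes() - 1);
    test.setClassIndex(test.numAttributes() - 1);

    // Grow and prune the tree on the training data
    REPTree tree = new REPTree();
    tree.buildClassifier(train);

    // Measure the pruned tree on unseen data
    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(tree, test);
    System.out.println(eval.toSummaryString("\nREPTree on hold-out test set\n", false));
  }
}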
