vote.java

来自「Weka」· Java 代码 · 共 661 行 · 第 1/2 页

JAVA
661
字号
    switch (m_CombinationRule) {      case AVERAGE_RULE:      case PRODUCT_RULE:      case MAJORITY_VOTING_RULE:      case MIN_RULE:      case MAX_RULE:	dist = distributionForInstance(instance);	if (instance.classAttribute().isNominal()) {	  index = Utils.maxIndex(dist);	  if (dist[index] == 0)	    result = Instance.missingValue();	  else	    result = index;	}	else if (instance.classAttribute().isNumeric()){	  result = dist[0];	}	else {	  result = Instance.missingValue();	}	break;      case MEDIAN_RULE:	result = classifyInstanceMedian(instance);	break;      default:	throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");    }        return result;  }  /**   * Classifies the given test instance, returning the median from all   * classifiers.   *   * @param instance the instance to be classified   * @return the predicted most likely class for the instance or    * Instance.missingValue() if no prediction is made   * @throws Exception if an error occurred during the prediction   */  protected double classifyInstanceMedian(Instance instance) throws Exception {    double[] results = new double[m_Classifiers.length];    double result;    for (int i = 0; i < results.length; i++)      results[i] = m_Classifiers[i].classifyInstance(instance);        if (results.length == 0)      result = 0;    else if (results.length == 1)      result = results[0];    else      result = Utils.kthSmallestValue(results, results.length / 2);        return result;  }  /**   * Classifies a given instance using the selected combination rule.   *   * @param instance the instance to be classified   * @return the distribution   * @throws Exception if instance could not be classified   * successfully   */  public double[] distributionForInstance(Instance instance) throws Exception {    double[] result = new double[instance.numClasses()];        switch (m_CombinationRule) {      case AVERAGE_RULE:	result = distributionForInstanceAverage(instance);	break;      case PRODUCT_RULE:	result = distributionForInstanceProduct(instance);	break;      case MAJORITY_VOTING_RULE:	result = distributionForInstanceMajorityVoting(instance);	break;      case MIN_RULE:	result = distributionForInstanceMin(instance);	break;      case MAX_RULE:	result = distributionForInstanceMax(instance);	break;      case MEDIAN_RULE:	result[0] = classifyInstance(instance);	break;      default:	throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");    }        if (!instance.classAttribute().isNumeric() && (Utils.sum(result) > 0))      Utils.normalize(result);        return result;  }    /**   * Classifies a given instance using the Average of Probabilities    * combination rule.   *   * @param instance the instance to be classified   * @return the distribution   * @throws Exception if instance could not be classified   * successfully   */  protected double[] distributionForInstanceAverage(Instance instance) throws Exception {    double[] probs = getClassifier(0).distributionForInstance(instance);    for (int i = 1; i < m_Classifiers.length; i++) {      double[] dist = getClassifier(i).distributionForInstance(instance);      for (int j = 0; j < dist.length; j++) {    	  probs[j] += dist[j];      }    }    for (int j = 0; j < probs.length; j++) {      probs[j] /= (double)m_Classifiers.length;    }    return probs;  }    /**   * Classifies a given instance using the Product of Probabilities    * combination rule.   *   * @param instance the instance to be classified   * @return the distribution   * @throws Exception if instance could not be classified   * successfully   */  protected double[] distributionForInstanceProduct(Instance instance) throws Exception {    double[] probs = getClassifier(0).distributionForInstance(instance);    for (int i = 1; i < m_Classifiers.length; i++) {      double[] dist = getClassifier(i).distributionForInstance(instance);      for (int j = 0; j < dist.length; j++) {    	  probs[j] *= dist[j];      }    }        return probs;  }    /**   * Classifies a given instance using the Majority Voting combination rule.   *   * @param instance the instance to be classified   * @return the distribution   * @throws Exception if instance could not be classified   * successfully   */  protected double[] distributionForInstanceMajorityVoting(Instance instance) throws Exception {    double[] probs = new double[instance.classAttribute().numValues()];    double[] votes = new double[probs.length];        for (int i = 0; i < m_Classifiers.length; i++) {      probs = getClassifier(i).distributionForInstance(instance);      int maxIndex = 0;      for(int j = 0; j<probs.length; j++) {          if(probs[j] > probs[maxIndex])        	  maxIndex = j;      }            // Consider the cases when multiple classes happen to have the same probability      for (int j=0; j<probs.length; j++) {	if (probs[j] == probs[maxIndex])	  votes[j]++;      }    }        int tmpMajorityIndex = 0;    for (int k = 1; k < votes.length; k++) {      if (votes[k] > votes[tmpMajorityIndex])	tmpMajorityIndex = k;    }        // Consider the cases when multiple classes receive the same amount of votes    Vector<Integer> majorityIndexes = new Vector<Integer>();    for (int k = 0; k < votes.length; k++) {      if (votes[k] == votes[tmpMajorityIndex])	majorityIndexes.add(k);     }    // Resolve the ties according to a uniform random distribution    int majorityIndex = majorityIndexes.get(m_Random.nextInt(majorityIndexes.size()));        //set probs to 0    for (int k = 0; k<probs.length; k++)      probs[k] = 0;    probs[majorityIndex] = 1; //the class that have been voted the most receives 1        return probs;  }    /**   * Classifies a given instance using the Maximum Probability combination rule.   *   * @param instance the instance to be classified   * @return the distribution   * @throws Exception if instance could not be classified   * successfully   */  protected double[] distributionForInstanceMax(Instance instance) throws Exception {    double[] max = getClassifier(0).distributionForInstance(instance);    for (int i = 1; i < m_Classifiers.length; i++) {      double[] dist = getClassifier(i).distributionForInstance(instance);      for (int j = 0; j < dist.length; j++) {    	  if(max[j]<dist[j])    		  max[j]=dist[j];      }    }        return max;  }    /**   * Classifies a given instance using the Minimum Probability combination rule.   *   * @param instance the instance to be classified   * @return the distribution   * @throws Exception if instance could not be classified   * successfully   */  protected double[] distributionForInstanceMin(Instance instance) throws Exception {    double[] min = getClassifier(0).distributionForInstance(instance);    for (int i = 1; i < m_Classifiers.length; i++) {      double[] dist = getClassifier(i).distributionForInstance(instance);      for (int j = 0; j < dist.length; j++) {    	  if(dist[j]<min[j])    		  min[j]=dist[j];      }    }        return min;  }     /**   * Returns the tip text for this property   *    * @return 		tip text for this property suitable for   * 			displaying in the explorer/experimenter gui   */  public String combinationRuleTipText() {    return "The combination rule used.";  }    /**   * Gets the combination rule used   *   * @return 		the combination rule used   */  public SelectedTag getCombinationRule() {    return new SelectedTag(m_CombinationRule, TAGS_RULES);  }  /**   * Sets the combination rule to use. Values other than   *   * @param newRule 	the combination rule method to use   */  public void setCombinationRule(SelectedTag newRule) {    if (newRule.getTags() == TAGS_RULES)      m_CombinationRule = newRule.getSelectedTag().getID();  }    /**   * Output a representation of this classifier   *    * @return a string representation of the classifier   */  public String toString() {    if (m_Classifiers == null) {      return "Vote: No model built yet.";    }    String result = "Vote combines";    result += " the probability distributions of these base learners:\n";    for (int i = 0; i < m_Classifiers.length; i++) {      result += '\t' + getClassifierSpec(i) + '\n';    }    result += "using the '";        switch (m_CombinationRule) {      case AVERAGE_RULE:	result += "Average of Probabilities";	break;	      case PRODUCT_RULE:	result += "Product of Probabilities";	break;	      case MAJORITY_VOTING_RULE:	result += "Majority Voting";	break;	      case MIN_RULE:	result += "Minimum Probability";	break;	      case MAX_RULE:	result += "Maximum Probability";	break;	      case MEDIAN_RULE:	result += "Median Probability";	break;	      default:	throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");    }        result += "' combination rule \n";    return result;  }  /**   * Main method for testing this class.   *   * @param argv should contain the following arguments:   * -t training file [-T test file] [-c class index]   */  public static void main(String [] argv) {    runClassifier(new Vote(), argv);  }}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?