📄 vote.java
字号:
case MAX_RULE: result = Utils.maxIndex(distributionForInstance(instance)); if (result == 0) result = Instance.missingValue(); break; case MEDIAN_RULE: result = classifyInstanceMedian(instance); break; default: throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!"); } return result; } /** * Classifies the given test instance, returning the median from all * classifiers. * * @param instance the instance to be classified * @return the predicted most likely class for the instance or * Instance.missingValue() if no prediction is made * @throws Exception if an error occurred during the prediction */ protected double classifyInstanceMedian(Instance instance) throws Exception { double[] results = new double[m_Classifiers.length]; double result; for (int i = 0; i < results.length; i++) results[i] = m_Classifiers[i].classifyInstance(instance); if (results.length == 0) result = 0; else if (results.length == 1) result = results[0]; else result = Utils.kthSmallestValue(results, results.length / 2); return result; } /** * Classifies a given instance using the selected combination rule. 
* * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified * successfully */ public double[] distributionForInstance(Instance instance) throws Exception { double[] result = new double[instance.numClasses()]; switch (m_CombinationRule) { case AVERAGE_RULE: result = distributionForInstanceAverage(instance); break; case PRODUCT_RULE: result = distributionForInstanceProduct(instance); break; case MAJORITY_VOTING_RULE: result = distributionForInstanceMajorityVoting(instance); break; case MIN_RULE: result = distributionForInstanceMin(instance); break; case MAX_RULE: result = distributionForInstanceMax(instance); break; case MEDIAN_RULE: result[0] = classifyInstance(instance); break; default: throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!"); } if (!instance.classAttribute().isNumeric()) Utils.normalize(result); return result; } /** * Classifies a given instance using the Average of Probabilities * combination rule. * * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified * successfully */ protected double[] distributionForInstanceAverage(Instance instance) throws Exception { double[] probs = getClassifier(0).distributionForInstance(instance); for (int i = 1; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); for (int j = 0; j < dist.length; j++) { probs[j] += dist[j]; } } for (int j = 0; j < probs.length; j++) { probs[j] /= (double)m_Classifiers.length; } return probs; } /** * Classifies a given instance using the Product of Probabilities * combination rule. 
* * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified * successfully */ protected double[] distributionForInstanceProduct(Instance instance) throws Exception { double[] probs = getClassifier(0).distributionForInstance(instance); for (int i = 1; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); for (int j = 0; j < dist.length; j++) { probs[j] *= dist[j]; } } return probs; } /** * Classifies a given instance using the Majority Voting combination rule. * * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified * successfully */ protected double[] distributionForInstanceMajorityVoting(Instance instance) throws Exception { double[] probs = getClassifier(0).distributionForInstance(instance); // If it was possible to get the number of classes without classifying // double probs = new double[getClassifier(0).numOfClasses()]; double[] votes = new double[probs.length]; for (int i = 0; i < m_Classifiers.length; i++) { probs = getClassifier(i).distributionForInstance(instance); int maxIndex = 0; for(int j=0; j<probs.length; j++) { if(probs[j] > probs[maxIndex]) maxIndex = j; } // Consider the cases when multiple classes happen to have the same probability for(int j=0; j<probs.length; j++) { if(probs[j] == probs[maxIndex]) votes[j]++; } } int tmpMajorityIndex = 0; for (int k = 1; k < votes.length; k++) { if(votes[k] > votes[tmpMajorityIndex]) tmpMajorityIndex = k; } // Consider the cases when multiple classes receive the same amount of votes Vector majorityIndexes = new Vector(); for (int k = 0; k < votes.length; k++) { if(votes[k] == votes[tmpMajorityIndex]) majorityIndexes.add(new Integer(k)); } // Resolve the ties according to a uniform random distribution int majorityIndex = ((Integer)majorityIndexes.get((int)(Math.random()/(1/majorityIndexes.size())))).intValue(); //set 
probs to 0 for(int k=0; k<probs.length; k++) probs[k]=0; probs[majorityIndex]=1; //the class that have been voted the most receives 1 return probs; } /** * Classifies a given instance using the Maximum Probability combination rule. * * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified * successfully */ protected double[] distributionForInstanceMax(Instance instance) throws Exception { double[] max = getClassifier(0).distributionForInstance(instance); for (int i = 1; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); for (int j = 0; j < dist.length; j++) { if(max[j]<dist[j]) max[j]=dist[j]; } } return max; } /** * Classifies a given instance using the Minimum Probability combination rule. * * @param instance the instance to be classified * @return the distribution * @throws Exception if instance could not be classified * successfully */ protected double[] distributionForInstanceMin(Instance instance) throws Exception { double[] min = getClassifier(0).distributionForInstance(instance); for (int i = 1; i < m_Classifiers.length; i++) { double[] dist = getClassifier(i).distributionForInstance(instance); for (int j = 0; j < dist.length; j++) { if(dist[j]<min[j]) min[j]=dist[j]; } } return min; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String combinationRuleTipText() { return "The combination rule used."; } /** * Gets the combination rule used * * @return the combination rule used */ public SelectedTag getCombinationRule() { return new SelectedTag(m_CombinationRule, TAGS_RULES); } /** * Sets the combination rule to use. 
* A SelectedTag whose tag array is not TAGS_RULES is ignored and leaves
   * the current rule unchanged.
   *
   * @param newRule the combination rule method to use
   */
  public void setCombinationRule(SelectedTag newRule) {
    // Only accept selections made from this class's own set of rule tags.
    if (newRule.getTags() != TAGS_RULES) {
      return;
    }
    m_CombinationRule = newRule.getSelectedTag().getID();
  }

  /**
   * Output a representation of this classifier: the specification of every
   * base learner followed by the name of the combination rule in use.
   *
   * @return a string representation of the classifier, or a placeholder
   *         message when the array of base classifiers is null
   * @throws IllegalStateException if the combination rule is not one of the
   *         known rules
   */
  public String toString() {
    if (m_Classifiers == null) {
      return "Vote: No model built yet.";
    }

    StringBuilder text = new StringBuilder("Vote combines");
    text.append(" the probability distributions of these base learners:\n");
    for (int i = 0; i < m_Classifiers.length; i++) {
      text.append('\t').append(getClassifierSpec(i)).append('\n');
    }
    text.append("using the '");

    String ruleName;
    switch (m_CombinationRule) {
      case AVERAGE_RULE:
        ruleName = "Average of Probabilities";
        break;
      case PRODUCT_RULE:
        ruleName = "Product of Probabilities";
        break;
      case MAJORITY_VOTING_RULE:
        ruleName = "Majority Voting";
        break;
      case MIN_RULE:
        ruleName = "Minimum Probability";
        break;
      case MAX_RULE:
        ruleName = "Maximum Probability";
        break;
      case MEDIAN_RULE:
        ruleName = "Median Probability";
        break;
      default:
        throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");
    }

    text.append(ruleName);
    text.append("' combination rule \n");
    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   *        -t training file [-T test file] [-c class index]
   */
  public static void main(String [] argv) {
    Vote voter = new Vote();
    runClassifier(voter, argv);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -