📄 classifierevaluator.java
字号:
}

    /**
     * Returns the zero-based rank at which the specified category
     * appears in the specified ranked classification.
     *
     * @param classification Ranked classification to search.
     * @param responseCategory Category whose rank is sought.
     * @return The first rank whose category equals the specified
     * category, or <code>mCategories.length-1</code> if the category
     * does not appear in the classification.
     */
    int getRank(RankedClassification classification,
                String responseCategory) {
        for (int rank = 0; rank < classification.size(); ++rank) {
            if (classification.category(rank).equals(responseCategory)) {
                return rank;
            }
        }
        // default to putting it in last rank
        return mCategories.length-1;
    }

    /**
     * Returns a scored precision-recall evaluation of the
     * classification of the specified reference category versus all
     * other categories using the classification scores.
     *
     * @param refCategory Reference category.
     * @return The scored one-versus-all precision-recall evaluation.
     * @throws IllegalArgumentException If the specified category
     * is unknown.
     */
    public ScoredPrecisionRecallEvaluation scoredOneVersusAll(String refCategory) {
        validateCategory(refCategory);
        return scoredOneVersusAll(mScoreOutcomeLists,
                                  categoryToIndex(refCategory));
    }

    /**
     * Returns a scored precision-recall evaluation of the
     * classification of the specified reference category versus all
     * other categories using the conditional probability scores.
     * This method may only be called for evaluations that have
     * scores.  The outcomes it reads are only recorded for responses
     * that are instances of <code>ConditionalClassification</code>.
     *
     * @param refCategory Reference category.
     * @return The conditional one-versus-all precision-recall evaluation.
     * @throws IllegalArgumentException If the specified category
     * is unknown.
     */
    public ScoredPrecisionRecallEvaluation conditionalOneVersusAll(String refCategory) {
        validateCategory(refCategory);
        return scoredOneVersusAll(mConditionalOutcomeLists,
                                  categoryToIndex(refCategory));
    }

    /**
     * Returns the first-best one-versus-all precision-recall
     * evaluation of the classification of the specified reference
     * category versus all other categories.  This method may be
     * called for any evaluation.
     *
     * <p>One case is added per recorded classification: the case is
     * in the reference if its reference category equals the specified
     * category, and in the response if the best category of its
     * classification equals the specified category.
     *
     * @param refCategory Reference category.
     * @return The first-best one-versus-all precision-recall
     * evaluation.
     * @throws IllegalArgumentException If the specified category
     * is unknown.
     */
    public PrecisionRecallEvaluation oneVersusAll(String refCategory) {
        validateCategory(refCategory);
        PrecisionRecallEvaluation prEval = new PrecisionRecallEvaluation();
        int numCases = mReferenceCategories.size();
        for (int i = 0; i < numCases; ++i) {
            Object caseRefCategory = mReferenceCategories.get(i);
            Classification response = (Classification) mClassifications.get(i);
            Object caseResponseCategory = response.bestCategory();
            boolean inRef = caseRefCategory.equals(refCategory);
            boolean inResp = caseResponseCategory.equals(refCategory);
            prEval.addCase(inRef,inResp);
        }
        return prEval;
    }

    /**
     * Builds a scored precision-recall evaluation from the outcomes
     * stored for a single category.
     *
     * @param outcomeLists Per-category lists of <code>ScoreOutcome</code>
     * instances, indexed by category index.
     * @param categoryIndex Index of the category to evaluate.
     * @return Evaluation containing one case (outcome and score) per
     * stored outcome for the category.
     */
    private ScoredPrecisionRecallEvaluation scoredOneVersusAll(ArrayList[] outcomeLists,
                                                               int categoryIndex) {
        ScoredPrecisionRecallEvaluation eval
            = new ScoredPrecisionRecallEvaluation();
        List responseList = outcomeLists[categoryIndex];
        for (int i = 0; i < responseList.size(); ++i) {
            ScoreOutcome outcome = (ScoreOutcome) responseList.get(i);
            eval.addCase(outcome.mOutcome,outcome.mScore);
        }
        return eval;
    }

    /**
     * Returns a string-based representation of the classification
     * results.
     *
     * <p>The report starts with the global confusion matrix report
     * and the reference averages available for the kinds of
     * classification seen so far, followed by one section per
     * category containing its first-best one-versus-all evaluation
     * and whichever rank, score, conditional probability and joint
     * probability histograms are available.
     *
     * @return A string-based representation of the classification
     * results.
     */
    public String toString() {
        StringBuffer sb = new StringBuffer();
        sb.append("CLASSIFIER EVALUATION\n");
        mConfusionMatrix.toStringGlobal(sb);

        // global averages, one line per kind of classification seen
        if (mHasRanked) {
            sb.append("Average Reference Rank=")
                .append(averageRankReference())
                .append("\n");
        }
        if (mHasScored) {
            sb.append("Average Score Reference=")
                .append(averageScoreReference())
                .append("\n");
        }
        if (mHasConditional) {
            sb.append("Average Conditional Probability Reference=")
                .append(averageConditionalProbabilityReference())
                .append("\n");
        }
        if (mHasJoint) {
            sb.append("Average Log2 Joint Probability Reference=")
                .append(averageLog2JointProbabilityReference())
                .append("\n");
        }

        sb.append("ONE VERSUS ALL EVALUATIONS BY CATEGORY\n");
        for (int i = 0; i < categories().length; ++i) {
            String category = categories()[i];
            sb.append("\nCATEGORY[").append(i).append("]=")
                .append(category).append("\n");
            sb.append("First-Best Precision/Recall Evaluation\n");
            sb.append(oneVersusAll(category)).append("\n");

            if (mHasRanked) {
                sb.append("Rank Histogram=\n");
                appendCategoryLine(sb);
                for (int rank = 0; rank < numCategories(); ++rank) {
                    if (rank > 0) {
                        sb.append(',');
                    }
                    sb.append(mRankCounts[i][rank]);
                }
                sb.append("\n");

                sb.append("Average Rank Histogram=\n");
                appendCategoryLine(sb);
                for (int j = 0; j < numCategories(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(averageRank(category,categories()[j]));
                }
                sb.append("\n");
            }

            if (mHasScored) {
                sb.append("Scored One Versus All\n");
                sb.append(scoredOneVersusAll(category).toString())
                    .append("\n");
                sb.append("Average Score Histogram=\n");
                appendCategoryLine(sb);
                for (int j = 0; j < numCategories(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(averageScore(category,categories()[j]));
                }
                sb.append("\n");
            }

            if (mHasConditional) {
                sb.append("Conditional One Versus All\n");
                sb.append(conditionalOneVersusAll(category).toString())
                    .append("\n");
                sb.append("Average Conditional Probability Histogram=\n");
                appendCategoryLine(sb);
                for (int j = 0; j < numCategories(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(averageConditionalProbability(category,
                                                            categories()[j]));
                }
                sb.append("\n");
            }

            if (mHasJoint) {
                // values on this row are log (base 2) joint probabilities;
                // the label text is kept unchanged for existing consumers
                sb.append("Average Joint Probability Histogram=\n");
                appendCategoryLine(sb);
                for (int j = 0; j < numCategories(); ++j) {
                    if (j > 0) {
                        sb.append(',');
                    }
                    sb.append(averageLog2JointProbability(category,
                                                          categories()[j]));
                }
                sb.append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Appends a header row listing the categories to the specified
     * buffer: a leading space, the comma-separated categories, then a
     * newline followed by a space that indents the row written next.
     *
     * @param sb Buffer to which the header row is appended.
     */
    void appendCategoryLine(StringBuffer sb) {
        sb.append(" ");
        for (int i = 0; i < numCategories(); ++i) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(categories()[i]);
        }
        sb.append("\n ");
    }

    /**
     * Checks that the specified category is a member of the category
     * set of this evaluator.
     *
     * @param category Category to validate.
     * @throws IllegalArgumentException If the category is not in the
     * category set.
     */
    private void validateCategory(String category) {
        if (mCategorySet.contains(category)) {
            return;
        }
        String msg = "Unknown category=" + category;
        throw new IllegalArgumentException(msg);
    }

    /**
     * Appends the rank histogram to the specified buffer as
     * comma-separated values, one row per category index with one
     * count per rank.  Rows are separated by newlines and there is no
     * trailing newline.
     *
     * NEEDS HEADERS AND ESCAPES.
     *
     * @param sb Buffer to which the rows are appended.
     */
    void rankHistogramToCSV(StringBuffer sb) {
        for (int i = 0; i < numCategories(); ++i) {
            if (i > 0) {
                sb.append('\n');
            }
            for (int rank = 0; rank < numCategories(); ++rank) {
                if (rank > 0) {
                    sb.append(',');
                }
                sb.append(mRankCounts[i][rank]);
            }
        }
    }

    /**
     * Returns the count-weighted average rank recorded in the rank
     * histogram for the category with the specified index.
     *
     * @param i Index of the category.
     * @return The average rank, or <code>Double.NaN</code> if no
     * ranks have been counted for the category (zero divided by zero).
     */
    double averageRankReference(int i) {
        double sum = 0.0;
        int count = 0;
        for (int rank = 0; rank < numCategories(); ++rank) {
            int rankCount = mRankCounts[i][rank];
            if (rankCount == 0) {
                continue;
            }
            count += rankCount;
            // multiply as doubles so a large rank * count product
            // cannot overflow int before being added to the sum
            sum += (double) rank * rankCount;
        }
        return sum / (double) count;
    }

    /**
     * Returns the index of the specified category as reported by the
     * confusion matrix.
     *
     * @param category Category whose index is returned.
     * @return The index of the category.
     * @throws IllegalArgumentException If the confusion matrix
     * reports a negative index for the category.
     */
    int categoryToIndex(String category) {
        int result = mConfusionMatrix.getIndex(category);
        if (result < 0) {
            String msg = "Unknown category=" + category;
            throw new IllegalArgumentException(msg);
        }
        return result;
    }

    /**
     * Returns the rank histogram count stored for the specified
     * category index and rank.
     *
     * @param categoryIndex Index of the category.
     * @param rank Zero-based rank.
     * @return The count stored at the specified position.
     */
    int rankCount(int categoryIndex, int rank) {
        return mRankCounts[categoryIndex][rank];
    }

    /**
     * Adds the specified classification as a response for the specified
     * reference category.
     *
     * <p>The confusion matrix is incremented using the best category
     * of the classification, and the case is recorded.  Ranking,
     * scoring and conditional probability statistics are updated when
     * the classification is an instance of the corresponding
     * classification type.
     *
     * @param referenceCategory Reference category for case.
     * @param classification Response classification for case.
     */
    public void addClassification(String referenceCategory,
                                  Classification classification) {
        mConfusionMatrix.increment(referenceCategory,
                                   classification.bestCategory());
        mReferenceCategories.add(referenceCategory);
        mClassifications.add(classification);
        ++mNumCases;

        // the checks are independent: a classification may match several
        if (classification instanceof RankedClassification) {
            mHasRanked = true;
            addRanking(referenceCategory,
                       (RankedClassification) classification);
        }
        if (classification instanceof ScoredClassification) {
            mHasScored = true;
            addScoring(referenceCategory,
                       (ScoredClassification) classification);
        }
        if (classification instanceof ConditionalClassification) {
            mHasConditional = true;
            addConditioning(referenceCategory,
                            (ConditionalClassification) classification);
        }
        if (classification instanceof JointClassification) {
            mHasJoint = true;
        }
    }

    /**
     * Returns the number of categories reported by the confusion
     * matrix.
     *
     * @return The number of categories.
     */
    final int numCategories() {
        return mConfusionMatrix.numCategories();
    }

    /**
     * Records the specified ranked response for the specified
     * reference category by updating the rank histogram.
     *
     * @param refCategory Reference category for the case.
     * @param ranking Ranked response for the case.
     */
    void addRanking(String refCategory, RankedClassification ranking) {
        updateRankHistogram(refCategory,ranking);
    }

    /**
     * Increments the rank histogram cell for the rank at which the
     * reference category appears in the specified ranking.  If the
     * ranking lists fewer categories than this evaluator has, the
     * ranking is flagged as defective.  If the reference category is
     * not found, it is counted at the last rank.
     *
     * @param refCategory Reference category for the case.
     * @param ranking Ranked response for the case.
     * @throws IllegalArgumentException If the reference category is
     * unknown.
     */
    private void updateRankHistogram(String refCategory,
                                     RankedClassification ranking) {
        int refCategoryIndex = categoryToIndex(refCategory);
        if (ranking.size() < numCategories()) {
            mDefectiveRanking = true;
        }
        for (int rank = 0;
             rank < numCategories() && rank < ranking.size();
             ++rank) {
            String category = ranking.category(rank);
            if (category.equals(refCategory)) {
                ++mRankCounts[refCategoryIndex][rank];
                return;
            }
        }
        // assume the reference has last rank
        ++mRankCounts[refCategoryIndex][mCategories.length-1];
    }

    /**
     * Records one score outcome per ranked category of the specified
     * scored response.  Each outcome is stored in the list of the
     * category it scores and holds the score, whether that category
     * equals the reference category, and whether it was ranked first.
     * If the response lists fewer categories than this evaluator has,
     * the scoring is flagged as defective.
     *
     * @param refCategory Reference category for the case.
     * @param scoring Scored response for the case.
     * @throws IllegalArgumentException If the response contains an
     * unknown category.
     */
    private void addScoring(String refCategory,
                            ScoredClassification scoring) {
        // will this rank < scoring.size() mess up eval?
        if (scoring.size() < numCategories()) {
            mDefectiveScoring = true;
        }
        for (int rank = 0;
             rank < numCategories() && rank < scoring.size();
             ++rank) {
            double score = scoring.score(rank);
            String category = scoring.category(rank);
            int categoryIndex = categoryToIndex(category);
            boolean match = category.equals(refCategory);
            ScoreOutcome outcome = new ScoreOutcome(score,match,rank==0);
            mScoreOutcomeLists[categoryIndex].add(outcome);
        }
    }

    /**
     * Records one conditional probability outcome per ranked category
     * of the specified response.  Each outcome is stored in the list
     * of the category it scores and holds the conditional probability,
     * whether that category equals the reference category, and whether
     * it was ranked first.  If the response lists fewer categories
     * than this evaluator has, the conditioning is flagged as
     * defective.
     *
     * @param refCategory Reference category for the case.
     * @param scoring Conditional response for the case.
     * @throws IllegalArgumentException If the response contains an
     * unknown category.
     */
    private void addConditioning(String refCategory,
                                 ConditionalClassification scoring) {
        if (scoring.size() < numCategories()) {
            mDefectiveConditioning = true;
        }
        for (int rank = 0;
             rank < numCategories() && rank < scoring.size();
             ++rank) {
            double score = scoring.conditionalProbability(rank);
            String category = scoring.category(rank);
            int categoryIndex = categoryToIndex(category);
            boolean match = category.equals(refCategory);
            ScoreOutcome outcome = new ScoreOutcome(score,match,rank==0);
            mConditionalOutcomeLists[categoryIndex].add(outcome);
        }
    }

    /**
     * A score assigned to a category for a single case, paired with
     * whether the category was the reference category for the case
     * and whether it was the first-ranked response.
     */
    static class ScoreOutcome implements Scored {
        private final double mScore;
        private final boolean mOutcome;
        private final boolean mFirstBest;

        /**
         * Constructs a score outcome.
         *
         * @param score Score assigned to the category.
         * @param outcome <code>true</code> if the category equals the
         * reference category for the case.
         * @param firstBest <code>true</code> if the category was at
         * rank zero in the response.
         */
        public ScoreOutcome(double score, boolean outcome, boolean firstBest) {
            mOutcome = outcome;
            mScore = score;
            mFirstBest = firstBest;
        }

        /**
         * Returns the score of this outcome.
         *
         * @return The score of this outcome.
         */
        public double score() {
            return mScore;
        }

        /**
         * Returns a string of the form
         * <code>(score: outcome, firstBest=flag)</code>.
         *
         * @return String representation of this outcome.
         */
        public String toString() {
            // separator added: the outcome used to run straight
            // into the "firstBest=" label
            return "(" + mScore + ": " + mOutcome
                + ", firstBest=" + mFirstBest + ")";
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -