ConfusionMatrix.java
     */
    public String[] categories() {
        return mCategories;
    }

    /**
     * Returns the number of categories for this confusion matrix.
     * The underlying two-dimensional matrix of counts for this
     * confusion matrix has dimensions equal to the number of
     * categories.  Note that <code>numCategories()</code> is
     * guaranteed to be the same as <code>categories().length</code>
     * and thus may be used to compute iteration bounds.
     *
     * @return The number of categories for this confusion matrix.
     */
    public int numCategories() {
        return categories().length;
    }

    /**
     * Returns the index of the specified category in the list of
     * categories, or <code>-1</code> if it is not a category for this
     * confusion matrix.  The index is the index in the array
     * returned by {@link #categories()}.
     *
     * @param category Category whose index is returned.
     * @return The index of the specified category in the list of
     * categories.
     * @see #categories()
     */
    public int getIndex(String category) {
        Integer index = (Integer) mCategoryToIndex.get(category);
        if (index == null) return -1;
        return index.intValue();
    }

    /**
     * Returns the matrix values.  All values will be non-negative.
     *
     * @return The matrix values.
     */
    public int[][] matrix() {
        return mMatrix;
    }

    /**
     * Adds one to the cell in the matrix for the specified reference
     * and response category indices.
     *
     * @param referenceCategoryIndex Index of reference category.
     * @param responseCategoryIndex Index of response category.
     * @throws IllegalArgumentException If either index is out of range.
     */
    public void increment(int referenceCategoryIndex, int responseCategoryIndex) {
        checkIndex("reference", referenceCategoryIndex);
        checkIndex("response", responseCategoryIndex);
        ++mMatrix[referenceCategoryIndex][responseCategoryIndex];
    }

    /**
     * Adds <code>n</code> to the cell in the matrix for the specified
     * reference and response category indices.
     *
     * @param referenceCategoryIndex Index of reference category.
     * @param responseCategoryIndex Index of response category.
     * @param num Number of instances to increment by.
     * @throws IllegalArgumentException If either index is out of range.
     */
    public void incrementByN(int referenceCategoryIndex,
                             int responseCategoryIndex,
                             int num) {
        checkIndex("reference", referenceCategoryIndex);
        checkIndex("response", responseCategoryIndex);
        mMatrix[referenceCategoryIndex][responseCategoryIndex] += num;
    }

    /**
     * Adds one to the cell in the matrix for the specified reference
     * and response categories.
     *
     * @param referenceCategory Name of reference category.
     * @param responseCategory Name of response category.
     * @throws IllegalArgumentException If either category is
     * not a category for this confusion matrix.
     */
    public void increment(String referenceCategory, String responseCategory) {
        increment(getIndex(referenceCategory), getIndex(responseCategory));
    }

    /**
     * Returns the value of the cell in the matrix for the specified
     * reference and response category indices.
     *
     * @param referenceCategoryIndex Index of reference category.
     * @param responseCategoryIndex Index of response category.
     * @return Value of specified cell in the matrix.
     * @throws IllegalArgumentException If either index is out of range.
     */
    public int count(int referenceCategoryIndex, int responseCategoryIndex) {
        checkIndex("reference", referenceCategoryIndex);
        checkIndex("response", responseCategoryIndex);
        return mMatrix[referenceCategoryIndex][responseCategoryIndex];
    }
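    // Illustrative usage sketch for the increment/count API above.  It assumes
    // this class is the ConfusionMatrix named by the file and has a
    // ConfusionMatrix(String[]) constructor, which this excerpt does not show;
    // the sketch method itself is hypothetical and only demonstrates the calls.
    static void incrementUsageSketch() {
        String[] categories = new String[] { "hot", "cold" };
        ConfusionMatrix matrix = new ConfusionMatrix(categories);  // assumed constructor
        matrix.increment("hot", "hot");       // reference "hot", response "hot": correct
        matrix.increment("hot", "cold");      // reference "hot", response "cold": an error
        matrix.increment(1, 1);               // by index; "cold" is index 1 if the constructor keeps the given order
        int correctHot = matrix.count(0, 0);  // 1
        int hotAsCold = matrix.count(0, 1);   // 1
    }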
    /**
     * Returns the total number of classifications.  This is just
     * the sum of every cell in the matrix:
     *
     * <blockquote><code>
     *   totalCount()
     *   = <big><big><big>&Sigma;</big></big></big><sub><sub>i</sub></sub>
     *     <big><big><big>&Sigma;</big></big></big><sub><sub>j</sub></sub>
     *     count(i,j)
     * </code></blockquote>
     *
     * @return The sum of the counts of the entries in the matrix.
     */
    public int totalCount() {
        int total = 0;
        int len = numCategories();
        for (int i = 0; i < len; ++i)
            for (int j = 0; j < len; ++j)
                total += mMatrix[i][j];
        return total;
    }

    /**
     * Returns the total number of responses that matched the
     * reference.  This is the sum of counts on the diagonal of the
     * matrix:
     *
     * <blockquote><code>
     *   totalCorrect()
     *   = <big><big><big>&Sigma;</big></big></big><sub><sub>i</sub></sub>
     *     count(i,i)
     * </code></blockquote>
     *
     * The value is the same as that of
     * <code>microAverage().correctResponse()</code>.
     *
     * @return The sum of the correct results.
     */
    public int totalCorrect() {
        int total = 0;
        int len = numCategories();
        for (int i = 0; i < len; ++i)
            total += mMatrix[i][i];
        return total;
    }

    /**
     * Returns the proportion of responses that are correct.
     * That is:
     *
     * <blockquote><code>
     *   totalAccuracy() = totalCorrect() / totalCount()
     * </code></blockquote>
     *
     * Note that the classification error is just one minus the
     * accuracy, because each response is either correct or incorrect.
     *
     * @return The proportion of responses that match the reference.
     */
    public double totalAccuracy() {
        return ((double) totalCorrect()) / (double) totalCount();
    }

    /**
     * Returns half the width of the 95% confidence interval for this
     * confusion matrix.  Thus the confidence is 95% that the accuracy
     * is the total accuracy plus or minus the return value of this method.
     *
     * <P>Confidence is determined as described in {@link #confidence(double)}
     * with parameter <code>z=1.96</code>.
     *
     * @return Half of the width of the 95% confidence interval.
     */
    public double confidence95() {
        return confidence(1.96);
    }

    /**
     * Returns half the width of the 99% confidence interval for this
     * confusion matrix.  Thus the confidence is 99% that the accuracy
     * is the total accuracy plus or minus the return value of this method.
     *
     * <P>Confidence is determined as described in {@link #confidence(double)}
     * with parameter <code>z=2.58</code>.
     *
     * @return Half of the width of the 99% confidence interval.
     */
    public double confidence99() {
        return confidence(2.58);
    }
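    // Hypothetical sketch of the accuracy and confidence-interval accessors
    // above on a small hand-built 2x2 matrix; assumes the same
    // ConfusionMatrix(String[]) constructor as the sketch above.
    static void accuracySketch() {
        ConfusionMatrix matrix = new ConfusionMatrix(new String[] { "hot", "cold" });  // assumed constructor
        matrix.incrementByN(0, 0, 9);   // reference "hot" answered "hot"
        matrix.incrementByN(0, 1, 1);   // reference "hot" answered "cold"
        matrix.incrementByN(1, 0, 2);   // reference "cold" answered "hot"
        matrix.incrementByN(1, 1, 8);   // reference "cold" answered "cold"
        int total = matrix.totalCount();           // 20
        int correct = matrix.totalCorrect();       // 9 + 8 = 17
        double accuracy = matrix.totalAccuracy();  // 17/20 = 0.85
        double halfWidth = matrix.confidence95();  // 1.96 * sqrt(0.85*0.15/20), roughly 0.16
    }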
    /**
     * Returns the normal approximation of half of the binomial
     * confidence interval for this confusion matrix for the specified
     * z-score.
     *
     * <P>A z-score represents the number of standard deviations from
     * the mean, with the following correspondence of z-score and
     * percentage confidence intervals:
     *
     * <blockquote><table border='1' cellpadding='5'>
     * <tr><td><i>Z</i></td> <td><i>Confidence +/- Z</i></td></tr>
     * <tr><td>1.65</td>     <td>90%</td></tr>
     * <tr><td>1.96</td>     <td>95%</td></tr>
     * <tr><td>2.58</td>     <td>99%</td></tr>
     * <tr><td>3.30</td>     <td>99.9%</td></tr>
     * </table></blockquote>
     *
     * Thus the z-score for a 95% confidence interval is 1.96 standard
     * deviations.  The confidence interval is just the accuracy plus or
     * minus the z-score times the standard deviation.
     *
     * To compute the normal approximation to the deviation of the
     * binomial distribution, assume <code>p=totalAccuracy()</code> and
     * <code>n=totalCount()</code>.  Then the confidence interval is
     * defined in terms of the deviation of <code>binomial(p,n)</code>,
     * which is defined by first taking the variance of the Bernoulli
     * (one trial) distribution with success rate <code>p</code>:
     *
     * <blockquote><code>
     *   variance(bernoulli(p)) = p * (1-p)
     * </code></blockquote>
     *
     * and then dividing by the number <code>n</code> of trials in the
     * binomial distribution to get the variance of the binomial
     * distribution:
     *
     * <blockquote><code>
     *   variance(binomial(p,n)) = p * (1-p) / n
     * </code></blockquote>
     *
     * and then taking the square root to get the deviation:
     *
     * <blockquote><code>
     *   dev(binomial(p,n)) = sqrt(p * (1-p) / n)
     * </code></blockquote>
     *
     * For instance, with <code>p=totalAccuracy()=.90</code> and
     * <code>n=totalCount()=10000</code>:
     *
     * <blockquote><code>
     *   dev(binomial(.9,10000)) = sqrt(0.9 * (1.0 - 0.9) / 10000) = 0.003
     * </code></blockquote>
     *
     * Thus to determine the 95% confidence interval, we take
     * <code>z = 1.96</code> for a half-interval width of
     * <code>1.96 * 0.003 = 0.00588</code>.  The resulting interval is
     * just <code>0.90 +/- 0.00588</code>, or roughly
     * <code>(.894,.906)</code>.
     *
     * @param z The z-score, or number of standard deviations.
     * @return Half the width of the confidence interval for the specified
     * number of deviations.
     */
    public double confidence(double z) {
        double p = totalAccuracy();
        double n = totalCount();
        return z * java.lang.Math.sqrt(p * (1-p) / n);
    }

    /**
     * The entropy of the decision problem itself as defined by the
     * counts for the reference.  The entropy of a distribution is the
     * average negative log probability of outcomes.  For the
     * reference distribution, this is:
     *
     * <blockquote><code>
     *   referenceEntropy()
     *   <br> = - <big><big><big>Σ</big></big></big><sub><sub>i</sub></sub>
     *            referenceLikelihood(i)
     *            * log<sub><sub>2</sub></sub> referenceLikelihood(i)
     *   <br><br>
     *   referenceLikelihood(i) = oneVsAll(i).referenceLikelihood()
     * </code></blockquote>
     *
     * @return The entropy of the reference distribution.
     */
    public double referenceEntropy() {
        double sum = 0.0;
        for (int i = 0; i < numCategories(); ++i) {
            double prob = oneVsAll(i).referenceLikelihood();
            sum += prob * Math.log2(prob);  // Math.log2 is from a utility Math class; java.lang.Math has no log2()
        }
        return -sum;
    }

    /**
     * The entropy of the response distribution.  The entropy of a
     * distribution is the average negative log probability of
     * outcomes.  For the response distribution, this is:
     *
     * <blockquote><code>
     *   responseEntropy()
     *   <br> = - <big><big><big>Σ</big></big></big><sub><sub>i</sub></sub>
     *            responseLikelihood(i)
     *            * log<sub><sub>2</sub></sub> responseLikelihood(i)
     *   <br><br>
     *   responseLikelihood(i) = oneVsAll(i).responseLikelihood()
     * </code></blockquote>
     *
     * @return The entropy of the response distribution.
     */
    public double responseEntropy() {
        double sum = 0.0;
        for (int i = 0; i < numCategories(); ++i) {
            double prob = oneVsAll(i).responseLikelihood();
            sum += prob * Math.log2(prob);
        }
        return -sum;
    }
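    // Worked sketch of the entropy formula documented above, computed directly
    // from a hypothetical reference distribution rather than via oneVsAll(), so
    // it stands alone; it uses java.lang.Math because that class has no log2().
    // A uniform two-way reference distribution has an entropy of exactly 1 bit.
    static double entropySketch() {
        double[] referenceLikelihoods = new double[] { 0.5, 0.5 };  // hypothetical values
        double sum = 0.0;
        for (int i = 0; i < referenceLikelihoods.length; ++i) {
            double p = referenceLikelihoods[i];
            if (p > 0.0)  // skip zero-probability outcomes: 0 log2 0 = 0
                sum += p * (java.lang.Math.log(p) / java.lang.Math.log(2.0));  // p * log2(p)
        }
        return -sum;      // 1.0 for the uniform distribution above
    }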
    /**
     * The cross-entropy of the response distribution against the
     * reference distribution.  The cross-entropy is defined by the
     * negative log probabilities of the response distribution
     * weighted by the reference distribution:
     *
     * <blockquote><code>
     *   crossEntropy()
     *   <br> = - <big><big><big>Σ</big></big></big><sub><sub>i</sub></sub>
     *            referenceLikelihood(i)
     *            * log<sub><sub>2</sub></sub> responseLikelihood(i)
     *   <br><br>
     *   responseLikelihood(i) = oneVsAll(i).responseLikelihood()
     *   <br>
     *   referenceLikelihood(i) = oneVsAll(i).referenceLikelihood()
     * </code></blockquote>
     *
     * Note that <code>crossEntropy() >= referenceEntropy()</code>.
     * The entropy of a distribution is simply the cross-entropy of
     * the distribution with itself.
     *
     * <P>Low cross-entropy does not entail good classification,
     * though good classification entails low cross-entropy.
     *
     * @return The cross-entropy of the response distribution
     * against the reference distribution.
     */
    public double crossEntropy() {
        double sum = 0.0;
        for (int i = 0; i < numCategories(); ++i) {
            PrecisionRecallEvaluation eval = oneVsAll(i);
            double referenceProb = eval.referenceLikelihood();
            double responseProb = eval.responseLikelihood();
            sum += referenceProb * Math.log2(responseProb);
        }
        return -sum;
    }

    /**
     * Returns the entropy of the joint reference and response
     * distribution as defined by the underlying matrix.  Joint
     * entropy is defined by:
     *
     * <blockquote><code>
     *   jointEntropy()
     *   <br> = - <big><big>Σ</big></big><sub><sub>i</sub></sub>
     *            <big><big>Σ</big></big><sub><sub>j</sub></sub>
     *            P'(i,j) * log<sub><sub>2</sub></sub> P'(i,j)
     * </code></blockquote>
     *
     * <blockquote><code>
     *   P'(i,j) = count(i,j) / totalCount()
     * </code></blockquote>
     *
     * and where by convention:
     *
     * <blockquote><code>
     *   0 log<sub><sub>2</sub></sub> 0 =<sub><sub>def</sub></sub> 0
     * </code></blockquote>
     *
     * @return Joint entropy of this confusion matrix.
     */
    public double jointEntropy() {
        double totalCount = totalCount();
        double entropySum = 0.0;
        for (int i = 0; i < numCategories(); ++i) {
            for (int j = 0; j < numCategories(); ++j) {
                double prob = ((double) count(i,j)) / totalCount;
                if (prob > 0.0)  // 0 log2 0 = 0 by the convention above
                    entropySum += prob * Math.log2(prob);
            }
        }
        return -entropySum;
    }
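    // Standalone sketch of the joint-entropy formula documented above, applied
    // to a hypothetical 2x2 count matrix; the counts and the method name are
    // illustrative only and do not come from this class.
    static double jointEntropySketch() {
        int[][] counts = new int[][] { { 9, 1 }, { 2, 8 } };  // hypothetical counts
        double total = 20.0;                                  // sum of all cells
        double entropySum = 0.0;
        for (int i = 0; i < counts.length; ++i) {
            for (int j = 0; j < counts[i].length; ++j) {
                double prob = counts[i][j] / total;           // P'(i,j)
                if (prob > 0.0)                               // convention: 0 log2 0 = 0
                    entropySum += prob * (java.lang.Math.log(prob) / java.lang.Math.log(2.0));
            }
        }
        return -entropySum;  // roughly 1.6 bits for this matrix
    }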