📄 clusterscore.java
字号:
* @return The set of true positives.
 */
public Set<Tuple<E>> truePositives() {
    // Keep only the reference pairs that the response also links.
    Set<Tuple<E>> shared = toEquivalences(mReferencePartition);
    shared.retainAll(toEquivalences(mResponsePartition));
    return shared;
}

/**
 * Returns the response-only relations of this scoring: pairs that the
 * response partition places in the same class but the reference
 * partition does not. Each relation is a {@link Tuple}, and both
 * orientations <code>(x,y)</code> and <code>(y,x)</code> of a spurious
 * link between <code>x</code> and <code>y</code> are included.
 *
 * <p>A freshly built set is returned on every call, so callers may
 * modify it.
 *
 * @return The set of false positives.
 */
public Set<Tuple<E>> falsePositives() {
    Set<Tuple<E>> inReference = toEquivalences(mReferencePartition);
    Set<Tuple<E>> spurious = toEquivalences(mResponsePartition);
    spurious.removeAll(inReference);
    return spurious;
}

/**
 * Returns the reference-only relations of this scoring: pairs that the
 * reference partition links but the response partition keeps apart.
 * Each relation is a {@link Tuple}; both <code>(x,y)</code> and
 * <code>(y,x)</code> appear for a missed link between <code>x</code>
 * and <code>y</code>.
 *
 * @return The set of false negatives.
*/
public Set<Tuple<E>> falseNegatives() {
    // Start from every reference pair and discard those the response also has.
    Set<Tuple<E>> missed = toEquivalences(mReferencePartition);
    missed.removeAll(toEquivalences(mResponsePartition));
    return missed;
}

/**
 * Builds the pairwise confusion counts between the reference and the
 * response partitions.
 *
 * <p>Counts are taken over ordered pairs <code>(x,y)</code> of elements,
 * self-pairs included, exactly as produced by {@link #toEquivalences(Set)}.
 * The true-negative count is whatever remains of the square of the number
 * of reference elements, so the four cells always sum to that square.
 *
 * @return Precision-recall evaluation built from the pair counts.
 */
private PrecisionRecallEvaluation calculateConfusionMatrix() {
    Set<Tuple<E>> referencePairs = toEquivalences(mReferencePartition);
    Set<Tuple<E>> unmatchedResponsePairs = toEquivalences(mResponsePartition);
    long truePos = 0L;
    long falseNeg = 0L;
    for (Tuple<E> pair : referencePairs) {
        // Removing each match leaves only the response-only pairs behind.
        if (unmatchedResponsePairs.remove(pair)) {
            ++truePos;
        } else {
            ++falseNeg;
        }
    }
    long falsePos = unmatchedResponsePairs.size();
    long elementCount = elementsOf(mReferencePartition).size();
    long trueNeg = elementCount * elementCount - truePos - falseNeg - falsePos;
    return new PrecisionRecallEvaluation(truePos, falseNeg, falsePos, trueNeg);
}

/**
 * Returns a multi-line, human-readable report of this score: the
 * equivalence-based precision-recall evaluation, followed by the MUC
 * scores and the B-cubed scores averaged both by cluster and by
 * element.
 *
 * @return String-based representation of this score.
*/ public String toString() { StringBuffer sb = new StringBuffer(); sb.append("CLUSTER SCORE"); sb.append("\nEquivalence Evaluation\n"); sb.append(mPrEval.toString()); sb.append("\nMUC Evaluation"); sb.append("\n MUC Precision = " + mucPrecision()); sb.append("\n MUC Recall = " + mucRecall()); sb.append("\n MUC F(1) = " + mucF()); sb.append("\nB-Cubed Evaluation"); sb.append("\n B3 Cluster Averaged Precision = " + b3ClusterPrecision()); sb.append("\n B3 Cluster Averaged Recall = " + b3ClusterRecall()); sb.append("\n B3 Cluster Averaged F(1) = " + b3ClusterF()); sb.append("\n B3 Element Averaged Precision = " + b3ElementPrecision()); sb.append("\n B3 Element Averaged Recall = " + b3ElementRecall()); sb.append("\n B3 Element Averaged F(1) = " + b3ElementF()); return sb.toString(); } /** * Returns the within-cluster scatter measure for the specified * clustering with respect to the specified distance. The * within-cluster scatter is simply the sum of the scatters for * each set in the clustering; see {@link #scatter(Set,Distance)} * for a definition of scatter. * * <blockquote><pre> * withinClusterScatter(clusters,distance) * = <big>Σ</big><sub><sub>cluster in clusters</sub></sub> scatter(cluster,distance)</pre></blockquote> * * <p>As the number of clusters increases, the within-cluster * scatter decreases monotonically. Typically, this is used * to determine how many clusters to return, by inspecting * a plot of within-cluster scatter against number of clusters * and looking for a "knee" in the graph. * * @param clustering Clustering to evaluate. * @param distance Distance against which to evaluate. * @return The within-cluster scatter score. */ static public <E> double withinClusterScatter(Set<? extends Set<? extends E>> clustering, Distance<? super E> distance) { double scatter = 0.0; for (Set<? extends E> s : clustering) scatter += scatter(s,distance); return scatter; } /** * Returns the scatter for the specified cluster based on the * specified distance. 
The scatter is the sum of all of the * pairwise distances between elements, with each pair of elements * counted once. Abusing notation to use <code>xs[i]</code> for * the <code>i</code>th element returned by the set's iterator, ** scatter is defined by: * * <blockquote><pre> * scatter(xs,distance) * = <big>Σ</big><sub><sub>i</sub></sub> <big>Σ</big><sub><sub>j < i</sub></sub> distance(xs[i],xs[j])</pre></blockquote> * * Note that elements are not compared to themselves. This * presupposes a distance for which the distance of an element to * itself is zero and which is symmetric. * * @param cluster Cluster to evaluate. * @param distance Distance against which to evaluate. * @return The total scatter for the specified set. */ static public <E> double scatter(Set<? extends E> cluster, Distance<? super E> distance) { Object[] elements = cluster.toArray(); double scatter = 0.0; for (int i = 0; i < elements.length; ++i) for (int j = i+1; j < elements.length; ++j) scatter += distance.distance((E)elements[i],(E)elements[j]); return scatter; } // includes self-equivalences for completeness of counts Set<Tuple<E>> toEquivalences(Set<? extends Set<? extends E>> partition) { Set<Tuple<E>> equivalences = new HashSet<Tuple<E>>(); for (Set<? 
extends E> equivalenceClass : partition) { Object[] xs = new Object[equivalenceClass.size()]; equivalenceClass.toArray(xs); for (int i = 0; i < xs.length; ++i) for (int j = 0; j < xs.length; ++j) equivalences.add(Tuple.<E>create((E)xs[i],(E)xs[j])); } return equivalences; } private static double b3ElementRecall(Set referencePartition, Set responsePartition) { double score = 0.0; Set elementsOfReference = elementsOf(referencePartition); Iterator referenceEqClassIterator = referencePartition.iterator(); while (referenceEqClassIterator.hasNext()) { Set referenceEqClass = (Set) referenceEqClassIterator.next(); Iterator referenceEqClassEltIterator = referenceEqClass.iterator(); while (referenceEqClassEltIterator.hasNext()) { Object referenceEqClassElt = referenceEqClassEltIterator.next(); score += uniformElementWeight(elementsOfReference) * b3Recall(referenceEqClassElt, referenceEqClass,responsePartition); } } return score; } private static double uniformElementWeight(Set elements) { return 1.0 / (double) elements.size(); } private static double uniformClusterWeight(Set eqClass, Set partition) { return 1.0 / ((double) (eqClass.size() * partition.size())); } private static double b3ClusterRecall(Set referencePartition, Set responsePartition) { double score = 0.0; Iterator referenceEqClassIterator = referencePartition.iterator(); while (referenceEqClassIterator.hasNext()) { Set referenceEqClass = (Set) referenceEqClassIterator.next(); Iterator referenceEqClassEltIterator = referenceEqClass.iterator(); while (referenceEqClassEltIterator.hasNext()) { Object referenceEqClassElt = referenceEqClassEltIterator.next(); score += uniformClusterWeight(referenceEqClass,referencePartition) * b3Recall(referenceEqClassElt, referenceEqClass,responsePartition); } } return score; } private static double b3Recall(Object element, Set referenceEqClass, Set responsePartition) { Set responseClass = getEquivalenceClass(element,responsePartition); return 
recallSets(referenceEqClass,responseClass); } private static double recallSets(Set referenceSet, Set responseSet) { if (referenceSet.size() == 0) return 1.0; return ((double) intersectionSize(referenceSet,responseSet)) / (double) referenceSet.size(); } private static long intersectionSize(Set set1, Set set2) { long count = 0; Iterator it = set1.iterator(); while (it.hasNext()) { Object x = it.next(); if (set2.contains(x)) ++count; } return count; } private static void assertPartitionSameSets(Set set1, Set set2) { assertValidPartition(set1); assertValidPartition(set2); if (!elementsOf(set1).equals(elementsOf(set2))) { String msg = "Partitions must be of same sets."; throw new IllegalArgumentException(msg); } } private static void assertValidPartition(Set partition) { Iterator eqClasses = partition.iterator(); HashSet eltsSoFar = new HashSet(); while (eqClasses.hasNext()) { Set eqClass = (Set) eqClasses.next(); Iterator members = eqClass.iterator(); while (members.hasNext()) { if (!eltsSoFar.add(members.next())) throw new IllegalArgumentException( "Partitions must not contain overlapping members."); } } } private static Set toPartition(Set[] equivalences) { HashSet partition = new HashSet(); Collections.addAll(partition,equivalences); return partition; } private static Set getEquivalenceClass(Object element, Set partition) { Iterator it = partition.iterator(); while (it.hasNext()) { Set equivalenceClass = (Set) it.next(); if (equivalenceClass.contains(element)) return equivalenceClass; } throw new IllegalArgumentException( "Element must be in an equivalence class in partition."); } private static Set elementsOf(Set partition) { HashSet elements = new HashSet(); Iterator it = partition.iterator(); while (it.hasNext()) { elements.addAll((Set)it.next()); } return elements; } private static double f(double precision, double recall) { return 2.0 * precision * recall / (precision + recall); } private static double mucRecall(Set referencePartition, Set responsePartition) { 
long numerator = 0; long denominator = 0; Iterator referenceEqClassesIt = referencePartition.iterator(); while (referenceEqClassesIt.hasNext()) { Set referenceEqClass = (Set) referenceEqClassesIt.next(); long numPartitions = 0; Iterator responseEqClasss = responsePartition.iterator(); while (responseEqClasss.hasNext()) { Set responseEqClass = (Set) responseEqClasss.next(); if (Collections.intersects(referenceEqClass,responseEqClass)) ++numPartitions; } numerator += referenceEqClass.size() - numPartitions; denominator += referenceEqClass.size() - 1; } if (denominator == 0) return 1.0; return ((double) numerator) / (double) denominator; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -