📄 corefclusteradv.java
字号:
// NOTE: the statements up to the first lone closing brace below are the tail of
// a method whose beginning lies outside this excerpt. They are reproduced
// unchanged (only re-indented).
        HashMap hm = new HashMap();
        while (i1.hasNext()) {
            WeightedEdge e = (WeightedEdge) i1.next();
            if ((e.getVertexA() == v1) && (e.getVertexB() != v2))
                hm.put((Object) e.getVertexB(), new Double(e.getWeight()));
            else if (e.getVertexA() != v2)
                hm.put((Object) e.getVertexA(), new Double(e.getWeight()));
        }
        try {
            g.remove(v1); // this also removes all edges incident with this vertex
        } catch (Exception ex) { ex.printStackTrace(); }
        // System.out.println("Hashmap for vertex: " + v1 + " is: " + hm);
        List edges2 = (List)g.getEdges(v2);
        Iterator i2 = edges2.iterator();
        Vertex cv = null;
        if (edges2.size() > 0) {
            while (i2.hasNext()) {
                WeightedEdge e = (WeightedEdge)i2.next();
                if (e.getVertexA() == v2)
                    cv = e.getVertexB();
                else
                    cv = e.getVertexA();
                double w2 = ((Double)hm.get(cv)).doubleValue();
                double w1 = e.getWeight();
                //double weight = (w1 > w2) ? w1 : w2; // max
                double weight = (w1 + w2)/2; // avg
                if (w1 == NegativeInfinite || w2 == NegativeInfinite) {
                    weight = NegativeInfinite; // precaution: avoid creeping away from Infinite
                }
                WeightedEdge ne = new WeightedEdgeImpl(newVertex, cv, weight);
                try {
                    g.addEdge(ne);
                } catch (Exception ex) { ex.printStackTrace(); }
            }
        } else { // in this case, no edges are left, just add new Vertex
            try {
                g.add(newVertex);
            } catch (Exception ex) { ex.printStackTrace(); }
        }
        try {
            g.remove(v2);
        } catch (Exception ex) { ex.printStackTrace(); }
        // System.out.println("After adding new edges: " + g);
    }

    /**
     * Prints the graph to standard output: first the vertex set and the object
     * held by every vertex, then every edge as
     * {@code <objectA> <----> (<weight>) <objectB>}.
     *
     * @param g graph to dump; its vertices are cast to {@code VertexImpl} and
     *          its edges to {@code WeightedEdge}
     */
    public void printGraph (WeightedGraph g) {
        Set vs = g.getVertexSet();
        System.out.println("Vertices: " + vs);
        for (Iterator i = vs.iterator(); i.hasNext();) {
            VertexImpl v = (VertexImpl)i.next();
            printVObj(v.getObject());
            System.out.println(" ");
        }

        Set es = g.getEdgeSet();
        System.out.println("Edges: ");
        for (Iterator i2 = es.iterator(); i2.hasNext();) {
            WeightedEdge e = (WeightedEdge)i2.next();
            VertexImpl v1 = (VertexImpl)e.getVertexA();
            VertexImpl v2 = (VertexImpl)e.getVertexB();
            printVObj(v1.getObject());
            System.out.print(" <----> (" + e.getWeight() + ") ");
            printVObj(v2.getObject());
            System.out.println("");
        }
    }

    /**
     * Scores one specific pair of n-best alternatives.
     *
     * The source of {@code origInstPair} must be a {@code NodePair} of two
     * {@code Citation}s. The {@code ind1}-th alternative of the first citation
     * and the {@code ind2}-th alternative of the second are paired into a new
     * instance (target "yes", same pipe as the original) and scored with
     * {@link #computeScore}.
     *
     * @param classifier   binary classifier used for scoring
     * @param origInstPair instance whose source holds the original citation pair
     * @param ind1         n-best index for the first citation
     * @param ind2         n-best index for the second citation
     * @return the classifier score of the selected pair
     */
    private double computeScore_NBest (MaxEnt classifier, Instance origInstPair, int ind1, int ind2) {
        NodePair origNodePair = (NodePair)origInstPair.getSource();
        Citation citation1 = (Citation)origNodePair.getObject1();
        Citation citation2 = (Citation)origNodePair.getObject2();
        Citation nBest1 = citation1.getNthBest(ind1);
        Citation nBest2 = citation2.getNthBest(ind2);
        NodePair newPair = new NodePair (nBest1, nBest2);
        Pipe pipe = origInstPair.getPipe();
        return computeScore(classifier, new Instance(newPair, "yes", null, newPair, pipe));
    }

    /**
     * Scores a citation pair using the n-best lists of both citations.
     *
     * The source of {@code origInstPair} must be a {@code NodePair} of two
     * {@code Citation}s. Every combination of one alternative from each n-best
     * list is scored (optionally multiplied by the product of the two
     * confidence scores when {@code confidenceWeightedScores} is set) and the
     * scores are combined as follows:
     * <ul>
     *   <li>if {@code useOptimal} is set: the maximum score when the instance
     *       target is "yes", the minimum score otherwise;</li>
     *   <li>otherwise: the plain average of all pair-wise scores.</li>
     * </ul>
     * When either n-best list is missing or empty the original pair is scored
     * directly; in that case a disagreement between the sign of the score and
     * the pair's label is printed and counted in {@code falseNegatives} /
     * {@code falsePositives}.
     *
     * @param classifier   binary classifier used for scoring
     * @param origInstPair instance whose source holds the citation pair
     * @return the combined edge score
     */
    private double computeScore_NBest(MaxEnt classifier, Instance origInstPair) {
        NodePair origNodePair = (NodePair)origInstPair.getSource();
        Citation citation1 = (Citation)origNodePair.getObject1();
        Citation citation2 = (Citation)origNodePair.getObject2();
        Collection nBest1 = citation1.getNBest();
        Collection nBest2 = citation2.getNBest();
        String label = origNodePair.getIdRel() ? "yes" : "no";
        Object name = origInstPair.getName();
        Object source = origInstPair.getSource();
        Pipe pipe = origInstPair.getPipe();

        // An empty list is treated like a missing one: with no alternatives
        // there would be nothing to average or to take a max/min over.
        if (nBest1 == null || nBest2 == null || nBest1.isEmpty() || nBest2.isEmpty()) {
            Instance instPair = new Instance (origNodePair, label, name, source, pipe);
            double score = computeScore(classifier, instPair);
            int i1 = citation1.getIndex();
            int i2 = citation2.getIndex();
            // Compare label contents, not references.
            if (score < 0.0 && "yes".equals(label)) {
                System.out.println(i1 + " " + i2 + " " + score + " " + label);
                falseNegatives++;
            } else if (score > 0.0 && "no".equals(label)) {
                System.out.println(i1 + " " + i2 + " " + score + " " + label);
                falsePositives++;
            }
            return score;
        }

        // Score all pairs drawn from the two sets of n-best citations.
        List scores = new ArrayList();
        for (Iterator iterator = nBest1.iterator(); iterator.hasNext();) {
            Citation nbest_citation1 = (Citation) iterator.next();
            for (Iterator iterator2 = nBest2.iterator(); iterator2.hasNext();) {
                Citation nbest_citation2 = (Citation) iterator2.next();
                NodePair nodePair = new NodePair(nbest_citation1, nbest_citation2, origNodePair.getIdRel());
                Instance instPair = new Instance (nodePair, label, name, null, pipe);
                double score = computeScore(classifier, instPair);
                if (confidenceWeightedScores) {
                    double weight = (nbest_citation1.getConfidenceScore() * nbest_citation2.getConfidenceScore());
                    scores.add(new Double(score*weight));
                } else {
                    scores.add(new Double(score));
                }
            }
        }

        double optEdgeWeight;
        if (useOptimal) {
            // This version returns the highest possible edge when the pair
            // _is_ coreferential and the lowest possible edge otherwise.
            if (origInstPair.getTarget().toString().equals("yes")) {
                optEdgeWeight = ((Double)Collections.max(scores)).doubleValue();
            } else {
                optEdgeWeight = ((Double)Collections.min(scores)).doubleValue();
            }
        } else {
            // just average all the pair-wise scores for now
            // eventually do something sophisticated here
            optEdgeWeight = collectionAvg (scores);
        }
        return optEdgeWeight;
    }

    /**
     * Arithmetic mean of a collection of {@code Double}s.
     *
     * @param collection values to average; every element is cast to {@code Double}
     * @return the mean; {@code NaN} for an empty collection (0.0 / 0.0)
     */
    protected double collectionAvg (Collection collection) {
        double sum = 0.0;
        for (Iterator it = collection.iterator(); it.hasNext(); ) {
            sum += ((Double)it.next()).doubleValue();
        }
        return sum/(double)collection.size();
    }

    /**
     * Tells whether {@link #nextIndexList} can still advance the given list,
     * i.e. whether at least one position holds a value below {@code N-1}.
     *
     * @param indexList current index configuration
     * @param N         number of values each position may take (0 .. N-1)
     * @return {@code true} if some position is below {@code N-1}
     */
    protected boolean hasNextIndexList(int[] indexList, int N) {
        for (int i = 0; i < indexList.length; i++) {
            if (indexList[i] < N-1)
                return true;
        }
        return false;
    }

    /**
     * Advances the index list in place to the next configuration, counting
     * like an odometer: the right-most position that is below {@code N-1} is
     * incremented and every position to its right is reset to 0. If no
     * position can be incremented the list is left unchanged.
     *
     * @param indexList configuration to advance (modified in place)
     * @param N         number of values each position may take (0 .. N-1)
     * @return the same array instance that was passed in
     */
    protected int[] nextIndexList(int[] indexList, int N) {
        for (int i = indexList.length-1; i >= 0; i--) {
            if (indexList[i] <= N-2) {
                indexList[i]++;
                for (int j = i+1; j <= indexList.length-1; j++) {
                    indexList[j] = 0;
                }
                break;
            }
        }
        return indexList;
    }

    /**
     * Overwrites every position of the index list with a value drawn uniformly
     * from 0 .. N-1, using a freshly created {@code java.util.Random}.
     *
     * @param indexList configuration to overwrite (modified in place)
     * @param N         exclusive upper bound for the random values
     * @return the same array instance that was passed in
     */
    protected int[] nextIndexListStochastic (int[] indexList, int N) {
        java.util.Random r = new java.util.Random();
        for (int i = 0; i < indexList.length; i++) {
            indexList[i] = r.nextInt(N);
        }
        return indexList;
    }

    /**
     * Total score of one n-best configuration of a cluster: the sum, over all
     * unordered pairs of distinct citations in {@code instList}, of the score
     * of the alternatives selected by {@code indexList}.
     *
     * The previous inner loop ran {@code j} from {@code i} down to 1, which
     * scored every citation against itself and never paired anything with
     * element 0, so {@code indexList[0]} had no influence on the result. The
     * loop now visits each pair {@code (i, j)} with {@code j < i} exactly once,
     * the same set of pairs that {@link #updateGraphNBest} turns into edges.
     *
     * @param indexList n-best index chosen for each citation (parallel to instList)
     * @param instList  the {@code Citation}s of one cluster
     * @return summed pair-wise score of the configuration
     */
    public double weightOfConfig (int [] indexList, List instList) {
        double score = 0.0;
        for (int i = 0; i < indexList.length; i++) {
            for (int j = i - 1; j >= 0; j--) {
                NodePair p = new NodePair((Citation)instList.get(i),(Citation)instList.get(j));
                Instance inst = new Instance(p, "yes", null, p, this.pipe);
                score += computeScore_NBest(meClassifier, inst, indexList[i], indexList[j]);
            }
        }
        return score;
    }

    /**
     * Adds to the graph one edge per unordered pair of citations in
     * {@code instList}, weighted by the score of the n-best alternatives
     * selected in {@code indexList}.
     *
     * @param graph        graph receiving the edges
     * @param indexList    n-best index chosen for each citation (parallel to instList)
     * @param instList     the {@code Citation}s of one cluster
     * @param alreadyAdded map from node object to its vertex, updated as vertices are created
     */
    public void updateGraphNBest (WeightedGraph graph, int [] indexList, List instList, HashMap alreadyAdded) {
        for (int i = 0; i < indexList.length; i++) {
            for (int j = i+1; j < indexList.length; j++) {
                Object c1 = instList.get(i);
                Object c2 = instList.get(j);
                NodePair p = new NodePair((Citation)c1,(Citation)c2);
                Instance inst = new Instance(p, "yes", null, p, this.pipe);
                double edgeScore = computeScore_NBest (meClassifier, inst, indexList[i], indexList[j]);
                constructEdgesUsingTrainedClusterer (graph, inst, alreadyAdded, new Double (edgeScore));
            }
        }
    }

    /**
     * Builds a graph in which, for every cluster of {@code keyPartitioning},
     * the n-best alternative of each citation is chosen so as to maximise
     * {@link #weightOfConfig}. Intra-cluster edges are added for the chosen
     * configuration, the remaining mentions are added as vertices, and the
     * graph is completed with inter-cluster edges by
     * {@link #completeGraphNBest}.
     *
     * @param mentions all mentions, passed on to {@code addVerticesToGraph}
     * @param N        number of n-best alternatives per citation
     * @return the constructed graph
     */
    public WeightedGraph constructOptimalEdgesUsingNBest (List mentions, int N) {
        WeightedGraph graph = new WeightedGraphImpl();
        HashMap alreadyAddedVertices = new HashMap(); // node object -> vertex
        HashMap bestCitationMap = new HashMap();      // citation -> chosen n-best index

        for (Iterator iter = keyPartitioning.iterator(); iter.hasNext();) {
            Collection cluster = (Collection)iter.next(); // get key cluster
            List instList = new ArrayList(cluster);
            int[] optimalIndexList = findBestIndexList(instList, N);

            for (int j = 0; j < optimalIndexList.length; j++) {
                bestCitationMap.put(instList.get(j), new Integer(optimalIndexList[j]));
            }
            updateGraphNBest (graph, optimalIndexList, instList, alreadyAddedVertices);
        }
        addVerticesToGraph(graph, mentions, alreadyAddedVertices);
        completeGraphNBest (graph, keyPartitioning, bestCitationMap);
        return graph;
    }

    /**
     * Searches the n-best configurations of one cluster for the one with the
     * highest {@link #weightOfConfig}, starting from the all-zero
     * configuration and stepping with {@link #nextIndexList}. The search is
     * exhaustive when N^size is at most 16000; otherwise only the first 4000
     * successor configurations are examined.
     *
     * NOTE(review): the bounded search walks configurations in odometer order,
     * so it only varies the trailing positions; nextIndexListStochastic exists
     * in this class but is not used here - confirm which was intended.
     */
    private int[] findBestIndexList (List instList, int N) {
        final int maxExhaustiveCombinations = 16000;
        final int maxBoundedSteps = 4000;

        int[] indexList = new int[instList.size()]; // Java zero-initialises int arrays
        int[] optimalIndexList = (int[])indexList.clone();
        double highestWeight = weightOfConfig(indexList, instList);

        // The double-to-int cast saturates at Integer.MAX_VALUE, so very large
        // search spaces still take the bounded branch.
        int numCombinations = (int)Math.pow((double)N, (double)instList.size());
        boolean exhaustive = numCombinations <= maxExhaustiveCombinations;

        int steps = 0;
        while (exhaustive ? hasNextIndexList(indexList, N) : steps < maxBoundedSteps) {
            indexList = nextIndexList(indexList, N);
            steps++;
            double weight = weightOfConfig(indexList, instList);
            if (weight > highestWeight) {
                highestWeight = weight;
                optimalIndexList = (int[])indexList.clone();
            }
        }
        return optimalIndexList;
    }

    /** Prints the entries of the list on one line, each preceded by a space. */
    private void printIList (int [] list) {
        for (int i = 0; i < list.length; i++) {
            System.out.print(" " + list[i]);
        }
    }

    /**
     * Adds an edge between every pair of citations that belong to different
     * clusters of {@code keyPartitioning}. Each edge is weighted by the score
     * of the n-best alternatives recorded for the two citations in
     * {@code citMap}.
     *
     * Only vertices whose object is a single-element {@code List} are
     * considered; the element is used to find the vertex of a citation.
     *
     * @param graph           graph to complete
     * @param keyPartitioning collection of clusters (each a {@code Collection} of {@code Citation})
     * @param citMap          map from citation to the {@code Integer} n-best index chosen for it
     */
    public void completeGraphNBest (WeightedGraph graph, Collection keyPartitioning, Map citMap) {
        HashMap m1 = new HashMap(); // map from objects to their vertices
        for (Iterator iter = graph.getVertexSet().iterator(); iter.hasNext();) {
            VertexImpl v = (VertexImpl)iter.next();
            Object held = v.getObject();
            if ((held instanceof List) && ((List)held).size() == 1) {
                m1.put(((List)held).get(0), v);
            }
        }

        List kList = new ArrayList(keyPartitioning);
        for (int i = 0; i < kList.size(); i++) {
            Collection c1 = (Collection)kList.get(i);
            for (int j = i+1; j < kList.size(); j++) {
                Collection c2 = (Collection)kList.get(j);
                for (Iterator i1 = c1.iterator(); i1.hasNext();) {
                    Citation cit1 = (Citation)i1.next();
                    VertexImpl v1 = (VertexImpl)m1.get((Object)cit1);
                    for (Iterator i2 = c2.iterator(); i2.hasNext();) {
                        Citation cit2 = (Citation)i2.next();
                        VertexImpl v2 = (VertexImpl)m1.get((Object)cit2);
                        NodePair np = new NodePair (cit1, cit2);
                        Instance inst = new Instance (np, "no", null, np, this.pipe);
                        double eval = computeScore_NBest(meClassifier, inst,
                                                         ((Integer)citMap.get(cit1)).intValue(),
                                                         ((Integer)citMap.get(cit2)).intValue());
                        try {
                            graph.addEdge (v1, v2, eval);
                        } catch (Exception e) {e.printStackTrace();}
                    }
                }
            }
        }
    }

    /**
     * Adds the edge described by a pseudo edge to the graph, creating and
     * registering vertices for its two end points when they are not yet known.
     *
     * Vertices are registered in {@code alreadyAdded} under the object wrapped
     * by the pseudo vertex, so the lookup uses that same object as key. The
     * previous code looked up by the {@code PseudoVertex} itself, a key this
     * method never stores, so a vertex it had registered was never found again.
     *
     * @param graph        graph receiving the edge
     * @param pedge        pseudo edge supplying the two end points and the weight
     * @param alreadyAdded map from node object to its vertex, updated in place
     */
    protected void constructEdgesFromPseudoEdges (WeightedGraph graph, PseudoEdge pedge, HashMap alreadyAdded ) {
        PseudoVertex pv1 = pedge.getV1();
        PseudoVertex pv2 = pedge.getV2();
        Object node1 = pv1.getObject();
        Object node2 = pv2.getObject();
        VertexImpl v1 = (VertexImpl)alreadyAdded.get(node1);
        VertexImpl v2 = (VertexImpl)alreadyAdded.get(node2);
        if (v1 == null) {
            v1 = newSingletonVertex(node1, alreadyAdded);
        }
        if (v2 == null) {
            v2 = newSingletonVertex(node2, alreadyAdded);
        }
        try {
            graph.addEdge (v1, v2, pedge.getWeight());
        } catch (Exception e) {e.printStackTrace();}
    }

    /**
     * Creates a vertex holding a one-element list with {@code node} and
     * registers it in {@code alreadyAdded} under {@code node}.
     */
    private VertexImpl newSingletonVertex (Object node, HashMap alreadyAdded) {
        ArrayList members = new ArrayList();
        members.add(node);
        VertexImpl vertex = new VertexImpl(members);
        alreadyAdded.put(node, vertex);
        return vertex;
    }

    /**
     * Same as the four-argument variant with no precomputed edge weight, so
     * the weight is derived from the trained model.
     */
    public void constructEdgesUsingTrainedClusterer (WeightedGraph graph, Instance instPair, HashMap alreadyAdded) {
        constructEdgesUsingTrainedClusterer (graph, instPair, alreadyAdded, null);
    }

    /**
     * Adds to the graph the edge between the two nodes of the pair held in
     * {@code instPair}, creating and registering their vertices when needed.
     *
     * The edge weight is {@code edgeWeight} when given. Otherwise it comes from
     * {@code meClassifier} when that is set (via the n-best scoring when
     * {@code useNBestInference} is on, else via {@link #computeScore}), or
     * else from {@code sgdParameters} as {@code scores[1] - scores[0]}. If
     * neither model is set the weight stays 0.0. No edge is added when either
     * node is {@code null}.
     *
     * @param graph        graph receiving the edge
     * @param instPair     instance whose source is the {@code NodePair} to connect
     * @param alreadyAdded map from node object to its vertex, updated in place
     * @param edgeWeight   precomputed weight, or {@code null} to compute it
     */
    public void constructEdgesUsingTrainedClusterer (WeightedGraph graph, Instance instPair, HashMap alreadyAdded, Double edgeWeight) {
        NodePair mentionPair = (NodePair)instPair.getSource();
        Object node1 = mentionPair.getObject1();
        Object node2 = mentionPair.getObject2();
        // Both lookups happen before any vertex is created, as before.
        VertexImpl v1 = (VertexImpl)alreadyAdded.get(node1);
        VertexImpl v2 = (VertexImpl)alreadyAdded.get(node2);
        if (v1 == null) {
            v1 = newSingletonVertex(node1, alreadyAdded);
        }
        if (v2 == null) {
            v2 = newSingletonVertex(node2, alreadyAdded);
        }

        double edgeVal = 0.0;
        if (edgeWeight != null) {
            edgeVal = edgeWeight.doubleValue();
        } else if (meClassifier != null) {
            if (useNBestInference) {
                // Plain score of the original pair, used only for the
                // diagnostic below.
                double edgeVal2 = computeScore(meClassifier, instPair);
                edgeVal = computeScore_NBest(meClassifier, instPair);
                // Report pairs on which n-best and plain scoring disagree in sign.
                if ((edgeVal > 0 && edgeVal2 < 0) || (edgeVal < 0 && edgeVal2 > 0))
                    System.out.println(" " + edgeVal + " (" + edgeVal2 + ")");
            } else {
                // Include the feature weights according to each label
                edgeVal = computeScore(meClassifier, instPair);
            }
        } else if (sgdParameters != null) {
            FeatureVector fv = (FeatureVector)instPair.getData();
            double scores [] = new double[2];
            getUnNormalizedScores (sgdParameters, fv, scores);
            edgeVal = scores[1] - scores[0];
        }

        try {
            if (node1 != null && node2 != null) {
                graph.addEdge (v1, v2, edgeVal);
            }
        } catch (Exception e) {e.printStackTrace();}
    }

    /**
     * Fills {@code scores[0]} and {@code scores[1]} with the unnormalised
     * score of each of the two labels: the value stored at column
     * {@code numSGDFeatures} of the label's row (presumably the bias / default
     * feature weight - confirm) plus the row's dot product with {@code fv}.
     *
     * @param lambdas parameter matrix with one row per label
     * @param fv      feature vector of the instance
     * @param scores  output array with at least two entries
     */
    public void getUnNormalizedScores (Matrix2 lambdas, FeatureVector fv, double[] scores) {
        for (int li = 0; li < 2; li++) {
            scores[li] = lambdas.value (li, numSGDFeatures)
                + lambdas.rowDotProduct (li, fv, numSGDFeatures,null);
        }
    }

    /**
     * Classifies the instance and returns the value of the label that is not
     * "no" minus the value of the label "no".
     *
     * NOTE: THIS ASSUMES THERE ARE JUST TWO LABELS - IE CLASSIFIER IS BINARY
     *
     * @param classifier binary classifier
     * @param instPair   instance to classify
     * @return positive when the non-"no" label has the larger value
     */
    private double computeScore(MaxEnt classifier, Instance instPair) {
        // Include the feature weights according to each label
        Classification classification = (Classification)classifier.classify(instPair);
        Labeling labeling = classification.getLabeling();
        if (labeling.labelAtLocation(0).toString().equals("no")) {
            return labeling.valueAtLocation(1)-labeling.valueAtLocation(0);
        }
        return labeling.valueAtLocation(0)-labeling.valueAtLocation(1);
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -