// corefclusteradv.java
// Extracted from a web listing of the MALLET toolkit (open-source NLP /
// machine-learning project). Java source, 1,911 lines total; this is page 1 of 5.
// NOTE(review): this chunk begins mid-method. The enclosing method's signature is on
// an earlier page of the file; from its body it takes (at least) a NodePair
// `origNodePair`, an Instance `origInstPair`, and a MaxEnt `classifier`, and returns
// a double edge weight computed over the two citations' n-best analyses.
// All code tokens below are unchanged from the original; only formatting and
// comments were added.
Citation citation1 = (Citation)origNodePair.getObject1();
Citation citation2 = (Citation)origNodePair.getObject2();
// N-best interpretations of each citation; may be null when unavailable.
Collection nBest1 = citation1.getNBest();
Collection nBest2 = citation2.getNBest();
String label;
if (origNodePair.getIdRel())
    label = "yes";
else
    label = "no";
Object name = origInstPair.getName();
Object source = origInstPair.getSource();
Pipe pipe = origInstPair.getPipe();
if (nBest1 == null || nBest2 == null) {
    // Fall back to scoring the original pair directly when no n-best lists exist.
    //System.err.println("Did not find n-best, using original");
    Instance instPair = new Instance (origNodePair, label, name, source, pipe);
    double score = computeScore(classifier, instPair);
    int i1 = citation1.getIndex();
    int i2 = citation2.getIndex();
    // NOTE(review): `label == "yes"` / `label == "no"` compares String references.
    // It happens to work here only because `label` was assigned the identical
    // interned literals above; `label.equals("yes")` would be the robust form.
    if (score < 0.0 && label == "yes") {
        // Classifier says "no" (negative margin) but gold label is "yes".
        System.out.println(i1 + " " + i2 + " " + score + " " + label);
        falseNegatives++;
    } else if (score > 0.0 && label == "no") {
        // Classifier says "yes" (positive margin) but gold label is "no".
        System.out.println(i1 + " " + i2 + " " + score + " " + label);
        falsePositives++;
    }
    return score;
}
List scores = new ArrayList();
int i = 0, j = 0; // pair indices; only used for (commented-out) debug output
// Score every pair drawn from the cross product of the two n-best lists.
for (Iterator iterator = nBest1.iterator(); iterator.hasNext();) {
    j=0;
    Citation nbest_citation1 = (Citation) iterator.next();
    for (Iterator iterator2 = nBest2.iterator(); iterator2.hasNext();) {
        Citation nbest_citation2 = (Citation) iterator2.next();
        NodePair nodePair = new NodePair(nbest_citation1, nbest_citation2,
                                         origNodePair.getIdRel());
        Instance instPair = new Instance (nodePair, label, name, null, pipe);
        // System.out.println(i + ", " + j);
        // System.out.println(nbest_citation1.rawstring);
        // System.out.println(nbest_citation2.rawstring);
        double score = computeScore(classifier, instPair);
        // System.out.println(i + ", " + j + ": " + score);
        if (confidenceWeightedScores) {
            // Weight each pairwise score by the product of the two n-best
            // confidence scores.
            double weight = (nbest_citation1.getConfidenceScore() *
                             nbest_citation2.getConfidenceScore());
            scores.add(new Double(score*weight));
        } else {
            scores.add(new Double(score));
        }
        j++;
    }
    i++;
}
// This version actually returns the highest possible edge when they _are_
// coreferential and the lowest possible edge otherwise.
double optEdgeWeight;
if (useOptimal) {
    if (origInstPair.getTarget().toString().equals("yes")) {
        optEdgeWeight = ((Double)Collections.max(scores)).doubleValue();
        //System.out.println("HIGHEST edge weight " + optEdgeWeight + " selected from " + scores);
    } else {
        optEdgeWeight = ((Double)Collections.min(scores)).doubleValue();
        //System.out.println("LOWEST edge weight " + optEdgeWeight + " selected from " + scores);
    }
} else {
    // Just average all the pair-wise scores for now;
    // eventually do something sophisticated here.
    optEdgeWeight = collectionAvg (scores);
}
return optEdgeWeight;
}

/**
 * Arithmetic mean of a Collection of Doubles.
 * NOTE(review): divides by size() unguarded — an empty collection yields NaN
 * (0.0/0.0); callers above always pass a non-empty score list.
 */
protected double collectionAvg (Collection collection) {
    double sum = 0.0;
    for (Iterator it = collection.iterator(); it.hasNext(); ) {
        sum += ((Double)it.next()).doubleValue();
    }
    return sum/(double)collection.size();
}

/**
 * True while indexList, viewed as a base-N counter, has not yet reached its
 * maximum value (all digits == N-1).
 */
protected boolean hasNextIndexList(int[] indexList, int N) {
    for(int i=0; i<indexList.length; i++){
        if(indexList[i] < N-1)
            return true;
    }
    return false;
}

/**
 * Advance indexList to the next configuration, treating it as a base-N
 * counter (rightmost digit least significant). Mutates and returns the
 * same array. Increments the rightmost digit that is <= N-2 and zeroes
 * everything to its right.
 * NOTE(review): when a digit left of position i is incremented, digits to its
 * right reset to 0, so enumeration order is the usual odometer order.
 */
protected int[] nextIndexList(int[] indexList, int N) {
    for(int i=indexList.length-1; i>=0; i--){
        if(indexList[i] <= N-2){
            indexList[i] ++;
            for(int j=i+1; j<=indexList.length-1;j++){
                indexList[j] = 0;
            }
            break;
        }
    }
    return indexList;
}

/**
 * Overwrite indexList with uniformly random digits in [0, N). Mutates and
 * returns the same array.
 * NOTE(review): constructs a new java.util.Random on every call; a shared
 * instance (seeded once) would avoid identical sequences on rapid successive
 * calls and repeated allocation.
 */
protected int[] nextIndexListStochastic (int[] indexList, int N) {
    java.util.Random r = new java.util.Random();
    int v = 0;
    for (int i=0; i < indexList.length; i++) {
        v = r.nextInt(N);
        indexList[i] = v;
    }
    return indexList;
}

/**
 * Sum of pairwise n-best scores for one configuration of a cluster:
 * indexList[k] selects which n-best interpretation of citation k to use.
 * NOTE(review): the inner loop `for (j=i; j > 0; j--)` looks off by one —
 * it scores the self-pair (i,i) and never pairs anything with index 0.
 * Compare updateGraphNBest below, which uses `j=i+1 .. length-1`. Verify
 * against the original MALLET source before changing.
 */
public double weightOfConfig (int [] indexList, List instList) {
    double score = 0.0;
    for (int i=0; i<indexList.length; i++) {
        for (int j=i; j > 0; j--) {
            NodePair p = new NodePair((Citation)instList.get(i),(Citation)instList.get(j));
            Instance inst = new Instance(p, "yes", null, p, this.pipe);
            score += computeScore_NBest(meClassifier, inst, indexList[i], indexList[j]);
        }
    }
    return score;
}

/**
 * Add an edge to `graph` for every unordered pair of citations in `instList`,
 * weighted by the n-best score of the configuration chosen in `indexList`.
 * `alreadyAdded` maps citation objects to their graph vertices (shared across
 * calls so vertices are reused).
 */
public void updateGraphNBest (WeightedGraph graph, int [] indexList, List instList,
                              HashMap alreadyAdded) {
    for (int i=0; i<indexList.length; i++) {
        for (int j=i+1; j < indexList.length; j++) {
            Object c1 = instList.get(i);
            Object c2 = instList.get(j);
            NodePair p = new NodePair((Citation)c1,(Citation)c2);
            Instance inst = new Instance(p, "yes", null, p, this.pipe);
            constructEdgesUsingTrainedClusterer (graph, inst, alreadyAdded,
                new Double (computeScore_NBest (meClassifier, inst, indexList[i], indexList[j])) );
        }
    }
}

/**
 * Build a weighted coreference graph over `mentions`, choosing for each key
 * cluster the n-best configuration (one of N interpretations per citation)
 * that maximizes weightOfConfig, then wiring intra-cluster edges with those
 * choices and cross-cluster edges via completeGraphNBest.
 * NOTE(review): when the search space exceeds 16000 configurations, only the
 * first 4000 odometer steps (via nextIndexList) are sampled — presumably
 * nextIndexListStochastic was intended for this branch; confirm.
 */
public WeightedGraph constructOptimalEdgesUsingNBest (List mentions, int N) {
    WeightedGraph graph = new WeightedGraphImpl();
    HashMap alreadyAddedVertices = new HashMap(); // keep track of vertices already created
    HashMap bestCitationMap = new HashMap();      // citation -> chosen n-best index
    for (Iterator iter = keyPartitioning.iterator(); iter.hasNext();) {
        Collection cluster = (Collection)iter.next(); // get key cluster
        List instList = Collections.list(Collections.enumeration(cluster));
        int [] indexList = new int[instList.size()];
        for (int j=0; j < indexList.length; j++)
            indexList[j] = 0; // initialize: all-zero configuration
        int[] optimalIndexList = (int[])indexList.clone();
        double highestWeight = weightOfConfig(indexList, instList);
        //System.out.println("Cluster size: " + cluster.size());
        // Total number of configurations = N^|cluster|.
        // NOTE(review): (int)Math.pow can overflow for large clusters; the
        // comparison below then misbehaves on Integer overflow.
        int numCombinations = (int)Math.pow((double)N,(double)cluster.size());
        if (numCombinations > 16000) {
            // Too many to enumerate: probe a fixed budget of 4000 configurations.
            for (int k=0; k < 4000; k++) {
                indexList = nextIndexList(indexList, N);
                double weight = weightOfConfig(indexList, instList);
                if( weight > highestWeight ){
                    highestWeight = weight;
                    optimalIndexList = (int[])indexList.clone();
                }
            }
        } else {
            // Exhaustive enumeration of all configurations.
            while (hasNextIndexList(indexList, N)) {
                indexList = nextIndexList(indexList, N);
                double weight = weightOfConfig(indexList, instList);
                //printIList(indexList);
                //System.out.println(" -> " + weight);
                if( weight > highestWeight ){
                    highestWeight = weight;
                    optimalIndexList = (int[])indexList.clone();
                }
            }
        }
        // Remember the winning n-best index for each citation in this cluster.
        for (int j=0; j < optimalIndexList.length; j++) {
            bestCitationMap.put(instList.get(j),new Integer(optimalIndexList[j]));
        }
        updateGraphNBest (graph, optimalIndexList, instList, alreadyAddedVertices);
    }
    addVerticesToGraph(graph, mentions, alreadyAddedVertices);
    //printGraph(graph);
    completeGraphNBest (graph, keyPartitioning, bestCitationMap);
    return graph;
}

/** Debug helper: print the digits of an index list on one line (no newline). */
private void printIList (int [] list) {
    for (int i=0; i < list.length; i++) {
        System.out.print(" " + list[i]);
    }
}

/**
 * Add the missing cross-cluster edges: for every pair of citations drawn from
 * two different key clusters, score the pair using each citation's chosen
 * n-best index (from citMap) and add the edge between their singleton vertices.
 */
public void completeGraphNBest (WeightedGraph graph, Collection keyPartitioning,
                                Map citMap) {
    HashMap m1 = new HashMap(); // map from objects to their vertices
    Set vs = graph.getVertexSet();
    for (Iterator iter = vs.iterator(); iter.hasNext();) {
        VertexImpl v = (VertexImpl)iter.next();
        // Only index singleton vertices (vertex payload is a one-element List).
        if ((v.getObject() instanceof List) && ((List)v.getObject()).size() == 1) {
            Object o = ((List)v.getObject()).get(0);
            m1.put(o,v);
        }
    }
    List kList = Collections.list(Collections.enumeration(keyPartitioning));
    for (int i=0; i < kList.size(); i++) {
        Collection c1 = (Collection)kList.get(i);
        for (int j=i+1; j < kList.size(); j++) {
            Collection c2 = (Collection)kList.get(j);
            for (Iterator i1 = c1.iterator(); i1.hasNext();) {
                Citation cit1 = (Citation)i1.next();
                VertexImpl v1 = (VertexImpl)m1.get((Object)cit1);
                for (Iterator i2 = c2.iterator(); i2.hasNext();) {
                    Citation cit2 = (Citation)i2.next();
                    VertexImpl v2 = (VertexImpl)m1.get((Object)cit2);
                    NodePair np = new NodePair (cit1, cit2);
                    Instance inst = new Instance (np, "no", null, np, this.pipe);
                    double eval = computeScore_NBest(meClassifier, inst,
                        ((Integer)citMap.get(cit1)).intValue(),
                        ((Integer)citMap.get(cit2)).intValue());
                    try {
                        graph.addEdge (v1, v2, eval);
                    } catch (Exception e) {e.printStackTrace();}
                }
            }
        }
    }
}

/**
 * Materialize a PseudoEdge into the real graph, creating singleton vertices
 * for either endpoint if they are not yet in `alreadyAdded`.
 * NOTE(review): lookups are keyed by the PseudoVertex (pv1/pv2) but new
 * vertices are cached under the wrapped object (node1/node2) — a later lookup
 * by PseudoVertex will miss and create a duplicate vertex. Verify intent.
 */
protected void constructEdgesFromPseudoEdges (WeightedGraph graph,
                                              PseudoEdge pedge,
                                              HashMap alreadyAdded ) {
    PseudoVertex pv1 = pedge.getV1();
    PseudoVertex pv2 = pedge.getV2();
    Object node1 = pv1.getObject();
    Object node2 = pv2.getObject();
    VertexImpl v1 = (VertexImpl)alreadyAdded.get(pv1);
    VertexImpl v2 = (VertexImpl)alreadyAdded.get(pv2);
    if (v1 == null) {
        ArrayList a1 = new ArrayList();
        a1.add(node1);
        v1 = new VertexImpl(a1);
        alreadyAdded.put(node1,v1);
    }
    if (v2 == null) {
        ArrayList a2 = new ArrayList();
        a2.add(node2);
        v2 = new VertexImpl(a2);
        alreadyAdded.put(node2,v2);
    }
    try {
        graph.addEdge (v1, v2, pedge.getWeight());
    } catch (Exception e) {e.printStackTrace();}
}

/** Convenience overload: edge weight computed by the default classifier. */
public void constructEdgesUsingTrainedClusterer (WeightedGraph graph,
                                                 Instance instPair,
                                                 HashMap alreadyAdded) {
    constructEdgesUsingTrainedClusterer (graph, instPair, alreadyAdded, null);
}

/** Convenience overload: uses this.meClassifier when no weight is given. */
public void constructEdgesUsingTrainedClusterer (WeightedGraph graph,
                                                 Instance instPair,
                                                 HashMap alreadyAdded,
                                                 Double edgeWeight) {
    constructEdgesUsingTrainedClusterer (graph, instPair, alreadyAdded,
                                         edgeWeight, this.meClassifier);
}

/**
 * Add an edge for the mention pair wrapped in instPair's source NodePair,
 * creating singleton vertices for endpoints not yet in `alreadyAdded`.
 * Edge weight: `edgeWeight` if given; otherwise the classifier's
 * yes-minus-no margin (n-best variant when useNBestInference is set);
 * otherwise the SGD linear model's score difference.
 */
public void constructEdgesUsingTrainedClusterer (WeightedGraph graph,
                                                 Instance instPair,
                                                 HashMap alreadyAdded,
                                                 Double edgeWeight,
                                                 MaxEnt classifier) {
    NodePair mentionPair = (NodePair)instPair.getSource();
    Object node1 = mentionPair.getObject1();
    Object node2 = mentionPair.getObject2();
    VertexImpl v1 = (VertexImpl)alreadyAdded.get(node1);
    VertexImpl v2 = (VertexImpl)alreadyAdded.get(node2);
    if (v1 == null) {
        ArrayList a1 = new ArrayList();
        a1.add(node1);
        v1 = new VertexImpl(a1);
        alreadyAdded.put(node1,v1);
    }
    if (v2 == null) {
        ArrayList a2 = new ArrayList();
        a2.add(node2);
        v2 = new VertexImpl(a2);
        alreadyAdded.put(node2,v2);
    }
    double edgeVal = 0.0;
    double edgeVal2;
    if (edgeWeight == null) {
        if (classifier != null) {
            if (useNBestInference) {
                Classification classification =
                    (Classification)classifier.classify(instPair);
                Labeling labeling = classification.getLabeling();
                edgeVal = computeScore_NBest(classifier, instPair);
                // edgeVal2 is the plain (non-n-best) margin, computed only to
                // report disagreements with the n-best score below.
                if (labeling.labelAtLocation(0).toString().equals("no")) {
                    edgeVal2 = labeling.valueAtLocation(1)-labeling.valueAtLocation(0);
                } else {
                    edgeVal2 = labeling.valueAtLocation(0)-labeling.valueAtLocation(1);
                }
                if ((edgeVal > 0 && edgeVal2 < 0) || (edgeVal < 0 && edgeVal2 > 0))
                    System.out.println(" " + edgeVal + " (" + edgeVal2 + ")");
            } else {
                // Include the feature weights according to each label.
                Classification classification =
                    (Classification)classifier.classify(instPair);
                Labeling labeling = classification.getLabeling();
                //classifier.getUnnormalizedClassificationScores(instPair, scores);
                if (labeling.labelAtLocation(0).toString().equals("no")) {
                    edgeVal = labeling.valueAtLocation(1)-labeling.valueAtLocation(0);
                } else {
                    edgeVal = labeling.valueAtLocation(0)-labeling.valueAtLocation(1);
                }
            }
        } else if (sgdParameters != null) {
            // Linear SGD-trained model: score difference between the two labels.
            FeatureVector fv = (FeatureVector)instPair.getData();
            double scores [] = new double[2];
            getUnNormalizedScores (sgdParameters, fv, scores);
            edgeVal = scores[1] - scores[0];
        }
    } else {
        edgeVal = edgeWeight.doubleValue();
    }
    try {
        if (node1 != null && node2 != null) {
            graph.addEdge (v1, v2, edgeVal);
        }
    } catch (Exception e) {e.printStackTrace();}
}

/**
 * Fill `scores` with the unnormalized linear scores of the two labels:
 * bias term (column numSGDFeatures) plus the row-by-feature-vector dot product.
 */
public void getUnNormalizedScores (Matrix2 lambdas, FeatureVector fv,
                                   double[] scores) {
    for (int li = 0; li < 2; li++) {
        scores[li] = lambdas.value (li, numSGDFeatures)
            + lambdas.rowDotProduct (li, fv, numSGDFeatures,null);
    }
}

/**
 * Classify the pair and return the signed margin in favor of "yes":
 * positive means coreferent, negative means not.
 */
private double computeScore(MaxEnt classifier, Instance instPair) {
    // Include the feature weights according to each label.
    Classification classification =
        (Classification)classifier.classify(instPair);
    Labeling labeling = classification.getLabeling();
    /** NOTE: THIS ASSUMES THERE ARE JUST TWO LABELS - IE CLASSIFIER IS
     * BINARY */
    double score = 0.0;
    if (labeling.labelAtLocation(0).toString().equals("no")) {
        score = labeling.valueAtLocation(1)-labeling.valueAtLocation(0);
    } else {
        score = labeling.valueAtLocation(0)-labeling.valueAtLocation(1);
    }
    return score;
}}
// (end of extracted page 1/5 — code-viewer keyboard-shortcut chrome removed)