CorefClusterAdv.java
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org. For further
   information, see the file `LICENSE' included with this distribution. */

/** @author Ben Wellner */

package edu.umass.cs.mallet.projects.seg_plus_coref.coreference;

import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.*;
import edu.umass.cs.mallet.projects.seg_plus_coref.graphs.*;
import salvo.jesus.graph.*;
import salvo.jesus.graph.VertexImpl;
import edu.umass.cs.mallet.base.types.*;
import edu.umass.cs.mallet.base.classify.*;
import edu.umass.cs.mallet.base.pipe.*;
import edu.umass.cs.mallet.base.pipe.iterator.*;
import edu.umass.cs.mallet.base.util.*;
import java.util.*;
import java.lang.*;
import java.io.*;

/*
  An object of this class will allow an InstanceList as well as a List of
  mentions to be passed to a method that will return a list of lists
  representing the partitioning of the mentions into clusters/equiv. classes.
  There should be exactly M choose 2 instances in the InstanceList where M is
  the size of the List of mentions (assuming a complete graph).
*/
public class CorefClusterAdv
{
  private double MAX_ITERS = 30;        // max number of typical cluster iterations
  private static boolean export_graph = true;
  private static int falsePositives = 0;
  private static int falseNegatives = 0;
  Collection keyPartitioning = null;    // key partitioning - optional for evaluation
                                        // against test clusters within search procedure
  private double MAX_REDUCTIONS = 10;   // max number of reductions allowed
  Pipe pipe;
  boolean trueNumStop = false;          // if true, only stop once true number of clusters achieved
  boolean useOptimal = false;           // if true, cheat
  boolean useNBestInference = false;    // true if we should use the Greedy N-best
                                        // method that Michael developed
  boolean fullPartition = false;
  boolean confidenceWeightedScores = false;
  int rBeamSize = 10;
  private final double NegativeInfinite = -1000000000;
  MaxEnt meClassifier = null;
  Matrix2 sgdParameters = null;
  int numSGDFeatures = 0;
  double threshold = 0.0;
  TreeModel treeModel = null;
  WeightedGraph wgraph = null;          // pointer to graph accessible from outside

  public CorefClusterAdv () {}

  public CorefClusterAdv (TreeModel tm) {
    this.treeModel = tm;
  }

  public CorefClusterAdv (Pipe p) {
    this.pipe = p;
  }

  public CorefClusterAdv (Pipe p, TreeModel tm) {
    this.pipe = p;
    this.treeModel = tm;
  }

  public CorefClusterAdv (double threshold) {
    this.threshold = threshold;
  }

  public CorefClusterAdv (double threshold, MaxEnt classifier, Pipe p) {
    this.threshold = threshold;
    this.meClassifier = classifier;
    this.pipe = p;
  }

  // version that has a tree model
  public CorefClusterAdv (double threshold, MaxEnt classifier, TreeModel tm, Pipe p) {
    this.threshold = threshold;
    this.meClassifier = classifier;
    this.treeModel = tm;
    this.pipe = p;
  }

  // in this case we don't have a MaxEnt classifier, but a set of parameters
  // learned via Stochastic Gradient Descent (or are being learned)
  public CorefClusterAdv (double threshold, Matrix2 sgdParameters, int numSGDFeatures) {
    this.threshold = threshold;
    this.sgdParameters = sgdParameters;
    this.numSGDFeatures = numSGDFeatures;
  }
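
  /* Illustrative usage sketch (added commentary, not part of the original
   * source). Assuming `pairInstances` holds one classifier instance per
   * mention pair (M choose 2 for M mentions) and `mentions` is the
   * corresponding mention list, a typical train-then-cluster sequence might
   * look like the following; the variable names are hypothetical:
   *
   *   CorefClusterAdv clusterer = new CorefClusterAdv (pipe);
   *   clusterer.train (trainingPairs);     // fit the pairwise MaxEnt model
   *   clusterer.setThreshold (0.0);
   *   Collection clusters = clusterer.clusterMentions (pairInstances, mentions);
   */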
  public void setConfWeightedScores (boolean b) { this.confidenceWeightedScores = b; }
  public void setRBeamSize (int s) { this.rBeamSize = s; }
  public void setOptimality (boolean b) { this.useOptimal = b; }
  public void setNBestInference (boolean b) { this.useNBestInference = b; }
  public void setTrueNumStop (boolean b) { this.trueNumStop = b; }
  public void setSearchParams (int iters, int reductions) {
    MAX_ITERS = iters;
    MAX_REDUCTIONS = reductions;
  }
  public void setThreshold (double t) { this.threshold = t; }
  public void setKeyPartitioning (Collection keyP) { this.keyPartitioning = keyP; }
  public void setFullPartition (boolean f) { this.fullPartition = f; }

  /* Initialize a list of lists where each inner list is a list with a single
     element. */
  public void loadME (String file) {
    MaxEnt me = null;
    try {
      ObjectInputStream ois = new ObjectInputStream (new FileInputStream (file));
      me = (MaxEnt)ois.readObject();
      ois.close();
    } catch (Exception e) { e.printStackTrace(); }
    this.meClassifier = me;
  }

  public void train (InstanceList ilist) {
    // just do plain MaxEnt training for now
    System.out.println("Training NOW: ");
    MaxEnt me = (MaxEnt)(new MaxEntTrainer().train (ilist, null, null, null, null));
    Alphabet alpha = ilist.getDataAlphabet();
    alpha.stopGrowth(); // hack to prevent alphabet from growing
    Trial t = new Trial (me, ilist);
    System.out.println("CorefClusterAdv -> Training F1 on \"yes\" is: " + t.labelF1("yes"));
    //me.write(new File("/tmp/MaxEnt_Output"));
    this.meClassifier = me;
  }

  public void testClassifier (InstanceList tlist) {
    Trial t = new Trial (meClassifier, tlist);
    System.out.println("test F1 on \"yes\": " + t.labelF1("yes"));
    for (Iterator it = tlist.iterator(); it.hasNext();) {
      Instance inst = (Instance)it.next();
      Classification classification = (Classification)meClassifier.classify(inst);
      Labeling l = classification.getLabeling();
      //System.out.println("Best label: " + l.getBestLabel().toString() + " "
      //                   + inst.getTarget().toString());
      if (!l.getBestLabel().toString().equals(inst.getTarget().toString())) {
        Citation c1 = (Citation)((NodePair)inst.getSource()).getObject1();
        Citation c2 = (Citation)((NodePair)inst.getSource()).getObject2();
        if (inst.getLabeling().getBestLabel().toString().equals("yes")) {
          System.out.println("FN: " + c1.print() + " " + c1.getString() + "\n "
                             + c2.print() + " " + c2.getString());
          System.out.println("Citation venue: " + c1.getField("venue")
                             + " --> " + c2.getField("venue"));
        } else if (inst.getLabeling().getBestLabel().toString().equals("no")) {
          System.out.println("FP: " + c1.print() + " " + c1.getString() + "\n "
                             + c2.print() + " " + c2.getString());
          System.out.println("Citation venue: " + c1.getField("venue")
                             + " --> " + c2.getField("venue"));
        }
        System.out.println(printParamDetails((FeatureVector)inst.getData(), classification));
      }
    }
  }
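
  /* Note (added commentary, not in the original source): the two methods
   * below assume a two-label ("yes"/"no") model whose parameter vector
   * concatenates one weight row per label (row order follows the label
   * alphabet), so for feature index f one label's weight is params[f] and
   * the other's is params[f + numParams]. For a feature with value v, its
   * net contribution toward one label over the other is then
   * v * (params[f] - params[f + numParams]); printParamDetails prints the
   * raw weight difference next to each feature value.
   */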
  public String printParamDetails (FeatureVector vec, Classification classification) {
    Labeling l = classification.getLabeling();
    Alphabet dictionary = vec.getAlphabet();
    int[] indices = vec.getIndices();
    double[] values = vec.getValues();
    double[] params = meClassifier.getParameters();
    int paramsLength = params.length;
    int indicesLength = indices.length;
    int numParams = paramsLength / 2;
    //assert (paramsLength == 2*indicesLength);
    //Thread.currentThread().dumpStack();
    StringBuffer sb = new StringBuffer ();
    int valuesLength = vec.numLocations();
    for (int i = 0; i < valuesLength; i++) {
      if (dictionary == null)
        sb.append ("[" + i + "]");
      else
        sb.append (dictionary.lookupObject(indices == null ? i : indices[i]).toString());
      sb.append ("(" + indices[i] + ")");
      //sb.append ("(" + i + ")");
      sb.append ("=");
      sb.append (values[i]);
      if (l.labelAtLocation(0).toString().equals("no"))
        sb.append (" (" + (params[indices[i]+numParams] - params[indices[i]]) + ")");
      else
        sb.append (" (" + (params[indices[i]] - params[indices[i]+numParams]) + ")");
      sb.append ("\n");
    }
    return sb.toString();
  }

  public void printParams (MaxEnt me) {
    double[] parameters = me.getParameters();
    int numFeatures = parameters.length / 2;
    Matrix2 matrix2 = new Matrix2 (parameters, 2, numFeatures);
    for (int i = 0; i < 2; i++) {
      System.out.print(i + ": ");
      for (int j = 0; j < numFeatures; j++) {
        System.out.print(j + "=" + matrix2.value(new int[] {i, j}) + " ");
      }
      System.out.println();
    }
  }

  public MaxEnt getClassifier () { return meClassifier; }

  public Collection clusterMentions (InstanceList ilist, List mentions) {
    return clusterMentions (ilist, mentions, -1, true);
  }

  /* performance time method */
  public Collection clusterMentions (InstanceList ilist, List mentions,
                                     int optimalNBest, boolean stochastic) {
    if (meClassifier != null || sgdParameters != null) {
      if (optimalNBest > 0) {
        System.out.println("Computing \"optimal\" edge weights using N-best lists to "
                           + optimalNBest);
        WeightedGraph g = constructOptimalEdgesUsingNBest (mentions, optimalNBest);
        wgraph = g;
        System.out.println("!!!! Constructed Graph !!!!!");
        //return null;
        if (stochastic)
          return partitionGraph (g);
        else
          return absoluteCluster (ilist, mentions);
        //return typicalClusterPartition(g);
      } else {
        if (fullPartition) {
          wgraph = createGraph (ilist, mentions);
          return partitionGraph (wgraph);
        } else if (stochastic) {
          return typicalClusterAdv (ilist, mentions);
          //return absoluteCluster (ilist, mentions);
          //wgraph = createGraph(ilist, mentions);
          //return partitionGraph (wgraph);
        } else {
          wgraph = createGraph (ilist, mentions);
          //return absoluteCluster (ilist, mentions);
          return typicalClusterPartition (wgraph);
        }
      }
    } else {
      return null;
    }
  }

  public WeightedGraph createGraph (InstanceList ilist, List mentions) {
    WeightedGraph graph = new WeightedGraphImpl();
    HashMap alreadyAddedVertices = new HashMap(); // keep track of vertices already added
    for (int i = 0; i < ilist.size(); i++) {
      constructEdgesUsingTrainedClusterer (graph, ilist.getInstance(i), alreadyAddedVertices);
    }
    System.out.println("Finished building graph");
    addVerticesToGraph (graph, mentions, alreadyAddedVertices);
    return graph;
  }

  // this just writes out the edges of a graph with one edge per line
  public void exportGraph (String file) {
    Set edges = wgraph.getEdgeSet();
    try {
      BufferedWriter out = new BufferedWriter (new FileWriter (file));
      for (Iterator it = edges.iterator(); it.hasNext(); ) {
        WeightedEdge e = (WeightedEdge)it.next();
        VertexImpl v1 = (VertexImpl)e.getVertexA();
        VertexImpl v2 = (VertexImpl)e.getVertexB();
        Citation c1 = (Citation)((List)v1.getObject()).get(0);
        Citation c2 = (Citation)((List)v2.getObject()).get(0);
        out.write(c1.print() + " " + c2.print() + " " + e.getWeight() + "\n");
      }
      out.close();
    } catch (IOException e) { e.printStackTrace(); }
  }

  // shallow copy of graph - just need to keep the edges
  public WeightedGraph copyGraph (WeightedGraph graph) {
    WeightedGraph copy = new WeightedGraphImpl();
    Set edgeSet = graph.getEdgeSet();
    Iterator i1 = edgeSet.iterator();
    HashMap map = new HashMap();
    while (i1.hasNext()) {
      WeightedEdge e1 = (WeightedEdge)i1.next();
      VertexImpl v1 = (VertexImpl)e1.getVertexA();
      VertexImpl v2 = (VertexImpl)e1.getVertexB();
      VertexImpl n1 = (VertexImpl)map.get(v1);
      VertexImpl n2 = (VertexImpl)map.get(v2);
      if (n1 == null) {
        Object o1 = v1.getObject();
        ArrayList l1 = new ArrayList();
        if (o1 instanceof List)
          for (int i = 0; i < ((List)o1).size(); i++)
            l1.add(((List)o1).get(i));
        else
          l1.add(o1);
        n1 = new VertexImpl(l1);
        map.put(v1, n1);
      }
      if (n2 == null) {
        Object o2 = v2.getObject();
        ArrayList l2 = new ArrayList();
        if (o2 instanceof List)
          for (int i = 0; i < ((List)o2).size(); i++)
            l2.add(((List)o2).get(i));
        else
          l2.add(o2);
        n2 = new VertexImpl(l2); // bug fix: original passed o2, discarding the copied list l2
        map.put(v2, n2);
      }
      WeightedEdge ne = new WeightedEdgeImpl (n1, n2, e1.getWeight());
      try {
        copy.addEdge(ne);
      } catch (Exception e) { e.printStackTrace(); }
    }
    return copy;
  }

  public void addVerticesToGraph (WeightedGraph graph, List mentions,
                                  HashMap alreadyAddedVertices) {
    for (int i = 0; i < mentions.size(); i++) {
      Object o = mentions.get(i);
      if (alreadyAddedVertices.get(o) == null) { // add only if it hasn't been added
        List l = new ArrayList();
        l.add(o);
        VertexImpl v = new VertexImpl(l);
        try {
          graph.add(v); // add the vertex
        } catch (Exception e) { e.printStackTrace(); }
      }
    }
  }

  private WeightedEdge chooseEdge3 (List edges, double minVal, double total,
                                    java.util.Random rand) {
    if (edges.size() > 0)
      return (WeightedEdge)edges.get(0);
    else
      return null;
  }

  // simpler more heuristic-based approach
  private WeightedEdge chooseEdge2 (List edges, double minVal, double total,
                                    java.util.Random rand) {
    //return (WeightedEdge)edges.first();
    if (edges.size() < 1)
      return null;
    int x = rand.nextInt(10);
    if (x > edges.size())
      x = edges.size();
    WeightedEdge e = null;
    Iterator i1 = edges.iterator();
    int i = 0;
    while (i1.hasNext() && i < x) {
      e = (WeightedEdge)i1.next();
      i++;
    }
    if (e != null)
      return e;
    else
      return (WeightedEdge)edges.get(0);
  }

  /* Algorithm: Sort edges by magnitude. Scale so they're all positive.
     Choose a random number between 0 and the sum of all the magnitudes.
     Select an edge in this fashion. Merge the two vertices and repeat. */
  private WeightedEdge chooseEdge (List edges, double minVal, double total,
                                   java.util.Random rand) {
    double x = rand.nextDouble() * total; // 0 < x < total
    double cur = 0.0;
    Iterator i1 = edges.iterator();
    while (i1.hasNext()) {
      WeightedEdge e = (WeightedEdge)i1.next();
      cur += (e.getWeight() - minVal); // SUBTRACT minVal
      if (cur > x) {
        return e;
      }
    }
    // this shouldn't really happen unless there are some kind of numerical
    // issues - default to the first edge
    return (WeightedEdge)edges.get(0);
  }
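
  /* Worked illustration (added commentary, not in the original source):
   * chooseEdge above samples an edge with probability proportional to its
   * weight shifted by minVal. E.g., for weights {-2, 1, 3} with minVal = -2,
   * the shifted masses are {0, 3, 5} and total = 8, so the edges are chosen
   * with probabilities 0, 3/8, and 5/8 respectively - note the minimum-weight
   * edge receives zero mass and can only be returned via the fallback after
   * the loop.
   */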
  private PseudoEdge choosePseudoEdge (List edges, java.util.Random rand) {
    if (edges.size() == 0)
      return null;
    double factor = Math.ceil(Math.log(edges.size())) * 20; // note: computed but currently unused
    int x = rand.nextInt(10);
    if (x > edges.size())
      x = edges.size();
    PseudoEdge e = null;
    Iterator i1 = edges.iterator();
    int i = 0;
    while (i1.hasNext() && i < x) {
      e = (PseudoEdge)i1.next();
      i++;
    }
    if (e != null)
      return e;
    else
      return (PseudoEdge)edges.get(0);
  }

  public double evaluatePartitioningExternal (InstanceList ilist, List mentions,
                                              Collection collection) {