corefclusteradv.java
来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 1,911 行 · 第 1/5 页
JAVA
1,911 行
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).http://www.cs.umass.edu/~mccallum/malletThis software is provided under the terms of the Common Public License,version 1.0, as published by http://www.opensource.org. For furtherinformation, see the file `LICENSE' included with this distribution. *//** @author Ben Wellner */package edu.umass.cs.mallet.projects.seg_plus_coref.coreference;import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.*;import edu.umass.cs.mallet.projects.seg_plus_coref.graphs.*;import salvo.jesus.graph.*;import salvo.jesus.graph.VertexImpl;import edu.umass.cs.mallet.base.types.*;import edu.umass.cs.mallet.base.classify.*;import edu.umass.cs.mallet.base.pipe.*;import edu.umass.cs.mallet.base.pipe.iterator.*;import edu.umass.cs.mallet.base.util.*;import java.util.*;import java.lang.*;import java.io.*;/*An object of this class will allow an InstanceList as well as a List ofmentions to be passed to a method that will return a list of listsrepresenting the partitioning of the mentions into clusters/equiv. 
classes.  There should be exactly M choose 2 instances in the InstanceList,
where M is the size of the List of mentions (assuming a complete graph). */
public class CorefClusterAdv {

    double MAX_ITERS = 30;              // max number of typical cluster iterations
    private static boolean export_graph = true;
    private static int falsePositives = 0;
    private static int falseNegatives = 0;
    Collection keyPartitioning = null;  // key partitioning - optional for evaluation
                                        // against test clusters within search procedure
    double MAX_REDUCTIONS = 10;         // max number of reductions allowed
    Pipe pipe;
    boolean trueNumStop = false;        // if true, only stop once true number of
                                        // clusters achieved
    boolean useOptimal = false;         // if true, cheat
    boolean useNBestInference = false;  // true if we should use the Greedy N-best
                                        // method that Michael developed
    boolean fullPartition = false;
    boolean confidenceWeightedScores = false;
    int rBeamSize = 10;
    final double NegativeInfinite = -1000000000;
    MaxEnt meClassifier = null;         // pairwise "yes"/"no" coreference classifier
    Matrix2 sgdParameters = null;       // alternative parameters learned via SGD
    int numSGDFeatures = 0;
    double threshold = 0.0;
    TreeModel treeModel = null;
    WeightedGraph wgraph = null;        // pointer to graph accessible from outside

    public CorefClusterAdv () {}

    public CorefClusterAdv (TreeModel tm) {
        this.treeModel = tm;
    }

    public CorefClusterAdv (Pipe p) {
        this.pipe = p;
    }

    public CorefClusterAdv (Pipe p, TreeModel tm) {
        this.pipe = p;
        this.treeModel = tm;
    }

    public CorefClusterAdv (double threshold) {
        this.threshold = threshold;
    }

    public CorefClusterAdv (double threshold, MaxEnt classifier, Pipe p) {
        this.threshold = threshold;
        this.meClassifier = classifier;
        this.pipe = p;
    }

    // version that has a tree model
    public CorefClusterAdv (double threshold, MaxEnt classifier, TreeModel tm, Pipe p) {
        this.threshold = threshold;
        this.meClassifier = classifier;
        this.treeModel = tm;
        this.pipe = p;
    }

    // in this case we don't have a MaxEnt classifier, but a set of parameters
    // learned via Stochastic Gradient Descent (or are being learned)
    public CorefClusterAdv (double threshold, Matrix2 sgdParameters, int numSGDFeatures) {
        this.threshold = threshold;
        this.sgdParameters = sgdParameters;
        this.numSGDFeatures = numSGDFeatures;
    }

    // --- simple configuration setters ---------------------------------------

    public void setConfWeightedScores (boolean b) { this.confidenceWeightedScores = b; }
    public void setRBeamSize (int s) { this.rBeamSize = s; }
    public void setOptimality (boolean b) { this.useOptimal = b; }
    public void setNBestInference (boolean b) { this.useNBestInference = b; }
    public void setTrueNumStop (boolean b) { this.trueNumStop = b; }
    public void setSearchParams (int iters, int reductions) { MAX_ITERS = iters; MAX_REDUCTIONS = reductions; }
    public void setThreshold (double t) { this.threshold = t; }
    public void setKeyPartitioning (Collection keyP) { this.keyPartitioning = keyP; }
    public void setFullPartition (boolean f) { this.fullPartition = f; }

    /* Initialize a list of lists where each inner list is a list with a single
       element. */
    // NOTE(review): the comment above looks stale — this method actually
    // deserializes a trained MaxEnt classifier from `file` into meClassifier.
    public void loadME (String file) {
        MaxEnt me = null;
        try {
            ObjectInputStream ois = new ObjectInputStream(new FileInputStream( file ));
            me = (MaxEnt)ois.readObject();
            ois.close();
        } catch (Exception e) {e.printStackTrace();}
        // NOTE(review): on any failure the exception is only printed and
        // meClassifier is silently set to null.
        this.meClassifier = me;
    }

    // Train the pairwise classifier on ilist and keep it as meClassifier.
    public void train (InstanceList ilist) {
        this.meClassifier = trainClassifier (ilist);
    }

    // Trains a MaxEnt classifier on ilist, freezes the data alphabet, prints
    // training F1 on the "yes" label, and returns the classifier.
    public MaxEnt trainClassifier (InstanceList ilist) { // just do plain MaxEnt training for now
        System.out.println("Training NOW: ");
        MaxEnt me = (MaxEnt)(new MaxEntTrainer().train (ilist, null, null, null, null));
        Alphabet alpha = ilist.getDataAlphabet();
        alpha.stopGrowth(); // hack to prevent alphabet from growing
        Trial t = new Trial(me, ilist);
        System.out.println("CorefClusterAdv -> Training F1 on \"yes\" is: " + t.labelF1("yes"));
        //me.write(new File("/tmp/MaxEnt_Output"));
        return me;
    }

    public void testClassifier (InstanceList tlist) {
        testClassifier (tlist, meClassifier);
    }

    // Evaluate `classifier` on tlist and dump diagnostics (FN/FP citation pairs,
    // venues, and per-feature parameter details) for every misclassified instance.
    public void testClassifier (InstanceList tlist, MaxEnt classifier) {
        Trial t = new Trial(classifier, tlist);
        // NOTE(review): the label says "accuracy" but the value printed is
        // F1 on the "yes" label.
        System.out.println("test accuracy: " + t.labelF1("yes"));
        for (Iterator it = tlist.iterator(); it.hasNext();) {
            Instance inst = (Instance)it.next();
            Classification classification = (Classification)classifier.classify(inst);
            Labeling l = classification.getLabeling();
            //System.out.println("Best label: " + l.getBestLabel().toString() + " "
            //                   + inst.getTarget().toString());
            if (!l.getBestLabel().toString().equals(inst.getTarget().toString())) {
                Citation c1 = (Citation)((NodePair)inst.getSource()).getObject1();
                Citation c2 = (Citation)((NodePair)inst.getSource()).getObject2();
                if (inst.getLabeling().getBestLabel().toString().equals("yes")) {
                    // true label "yes", predicted differently -> false negative
                    System.out.println("FN: " + c1.print() + " " + c1.getString()
                                       + "\n " + c2.print() + " " + c2.getString());
                    System.out.println("Citation venue: " + c1.getField("venue")
                                       + " --> " + c2.getField("venue"));
                } else if (inst.getLabeling().getBestLabel().toString().equals("no")) {
                    // true label "no", predicted differently -> false positive
                    System.out.println("FP: " + c1.print() + " " + c1.getString()
                                       + "\n " + c2.print() + " " + c2.getString());
                    System.out.println("Citation venue: " + c1.getField("venue")
                                       + " --> " + c2.getField("venue"));
                }
                System.out.println(printParamDetails((FeatureVector)inst.getData(),
                                                     classification, classifier));
            }
        }
    }

    // Render each feature of `vec` as "name(index)=value (paramDiff)", where
    // paramDiff is the difference between the two per-label parameter blocks
    // for that feature; its sign is flipped when the label at location 0 is "no".
    public String printParamDetails (FeatureVector vec, Classification classification,
                                     MaxEnt classifier) {
        Labeling l = classification.getLabeling();
        Alphabet dictionary = vec.getAlphabet();
        int [] indices = vec.getIndices();
        double [] values = vec.getValues();
        double [] params = classifier.getParameters();
        int paramsLength = params.length;
        int indicesLength = indices.length;
        int numParams = paramsLength/2; // parameters stored as two blocks, one per label
        //assert (paramsLength == 2*indicesLength);
        //Thread.currentThread().dumpStack();
        StringBuffer sb = new StringBuffer ();
        int valuesLength = vec.numLocations();
        for (int i = 0; i < valuesLength; i++) {
            if (dictionary == null)
                sb.append ("["+i+"]");
            else
                sb.append (dictionary.lookupObject(indices == null ? i : indices[i]).toString());
            sb.append ("(" + indices[i] +")");
            //sb.append ("(" + i +")");
            sb.append ("=");
            sb.append (values[i]);
            if (l.labelAtLocation(0).toString().equals("no"))
                sb.append (" (" + (params[indices[i]+numParams]-params[indices[i]]) + ")");
            else
                sb.append (" (" + (params[indices[i]]-params[indices[i]+numParams]) + ")");
            sb.append ("\n");
        }
        return sb.toString();
    }

    // Print the MaxEnt parameters as a 2 x numFeatures matrix, one row per label.
    public void printParams (MaxEnt me) {
        double[] parameters = me.getParameters();
        int numFeatures = parameters.length/2;
        Matrix2 matrix2 = new Matrix2(parameters,2,numFeatures);
        for (int i=0; i<2; i++) {
            System.out.print(i + ": ");
            for (int j=0; j<numFeatures; j++) {
                System.out.print(j + "=" + matrix2.value(new int[] {i,j}) + " ");
            }
            System.out.println();
        }
    }

    public MaxEnt getClassifier() {return meClassifier;}

    // Convenience overload: no N-best "optimal" weights, stochastic search.
    public Collection clusterMentions (InstanceList ilist, List mentions) {
        return clusterMentions (ilist, mentions, -1, true);
    }

    /* performance time method */
    // Dispatches to one of several clustering strategies depending on
    // optimalNBest, fullPartition and stochastic; as a side effect wgraph is
    // set on most paths.  Returns null when neither a MaxEnt classifier nor
    // SGD parameters are available.
    public Collection clusterMentions (InstanceList ilist, List mentions,
                                       int optimalNBest, boolean stochastic) {
        if (meClassifier != null || sgdParameters != null) {
            if (optimalNBest > 0) {
                System.out.println("Computing \"optimal\" edge weights using N-best lists to "
                                   + optimalNBest);
                WeightedGraph g = constructOptimalEdgesUsingNBest (mentions, optimalNBest);
                wgraph = g;
                System.out.println("!!!! Constructed Graph !!!!!");
                //return null;
                if (stochastic)
                    return partitionGraph (g);
                else
                    return absoluteCluster (ilist, mentions);
                //return typicalClusterPartition(g);
            } else {
                if (fullPartition) {
                    wgraph = createGraph(ilist, mentions);
                    return partitionGraph(wgraph);
                } else if (stochastic) {
                    return typicalClusterAdv (ilist, mentions);
                    //return absoluteCluster (ilist, mentions);
                    //wgraph = createGraph(ilist, mentions);
                    //return partitionGraph (wgraph);
                } else {
                    wgraph = createGraph(ilist, mentions);
                    //return absoluteCluster (ilist, mentions);
                    return typicalClusterPartition (wgraph);
                }
            }
        } else {
            return null;
        }
    }

    public WeightedGraph createGraph (InstanceList ilist, List mentions) {
        return createGraph (ilist, mentions, new WeightedGraphImpl());
    }

    public WeightedGraph createGraph (InstanceList ilist, List mentions, WeightedGraph graph) {
        return createGraph (ilist, mentions, graph, this.meClassifier);
    }

    // Build a weighted graph over mentions: each instance (mention pair) is
    // handed to constructEdgesUsingTrainedClusterer (presumably adding an edge;
    // helper not visible in this chunk), then addVerticesToGraph ensures every
    // mention is present as a vertex.
    public WeightedGraph createGraph (InstanceList ilist, List mentions,
                                      WeightedGraph graph, MaxEnt classifier) {
        HashMap alreadyAddedVertices = new HashMap(); // keep track of vertices already added
        for (int i=0; i < ilist.size(); i++) {
            constructEdgesUsingTrainedClusterer(graph, ilist.getInstance(i),
                                                alreadyAddedVertices, null, classifier);
        }
        System.out.println("Finished building graph");
        addVerticesToGraph(graph, mentions, alreadyAddedVertices);
        return graph;
    }

    // this just writes out the edges of a graph with one edge per line
    public void exportGraph (String file) {
        Set edges = wgraph.getEdgeSet();
        try {
            BufferedWriter out = new BufferedWriter(new FileWriter(file));
            for (Iterator it = edges.iterator(); it.hasNext(); ) {
                WeightedEdge e = (WeightedEdge)it.next();
                VertexImpl v1 = (VertexImpl)e.getVertexA();
                VertexImpl v2 = (VertexImpl)e.getVertexB();
                Citation c1 = (Citation)((List)v1.getObject()).get(0);
                Citation c2 = (Citation)((List)v2.getObject()).get(0);
                out.write(c1.print() + " " + c2.print() + " " + e.getWeight() + "\n");
            }
            out.close();
        } catch (IOException e){e.printStackTrace();}; // NOTE(review): error only printed; `out` is not closed on failure
    }
//shallow copy of graph - just need to keep the edges public WeightedGraph copyGraph (WeightedGraph graph) { WeightedGraph copy = new WeightedGraphImpl(); Set edgeSet = graph.getEdgeSet(); Iterator i1 = edgeSet.iterator(); HashMap map = new HashMap(); while (i1.hasNext()) { WeightedEdge e1 = (WeightedEdge)i1.next(); VertexImpl v1 = (VertexImpl)e1.getVertexA(); VertexImpl v2 = (VertexImpl)e1.getVertexB(); VertexImpl n1 = (VertexImpl)map.get(v1); VertexImpl n2 = (VertexImpl)map.get(v2); if (n1 == null) { Object o1 = v1.getObject(); ArrayList l1 = new ArrayList(); if (o1 instanceof List) for (int i=0; i < ((List)o1).size(); i++) l1.add(((List)o1).get(i)); else l1.add(o1); n1 = new VertexImpl(l1); map.put(v1,n1); } if (n2 == null) { Object o2 = v2.getObject(); ArrayList l2 = new ArrayList(); if (o2 instanceof List) for (int i=0; i < ((List)o2).size(); i++) l2.add(((List)o2).get(i)); else l2.add(o2); n2 = new VertexImpl(o2); map.put(v2,n2); }
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?