📄 minimizedisagreementsclustering.java
字号:
/* Copyright (C) 2002 Dept. of Computer Science, Univ. of Massachusetts, Amherst This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This program toolkit free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. For more details see the GNU General Public License and the file README-LEGAL. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. *//** @author Ben Wellner */package edu.umass.cs.mallet.projects.seg_plus_coref.graphs;import edu.umass.cs.mallet.projects.seg_plus_coref.anaphora.*;import edu.umass.cs.mallet.projects.seg_plus_coref.clustering.*;import java.util.*;import java.io.*;import java.text.*;import java.lang.reflect.Array;import salvo.jesus.graph.*;import salvo.jesus.graph.visual.*;import salvo.jesus.graph.visual.layout.*;public class MinimizeDisagreementsClustering{ final static double THRESHOLD = 7.0; WeightedGraph origGraph; SortedSet origVertices; // set of vertices in graph - unmodifiable Set clusters; // set of sets of vertices Random randGen; // will use this quite a bit in this class int graphSize; Object []vertexArray; double delta; double removeDelta; double addDelta; int goodsCnt = 0; int removesCnt = 0; HashMap vertexMap; // maps graph copys vertices to original graphs vertices Comparator comparator; MappedGraph mappedGraph = null; public int getGoods () { return goodsCnt; } public int getRemoves () { return removesCnt; } public MinimizeDisagreementsClustering (MappedGraph g, double d) { this 
(g.getGraph(), d); mappedGraph = g; } public MinimizeDisagreementsClustering (WeightedGraph g, double d) { origGraph = g; comparator = getComparator(g); origVertices = sortSet(origGraph.getVertexSet(),comparator); randGen = new Random(); graphSize = origVertices.size(); delta = d; removeDelta = delta * 3; addDelta = delta * 7; } private Comparator getComparator (Graph g) { Iterator i = g.getVerticesIterator(); VertexImpl v = null; if (i.hasNext()) { v = (VertexImpl)i.next(); } else { System.out.println("Warning: No Comparator"); } if ((v != null) && (v.getObject() instanceof Mention)) return new CompareMentionVertices(); else return new CompareVertices(); } private HashMap createVertexMap (Set v1, Set v2) { HashMap map = new HashMap(); Object[] a1 = v1.toArray(); Object[] a2 = v2.toArray(); for (int i=0; i < Array.getLength(a1); i++) { map.put(a1[i],a2[i]); } return map; } public Object deepCopy(Object oldObj) throws Exception { ObjectOutputStream oos = null; ObjectInputStream ois = null; try { ByteArrayOutputStream bos = new ByteArrayOutputStream(); // A oos = new ObjectOutputStream(bos); // B // serialize and pass the object oos.writeObject(oldObj); // C oos.flush(); // D ByteArrayInputStream bin = new ByteArrayInputStream(bos.toByteArray()); // E ois = new ObjectInputStream(bin); // F // return the new object return ois.readObject(); // G } catch(Exception e) { System.out.println("Exception in ObjectCloner = " + e); throw(e); } finally { if(oos != null) oos.close(); if(ois != null) ois.close(); } } private boolean allIntegers (Set vertices) { Iterator i = vertices.iterator(); while (i.hasNext()) { VertexImpl v = (VertexImpl)i.next(); if (!(v.getObject() instanceof Integer)) return false; } return true; } /***** Greedy aglomerative clustering implemented here.... 
*/ public Clustering getClusteringGreedily () { double cost = 10000.0; boolean notFinished = true; WeightedGraph curGraph = null; // make copy of original graph try { if (allIntegers (origGraph.getVertexSet())) System.out.println("All vertex objects are integers"); else System.out.println("At least one vertex object is a non-integer"); curGraph = (WeightedGraph)deepCopy(origGraph); // copy graph, since we muck with it } catch (Exception e) {e.printStackTrace();} vertexMap = createVertexMap(sortSet(curGraph.getVertexSet(), comparator), origVertices); // create map between copy and original Set curSet = curGraph.getVertexSet(); Iterator i = curSet.iterator(); Clustering curClustering = new Clustering(); while (i.hasNext()) { Cluster cl = new Cluster(); cl.add(i.next()); curClustering.add(cl); } System.out.println("Initial clustering: "); curClustering.print(); while (cost > THRESHOLD) { cost = nextBestClustering(curClustering, curGraph); System.out.println("GRAPH COST: " + cost); curClustering.print(); } if (mappedGraph != null) { System.out.println("Remapping clusters:"); return remapClusters(curClustering); } else return curClustering; } public double nextBestClustering (Set curClustering, WeightedGraph graph) { double bestScore = 0.0; Cluster best1 = null; Cluster best2 = null; Object cArray[] = curClustering.toArray(); for (int i=0; i<curClustering.size(); i++) { for (int j=0; j < i; j++) { double curScore = evaluatePair((Cluster)cArray[i],(Cluster)cArray[j], graph); if (curScore > bestScore) { bestScore = curScore; best1 = (Cluster)cArray[i]; best2 = (Cluster)cArray[j]; } } } curClustering.remove(best1); curClustering.remove(best2); if ((best1 != null) && (best2 != null)) System.out.println("Merging clusters: " + best1 + " and " + best2); curClustering.add(mergeClusters(best1,best2)); return bestScore; } public double evaluatePair (Cluster c1, Cluster c2, WeightedGraph graph) { double total = 0.0; int numEdges = 0; Iterator i1 = c1.iterator(); while 
(i1.hasNext()) { List edges = graph.getEdges((Vertex)i1.next()); Iterator e1 = edges.iterator(); while (e1.hasNext()) { WeightedEdge edge = (WeightedEdge)e1.next(); if ((c2.contains(edge.getVertexA())) || (c2.contains(edge.getVertexB()))) { numEdges++; total += edge.getWeight(); } } } return total/(double)numEdges; } public Cluster mergeClusters (Cluster c1, Cluster c2) { Cluster newCluster = new Cluster(); Iterator i1 = c1.iterator(); while (i1.hasNext()) { newCluster.add(i1.next()); } Iterator i2 = c1.iterator(); while (i2.hasNext()) { newCluster.add(i2.next()); } return newCluster; } // this is the old case (where the threshold wasn't a parameter) // default threshold to 0.0 here public Clustering getClustering (List selectVertices) { return getClustering (selectVertices, 0.0); } public Clustering getClustering (List selectVertices, double threshold) { Clustering clusters = new Clustering(); WeightedGraph curGraph = null; // make copy of original graph List newSelectVertices = new Stack(); try { if (allIntegers (origGraph.getVertexSet())) System.out.println("All vertex objects are integers"); else System.out.println("At least one vertex object is a non-integer"); curGraph = (WeightedGraph)deepCopy(origGraph); // copy graph, since we muck with it } catch (Exception e) {e.printStackTrace();} vertexMap = createVertexMap(sortSet(curGraph.getVertexSet(), comparator), origVertices); // create map between copy and original while (!(curGraph.getVertexSet().isEmpty())) { Vertex v = null; //Vertex v = selectRandomVertex(curGraph); if ((selectVertices != null)) { while ((!selectVertices.isEmpty()) && (!curGraph.getVertexSet().contains(v))) { v = (Vertex)selectVertices.remove(0); } if (v == null) v = selectHeaviestVertex(curGraph); } else { v = selectHeaviestVertex(curGraph); } newSelectVertices.add(v); Set nPlus = getNPlus(curGraph, v); nPlus.add(v); // initial cluster must include selected vertex (of course) Set cluster = sortSet(findOptimalCluster (curGraph, 
(Set)sortSet(nPlus, comparator), threshold), comparator); if (!(cluster.isEmpty())) { Iterator i = cluster.iterator(); while (i.hasNext()) { try { curGraph.remove((Vertex)i.next()); // remove created cluster from graph } catch (Exception e) {e.printStackTrace();} } Cluster toAdd = mapCluster(vertexMap, cluster); clusters.add(toAdd); } else { Iterator iter = curGraph.getVertexSet().iterator(); while (iter.hasNext()) { Cluster s = new Cluster(); s.add(iter.next()); clusters.add(s); } break; } } if (mappedGraph != null) { Clustering cl = remapClusters (clusters); cl.setSelectVertices (newSelectVertices); return cl; } else { clusters.setSelectVertices (newSelectVertices); return clusters; } } // this method remaps clusters back into the original objects associated with a MappedGraph // ... yes there are way too many hashtables in this object . . . private Clustering remapClusters (Clustering clusters) { Clustering set = new Clustering(); Iterator i = clusters.iterator(); while (i.hasNext()) { Cluster cluster = (Cluster)i.next(); LinkedHashSet s1 = new LinkedHashSet(); Iterator i1 = cluster.iterator(); while (i1.hasNext()) { Vertex v = (Vertex)i1.next(); s1.add(mappedGraph.getObjectFromVertex (v)); } set.add(s1); } return set; } private Cluster mapCluster (HashMap map, Set cluster) { Cluster realCluster = new Cluster (); Iterator i = cluster.iterator(); while (i.hasNext()) { realCluster.add(map.get(i.next())); } return realCluster; } // method will select vertex in Graph that has the highest sum of incident edges (> 0) // could also choose vertex that has the highest variance over incident edges // i.e. 
 // most of the values are close to 1 or -1 (not 0) - this is the vertex we
    // are most "confident about"
    /**
     * Returns the vertex of {@code g} with the largest {@code getVWeight(v, g)} value.
     * Vertices are scanned in the order given by {@code sortSet(g.getVertexSet(), comparator)}.
     * Because the comparison is a strict {@code >}, the earliest vertex in that order wins ties.
     * Returns null when {@code g} has no vertices, or when no vertex's weight exceeds the
     * starting value of -1.0.
     * NOTE(review): getVWeight is not visible here. The comment above describes it as a sum of
     * positive incident edge weights (so presumably always >= 0), which would make a null
     * return impossible for a non-empty graph. Confirm.
     */
    private Vertex selectHeaviestVertex (Graph g) {
        double curHeaviestWeight = -1.0;
        Vertex curHeaviest = null;
        Set set = sortSet(g.getVertexSet(), comparator);
        Iterator i = set.iterator();
        while (i.hasNext()) {
            Vertex v = (Vertex)i.next();
            double vWeight = getVWeight (v, g);
            if (vWeight > curHeaviestWeight) {
                curHeaviestWeight = vWeight;
                curHeaviest = v;
            }
        }
        return curHeaviest;
    }

    /**
     * Creates a TreeSet ordered by {@code c}, or by the elements' natural ordering when
     * {@code c} is null, and starts iterating {@code origSet}. The rest of this method lies
     * outside this view; presumably it adds origSet's elements to tSet and returns it (TODO
     * confirm).
     * NOTE(review): a TreeSet treats elements that compare as 0 as duplicates. A comparator
     * that is inconsistent with equals would therefore silently drop vertices.
     */
    public SortedSet sortSet (Set origSet, Comparator c) {
        TreeSet tSet = null;
        if (c != null) {
            tSet = new TreeSet(c);
        } else {
            tSet = new TreeSet(); // assume elements have a natural order
        }
        Iterator i = origSet.iterator();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -