corefclusteradv.java
From "MALLET is an open-source project in the fields of natural language processing and machine learning." · Java code · 1,911 lines total · page 1 of 5
JAVA
1,911 lines
// NOTE(review): this page of the scrape begins mid-method. The fragment below is the
// tail of a method defined on the previous page; it appears to fold the merged
// vertex sets s1/s2 into each remaining pseudo-vertex's cluster and return the
// updated score — confirm against the preceding chunk. Braces are unbalanced here
// only because the method's opening lines are outside this view.
PseudoVertex v22 = (PseudoVertex)i2.next();
Set s22 = v22.getCluster();
s22.addAll(s1);
s22.addAll(s2);
} } return newScore; }

/**
 * Computes the tree-model objective for the trivial clustering in which every
 * pseudo-vertex's underlying Citation forms its own singleton cluster.
 *
 * @param pvertices collection of PseudoVertex wrappers (one per mention)
 * @return value of {@code treeModel.computeTreeObjFn} on the all-singletons partition
 */
public double computeInitialTreeObjScore (Collection pvertices) {
  Collection c1 = (Collection)new ArrayList();
  for (Iterator ii = pvertices.iterator(); ii.hasNext(); ) {
    PseudoVertex ppv = (PseudoVertex)ii.next();
    // one singleton cluster per vertex, holding the wrapped Citation
    Collection c2 = (Collection)new ArrayList();
    c2.add((Citation)ppv.getObject());
    c1.add(c2);
  }
  return treeModel.computeTreeObjFn(c1);
}

/**
 * Greedy agglomerative clustering: walks pseudo-edges in sorted order and merges
 * the endpoint clusters whenever doing so does not reduce the objective value,
 * stopping after MAX_REDUCTIONS consecutive non-improving merges. Side effect:
 * rebuilds {@code this.wgraph} from the pseudo-edges for later evaluation.
 *
 * @param ilist instances for the mentions (passed through to edge/vertex builders)
 * @param mentions mention list; order presumably parallels ilist — TODO confirm
 * @return clustering as a Collection of Collections of Citation objects
 */
public Collection absoluteCluster (InstanceList ilist, List mentions) {
  List mCopy = new ArrayList();
  boolean newCluster;
  java.util.Random r = new java.util.Random();
  int numTries = 0;  // consecutive merges that failed to improve the objective
  // get pseudo edges and sort em
  HashMap vsToPvs = new HashMap(); // vsToPvs set destructively
  Collection pvertices = createPseudoVertices (ilist, mentions, vsToPvs);
  for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {
    PseudoVertex v = (PseudoVertex)i1.next();
    mCopy.add(v);
  }
  List pedges = createPseudoEdges (ilist, vsToPvs);
  // sort so the greedy pass considers edges in comparator order (best-first)
  Collections.sort(pedges,new PseudoEdgeComparator());
  double initialObjVal = computeInitialObjFnVal (pedges);
  System.out.println("initial obj fn: " + initialObjVal);
  double objFnVal = initialObjVal;
  double prevVal = -10000000000.0;  // sentinel: below any real objective value
  int i = 0;  // index of the next edge to try; advances even on rejected merges
  while (true) {
    prevVal = objFnVal;
    PseudoEdge pedge = (PseudoEdge)pedges.get(i); // choosePseudoEdge (pedges, r);
    PseudoVertex v1 = (PseudoVertex)pedge.getV1();
    PseudoVertex v2 = (PseudoVertex)pedge.getV2();
    Set s1 = new LinkedHashSet();
    Set s2 = new LinkedHashSet();
    // make copies of the sets
    for (Iterator i1 = v1.getCluster().iterator(); i1.hasNext(); ) {
      s1.add(i1.next());
    }
    for (Iterator i1 = v2.getCluster().iterator(); i1.hasNext(); ) {
      s2.add(i1.next());
    }
    // this is the case for when the edge is now irrelevent through
    // transitive closure - just remove it and start at beginning
    // NOTE(review): the loop index i is NOT advanced on this path, so the same
    // already-merged edge is retried forever — potential infinite loop; verify.
    if (s1.contains(v2) || s2.contains(v1)) {
      //pedges.remove(i);
      continue;
    }
    double[] n = new double[]{0.0};  // throwaway tree-score accumulator (unused here)
    objFnVal = updateScore (objFnVal, n , v1, v2, s1, s2, false);
    i++;
    if (objFnVal <= prevVal) {
      numTries++;
      objFnVal = prevVal; // reset it, since we don't commit to this
    } else {
      numTries = 0;
    }
    if (numTries > MAX_REDUCTIONS)
      break;
    //System.out.println("ObjFnVal: " + objFnVal);
  }
  // build a proper graph from the edges since
  // the evaluation code relies on this structure
  this.wgraph = buildGraphFromPseudoEdges (pedges, mentions);
  // collect the distinct clusters: each vertex carries its (shared) cluster set,
  // so skip any cluster already subsumed by one previously added
  Collection citClustering = new ArrayList();
  for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {
    PseudoVertex v1 = (PseudoVertex)i1.next();
    Collection cluster = v1.getCluster();
    newCluster = true;
    if (citClustering.size() == 0)
      newCluster = true;
    for (Iterator i2 = citClustering.iterator(); i2.hasNext(); ) {
      Collection c2 = (Collection)i2.next();
      if (c2.containsAll(cluster)) {
        newCluster = false;
        break;
      }
    }
    if (newCluster) {
      citClustering.add(cluster);
    }
  }
  // unwrap PseudoVertex clusters into clusters of the underlying Citations
  Collection realClustering = new ArrayList();
  for (Iterator i1 = citClustering.iterator(); i1.hasNext(); ) {
    Collection s1 = (Collection)i1.next();
    Collection n1 = new ArrayList();
    for (Iterator i2 = s1.iterator(); i2.hasNext(); ) {
      n1.add((Citation)((PseudoVertex)i2.next()).getObject());
    }
    realClustering.add (n1);
  }
  return (Collection)realClustering;
}

/**
 * Extracts the current partitioning implied by the pseudo-vertices' cluster sets,
 * de-duplicating clusters that several vertices share (subset-containment test).
 *
 * @param pvertices pseudo-vertices whose getCluster() sets encode the partition
 * @return Collection of distinct cluster Collections (still PseudoVertex members)
 */
protected Collection getPseudoClustering (Collection pvertices) {
  Collection citClustering = new ArrayList();
  boolean newCluster;
  for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {
    PseudoVertex v1 = (PseudoVertex)i1.next();
    Collection cluster = v1.getCluster();
    newCluster = true;
    if (citClustering.size() == 0)
      newCluster = true;
    for (Iterator i2 = citClustering.iterator(); i2.hasNext(); ) {
      Collection c2 = (Collection)i2.next();
      if (c2.containsAll(cluster)) {
        newCluster = false;
        break;
      }
    }
    if (newCluster) {
      citClustering.add(cluster);
    }
  }
  return citClustering;
}

/**
 * Randomized variant of the greedy clustering: repeats MAX_ITERS restarts; within
 * each restart, repeatedly picks a random edge from the top rBeamSize of the
 * sorted edge list and merges its endpoint clusters. Stopping rule depends on
 * {@code trueNumStop}: either stop when the cluster count reaches the key
 * partitioning's size, or after MAX_REDUCTIONS consecutive non-improving merges.
 * Prints per-restart diagnostics (objective, tree objective, pairwise F1).
 * Side effect: rebuilds {@code this.wgraph} each restart.
 *
 * NOTE(review): only the clustering from the LAST restart is returned — earlier
 * restarts' results are discarded; confirm that is intended.
 *
 * @param ilist instances for the mentions
 * @param mentions mention list
 * @return final clustering as Collection of Collections of Citation objects
 */
public Collection typicalClusterAdv (InstanceList ilist, List mentions) {
  List pedgesC = null;
  Collection pvertices = null;
  for (int j=0; j < MAX_ITERS; j++) {
    List mCopy = new ArrayList();
    java.util.Random r = new java.util.Random();
    int numTries = 0;
    // get pseudo edges and sort em
    HashMap vsToPvs = new HashMap(); // vsToPvs set destructively
    pvertices = createPseudoVertices (ilist, mentions, vsToPvs);
    for (Iterator i1 = pvertices.iterator(); i1.hasNext(); ) {
      PseudoVertex v = (PseudoVertex)i1.next();
      mCopy.add(v);
    }
    List pedges = createPseudoEdges (ilist, vsToPvs);
    // keep an unsorted copy of the edges for graph construction below
    pedgesC = new ArrayList();
    for (Iterator it = pedges.iterator(); it.hasNext();) {
      pedgesC.add((PseudoEdge)it.next());
    }
    // build the actual graph for scoring puposees
    this.wgraph = buildGraphFromPseudoEdges (pedgesC, mentions);
    Collections.sort(pedges,new PseudoEdgeComparator());
    // NOTE(review): this loop only reads the top-50 edges and discards them —
    // looks like leftover debugging code with no effect.
    for (int k=0; k < 50; k++) {
      PseudoEdge e1 = (PseudoEdge)pedges.get(k);
    }
    double initialObjVal = computeInitialObjFnVal (pedges);
    double initialTreeObjVal;
    if (treeModel != null)
      initialTreeObjVal = computeInitialTreeObjScore (pvertices);
    else
      initialTreeObjVal = 0.0;
    // single-element array so updateScore can mutate the tree objective in place
    double[] treeObjVal;
    treeObjVal = new double[]{initialTreeObjVal};
    double objFnVal = initialObjVal;
    double prevVal = -10000000000.0;  // sentinel below any real objective value
    int i = 0;
    int numClusters = pvertices.size();  // starts all-singletons; drops per merge
    while (true) {
      prevVal = objFnVal;
      int choice = r.nextInt(rBeamSize); // selection size
      // NOTE(review): off-by-one — choice == pedges.size() passes this guard and
      // pedges.get(choice) would throw; the check should likely be >=. Verify.
      if (choice > pedges.size())
        break;
      PseudoEdge pedge = (PseudoEdge)pedges.get(choice); // choosePseudoEdge (pedges, r);
      PseudoVertex v1 = (PseudoVertex)pedge.getV1();
      PseudoVertex v2 = (PseudoVertex)pedge.getV2();
      Set s1 = new LinkedHashSet();
      Set s2 = new LinkedHashSet();
      // make copies of the sets of vertices represented by each pseudovertex
      for (Iterator i1 = v1.getCluster().iterator(); i1.hasNext(); ) {
        s1.add(i1.next());
      }
      for (Iterator i1 = v2.getCluster().iterator(); i1.hasNext(); ) {
        s2.add(i1.next());
      }
      // this is the case for when the edge is now irrelevent through
      // transitive closure - just remove it and start at beginning
      if (s1.contains(v2) || s2.contains(v1)) {
        pedges.remove(choice);
        continue;
      }
      numClusters--;
      if (trueNumStop) {
        objFnVal = updateScore (objFnVal, treeObjVal, v1, v2, s1, s2, true);
        //Collection cl1 = (Collection)getClusteringFromPseudo(pvertices);
        //System.out.println("+++++++++++++++++++++++++++");
        //double treeV = treeModel.computeTreeObjFn(cl1, false);
        //System.out.println("+++++++++++++++++++++++++++");
        //System.out.println("treeV: " + treeV + " updated: " + treeObjVal[0]);
      } else
        objFnVal = updateScore (objFnVal, treeObjVal, v1, v2, s1, s2, false);
      if (!trueNumStop && objFnVal <= prevVal) {
        numTries++;
        objFnVal = prevVal; // reset it, since we don't commit to this
      } else {
        //System.out.println(objFnVal + "," + treeObjVal[0]);
        pedges.remove(choice);
        numTries = 0;
      }
      if (trueNumStop && numClusters <= keyPartitioning.size()) {
        // stop at the true number of clusters; report agreement diagnostics
        Collection cl1 = (Collection)getClusteringFromPseudo(pvertices);
        PairEvaluate pairEval = new PairEvaluate (keyPartitioning, cl1);
        pairEval.evaluate();
        double curAgree = evaluatePartitioningAgree (cl1, this.wgraph);
        double curDisAgree = evaluatePartitioningDisAgree (cl1, this.wgraph);
        /* double treeV = treeModel.computeTreeObjFn(cl1, false);
        if (Math.abs(treeV - treeObjVal[0]) > 0.01)
          System.out.println("Tree values don't match: " + treeV + ":" + treeObjVal[0]);*/
        int singles = numSingletons(cl1);
        System.out.println(objFnVal + "," + treeObjVal[0] + "," + curAgree + ","
                           + curDisAgree + "," + singles + "," + pairEval.getF1());
        break;
      } else if (numTries > MAX_REDUCTIONS) {
        // objective-based stopping: too many consecutive rejected merges
        Collection cl1 = (Collection)getClusteringFromPseudo(pvertices);
        PairEvaluate pairEval = new PairEvaluate (keyPartitioning, cl1);
        pairEval.evaluate();
        System.out.println(objFnVal + "," + treeObjVal[0] + "," + pairEval.getF1());
        break;
      }
      //System.out.println("ObjFnVal: " + objFnVal);
    }
  }
  // build a proper graph from the edges since
  // the evaluation code relies on this structure
  return (Collection)getClusteringFromPseudo(pvertices);
}

/**
 * Counts size-1 clusters in a clustering.
 *
 * @param clustering Collection of cluster Collections
 * @return number of singleton clusters
 */
protected int numSingletons (Collection clustering) {
  int total = 0;
  for (Iterator it = clustering.iterator(); it.hasNext(); ) {
    if (((Collection)it.next()).size() == 1)
      total++;
  }
  return total;
}

/**
 * Converts the pseudo-vertex partitioning into a clustering of the underlying
 * Citation objects (de-duplicated via {@link #getPseudoClustering}).
 *
 * @param pvertices pseudo-vertices encoding the current partition
 * @return Collection of Collections of Citation objects
 */
protected Collection getClusteringFromPseudo (Collection pvertices) {
  Collection citClustering = getPseudoClustering (pvertices);
  Collection realClustering = new ArrayList();
  for (Iterator i1 = citClustering.iterator(); i1.hasNext(); ) {
    Collection s1 = (Collection)i1.next();
    Collection n1 = new ArrayList();
    for (Iterator i2 = s1.iterator(); i2.hasNext(); ) {
      n1.add((Citation)((PseudoVertex)i2.next()).getObject());
    }
    realClustering.add (n1);
  }
  return (Collection)realClustering;
}

/**
 * Materializes a WeightedGraph from pseudo-edges, then adds any mention vertices
 * that no edge touched (via addVerticesToGraph) so isolated mentions appear too.
 *
 * @param pedges list of PseudoEdge objects to convert
 * @param mentions full mention list, used to add edge-less vertices
 * @return the constructed graph
 */
protected WeightedGraph buildGraphFromPseudoEdges (List pedges, List mentions) {
  HashMap alreadyAdded = new HashMap();  // tracks vertices already in the graph
  WeightedGraph w = (WeightedGraph)new WeightedGraphImpl();
  for (Iterator it = pedges.iterator(); it.hasNext(); ) {
    constructEdgesFromPseudoEdges (w, (PseudoEdge)it.next(), alreadyAdded);
  }
  addVerticesToGraph (w, mentions, alreadyAdded);
  return w;
}

/**
 * Standard greedy graph partitioning: repeatedly finds the highest-weight edge
 * and merges its endpoints (destructively, via mergeVertices) until the best
 * remaining edge weight falls below {@code threshold}.
 *
 * @param graph weighted graph to partition; mutated in place
 * @return clustering of the original objects recovered from the merged vertices
 */
public Collection typicalClusterPartition (WeightedGraph graph) {
  /* Iterator i0 = ((Set)graph.getVertexSet()).iterator();
  while (i0.hasNext()) {
    VertexImpl v = (VertexImpl)i0.next();
    System.out.println("Map: " + v.getObject() + " -> " +
      ((Citation)((Node)v.getObject()).getObject()).getBaseString() );
  }*/
  //System.out.println("Top Graph: " + graph);
  while (true) {
    double bestEdgeVal = -100000000;  // sentinel below any expected edge weight
    WeightedEdge bestEdge = null;
    //System.out.println("Top Graph: " + graph);
    Set edgeSet = graph.getEdgeSet();
    Iterator i1 = edgeSet.iterator();
    // get highest edge value in this loop
    while (i1.hasNext()) {
      WeightedEdge e1 = (WeightedEdge)i1.next();
      if (e1.getWeight() > bestEdgeVal) {
        bestEdgeVal = e1.getWeight();
        bestEdge = e1;
      }
    }
    System.err.println ("bestEdgeVal: " + bestEdgeVal + " threshold: " + threshold);
    if (bestEdgeVal < threshold)
      break;
    else {
      if (bestEdge != null) {
        VertexImpl v1 = (VertexImpl)bestEdge.getVertexA();
        VertexImpl v2 = (VertexImpl)bestEdge.getVertexB();
        /* System.out.println("Best edge val: " + bestEdgeVal);
        System.out.println("Merging vertices: " + v1.getObject() + " and " + v2.getObject()); */
        mergeVertices(graph, v1, v2);
      }
    }
  }
  System.out.println("Final graph now has " + graph.getVertexSet().size() + " nodes");
  return getCollectionOfOriginalObjects ((Collection)graph.getVertexSet());
}

// NOTE(review): this page of the scrape ends mid-method — partitionGraph
// continues on the next page; the body below is the visible prefix only.
public Collection partitionGraph (WeightedGraph origGraph) {
  java.util.Random rand = new java.util.Random();
  double bestCost = -100000000000.0;
  double curCost = bestCost;
  Collection bestPartitioning = null;
  // evalFreq is the frequency with which evaluations occur
  // in the early stages, it is silly to keep doing a complete
  // evaluation of the objective fn
  for (int i=0; i < MAX_ITERS; i++) {
    //System.out.println("Iteration " + i);
    double cost = -100000000.0;
    double bCost = cost;
    int evalFreq = 10;
    // this is a counter that will increment each time
    // a new edge is tried and the result is a graph
    // with a reduced total objective value
    int numReductions = 0;
    double treeCost = 0.0;
    int iter = 0;
    Collection curPartitioning = null;
    Collection localBestPartitioning = null;
    // work on copies so the caller's graph is not destroyed by merging
    WeightedGraph graph = copyGraph(origGraph);
    WeightedGraph graph1 = copyGraph(graph);
    while (true) {
      Collection c0 = (Collection)graph.getEdgeSet();
      List sortedEdges = Collections.list(Collections.enumeration(c0));
      System.out.println("Size of sorted edges: " + sortedEdges.size());
      if (sortedEdges.size() > 0) {
        EdgeComparator comp = new EdgeComparator();
        Collections.sort(sortedEdges, comp);
        // minVal is the smallest weight (list is sorted descending by comparator —
        // TODO confirm EdgeComparator's order); shift all weights to non-negative
        double minVal = ((WeightedEdge)sortedEdges.get(sortedEdges.size()-1)).getWeight();
        double totalVal = 0.0;
        Iterator il = (Iterator)sortedEdges.iterator();
        while (il.hasNext()) {
          totalVal += ((WeightedEdge)il.next()).getWeight();
        }
        totalVal += sortedEdges.size()*(-minVal);
⌨️ Keyboard shortcuts
Copy code: Ctrl + C
Search code: Ctrl + F
Full-screen mode: F11
Increase font size: Ctrl + =
Decrease font size: Ctrl + -
Show shortcuts: ?