📄 rmnassigner.java
字号:
if (useMultipleMetrics) { // centroidIdx1 == centroidIdx2 double penalty = 2.0 - ((KL) metrics[centroidIdx1]).distanceJS(instance1, instance2); weight += cost * penalty; } else { // single metric for all clusters double penalty = 2.0 - ((KL) metric).distanceJS(instance1, instance2); weight += cost * penalty; } weightMatrix[centroidIdx1][centroidIdx2] = Math.exp(-m_constraintWeight*weight/m_expScalingFactor); weightMatrix[centroidIdx2][centroidIdx1] = Math.exp(-m_constraintWeight*weight/m_expScalingFactor); } else if (metric instanceof WeightedEuclidean || metric instanceof WeightedMahalanobis) { if (useMultipleMetrics) { // centroidIdx1 == centroidIdx2 double maxDistance = metrics[centroidIdx1].distance(m_clusterer.m_maxCLPoints[centroidIdx1][0], m_clusterer.m_maxCLPoints[centroidIdx1][1]); double distance = metrics[centroidIdx1].distance(instance1, instance2); weight += cost * (maxDistance * maxDistance - distance * distance); } else { // single metric for all clusters double maxDistance = metric.distance(m_clusterer.m_maxCLPoints[0][0], m_clusterer.m_maxCLPoints[0][1]); double distance = metric.distance(instance1, instance2); weight += cost * (maxDistance * maxDistance - distance * distance); } weightMatrix[centroidIdx1][centroidIdx2] = Math.exp(-m_constraintWeight*weight/m_expScalingFactor); weightMatrix[centroidIdx2][centroidIdx1] = Math.exp(-m_constraintWeight*weight/m_expScalingFactor); } } else { // no constraint violation weightMatrix[centroidIdx1][centroidIdx2] = 1; weightMatrix[centroidIdx2][centroidIdx1] = 1; } if (weightMatrix[centroidIdx1][centroidIdx2] < m_epsilon) { weightMatrix[centroidIdx1][centroidIdx2] = m_epsilon; } if (weightMatrix[centroidIdx2][centroidIdx1] < m_epsilon) { weightMatrix[centroidIdx2][centroidIdx1] = m_epsilon; } if (verbose) { System.out.println("Link weight[" + centroidIdx1 + "," + centroidIdx2 + "] for pair: (" + pair.first + "," + pair.second + "," + linkType + ") = " + weightMatrix[centroidIdx1][centroidIdx2]); } } } } 
PotentialFactory2 pf2 = new PotentialFactory2(weightMatrix); Potential pot = pf2.newInstance(); Variable[] nodePair = new Variable[2]; // add edges between potential and variable nodes nodePair[0] = vars[pair.first]; nodePair[1] = vars[pair.second]; fg.addEdges(pot, nodePair); } } fg.allocateMessages(); System.out.println("Doing MPE inference"); if (m_singlePass) { fg.setMPE(); // Razvan's fast approximate computation } else { fg.setExactMPE(); // Kevin Murphy's exact computation } /****/ /**** Compare to default assigner */ /****/ // compare to default assigner SimpleAssigner simple = new SimpleAssigner(m_clusterer); int [] clusterAssignments = m_clusterer.getClusterAssignments(); int [] oldAssignments = new int[numInstances]; int [] simpleAssignments = new int[numInstances]; // backup assignments before E-step for (int i=0; i<numInstances; i++) { oldAssignments[i] = clusterAssignments[i]; } // get assignments with default E-step simple.assign(); for (int i=0; i<numInstances; i++) { simpleAssignments[i] = clusterAssignments[i]; // restore assignments to state before E-step clusterAssignments[i] = oldAssignments[i]; } // number of differences between default and RMN assignments int numDiff = 0; int numSame = 0; boolean invalidAssignments = false; double ratioMissassigned = 0; double ratioNonMissassigned = 0; // Make new cluster assignments using RMN inference for (int i = 0; i < numInstances; i++) { Variable var = vars[i]; int newAssignment = var.getInfValue(); if (verbose) { System.out.println("Variable " + i + " has MPE " + newAssignment); } if (clusterAssignments[i] != newAssignment) { if (verbose) { System.out.println("Moving instance " + i + " from cluster " + clusterAssignments[i] + " to cluster " + newAssignment); } clusterAssignments[i] = newAssignment; moved++; } if (clusterAssignments[i] == -1) { // current cluster assignment invalid invalidAssignments = true; break; // exit for loop } } // 0/NaN in RMN, fallback to SimpleAssigner if (invalidAssignments) 
{ System.out.println("Instances not correctly assigned by RMN, backing off and assigning by SimpleAssigner"); for (int i=0; i<numInstances; i++) { clusterAssignments[i] = simpleAssignments[i]; } } else { // compare RMNAssignments to simpleAssignments for (int i = 0; i < numInstances; i++) { if (clusterAssignments[i] != simpleAssignments[i]) { numDiff++; // count number of constraint violations for this point HashMap instanceConstraintHash = m_clusterer.getInstanceConstraintsHash(); int numViolated = 0; int numTotal = 0; Object list = instanceConstraintHash.get(new Integer(i)); if (list != null) { // there are constraints associated with this instance ArrayList constraintList = (ArrayList) list; numTotal = constraintList.size(); for (int j = 0; j < constraintList.size(); j++) { InstancePair pair = (InstancePair) constraintList.get(j); int firstIdx = pair.first; int secondIdx = pair.second; int centroidIdx = (firstIdx == i) ? clusterAssignments[firstIdx] : clusterAssignments[secondIdx]; int otherIdx = (firstIdx == i) ? clusterAssignments[secondIdx] : clusterAssignments[firstIdx]; // check whether the constraint is violated if (otherIdx != -1 && otherIdx < numClusters) { if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { numViolated++; } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { numViolated++; } } } } // System.out.println("#constraints violated for point " + i + " = " + numViolated); // compare to simpleAssignments double ratio = (numTotal == 0) ? 
0 : ((numViolated+0.0)/numTotal); System.out.print("Point: " + i + "..."); System.out.println("clusterAssignments: " + clusterAssignments[i] + ", simpleAssignments[i] = " + simpleAssignments[i]); System.out.println("numTotal: " + numTotal + ", ratio: " + ratio); if (numTotal > 0) { if (clusterAssignments[i] != simpleAssignments[i]) { numDiff++; System.out.println("MISSASSIGNED; violated/total = " + numViolated + "/" + numTotal + "\t=" + ((float) ratio)); ratioMissassigned += ratio; } else { System.out.println("NOT MISASSIGNED; violated/total = " + numViolated + "/" + numTotal + "\t=" + ((float) ratio)); ratioNonMissassigned += ratio; numSame++; } } } } } System.out.println("\tAVG for misassigned: " + ((float) (ratioMissassigned/numDiff)) + "\n\tAVG for non-misassigned: " + ((float) (ratioNonMissassigned/numSame))); System.out.println("Moved " + moved + " points in RMN inference E-step"); /****/ /**** End of comparing to default assigner */ /****/ return moved; }

  /**
   * Chooses which MPE inference routine the assigner runs on the factor graph.
   *
   * @param b {@code true} to use the single-pass approximate computation,
   *          {@code false} to use the exact computation
   */
  public void setSinglePass (boolean b) {
    this.m_singlePass = b;
  }

  /**
   * Reports which MPE inference routine is currently selected.
   *
   * @return {@code true} if the single-pass approximate computation is used,
   *         {@code false} if the exact computation is used
   */
  public boolean getSinglePass () {
    return m_singlePass;
  }

  /**
   * Sets the scaling factor that divides the weighted constraint penalty
   * inside the exponent when link weights are computed. The new value is
   * also reported on standard output.
   *
   * @param s scaling factor
   */
  public void setExpScalingFactor (double s) {
    this.m_expScalingFactor = s;
    System.out.println("Setting expScalingFactor to: " + m_expScalingFactor);
  }

  /**
   * Returns the scaling factor applied inside the link-weight exponent.
   *
   * @return the current scaling factor
   */
  public double getExpScalingFactor () {
    return m_expScalingFactor;
  }

  /**
   * Sets the weight that multiplies the constraint penalty inside the
   * exponent when link weights are computed. The new value is also reported
   * on standard output.
   *
   * @param w constraint weight
   */
  public void setConstraintWeight (double w) {
    this.m_constraintWeight = w;
    System.out.println("Setting constraintWeight to: " + m_constraintWeight);
  }

  /**
   * Returns the weight that multiplies the constraint penalty.
   *
   * @return the current constraint weight
   */
  public double getConstraintWeight () {
    return m_constraintWeight;
  }

  /**
   * Option parsing is not implemented yet; the supplied options are ignored.
   *
   * @param options option strings (currently unused)
   * @throws Exception declared for interface compatibility; never thrown by
   *         the current implementation
   */
  public void setOptions (String[] options) throws Exception {
    // TODO
  }

  /**
   * Lists the options this assigner understands. Option descriptions are not
   * implemented yet, so the enumeration is empty.
   *
   * @return an empty enumeration (never {@code null}, so callers can iterate
   *         it without a null check)
   */
  public Enumeration listOptions () {
    // TODO: describe -singlePass and -constrWt once setOptions parses them
    return new java.util.Vector().elements();
  }

  /**
   * Builds the option array describing the current settings.
   *
   * The array always has five slots: {@code -singlePass} when that flag is
   * set, then {@code -constrWt} followed by the constraint weight; unused
   * trailing slots are filled with empty strings. The exp scaling factor is
   * deliberately not emitted (its lines are commented out below).
   *
   * @return the option array, padded with empty strings
   */
  public String [] getOptions () {
    String[] options = new String[5];
    int current = 0;

    if (m_singlePass) {
      options[current++] = "-singlePass";
    }
    //    options[current++] = "-expScale";
    //    options[current++] = "" + m_expScalingFactor;
    options[current++] = "-constrWt";
    options[current++] = "" + m_constraintWeight;

    // pad the remaining slots so no element is null
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }
}

// TODO:
// 1. Add potential nodes for multiple-metric WeightedMahalanobis or WeightedEuclidean
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -