📄 RMNAssigner.java
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * RMNAssigner.java
 * RMN assignment for K-Means
 * Copyright (C) 2004 Misha Bilenko, Sugato Basu
 */

package weka.clusterers.assigners;

import java.io.*;
import java.util.*;

import weka.core.*;
import weka.core.metrics.*;
import weka.clusterers.*;
import rmn.*;

public class RMNAssigner extends MPCKMeansAssigner {

  /** Inference can be single-pass approximate or multi-pass approximate */
  protected boolean m_singlePass = true;

  /** Scaling factor for the exponent (recomputed from the data in assign()) */
  protected double m_expScalingFactor = 10;

  /** Scaling factor for constraint weights */
  protected double m_constraintWeight = 1000;

  /** Small value used in place of 0 in some places, to avoid numerical underflow */
  public double m_epsilon;

  public RMNAssigner() {
    super();
  }

  public RMNAssigner(MPCKMeans clusterer) {
    super(clusterer);
  }

  /**
   * This is NOT a sequential assignment method: all points are assigned
   * jointly by inference over a factor graph, so this returns false.
   */
  public boolean isSequential() {
    return false;
  }

  /**
   * The main assignment method.
   * @return the number of points that changed assignment
   */
  public int assign() throws Exception {
    int moved = 0;
    boolean verbose = m_clusterer.getVerbose();
    m_epsilon = 1e-9;

    Metric metric = m_clusterer.getMetric();
    LearnableMetric[] metrics = m_clusterer.getMetrics();
    boolean useMultipleMetrics = m_clusterer.getUseMultipleMetrics();
    Instances instances = m_clusterer.getInstances();
    Instances centroids = m_clusterer.getClusterCentroids();
    int numInstances = instances.numInstances();
    int numClusters = m_clusterer.getNumClusters();
    Random random = new Random(m_clusterer.getRandomSeed()); // initialize random number generator

    // create the factor graph
    FactorGraph fg = new FactorGraph();

    // create one variable node per instance; each variable ranges over the clusters
    Variable[] vars = new Variable[numInstances];
    for (int i = 0; i < numInstances; i++) {
      String name = "var:" + i;
      vars[i] = new Variable(name, numClusters);
    }

    // compute the instance-centroid similarity/distance matrix and track its maximum
    double maxSim = 0;
    double[][] simMatrix = new double[numInstances][numClusters];
    for (int i = 0; i < numInstances; i++) {
      Instance instance = instances.instance(i);
      for (int centroidIdx = 0; centroidIdx < numClusters; centroidIdx++) {
        Instance centroid = centroids.instance(centroidIdx);
        if (!m_clusterer.isObjFunDecreasing()) { // increasing obj. function
          if (useMultipleMetrics) { // one metric per cluster
            simMatrix[i][centroidIdx] = metrics[centroidIdx].similarity(instance, centroid);
          } else {
            simMatrix[i][centroidIdx] = metric.similarity(instance, centroid);
          }
        } else { // decreasing obj. function
          if (useMultipleMetrics) { // one metric per cluster
            simMatrix[i][centroidIdx] = metrics[centroidIdx].distance(instance, centroid);
          } else {
            simMatrix[i][centroidIdx] = metric.distance(instance, centroid);
          }
          if (metric instanceof WeightedEuclidean || metric instanceof WeightedMahalanobis) {
            simMatrix[i][centroidIdx] *= simMatrix[i][centroidIdx]; // squared distance
          }
        }
        if (maxSim < simMatrix[i][centroidIdx]) {
          maxSim = simMatrix[i][centroidIdx];
        }
        if (verbose) {
          System.out.println("simMatrix[" + i + "," + centroidIdx + "]: "
              + simMatrix[i][centroidIdx] + ", MaxSim: " + maxSim);
        }
      }
    }
    // scale exponents by the largest observed value so that Math.exp stays in range
    m_expScalingFactor = maxSim;

    // create one centroid (unary) potential node per instance
    for (int i = 0; i < numInstances; i++) {
      double[] weightVector = new double[numClusters];
      // fill up the potential table using the current distances
      for (int centroidIdx = 0; centroidIdx < numClusters; centroidIdx++) {
        if (!m_clusterer.isObjFunDecreasing()) { // increasing obj. function
          weightVector[centroidIdx] = Math.exp(simMatrix[i][centroidIdx] / m_expScalingFactor);
        } else { // decreasing obj. function
          weightVector[centroidIdx] = Math.exp(-simMatrix[i][centroidIdx] / m_expScalingFactor);
        }
        if (weightVector[centroidIdx] < m_epsilon) { // floor at epsilon to avoid underflow
          weightVector[centroidIdx] = m_epsilon;
        }
        if (verbose) {
          System.out.println("Centroid weight[" + centroidIdx + "] for instance: "
              + i + " = " + weightVector[centroidIdx]);
        }
      }
      // create the centroid potential node and connect it to the instance's variable
      PotentialFactory1 pf1 = new PotentialFactory1(weightVector);
      Potential pot = pf1.newInstance();
      Variable[] node = new Variable[1];
      node[0] = vars[i];
      fg.addEdges(pot, node);
    }

    // create must-link (ML) and cannot-link (CL) pairwise potential nodes
    HashMap constraintsHash = m_clusterer.getConstraintsHash();
    if (constraintsHash != null) {
      System.out.println("Creating constraint potential nodes");
      Set pointPairs = constraintsHash.keySet();
      Iterator pairItr = pointPairs.iterator();
      // iterate over the pairs in the constraints hash
      while (pairItr.hasNext()) {
        InstancePair pair = (InstancePair) pairItr.next();
        Instance instance1 = instances.instance(pair.first);
        Instance instance2 = instances.instance(pair.second);
        int linkType = ((Integer) constraintsHash.get(pair)).intValue();
        double cost = 0;
        if (linkType == InstancePair.MUST_LINK) {
          cost = m_clusterer.getMustLinkWeight();
        } else if (linkType == InstancePair.CANNOT_LINK) {
          cost = m_clusterer.getCannotLinkWeight();
        }
        if (verbose) {
          System.out.println(pair + ": type = " + linkType);
        }

        double[][] weightMatrix = new double[numClusters][numClusters];
        if (linkType == InstancePair.MUST_LINK) {
          // create an ML potential node: penalize assignments to different clusters
          for (int centroidIdx1 = 0; centroidIdx1 < numClusters; centroidIdx1++) {
            for (int centroidIdx2 = 0; centroidIdx2 < numClusters; centroidIdx2++) {
              // fill up the potential table using the current distances
              if (centroidIdx1 != centroidIdx2) { // constraint violated
                double weight = 0;
                if (metric instanceof WeightedDotP) {
                  if (useMultipleMetrics) {
                    // split the penalty in half between the two involved clusters
                    double sim1 = metrics[centroidIdx1].similarity(instance1, instance2);
                    weight -= 0.5 * cost * (1 - sim1);
                    double sim2 = metrics[centroidIdx2].similarity(instance1, instance2);
                    weight -= 0.5 * cost * (1 - sim2);
                  } else { // single metric for all clusters
                    double sim = metric.similarity(instance1, instance2);
                    weight -= cost * (1 - sim);
                  }
                  weightMatrix[centroidIdx1][centroidIdx2] =
                      Math.exp(m_constraintWeight * m_expScalingFactor * weight);
                  weightMatrix[centroidIdx2][centroidIdx1] =
                      Math.exp(m_constraintWeight * m_expScalingFactor * weight);
                } else if (metric instanceof KL) {
                  if (useMultipleMetrics) {
                    // split the penalty in half between the two involved clusters
                    double penalty1 = ((KL) metrics[centroidIdx1]).distanceJS(instance1, instance2);
                    weight += 0.5 * cost * penalty1;
                    double penalty2 = ((KL) metrics[centroidIdx2]).distanceJS(instance1, instance2);
                    weight += 0.5 * cost * penalty2;
                  } else { // single metric for all clusters
                    double penalty = ((KL) metric).distanceJS(instance1, instance2);
                    weight += cost * penalty;
                  }
                  weightMatrix[centroidIdx1][centroidIdx2] =
                      Math.exp(-m_constraintWeight * weight / m_expScalingFactor);
                  weightMatrix[centroidIdx2][centroidIdx1] =
                      Math.exp(-m_constraintWeight * weight / m_expScalingFactor);
                } else if (metric instanceof WeightedEuclidean
                           || metric instanceof WeightedMahalanobis) {
                  if (useMultipleMetrics) {
                    // split the penalty in half between the two involved clusters
                    double distance1 = metrics[centroidIdx1].distance(instance1, instance2);
                    weight += 0.5 * cost * distance1 * distance1;
                    double distance2 = metrics[centroidIdx2].distance(instance1, instance2);
                    weight += 0.5 * cost * distance2 * distance2;
                  } else { // single metric for all clusters
                    double distance = metric.distance(instance1, instance2);
                    weight += cost * distance * distance;
                  }
                  weightMatrix[centroidIdx1][centroidIdx2] =
                      Math.exp(-m_constraintWeight * weight / m_expScalingFactor);
                  weightMatrix[centroidIdx2][centroidIdx1] =
                      Math.exp(-m_constraintWeight * weight / m_expScalingFactor);
                }
              } else { // same cluster: no constraint violation
                weightMatrix[centroidIdx1][centroidIdx2] = 1;
                weightMatrix[centroidIdx2][centroidIdx1] = 1;
              }
              if (weightMatrix[centroidIdx1][centroidIdx2] < m_epsilon) {
                weightMatrix[centroidIdx1][centroidIdx2] = m_epsilon;
              }
              if (weightMatrix[centroidIdx2][centroidIdx1] < m_epsilon) {
                weightMatrix[centroidIdx2][centroidIdx1] = m_epsilon;
              }
              if (verbose) {
                System.out.println("Link weight[" + centroidIdx1 + "," + centroidIdx2
                    + "] for pair: (" + pair.first + "," + pair.second + ","
                    + linkType + ") = " + weightMatrix[centroidIdx1][centroidIdx2]);
              }
            }
          }
        } else {
          // create a CL potential node: penalize assignments to the same cluster
          for (int centroidIdx1 = 0; centroidIdx1 < numClusters; centroidIdx1++) {
            for (int centroidIdx2 = 0; centroidIdx2 < numClusters; centroidIdx2++) {
              // fill up the potential table using the current distances
              if (centroidIdx1 == centroidIdx2) { // constraint violated
                double weight = 0;
                if (metric instanceof WeightedDotP) {
                  if (useMultipleMetrics) { // centroidIdx1 == centroidIdx2
                    weight -= cost * metrics[centroidIdx1].similarity(instance1, instance2);
                  } else { // single metric for all clusters
                    weight -= cost * metric.similarity(instance1, instance2);
                  }
                  weightMatrix[centroidIdx1][centroidIdx2] =
                      Math.exp(m_constraintWeight * m_expScalingFactor * weight);
                  weightMatrix[centroidIdx2][centroidIdx1] =
                      Math.exp(m_constraintWeight * m_expScalingFactor * weight);
                } else if (metric instanceof KL) {
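The listing above breaks off inside the cannot-link KL branch; the remainder of the file is not part of this source. For intuition about the two kinds of potential tables the code builds, here is a minimal, self-contained sketch. All names in it (PotentialSketch, centroidPotentials, mustLinkPotentials) are hypothetical and not part of the original file; it only mirrors the arithmetic shown above: unary centroid potentials as exponentiated similarities scaled by the maximum observed value, and a pairwise must-link table that leaves agreeing assignments at weight 1 and down-weights disagreeing ones in the exp(-constraintWeight * penalty / scale) form used in the KL and WeightedEuclidean branches.

// PotentialSketch.java -- illustrative only, not part of RMNAssigner
import java.util.Arrays;

class PotentialSketch {
  static final double EPSILON = 1e-9; // same underflow floor as m_epsilon above

  /** Unary (centroid) potentials: exp(sim / maxSim), floored at EPSILON. */
  static double[] centroidPotentials(double[] sims, double maxSim) {
    double[] w = new double[sims.length];
    for (int k = 0; k < sims.length; k++) {
      w[k] = Math.max(Math.exp(sims[k] / maxSim), EPSILON);
    }
    return w;
  }

  /**
   * Pairwise must-link potentials over k clusters: weight 1 when both points
   * take the same cluster (no violation), exp(-cw * penalty / scale) otherwise.
   */
  static double[][] mustLinkPotentials(int k, double penalty, double cw, double scale) {
    double[][] w = new double[k][k];
    for (int a = 0; a < k; a++) {
      for (int b = 0; b < k; b++) {
        w[a][b] = (a == b) ? 1.0 : Math.max(Math.exp(-cw * penalty / scale), EPSILON);
      }
    }
    return w;
  }

  public static void main(String[] args) {
    double[] sims = {0.9, 0.2, 0.5};
    double maxSim = 0.9; // largest observed similarity, as in m_expScalingFactor
    System.out.println(Arrays.toString(centroidPotentials(sims, maxSim)));
    System.out.println(Arrays.deepToString(mustLinkPotentials(3, 0.4, 1000, maxSim)));
  }
}

Flooring every entry at EPSILON keeps the tables strictly positive, which matches the stated purpose of m_epsilon in the original code: message-passing inference over the factor graph never has to divide by, or take the log of, an exact zero.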