📄 .#mpckmeans.java.1.110
字号:
m_objNormalizerCurrPoint = -m_logTerms[centroidIdx]; } // only add the constraints if seedable or constrained if (m_Seedable || (m_Trainable != TRAINING_NONE)) { Object list = m_instanceConstraintHash.get(new Integer(instIdx)); if (list != null) { // there are constraints associated with this instance ArrayList constraintList = (ArrayList) list; for (int i = 0; i < constraintList.size(); i++) { InstancePair pair = (InstancePair) constraintList.get(i); int firstIdx = pair.first; int secondIdx = pair.second; Instance instance1 = m_Instances.instance(firstIdx); Instance instance2 = m_Instances.instance(secondIdx); int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx]; // check whether the constraint is violated if (otherIdx != -1 && otherIdx < m_NumClusters) { if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { // split penalty in half between the two involved clusters if (m_useMultipleMetrics) { double penalty1 = m_metrics[otherIdx].penaltySymmetric(instance1, instance2); double penalty2 = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2); m_objMustLinksCurrPoint += 0.5 * m_MLweight * (penalty1 + penalty2); } else { double penalty = m_metric.penaltySymmetric(instance1, instance2); m_objMustLinksCurrPoint += m_MLweight * penalty; } } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { double penalty = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2); m_objCannotLinksCurrPoint += m_CLweight * (m_maxCLPenalties[centroidIdx] - penalty); if (m_maxCLPenalties[centroidIdx] - penalty < 0) { System.out.println("***NEGATIVE*** penalty: " + penalty + " for CL constraint"); } } } } } } double total = m_objVarianceCurrPoint + m_objCannotLinksCurrPoint + m_objMustLinksCurrPoint + m_objNormalizerCurrPoint; if(m_verbose) { System.out.println("Final penalty for instance " + instIdx + " and centroid " + centroidIdx + " is: " + total); } return total; } /** M-step of the KMeans clustering algorithm -- updates cluster centroids */ protected void updateClusterCentroids() throws Exception { // System.out.println("CENTROIDS BEFORE: " + m_ClusterCentroids); // M-step: update cluster centroids Instances [] tempI = new Instances[m_NumClusters]; Instances tempCentroids = m_ClusterCentroids; Instances tempNewCentroids = new Instances(m_Instances, m_NumClusters); m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); for (int i = 0; i < m_NumClusters; i++) { tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i } for (int i = 0; i < m_Instances.numInstances(); i++) { tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i)); } // Calculates cluster centroids for (int i = 0; i < m_NumClusters; i++) { double [] values = new double[m_Instances.numAttributes()]; if (m_isSparseInstance) { values = ClusterUtils.meanOrMode(tempI[i]); // uses fast meanOrMode } else { for (int j = 0; j < m_Instances.numAttributes(); j++) { values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode } } // cluster centroids are dense in SPKMeans Instance centroid; if (m_isSparseInstance) { centroid = new SparseInstance(1.0, values); } else { // non-sparse centroid = new Instance(1.0, values); } // if we are using a smoothing metric, smooth the centroids if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing()) { SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric; centroid = smoothingMetric.smoothInstance(centroid); System.out.println("Using Smoothing ..."); } m_ClusterCentroids.add(centroid); // tempNewCentroids.add(centroid);// m_ClusterCentroids.delete(); // for (int j = 0; j <= i; j++) {// m_ClusterCentroids.add(tempNewCentroids.instance(j));// }// for (int j = i+1; j < m_NumClusters; j++) {// m_ClusterCentroids.add(tempCentroids.instance(j));// } // double objBackup = m_Objective;// System.out.println(calculateObjectiveFunction(false));// m_Objective = objBackup; // in SPKMeans, cluster centroids need to be normalized if (m_metric.doesNormalizeData()) { m_metric.normalizeInstanceWeighted(m_ClusterCentroids.instance(i)); } } if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing()) updateSmoothingMetrics(); for (int i = 0; i < m_NumClusters; i++) tempI[i] = null; // free memory // System.out.println("CENTROIDS AFTER: " + m_ClusterCentroids); } /** M-step of the KMeans clustering algorithm -- updates metric * weights. Invoked only when we're using non-Potts model * and metric is trainable */ protected void updateMetricWeights() throws Exception { if (m_useMultipleMetrics) { for (int i = 0; i < m_NumClusters; i++) { m_metricLearners[i].trainMetric(i); } } else { m_metricLearner.trainMetric(-1); } InitNormalizerRegularizer(); } /** checks for convergence */ public boolean convergenceCheck(double oldObjective, double newObjective) throws Exception { boolean converged = false; // Convergence check if(Math.abs(oldObjective - newObjective) < m_ObjFunConvergenceDifference) { System.out.println("Final objective function is: " + newObjective); converged = true; } // number of iterations check if (m_numBlankIterations >= m_maxBlankIterations) { System.out.println("Max blank iterations reached ...\n"); converged = true; } if (m_Iterations >= m_maxIterations) { System.out.println("Max iterations reached ...\n"); converged = true; } return converged; } /** calculates objective function */ public double calculateObjectiveFunction(boolean isComplete) throws Exception { if (m_verbose) { System.out.println("Calculating objective function ..."); } // update the oldObjective only if previous estimate of m_Objective // was complete if (isComplete) { m_OldObjective = m_Objective; } m_Objective = 0; m_objVariance = 0; m_objMustLinks = 0; m_objCannotLinks = 0; m_objNormalizer = 0; // temporarily halve weights since every constraint is counted twice double tempML = m_MLweight; double tempCL = m_CLweight; m_MLweight = tempML/2; m_CLweight = tempCL/2; if (m_verbose) { System.out.println("Must link weight: " + m_MLweight); System.out.println("Cannot link weight: " + m_CLweight); } for (int i=0; i<m_Instances.numInstances(); i++) { if (m_isOfflineMetric) { double dist = m_metric.penalty(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += dist; if (m_verbose) { System.out.println("Component for " + i + " = " + dist); } } else { m_Objective += penaltyForInstance(i, m_ClusterAssignments[i]); m_objVariance += m_objVarianceCurrPoint; m_objMustLinks += m_objMustLinksCurrPoint; m_objCannotLinks += m_objCannotLinksCurrPoint; m_objNormalizer += m_objNormalizerCurrPoint; } } m_Objective -= m_objRegularizer; m_MLweight = tempML; m_CLweight = tempCL; // reset the values of the constraint weights // Oscillation check if ((float)m_OldObjective < (float)m_Objective) { System.out.println("WHOA!!! Oscillations => bug in EM step?"); System.out.println("Old objective:" + (float)m_OldObjective + " < New objective: " + (float)m_Objective); }// // TEMPORARY BLAH// System.out.println("\tvar=" + ((float)m_objVariance)// + "\tC=" + ((float)m_objCannotLinks)// + "\tM=" + ((float)m_objMustLinks)// + "\tLOG=" + ((float)m_objNormalizer) // + "\tREG=" + ((float)m_objRegularizer)); return m_Objective; } /** Actual KMeans function */ protected void runKMeans() throws Exception { boolean converged = false; m_Iterations = 0; m_numBlankIterations = 0; m_Objective = Double.POSITIVE_INFINITY; if (!m_isOfflineMetric) { // initialize normalizer and regularizer m_metric.resetMetric(); // initialize max CL penalties if (m_ConstraintsHash.size() > 0) { m_maxCLPenalties = calculateMaxCLPenalties(); } } while (!converged) { m_OldObjective = m_Objective; // E-step int numMovedPoints = findBestAssignments(); m_numBlankIterations = (numMovedPoints == 0) ? m_numBlankIterations+1 : 0; {//if (m_verbose) { calculateObjectiveFunction(false); } System.out.println((float)m_Objective + " - Objective function after point assignment(CALC)"); System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer)); // M-step updateClusterCentroids(); System.out.println("\n" + m_Iterations + ". Objective function: " + ((float)m_Objective)); {//if (m_verbose) { calculateObjectiveFunction(true); } System.out.println((float)m_Objective + " - Objective function after centroid estimation"); System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer)); if (m_Trainable == TRAINING_INTERNAL && !m_isOfflineMetric) { updateMetricWeights(); {//if (m_verbose) { calculateObjectiveFunction(true); } System.out.println((float)m_Objective + " - Objective function after metric update"); System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer)); if (m_ConstraintsHash.size() > 0) { m_maxCLPenalties = calculateMaxCLPenalties(); } } converged = convergenceCheck(m_OldObjective, m_Objective); m_Iterations++; } if (m_verbose) { System.out.println("Done clustering; top cluster features: "); for (int i = 0; i < m_NumClusters; i++){ System.out.println("Centroid " + i); TreeMap map = new TreeMap(Collections.reverseOrder()); Instance centroid= m_ClusterCentroids.instance(i); for (int j = 0; j < centroid.numValues(); j++) { Attribute attr = centroid.attributeSparse(j); map.put(new Double(centroid.value(attr)), attr.name()); } Iterator it = map.entrySet().iterator(); for (int j=0; j < 5 && it.hasNext(); j++) { Map.Entry entry = (Map.Entry) it.next(); System.out.println("\t" + entry.getKey() + "\t" + entry.getValue()); } } } } /** reset the value of the objective function and all of its components */ public void resetObjective() { m_Objective = 0; m_objVariance = 0; m_objCannotLinks = 0; m_objMustLinks = 0; m_objNormalizer = 0; m_objRegularizer = 0; } /** Go through the cannot-link constraints and find the current maximum distance * @return an array of maximum weighted distances. If a single metric is used, maximum distance * is calculated over the entire dataset */ // TODO: non-datasetWide case is not debugged currently!!! protected double[] calculateMaxCLPenalties() throws Exception { double [] maxPenalties = null; double [][] minValues = null; double [][] maxValues = null; int[] attrIdxs = null; maxPenalties = new double[m_NumClusters]; m_maxCLPoints = new Instance[m_NumClusters][2]; m_maxCLDiffInstances = new Instance[m_NumClusters]; for (int i = 0; i < m_NumClusters; i++) { m_maxCLPoints[i][0] = new Instance(m_Instances.numAttributes()); m_maxCLPoints[i][1] = new Instance(m_Instances.numAttributes()); m_maxCLPoints[i][0].setDataset(m_Instances); m_maxCLPoints[i][1].setDataset(m_Instances); m_maxCLDiffInstances[i] = new Instance(m_Instances.numAttributes()); m_maxCLDiffInstances[i].setDataset(m_Instances); } // TEMPORARY PLUG: this was supposed to take care of WeightedDotp, // but it turns out that with weighting similarity can be > 1. // if (m_metric.m_fixedMaxDistance) {// for (int i = 0; i < m_NumClusters; i++) {// maxPenalties[i] = m_metric.getMaxDistance(); // }// return maxPenalties; // } minValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()]; maxValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()]; attrIdxs = m_metrics[0].getAttrIndxs(); // temporary plug: if this if the first iteration when no instances were assigned to clusters, // dataset-wide (not cluster-wide!) minimum and maximum are used even for the case with // multiple metrics boolean datasetWide = true; if (m_useMultipleMetrics && m_Iterations > 0) { datasetWide = false; } // TODO: Mahalanobis - check with getMaxPoints // go through all points // if (m_useMultipleMetrics) {// for (int i = 0; i < m_metrics.length; i++) { // double[][] maxPoints = ((WeightedMahalanobis)m_metrics[i]).getMaxPoints(m_ConstraintsHash, m_Instances);// minValues[i] = maxPoints[0];// maxValues[i] = maxPoints[1];// // System.out.println("Max points " + i);// // for (int j = 0; j < maxPoints[0].length; j++) { System.out.println(maxPoints[0][j] + " - " + maxPoints[1][j]);}// }// } else {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -