📄 mpckmeans.java
字号:
if (m_ClusterAssignments[instIdx] != bestCluster) { if (m_ClusterAssignments[instIdx] >= 0 && m_ClusterAssignments[instIdx] < m_NumClusters) { //if (m_verbose) { System.out.println("Moving instance " + instIdx + " from cluster " + m_ClusterAssignments[instIdx] + " to cluster " + bestCluster + " penalty:" + ((float)penaltyForInstance(instIdx, m_ClusterAssignments[instIdx])) + "=>" + ((float)lowestPenalty)); } moved = 1; m_ClusterAssignments[instIdx] = bestCluster; } if (m_verbose) { System.out.println("Assigning instance " + instIdx + " to cluster " + bestCluster); } return moved; } /** Delegate the distance calculation to the method appropriate for the current metric */ public double penaltyForInstance(int instIdx, int centroidIdx) throws Exception { m_objVarianceCurrPoint = 0; m_objCannotLinksCurrPoint = 0; m_objMustLinksCurrPoint = 0; m_objNormalizerCurrPoint = 0; int violatedConstraints = 0; // variance contribution Instance instance = m_Instances.instance(instIdx); Instance centroid = m_ClusterCentroids.instance(centroidIdx); m_objVarianceCurrPoint = m_metrics[centroidIdx].penalty(instance, centroid); // regularizer and normalizer contribution if (m_Trainable == TRAINING_INTERNAL) { m_objNormalizerCurrPoint = -m_logTerms[centroidIdx]; } // only add the constraints if seedable or constrained // if (m_Seedable || (m_Trainable != TRAINING_NONE)) { // Sugato: replacing, in order to be able to run MKMeans (no // constraint violation, only metric learning) if (m_Seedable) { Object list = m_instanceConstraintHash.get(new Integer(instIdx)); if (list != null) { // there are constraints associated with this instance ArrayList constraintList = (ArrayList) list; for (int i = 0; i < constraintList.size(); i++) { InstancePair pair = (InstancePair) constraintList.get(i); int firstIdx = pair.first; int secondIdx = pair.second; Instance instance1 = m_Instances.instance(firstIdx); Instance instance2 = m_Instances.instance(secondIdx); int otherIdx = (firstIdx == instIdx) ? m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx]; // check whether the constraint is violated if (otherIdx != -1 && otherIdx < m_NumClusters) { if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { violatedConstraints++; // split penalty in half between the two involved clusters if (m_useMultipleMetrics) { double penalty1 = m_metrics[otherIdx].penaltySymmetric(instance1, instance2); double penalty2 = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2); m_objMustLinksCurrPoint += 0.5 * m_MLweight * (penalty1 + penalty2); } else { double penalty = m_metric.penaltySymmetric(instance1, instance2); m_objMustLinksCurrPoint += m_MLweight * penalty; } } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { violatedConstraints++; double penalty = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2); m_objCannotLinksCurrPoint += m_CLweight * (m_maxCLPenalties[centroidIdx] - penalty); if (m_maxCLPenalties[centroidIdx] - penalty < 0) { System.out.println("***NEGATIVE*** penalty: " + penalty + " for CL constraint"); } } } } } } double total = m_objVarianceCurrPoint + m_objCannotLinksCurrPoint + m_objMustLinksCurrPoint + m_objNormalizerCurrPoint; if(m_verbose) { System.out.println("Final penalty for instance " + instIdx + " and centroid " + centroidIdx + " is: " + total); } return total; } /** M-step of the KMeans clustering algorithm -- updates cluster centroids */ protected void updateClusterCentroids() throws Exception { Instances [] tempI = new Instances[m_NumClusters]; Instances tempCentroids = m_ClusterCentroids; Instances tempNewCentroids = new Instances(m_Instances, m_NumClusters); m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); // tempI[i] stores the cluster instances for cluster i for (int i = 0; i < m_NumClusters; i++) { tempI[i] = new Instances(m_Instances, 0); } for (int i = 0; i < m_Instances.numInstances(); i++) { tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i)); } // Calculates cluster centroids for (int i = 0; i < m_NumClusters; i++) { double [] values = new double[m_Instances.numAttributes()]; Instance centroid = null; if (m_isSparseInstance) { // uses fast meanOrMode values = ClusterUtils.meanOrMode(tempI[i]); centroid = new SparseInstance(1.0, values); } else { // non-sparse, go through each attribute for (int j = 0; j < m_Instances.numAttributes(); j++) { values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode } centroid = new Instance(1.0, values); } // // debugging: compare previous centroid w/current:// double w = 0; // for (int j = 0; j < m_Instances.numAttributes(); j++) w += values[j] * values[j];// double w1 = 0; // for (int j = 0; j < m_Instances.numAttributes(); j++) w1 += tempCentroids.instance(i).value(j) * tempCentroids.instance(i).value(j); // System.out.println("\tOldCentroid=" + w1);// System.out.println("\tNewCentroid=" + w); // double prevObj = 0, currObj = 0;// for (int j = 0; j < tempI[i].numInstances(); j++) {// Instance instance = tempI[i].instance(j);// double prevPen = m_metrics[i].penalty(instance, tempCentroids.instance(i));// double currPen = m_metrics[i].penalty(instance, centroid);// prevObj += prevPen;// currObj += currPen; // //System.out.println("\t\t" + j + " " + prevPen + " -> " + currPen + "\t" + prevObj + " -> " + currObj); // }// // dump instances out if there is a problem.// System.out.println("\t\t" + prevObj + " -> " + currObj); // if (currObj > prevObj) {// PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream("/tmp/INST.arff")), true);// out.println(new Instances(tempI[i], 0));// out.println(centroid);// out.println(tempCentroids.instance(i)); // for (int j = 0; j < tempI[i].numInstances(); j++) {// out.println(tempI[i].instance(j));// }// out.close();// System.out.println(" Updated cluster " + i + "("// + tempI[i].numInstances());// System.exit(0); // } // if we are using a smoothing metric, smooth the centroids if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing()) { System.out.println("\tSmoothing..."); SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric; centroid = smoothingMetric.smoothInstance(centroid); } // DEBUGGING: replaced line under with block below m_ClusterCentroids.add(centroid); // {// tempNewCentroids.add(centroid);// m_ClusterCentroids.delete(); // for (int j = 0; j <= i; j++) {// m_ClusterCentroids.add(tempNewCentroids.instance(j));// }// for (int j = i+1; j < m_NumClusters; j++) {// m_ClusterCentroids.add(tempCentroids.instance(j));// } // double objBackup = m_Objective;// System.out.println(" Updated cluster " + i + "("// + tempI[i].numInstances() + "); obj=" +// calculateObjectiveFunction(false));// m_Objective = objBackup;// } // in SPKMeans, cluster centroids need to be normalized if (m_metric.doesNormalizeData()) { m_metric.normalizeInstanceWeighted(m_ClusterCentroids.instance(i)); } } if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing()) updateSmoothingMetrics(); for (int i = 0; i < m_NumClusters; i++) tempI[i] = null; // free memory } /** M-step of the KMeans clustering algorithm -- updates metric * weights. Invoked only when we're using non-Potts model * and metric is trainable */ protected void updateMetricWeights() throws Exception { if (m_useMultipleMetrics) { for (int i = 0; i < m_NumClusters; i++) { m_metricLearners[i].trainMetric(i); } } else { m_metricLearner.trainMetric(-1); } InitNormalizerRegularizer(); } /** checks for convergence */ public boolean convergenceCheck(double oldObjective, double newObjective) throws Exception { boolean converged = false; // Convergence check if(Math.abs(oldObjective - newObjective) < m_ObjFunConvergenceDifference) { System.out.println("Final objective function is: " + newObjective); converged = true; } // number of iterations check if (m_numBlankIterations >= m_maxBlankIterations) { System.out.println("Max blank iterations reached ...\n"); System.out.println("Final objective function is: " + newObjective); converged = true; } if (m_Iterations >= m_maxIterations) { System.out.println("Max iterations reached ...\n"); System.out.println("Final objective function is: " + newObjective); converged = true; } return converged; } /** calculates objective function */ public double calculateObjectiveFunction(boolean isComplete) throws Exception { System.out.println("\tCalculating objective function ..."); // update the oldObjective only if previous estimate of m_Objective // was complete if (isComplete) { m_OldObjective = m_Objective; } m_Objective = 0; m_objVariance = 0; m_objMustLinks = 0; m_objCannotLinks = 0; m_objNormalizer = 0; // Some debugging code: tracking per-cluster objective double[] objectives = new double[m_NumClusters]; // temporarily halve weights since every constraint is counted twice double tempML = m_MLweight; double tempCL = m_CLweight; m_MLweight = tempML/2; m_CLweight = tempCL/2; if (m_verbose) { System.out.println("Must link weight: " + m_MLweight); System.out.println("Cannot link weight: " + m_CLweight); } for (int i=0; i<m_Instances.numInstances(); i++) { if (m_isOfflineMetric) { double dist = m_metric.penalty(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i])); m_Objective += dist; if (m_verbose) { System.out.println("Component for " + i + " = " + dist); } } else { double penalty = penaltyForInstance(i, m_ClusterAssignments[i]); objectives[m_ClusterAssignments[i]] += penalty; m_Objective += penalty; m_objVariance += m_objVarianceCurrPoint; m_objMustLinks += m_objMustLinksCurrPoint; m_objCannotLinks += m_objCannotLinksCurrPoint; m_objNormalizer += m_objNormalizerCurrPoint; } } m_Objective -= m_objRegularizer; m_MLweight = tempML; m_CLweight = tempCL; // reset the values of the constraint weights // debugging: reporting per-cluster objectives for (int i = 0; i < m_NumClusters; i++) { System.out.println("\t\tCluster " + i + " obj=" + objectives[i]); } System.out.println("\tTotalObj=" + m_Objective); // Oscillation check if ((float)m_OldObjective < (float)m_Objective) { System.out.println("WHOA!!! Oscillations => bug in EM step?"); System.out.println("Old objective:" + (float)m_OldObjective + " < New objective: " + (float)m_Objective); }// // TEMPORARY BLAH// System.out.println("\tvar=" + ((float)m_objVariance)// + "\tC=" + ((float)m_objCannotLinks)// + "\tM=" + ((float)m_objMustLinks)// + "\tLOG=" + ((float)m_objNormalizer) // + "\tREG=" + ((float)m_objRegularizer)); return m_Objective; } /** Actual KMeans function */ protected void runKMeans() throws Exception { boolean converged = false; m_Iterations = 0; m_numBlankIterations = 0; m_Objective = Double.POSITIVE_INFINITY; if (!m_isOfflineMetric) { if (m_useMultipleMetrics) { for (int i = 0; i < m_metrics.length; i++) { m_metrics[i].resetMetric(); m_metricLearners[i].resetLearner(); } } else { m_metric.resetMetric(); m_metricLearner.resetLearner(); } // initialize max CL penalties if (m_ConstraintsHash.size() > 0) { m_maxCLPenalties = calculateMaxCLPenalties(); } } // initialize m_ClusterAssignments for (int i=0; i<m_NumClusters; i++) { m_ClusterAssignments[i] = -1; } PrintStream fincoh = null; if (m_ConstraintIncoherenceFile != null) { fincoh = new PrintStream(new FileOutputStream(m_ConstraintIncoherenceFile)); } while (!converged) { System.out.println("\n" + m_Iterations + ". Objective function: " + ((float)m_Objective)); m_OldObjective = m_Objective; // E-step int numMovedPoints = findBestAssignments(); m_numBlankIterations = (numMovedPoints == 0) ? m_numBlankIterations+1 : 0; // calculateObjectiveFunction(false); System.out.println((float)m_Objective + " - Objective function after point assignment(CALC)"); System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer)); // M-step updateClusterCentroids(); // calculateObjectiveFunction(false); System.out.println((float)m_Objective + " - Objective function after centroid estimation"); System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer)); if (m_Trainable == TRAINING_INTERNAL && !m_isOfflineMetric) { updateMetricWeights(); if (m_verbose) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -