📄 .#mpckmeans.java.1.106
字号:
}
    // NOTE(review): everything down to the "return total;" below is the tail of a
    // per-instance penalty method (presumably penaltyForInstance(instIdx, centroidIdx),
    // judging by the verbose message and the call in calculateObjectiveFunction);
    // its opening lines are outside this chunk, so the brace depth here is a
    // best-effort reconstruction -- TODO confirm against the full file.
      } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) {
        // cannot-link violation: charge weight * (max possible CL penalty for this
        // cluster minus the actual penalty between the two constrained instances)
        double penalty = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2);
        m_objCannotLinksCurrPoint += m_CLweight * (m_maxCLPenalties[centroidIdx] - penalty);
        // sanity check: the cached per-cluster maximum should dominate every observed penalty
        if (m_maxCLPenalties[centroidIdx] - penalty < 0) {
          System.out.println("***NEGATIVE*** penalty: " + penalty + " for CL constraint");
        }
      }
    }
  }
}
}
  // total per-point objective = variance + cannot-link + must-link + normalizer components
  double total = m_objVarianceCurrPoint + m_objCannotLinksCurrPoint + m_objMustLinksCurrPoint + m_objNormalizerCurrPoint;
  if(m_verbose) {
    System.out.println("Final penalty for instance " + instIdx + " and centroid " + centroidIdx + " is: " + total);
  }
  return total;
}

/**
 * M-step of the KMeans clustering algorithm -- updates cluster centroids.
 * Groups instances by their current cluster assignment, recomputes each
 * centroid as the per-attribute mean/mode, then optionally smooths
 * (SmoothingMetric) and/or normalizes the centroid depending on the metric.
 *
 * @throws Exception if centroid computation or metric smoothing fails
 */
protected void updateClusterCentroids() throws Exception {
  // M-step: update cluster centroids
  Instances [] tempI = new Instances[m_NumClusters];
  m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
  for (int i = 0; i < m_NumClusters; i++) {
    tempI[i] = new Instances(m_Instances, 0); // tempI[i] stores the cluster instances for cluster i
  }
  // bucket every instance into its assigned cluster
  for (int i = 0; i < m_Instances.numInstances(); i++) {
    tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i));
  }
  // Calculates cluster centroids
  for (int i = 0; i < m_NumClusters; i++) {
    double [] values = new double[m_Instances.numAttributes()];
    if (m_isSparseInstance) {
      values = ClusterUtils.meanOrMode(tempI[i]); // uses fast meanOrMode
    } else {
      for (int j = 0; j < m_Instances.numAttributes(); j++) {
        values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode
      }
    }
    // cluster centroids are dense in SPKMeans
    Instance centroid = new Instance(1.0, values);
    // if we are using a smoothing metric, smooth the centroids
    if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing()) {
      SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric;
      centroid = smoothingMetric.smoothInstance(centroid);
    }
    m_ClusterCentroids.add(centroid);
    // in SPKMeans, cluster centroids need to be normalized
    if (m_metric.doesNormalizeData()) {
      m_metric.normalizeInstanceWeighted(m_ClusterCentroids.instance(i));
    }
  }
  if (m_metric instanceof SmoothingMetric && ((SmoothingMetric) m_metric).getUseSmoothing())
    updateSmoothingMetrics();
  for (int i = 0; i < m_NumClusters; i++)
    tempI[i] = null; // free memory
}

/**
 * M-step of the KMeans clustering algorithm -- updates metric
 * weights. Invoked only when we're using non-Potts model
 * and metric is trainable. Trains one metric per cluster when multiple
 * metrics are in use, otherwise trains the single shared metric, then
 * re-initializes the normalizer/regularizer.
 *
 * @throws Exception if metric training fails
 */
protected void updateMetricWeights() throws Exception {
  if (m_useMultipleMetrics) {
    for (int i = 0; i < m_NumClusters; i++) {
      m_metricLearners[i].trainMetric(i);
    }
  } else {
    m_metricLearner.trainMetric(-1); // NOTE(review): -1 appears to mean "no specific cluster" -- TODO confirm trainMetric contract
  }
  InitNormalizerRegularizer();
}

/**
 * Checks for convergence.
 *
 * @param oldObjective objective value from the previous iteration
 * @param newObjective objective value from the current iteration
 * @return true when the objective change falls below
 *         m_ObjFunConvergenceDifference, or when the blank-iteration or
 *         max-iteration limits have been reached
 * @throws Exception declared for signature consistency; nothing in this
 *         body throws directly
 */
public boolean convergenceCheck(double oldObjective, double newObjective) throws Exception {
  boolean converged = false;
  // Convergence check
  if(Math.abs(oldObjective - newObjective) < m_ObjFunConvergenceDifference) {
    System.out.println("Final objective function is: " + newObjective);
    converged = true;
  }
  // number of iterations check
  if (m_numBlankIterations >= m_maxBlankIterations) {
    System.out.println("Max blank iterations reached ...\n");
    converged = true;
  }
  if (m_Iterations >= m_maxIterations) {
    System.out.println("Max iterations reached ...\n");
    converged = true;
  }
  return converged;
}

/**
 * Calculates objective function over all instances: per-instance penalty
 * (or plain distance to the assigned centroid for an offline metric),
 * minus the regularizer. Also accumulates the variance / must-link /
 * cannot-link / normalizer components into their fields.
 *
 * @param isComplete when true, the previous estimate of m_Objective was
 *        complete, so it is saved into m_OldObjective before recomputing
 * @return the recomputed objective (also stored in m_Objective)
 * @throws Exception if the per-instance penalty computation fails
 */
public double calculateObjectiveFunction(boolean isComplete) throws Exception {
  if (m_verbose) {
    System.out.println("Calculating objective function ...");
  }
  // update the oldObjective only if previous estimate of m_Objective
  // was complete
  if (isComplete) {
    m_OldObjective = m_Objective;
  }
  m_Objective = 0;
  m_objVariance = 0;
  m_objMustLinks = 0;
  m_objCannotLinks = 0;
  m_objNormalizer = 0;
  // temporarily halve weights since every constraint is counted twice
  double tempML = m_MLweight;
  double tempCL = m_CLweight;
  m_MLweight = tempML/2;
  m_CLweight = tempCL/2;
  if (m_verbose) {
    System.out.println("Must link weight: " + m_MLweight);
    System.out.println("Cannot link weight: " + m_CLweight);
  }
  for (int i=0; i<m_Instances.numInstances(); i++) {
    if (m_isOfflineMetric) {
      // offline metric: objective is just the penalty to the assigned centroid
      double dist = m_metric.penalty(m_Instances.instance(i), m_ClusterCentroids.instance(m_ClusterAssignments[i]));
      m_Objective += dist;
      if (m_verbose) {
        System.out.println("Component for " + i + " = " + dist);
      }
    } else {
      // online metric: penaltyForInstance fills the *CurrPoint component
      // fields as a side effect; accumulate them here
      m_Objective += penaltyForInstance(i, m_ClusterAssignments[i]);
      m_objVariance += m_objVarianceCurrPoint;
      m_objMustLinks += m_objMustLinksCurrPoint;
      m_objCannotLinks += m_objCannotLinksCurrPoint;
      m_objNormalizer += m_objNormalizerCurrPoint;
    }
  }
  m_Objective -= m_objRegularizer;
  m_MLweight = tempML;
  m_CLweight = tempCL; // reset the values of the constraint weights
  // Oscillation check
  if ((float)m_OldObjective < (float)m_Objective) {
    System.out.println("WHOA!!! Oscillations => bug in EM step?");
    System.out.println("Old objective:" + (float)m_OldObjective + " < New objective: " + (float)m_Objective);
  }
  return m_Objective;
}

/**
 * Actual KMeans function: alternates E-step (point assignment) and M-step
 * (centroid update, plus metric-weight update when training internally)
 * until convergenceCheck signals termination.
 *
 * @throws Exception if any E-step/M-step computation fails
 */
protected void runKMeans() throws Exception {
  boolean converged = false;
  m_Iterations = 0;
  m_numBlankIterations = 0;
  m_Objective = Double.POSITIVE_INFINITY;
  if (!m_isOfflineMetric) {
    // initialize normalizer and regularizer
    m_metric.resetMetric();
    // initialize max CL penalties
    if (m_ConstraintsHash.size() > 0) {
      m_maxCLPenalties = calculateMaxCLPenalties();
    }
  }
  while (!converged) {
    m_OldObjective = m_Objective;
    // E-step
    int numMovedPoints = findBestAssignments();
    // a "blank" iteration is one where no point changed cluster
    m_numBlankIterations = (numMovedPoints == 0) ? m_numBlankIterations+1 : 0;
    { //if (m_verbose) {
      calculateObjectiveFunction(false);
    }
    System.out.println((float)m_Objective + " - Objective function after point assignment(CALC)");
    System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer));
    // M-step
    updateClusterCentroids();
    System.out.println("\n" + m_Iterations + ". Objective function: " + ((float)m_Objective));
    { //if (m_verbose) {
      calculateObjectiveFunction(true);
    }
    System.out.println((float)m_Objective + " - Objective function after centroid estimation");
    System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer));
    if (m_Trainable == TRAINING_INTERNAL && !m_isOfflineMetric) {
      updateMetricWeights();
      { //if (m_verbose) {
        calculateObjectiveFunction(true);
      }
      System.out.println((float)m_Objective + " - Objective function after metric update");
      System.out.println("\tvar=" + ((float)m_objVariance) + "\tC=" + ((float)m_objCannotLinks) + "\tM=" + ((float)m_objMustLinks) + "\tLOG=" + ((float)m_objNormalizer) + "\tREG=" + ((float)m_objRegularizer));
      // metric weights changed, so the cached maximum CL penalties are stale
      if (m_ConstraintsHash.size() > 0) {
        m_maxCLPenalties = calculateMaxCLPenalties();
      }
    }
    converged = convergenceCheck(m_OldObjective, m_Objective);
    m_Iterations++;
  }
  if (m_verbose) {
    System.out.println("Done clustering; top cluster features: ");
    for (int i = 0; i < m_NumClusters; i++){
      System.out.println("Centroid " + i);
      // map attribute values -> names in descending value order, then print the top 5
      TreeMap map = new TreeMap(Collections.reverseOrder());
      Instance centroid= m_ClusterCentroids.instance(i);
      for (int j = 0; j < centroid.numValues(); j++) {
        Attribute attr = centroid.attributeSparse(j);
        map.put(new Double(centroid.value(attr)), attr.name());
      }
      Iterator it = map.entrySet().iterator();
      for (int j=0; j < 5 && it.hasNext(); j++) {
        Map.Entry entry = (Map.Entry) it.next();
        System.out.println("\t" + entry.getKey() + "\t" + entry.getValue());
      }
    }
  }
}

/** Reset the value of the objective function and all of its components. */
public void resetObjective() {
  m_Objective = 0;
  m_objVariance = 0;
  m_objCannotLinks = 0;
  m_objMustLinks = 0;
  m_objNormalizer = 0;
  m_objRegularizer = 0;
}

/**
 * Go through the cannot-link constraints and find the current maximum distance
 * @return an array of maximum weighted distances. If a single metric is used, maximum distance
 * is calculated over the entire dataset
 */
// TODO: non-datasetWide case is not debugged currently!!!
protected double[] calculateMaxCLPenalties() throws Exception {
  double [] maxPenalties = null;
  double [][] minValues = null;
  double [][] maxValues = null;
  int[] attrIdxs = null;
  maxPenalties = new double[m_NumClusters];
  m_maxCLPoints = new Instance[m_NumClusters][2];
  m_maxCLDiffInstances = new Instance[m_NumClusters];
  // allocate the per-cluster extreme points and diff instances up front
  for (int i = 0; i < m_NumClusters; i++) {
    m_maxCLPoints[i][0] = new Instance(m_Instances.numAttributes());
    m_maxCLPoints[i][1] = new Instance(m_Instances.numAttributes());
    m_maxCLPoints[i][0].setDataset(m_Instances);
    m_maxCLPoints[i][1].setDataset(m_Instances);
    m_maxCLDiffInstances[i] = new Instance(m_Instances.numAttributes());
    m_maxCLDiffInstances[i].setDataset(m_Instances);
  }
  // TEMPORARY PLUG: this was supposed to take care of WeightedDotp,
  // but it turns out that with weighting similarity can be > 1.
  //   if (m_metric.m_fixedMaxDistance) {
  //     for (int i = 0; i < m_NumClusters; i++) {
  //       maxPenalties[i] = m_metric.getMaxDistance();
  //     }
  //     return maxPenalties;
  //   }
  minValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()];
  maxValues = new double[m_NumClusters][m_metrics[0].getNumAttributes()];
  attrIdxs = m_metrics[0].getAttrIndxs();
  // temporary plug: if this is the first iteration when no instances were
  // assigned to clusters, dataset-wide (not cluster-wide!) minimum and
  // maximum are used even for the case with multiple metrics
  boolean datasetWide = true;
  if (m_useMultipleMetrics && m_Iterations > 0) {
    datasetWide = false;
  }
  // TODO: Mahalanobis - check with getMaxPoints
  // go through all points
  //   if (m_useMultipleMetrics) {
  //     for (int i = 0; i < m_metrics.length; i++) {
  //       double[][] maxPoints = ((WeightedMahalanobis)m_metrics[i]).getMaxPoints(m_ConstraintsHash, m_Instances);
  //       minValues[i] = maxPoints[0];
  //       maxValues[i] = maxPoints[1];
  //       // System.out.println("Max points " + i);
  //       // for (int j = 0; j < maxPoints[0].length; j++) { System.out.println(maxPoints[0][j] + " - " + maxPoints[1][j]);}
  //     }
  //   } else {
  //     double[][] maxPoints = ((WeightedMahalanobis)m_metric).getMaxPoints(m_ConstraintsHash, m_Instances);
  //     minValues[0] = maxPoints[0];
  //     maxValues[0] = maxPoints[1];
  //     for (int i = 0; i < m_metrics.length; i++) {
  //       minValues[i] = maxPoints[0];
  //       maxValues[i] = maxPoints[1];
  //     }
  //     // System.out.println("Max points:");
  //     // for (int i = 0; i < maxPoints[0].length; i++) { System.out.println(maxPoints[0][i] + " - " + maxPoints[1][i]);}
  //   }
  // } else {
  // find the enclosing hypercube for WeightedEuclidean etc.
  for (int i = 0; i < m_Instances.numInstances(); i++) {
    Instance instance = m_Instances.instance(i);
    for (int j = 0; j < attrIdxs.length; j++) {
      double val = instance.value(attrIdxs[j]);
      if (datasetWide) {
        if (val < minValues[0][j]) {
          minValues[0][j] = val;
        }
        if (val > maxValues[0][j]) {
          maxValues[0][j] = val;
        }
      } else { // cluster-specific min's and max's are needed
        if (val < minValues[m_ClusterAssignments[i]][j]) {
          minValues[m_ClusterAssignments[i]][j] = val;
        }
        if (val > maxValues[m_ClusterAssignments[i]][j]) {
          maxValues[m_ClusterAssignments[i]][j] = val;
        }
      }
    }
  }
  // get the max/min points
  if (datasetWide) {
    for (int i = 0; i < attrIdxs.length; i++) {
      m_maxCLPoints[0][0].setValue(attrIdxs[i], minValues[0][i]);
      m_maxCLPoints[0][1].setValue(attrIdxs[i], maxValues[0][i]);
    }
    // must copy these over all clusters - just for the first iteration
    for (int j = 1; j < m_NumClusters; j++) {
      for (int i = 0; i < attrIdxs.length; i++) {
        m_maxCLPoints[j][0].setValue(attrIdxs[i], minValues[0][i]);
        m_maxCLPoints[j][1].setValue(attrIdxs[i], maxValues[0][i]);
      }
    }
  } else { // cluster-specific
    for (int j = 0; j < m_NumClusters; j++) {
      for (int i = 0; i < attrIdxs.length; i++) {
        m_maxCLPoints[j][0].setValue(attrIdxs[i], minValues[j][i]);
        m_maxCLPoints[j][1].setValue(attrIdxs[i], maxValues[j][i]);
      }
    }
  }
  // calculate the distances
  if (datasetWide) {
    maxPenalties[0] = m_metrics[0].penaltySymmetric(m_maxCLPoints[0][0], m_maxCLPoints[0][1]);
    m_maxCLDiffInstances[0] = m_metrics[0].createDiffInstance(m_maxCLPoints[0][0], m_maxCLPoints[0][1]);
    for (int i = 1; i < maxPenalties.length; i++) {
      maxPenalties[i] = maxPenalties[0];
      m_maxCLDiffInstances[i] = m_maxCLDiffInstances[0];
    }
  } else { // cluster-specific - SHOULD BE FIXED!!!!
    // NOTE(review): this method continues past the end of this chunk;
    // the statement below is completed outside the visible source.
    for (int j = 0; j < m_NumClusters; j++) {
      for (int i = 0; i < attrIdxs.length; i++) {
        maxPenalties[j] += m_metrics[j].penaltySymmetric(m_maxCLPoints[j][0],
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -