mpckmeans.java
    for (int i = 0; i < labeledPairs.size(); i++) {
      InstancePair pair = (InstancePair) labeledPairs.get(i);
      Integer firstInt = new Integer(pair.first);
      Integer secondInt = new Integer(pair.second);

      // for first point: add instances with constraints to seedHash
      if (!m_SeedHash.contains(firstInt)) {
        if (m_verbose) {
          System.out.println("Adding " + firstInt + " to seedHash");
        }
        m_SeedHash.add(firstInt);
      }

      // for second point
      if (!m_SeedHash.contains(secondInt)) {
        m_SeedHash.add(secondInt);
        if (m_verbose) {
          System.out.println("Adding " + secondInt + " to seedHash");
        }
      }

      if (pair.first >= pair.second) {
        throw new Exception("Ordering reversed - something wrong!!");
      } else {
        InstancePair newPair =
          new InstancePair(pair.first, pair.second, InstancePair.DONT_CARE_LINK);
        m_ConstraintsHash.put(newPair, new Integer(pair.linkType)); // WLOG first < second
        if (m_verbose) {
          System.out.println("Adding constraint (" + pair.first + "," + pair.second
                             + "), " + pair.linkType);
        }

        // hash the constraints for the instances involved
        Object constraintList1 = m_instanceConstraintHash.get(firstInt);
        if (constraintList1 == null) {
          ArrayList constraintList = new ArrayList();
          constraintList.add(pair);
          m_instanceConstraintHash.put(firstInt, constraintList);
        } else {
          ((ArrayList) constraintList1).add(pair);
        }

        Object constraintList2 = m_instanceConstraintHash.get(secondInt);
        if (constraintList2 == null) {
          ArrayList constraintList = new ArrayList();
          constraintList.add(pair);
          m_instanceConstraintHash.put(secondInt, constraintList);
        } else {
          ((ArrayList) constraintList2).add(pair);
        }
      }
    }
  }

    m_StartingIndexOfTest = startingIndexOfTest;
    if (m_verbose) {
      System.out.println("Starting index of test: " + m_StartingIndexOfTest);
    }

    // learn metric using labeled data,
    // then cluster both the labeled and unlabeled data
    System.out.println("Initializing metric: " + m_metric);
    m_metric.buildMetric(unlabeledData);
    m_metricBuilt = true;
    m_metricLearner.setMetric(m_metric);
    m_metricLearner.setClusterer(this);

    // normalize all data for SPKMeans
    if (m_metric.doesNormalizeData()) {
      for (int i = 0; i < unlabeledData.numInstances(); i++) {
        m_metric.normalizeInstanceWeighted(unlabeledData.instance(i));
      }
    }

    // either create a new metric per cluster if multiple metrics are used,
    // or just point them all to m_metric
    m_metrics = new LearnableMetric[numClusters];
    m_metricLearners = new MPCKMeansMetricLearner[numClusters];
    for (int i = 0; i < m_metrics.length; i++) {
      if (m_useMultipleMetrics) {
        m_metrics[i] = (LearnableMetric) m_metric.clone();
        m_metricLearners[i] = (MPCKMeansMetricLearner) m_metricLearner.clone();
        m_metricLearners[i].setMetric(m_metrics[i]);
        m_metricLearners[i].setClusterer(this);
      } else {
        m_metrics[i] = m_metric;
        m_metricLearners[i] = m_metricLearner;
      }
    }

    buildClusterer(unlabeledData, numClusters);
  }
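The loop above consumes a list of InstancePair objects, each encoding a pairwise constraint between two instance indices with first < second (the ordering check throws otherwise). A minimal, hypothetical sketch of assembling such a list is shown below; the MUST_LINK and CANNOT_LINK constants are assumed to exist on InstancePair alongside the DONT_CARE_LINK constant used above, since their definitions are not visible in this fragment.

    // Hypothetical illustration: building a list of pairwise constraints.
    // MUST_LINK and CANNOT_LINK are assumed link-type constants on InstancePair;
    // indices must satisfy first < second, as enforced by the loop above.
    ArrayList labeledPairs = new ArrayList();
    labeledPairs.add(new InstancePair(3, 17, InstancePair.MUST_LINK));    // same cluster
    labeledPairs.add(new InstancePair(3, 42, InstancePair.CANNOT_LINK));  // different clusters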
  /**
   * Generates a clusterer. Instances in the data have to be
   * either all sparse or all non-sparse.
   *
   * @param data set of instances serving as training data
   * @exception Exception if the clusterer has not been
   * generated successfully
   */
  public void buildClusterer(Instances data) throws Exception {
    System.out.println("ML weight=" + m_MLweight);
    System.out.println("CL weight= " + m_CLweight);
    System.out.println("LOG term weight=" + m_logTermWeight);
    System.out.println("Regularizer weight= " + m_regularizerTermWeight);

    m_RandomNumberGenerator = new Random(m_RandomSeed);

    if (m_metric instanceof OfflineLearnableMetric) {
      m_isOfflineMetric = true;
    } else {
      m_isOfflineMetric = false;
    }

    // Don't rebuild the metric if it was already trained
    if (!m_metricBuilt) {
      m_metric.buildMetric(data);
      m_metricBuilt = true;
      m_metricLearner.setMetric(m_metric);
      m_metricLearner.setClusterer(this);

      m_metrics = new LearnableMetric[m_NumClusters];
      m_metricLearners = new MPCKMeansMetricLearner[m_NumClusters];
      for (int i = 0; i < m_metrics.length; i++) {
        if (m_useMultipleMetrics) {
          m_metrics[i] = (LearnableMetric) m_metric.clone();
          m_metricLearners[i] = (MPCKMeansMetricLearner) m_metricLearner.clone();
          m_metricLearners[i].setMetric(m_metrics[i]);
          m_metricLearners[i].setClusterer(this);
        } else {
          m_metrics[i] = m_metric;
          m_metricLearners[i] = m_metricLearner;
        }
      }
    }

    setInstances(data);
    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
    m_ClusterAssignments = new int[m_Instances.numInstances()];

    if (m_Instances.checkForNominalAttributes() && m_Instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n");
    }

    m_ClusterCentroids = m_Initializer.initialize();

    // if all instances are smoothed by the metric, the centroids
    // need to be smoothed too (note that this is independent of
    // centroid smoothing performed by K-Means)
    if (m_metric instanceof InstanceConverter) {
      System.out.println("Converting centroids...");
      Instances convertedCentroids = new Instances(m_ClusterCentroids, m_NumClusters);
      for (int i = 0; i < m_ClusterCentroids.numInstances(); i++) {
        Instance centroid = m_ClusterCentroids.instance(i);
        convertedCentroids.add(((InstanceConverter) m_metric).convertInstance(centroid));
      }

      m_ClusterCentroids.delete();
      for (int i = 0; i < convertedCentroids.numInstances(); i++) {
        m_ClusterCentroids.add(convertedCentroids.instance(i));
      }
    }

    System.out.println("Done initializing clustering ...");
    getIndexClusters();

    if (m_verbose && m_Seedable) {
      printIndexClusters();
      for (int i = 0; i < m_NumClusters; i++) {
        System.out.println("Centroid " + i + ": " + m_ClusterCentroids.instance(i));
      }
    }

    // Some extra work for smoothing metrics
    if (m_metric instanceof SmoothingMetric
        && ((SmoothingMetric) m_metric).getUseSmoothing()) {
      SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric;
      Instances smoothedCentroids = new Instances(m_Instances, m_NumClusters);

      for (int i = 0; i < m_ClusterCentroids.numInstances(); i++) {
        Instance smoothedCentroid =
          smoothingMetric.smoothInstance(m_ClusterCentroids.instance(i));
        smoothedCentroids.add(smoothedCentroid);
      }
      m_ClusterCentroids = smoothedCentroids;

      updateSmoothingMetrics();
    }

    runKMeans();
  }

  protected void updateSmoothingMetrics() {
    if (m_useMultipleMetrics) {
      for (int i = 0; i < m_NumClusters; i++) {
        ((SmoothingMetric) m_metrics[i]).updateAlpha();
      }
    } else {
      ((SmoothingMetric) m_metric).updateAlpha();
    }
  }
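For orientation, here is a minimal, hypothetical driver that feeds a dataset to the clusterer. It assumes the enclosing class is named MPCKMeans (per the file name mpckmeans.java), that it has a default constructor, and that the two-argument buildClusterer(Instances, int) overload invoked at the end of the previous fragment is public; the ARFF path is a placeholder, and the demo is assumed to live in the same package as MPCKMeans.

    import java.io.BufferedReader;
    import java.io.FileReader;
    import weka.core.Instances;

    public class MPCKMeansDemo {
      public static void main(String[] args) throws Exception {
        // Load a dataset with the standard Weka Instances reader (placeholder path).
        // Attributes should be numeric, per the nominal/string check in buildClusterer().
        Instances data = new Instances(new BufferedReader(new FileReader("data.arff")));

        MPCKMeans clusterer = new MPCKMeans();   // assumed class name, per mpckmeans.java
        clusterer.setSeedable(true);             // setter shown further down in this listing
        clusterer.buildClusterer(data, 3);       // overload called at the end of the fragment above
      }
    }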
  /**
   * Reset all values that have been learned
   */
  public void resetClusterer() throws Exception {
    m_metric.resetMetric();
    if (m_useMultipleMetrics) {
      for (int i = 0; i < m_metrics.length; i++) {
        m_metrics[i].resetMetric();
      }
    }
    m_SeedHash = null;
    m_ConstraintsHash = null;
    m_instanceConstraintHash = null;
  }

  /** Turn seeding on and off
   * @param seedable should seeding be done?
   */
  public void setSeedable(boolean seedable) {
    m_Seedable = seedable;
  }

  /** Turn metric learning on and off
   * @param trainable should metric learning be done?
   */
  public void setTrainable(SelectedTag trainable) {
    if (trainable.getTags() == TAGS_TRAINING) {
      if (m_verbose) {
        System.out.println("Trainable: " + trainable.getSelectedTag().getReadable());
      }
      m_Trainable = trainable.getSelectedTag().getID();
    }
  }

  /** Is seeding performed?
   * @return is seeding being done?
   */
  public boolean getSeedable() {
    return m_Seedable;
  }

  /** Is metric learning performed?
   * @return is metric learning being done?
   */
  public SelectedTag getTrainable() {
    return new SelectedTag(m_Trainable, TAGS_TRAINING);
  }

  /** We can have clusterers that don't utilize seeding */
  public boolean seedable() {
    return m_Seedable;
  }

  /** Outputs the current clustering
   *
   * @exception Exception if something goes wrong
   */
  public void printIndexClusters() throws Exception {
    if (m_IndexClusters == null)
      throw new Exception("Clusters were not created");

    for (int i = 0; i < m_NumClusters; i++) {
      HashSet cluster = m_IndexClusters[i];
      if (cluster == null) {
        System.out.println("Cluster " + i + " is null");
      } else {
        System.out.println("Cluster " + i + " consists of "
            + cluster.size() + " elements");
        Iterator iter = cluster.iterator();
        while (iter.hasNext()) {
          int idx = ((Integer) iter.next()).intValue();
          Instance inst = m_TotalTrainWithLabels.instance(idx);
          if (m_TotalTrainWithLabels.classIndex() >= 0) {
            System.out.println("\t\t" + idx + ":"
                + inst.classAttribute().value((int) inst.classValue()));
          }
        }
      }
    }
  }
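Seeding and metric training are toggled through the setters above; setTrainable()/getTrainable() use Weka's SelectedTag mechanism against the TAGS_TRAINING tag array. Continuing the driver sketch from earlier, the lines below are a hypothetical illustration only: TRAINING_INTERNAL is an assumed tag ID and TAGS_TRAINING is assumed to be publicly accessible, since neither definition is visible in this fragment.

    // Hypothetical: enable seeding and pick a metric-training mode.
    // TRAINING_INTERNAL is an assumed constant; TAGS_TRAINING is the tag array
    // referenced by setTrainable() above.
    clusterer.setSeedable(true);
    clusterer.setTrainable(new SelectedTag(MPCKMeans.TRAINING_INTERNAL, MPCKMeans.TAGS_TRAINING));
    System.out.println("Seeding: " + clusterer.getSeedable()
        + ", trainable: " + clusterer.getTrainable().getSelectedTag().getReadable());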
  /** E-step of the KMeans clustering algorithm -- find best cluster
   *  assignments. Returns the number of points moved in this step.
   */
  protected int findBestAssignments() throws Exception {
    int moved = 0;
    double distance = 0;
    m_Objective = 0;
    m_objVariance = 0;
    m_objCannotLinks = 0;
    m_objMustLinks = 0;
    m_objNormalizer = 0;

    // Initialize the regularizer and normalizer hashes
    InitNormalizerRegularizer();

    if (m_isOfflineMetric) {
      moved = assignAllInstancesToClusters();
    } else {
      moved = assignPoints();
    }

    if (m_verbose) {
      System.out.println(" " + moved + " points moved in this E-step");
    }
    return moved;
  }

  /** Initialize m_logTerms and m_regularizerTerms */
  protected void InitNormalizerRegularizer() {
    m_logTerms = new double[m_NumClusters];
    m_objRegularizer = 0;

    if (m_useMultipleMetrics) {
      for (int i = 0; i < m_NumClusters; i++) {
        m_logTerms[i] = m_logTermWeight * m_metrics[i].getNormalizer();
        if (m_regularize) {
          m_objRegularizer += m_regularizerTermWeight * m_metrics[i].regularizer();
        }
      }
    } else {
      // we fill the logTerms with the log(det) of the only weight matrix
      m_logTerms[0] = m_logTermWeight * m_metric.getNormalizer();
      for (int i = 1; i < m_logTerms.length; i++) {
        m_logTerms[i] = m_logTerms[0];
      }
      if (m_regularize) {
        m_objRegularizer = m_regularizerTermWeight * m_metric.regularizer();
      }
    }
  }

  /** Decides which assignment strategy to use based on argument passed in */
  int assignPoints() throws Exception {
    int moved = 0;
    moved = m_Assigner.assign();

    m_Objective = m_objVariance + m_objMustLinks + m_objCannotLinks
      + m_objNormalizer - m_objRegularizer;

    if (m_verbose) {
      System.out.println((float) m_Objective
          + " - Objective function (incomplete) after assignment");
      System.out.println("\tvar=" + ((float) m_objVariance)
          + "\tC=" + ((float) m_objCannotLinks)
          + "\tM=" + ((float) m_objMustLinks)
          + "\tLOG=" + ((float) m_objNormalizer)
          + "\tREG=" + ((float) m_objRegularizer));
    }

    // TODO: add a m_fast switch and put the following line inside it.
    // calculateObjectiveFunction();
    return moved;
  }

  /**
   * Classifies the instance using the current clustering, considering constraints
   *
   * @param instIdx the index of the instance to be assigned to a cluster
   * @return the number of the assigned cluster as an integer if the
   * class is enumerated, otherwise the predicted value
   * @exception Exception if instance could not be classified
   * successfully
   */
  public int assignInstanceToClusterWithConstraints(int instIdx) throws Exception {
    int bestCluster = 0;
    double lowestPenalty = Double.MAX_VALUE;
    int moved = 0;

    // try each cluster and find one with lowest penalty
    for (int i = 0; i < m_NumClusters; i++) {
      double penalty = penaltyForInstance(instIdx, i);
      if (penalty < lowestPenalty) {
        lowestPenalty = penalty;
        bestCluster = i;
        m_objVarianceCurrPointBest = m_objVarianceCurrPoint;
        m_objNormalizerCurrPointBest = m_objNormalizerCurrPoint;
        m_objMustLinksCurrPointBest = m_objMustLinksCurrPoint;
        m_objCannotLinksCurrPointBest = m_objCannotLinksCurrPoint;
      }
    }

    m_objVariance += m_objVarianceCurrPointBest;
    m_objNormalizer += m_objNormalizerCurrPointBest;
    m_objMustLinks += m_objMustLinksCurrPointBest;
    m_objCannotLinks += m_objCannotLinksCurrPointBest;
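Taken together, this E-step appears to implement the greedy assignment of MPC-KMeans: each point is placed in the cluster that minimizes its share of the objective. A hedged reading of the accumulators above (the exact sign conventions of getNormalizer() and of the penalty functions are not visible in this fragment):

    m_objVariance    : distortion ||x_i - mu_{l_i}||^2 under the (per-cluster) learned metric
    m_objMustLinks   : weighted penalties for violated must-link constraints (m_MLweight)
    m_objCannotLinks : weighted penalties for violated cannot-link constraints (m_CLweight)
    m_objNormalizer  : the metric's log-determinant term, scaled by m_logTermWeight (m_logTerms)
    m_objRegularizer : optional metric regularizer, scaled by m_regularizerTermWeight

    m_Objective = m_objVariance + m_objMustLinks + m_objCannotLinks + m_objNormalizer - m_objRegularizer

assignInstanceToClusterWithConstraints() evaluates the per-point share of these terms for every candidate cluster via penaltyForInstance() and keeps the cluster with the lowest total penalty.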