pcsoftkmeans.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 1,952 行 · 第 1/5 页
JAVA
1,952 行
m_ClusterAssignments[instNumber] = j; } m_SumOfClusterInstances[j].setDataset(m_Instances); // m_SumOfClusterInstances suitably normalized now m_ClusterCentroids.add(m_SumOfClusterInstances[i]); } for (int j=m_NumClusters; j < m_NumCurrentClusters; j++) { int i = indices[j]; Iterator iter = m_NeighborSets[i].iterator(); while (iter.hasNext()) { int instNumber = ((Integer) iter.next()).intValue(); if (m_verbose) { System.out.println("Assigning " + instNumber + " to cluster -1"); } m_ClusterAssignments[instNumber] = -1; } } m_NumCurrentClusters = m_NumClusters; // adding other inferred ML and CL links addMLAndCLTransitiveClosure(indices); return; } else if( m_NumCurrentClusters < m_NumClusters ){ createCentroids(); addMLAndCLTransitiveClosure(null); } } /** This function divides every attribute value in an instance by * the instance weight -- useful to find the mean of a cluster in * Euclidean space * @param inst Instance passed in for normalization (destructive update) */ protected void normalizeByWeight(Instance inst) { double weight = inst.weight(); if (m_verbose) { System.out.println("Before weight normalization: " + inst); } if (inst instanceof SparseInstance) { for (int i=0; i<inst.numValues(); i++) { inst.setValueSparse(i, inst.valueSparse(i)/weight); } } else if (!(inst instanceof SparseInstance)) { for (int i=0; i<inst.numAttributes(); i++) { inst.setValue(i, inst.value(i)/weight); } } if (m_verbose) { System.out.println("After weight normalization: " + inst); } } /** Finds the sum of instance sum with instance inst */ Instance sumWithInstance(Instance sum, Instance inst) throws Exception { Instance newSum; if (sum == null) { if (m_isSparseInstance) { newSum = new SparseInstance(inst); newSum.setDataset(m_Instances); } else { newSum = new Instance(inst); newSum.setDataset(m_Instances); } } else { newSum = sumInstances(sum, inst); } return newSum; } /** Finds sum of 2 instances (handles sparse and non-sparse) */ protected Instance sumInstances(Instance inst1, Instance inst2) throws Exception { int numAttributes = inst1.numAttributes(); if (inst2.numAttributes() != numAttributes) { throw new Exception ("Error!! inst1 and inst2 should have same number of attributes."); } // if (m_verbose) { // System.out.println("Instance 1 is: " + inst1 + ", instance 2 is: " + inst2); // } double weight1 = inst1.weight(), weight2 = inst2.weight(); double [] values = new double[numAttributes]; Instance newInst; for (int i=0; i<numAttributes; i++) { values[i] = 0; } if (inst1 instanceof SparseInstance && inst2 instanceof SparseInstance) { for (int i=0; i<inst1.numValues(); i++) { int indexOfIndex = inst1.index(i); values[indexOfIndex] = inst1.valueSparse(i); } for (int i=0; i<inst2.numValues(); i++) { int indexOfIndex = inst2.index(i); values[indexOfIndex] += inst2.valueSparse(i); } newInst = new SparseInstance(weight1+weight2, values); newInst.setDataset(m_Instances); } else if (!(inst1 instanceof SparseInstance) && !(inst2 instanceof SparseInstance)){ for (int i=0; i<numAttributes; i++) { values[i] = inst1.value(i) + inst2.value(i); } newInst = new Instance(weight1+weight2, values); newInst.setDataset(m_Instances); } else { throw new Exception ("Error!! inst1 and inst2 should be both of same type -- sparse or non-sparse"); } // if (m_verbose) { // System.out.println("Sum instance is: " + newInst); // } return newInst; } /** Outputs the current clustering * * @exception Exception if something goes wrong */ public void printIndexClusters() throws Exception { for (int j = 0; j < m_NumClusters; j++) { System.out.println("Cluster " + j); for (int i=0; i<m_Instances.numInstances(); i++) { System.out.println("Point: " + i + ", prob: " + m_ClusterDistribution[i][j]); } } } /** E-step of the KMeans clustering algorithm -- find new cluster assignments and new objective function */ protected double findAssignments() throws Exception{ double m_Objective = 0; for (int i=0; i<m_Instances.numInstances(); i++) { Instance inst = m_Instances.instance(i); try { // Update cluster assignment probs m_Objective += assignInstanceToClustersWithConstraints(i); } catch (Exception e) { System.out.println("Could not find distance. Exception: " + e); e.printStackTrace(); } } return m_Objective; } /** * Classifies the instance using the current clustering considering * constraints, updates cluster assignment probs * * @param instance the instance to be assigned to a cluster * @exception Exception if instance could not be assigned to clusters * successfully */ public double assignInstanceToClustersWithConstraints(int instIdx) throws Exception { double objectiveForInstIdx = 0; for (int j = 0; j < m_NumClusters; j++) { if (!m_objFunDecreasing) { double sim = similarityInPottsModel(instIdx, j); m_ClusterDistribution[instIdx][j] = Math.exp(m_Kappa*sim); // similarity // System.out.println("Weight value for sim between instance " + instIdx + " and cluster " + j + " = " + m_ClusterDistribution[instIdx][j] + ", sim = " + sim); } else { double dist = squareDistanceInPottsModel(instIdx, j); m_ClusterDistribution[instIdx][j] = Math.exp(-m_Kappa*dist); // distance // System.out.println("Weight value for dist between instance " + instIdx + " and cluster " + j + " = " + m_ClusterDistribution[instIdx][j] + ", dist = " + dist); } } // System.out.println(); if (!m_objFunDecreasing) { objectiveForInstIdx = Math.log(Utils.sum(m_ClusterDistribution[instIdx])); } else { objectiveForInstIdx = -Math.log(Utils.sum(m_ClusterDistribution[instIdx])); } // normalize to get posterior probs of cluster assignment Utils.normalize(m_ClusterDistribution[instIdx]); if (m_verbose) { System.out.println("Obj component is: " + objectiveForInstIdx); System.out.println("Posteriors for instance: " + instIdx); for (int j = 0; j < m_NumClusters; j++) { System.out.print(m_ClusterDistribution[instIdx][j] + " "); } System.out.println(); } return objectiveForInstIdx; } /** finds similarity between instance and centroid in Potts Model with Relaxation Labeling */ double similarityInPottsModel(int instIdx, int centroidIdx) throws Exception{ double sim = m_metric.similarity(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx)); if (false) { if (m_ConstraintsHash != null) { HashMap cluster = m_IndexClusters[centroidIdx]; if (cluster != null) { Iterator iter = cluster.entrySet().iterator(); while(iter.hasNext()) { int j = ((Integer) iter.next()).intValue(); int first = (j<instIdx)? j:instIdx; int second = (j>=instIdx)? j:instIdx; InstancePair pair = new InstancePair(first, second, InstancePair.DONT_CARE_LINK); if (m_ConstraintsHash.containsKey(pair)) { int linkType = ((Integer) m_ConstraintsHash.get(pair)).intValue(); if (linkType == InstancePair.MUST_LINK) { // count up number of must-links satisfied, instead of number of must-links violated. So, add to sim if (m_verbose) { System.out.println("Found satisfied must link between: " + first + " and " + second); } sim += m_MustLinkWeight; } else if (linkType == InstancePair.CANNOT_LINK) { if (m_verbose) { System.out.println("Found violated cannot link between: " + first + " and " + second); } sim -= m_CannotLinkWeight; } } } } // end while } } return sim; } /** finds squaredistance between instance and centroid in Potts Model with Relaxation Labeling */ double squareDistanceInPottsModel(int instIdx, int centroidIdx) throws Exception{ double dist = m_metric.distance(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx)); dist *= dist; // doing the squaring here itself if(m_verbose) { System.out.println("Unconstrained distance between instance " + instIdx + " and centroid " + centroidIdx + " is: " + dist); } if (false) { if (m_ConstraintsHash != null) { HashMap cluster = m_IndexClusters[centroidIdx]; if (cluster != null) { Iterator iter = cluster.entrySet().iterator(); while(iter.hasNext()) { int j = ((Integer) iter.next()).intValue(); int first = (j < instIdx)? j:instIdx; int second = (j >= instIdx)? j:instIdx; InstancePair pair = new InstancePair(first, second, InstancePair.DONT_CARE_LINK); if (m_ConstraintsHash.containsKey(pair)) { int linkType = ((Integer) m_ConstraintsHash.get(pair)).intValue(); if (linkType == InstancePair.MUST_LINK) { // count up number of must-links satisfied, instead of number of must-links violated. So, subtract from dist if (m_verbose) { System.out.println("Found satisfied must link between: " + first + " and " + second); } dist -= m_MustLinkWeight; } else if (linkType == InstancePair.CANNOT_LINK) { if (m_verbose) { System.out.println("Found violated cannot link between: " + first + " and " + second); } dist += m_CannotLinkWeight; } } } } } } if(m_verbose) { System.out.println("Final distance between instance " + instIdx + " and centroid " + centroidIdx + " is: " + dist); } return dist; } /** M-step of the KMeans clustering algorithm -- updates cluster centroids */ protected void updateClusterCentroids() throws Exception { // M-step: update cluster centroids Instances [] tempI = new Instances[m_NumClusters]; m_ClusterCentroids = new Instances(m_Instances, m_NumClusters); for (int j = 0; j < m_NumClusters; j++) { tempI[j] = new Instances(m_Instances, 0); // tempI[j] stores the cluster instances for cluster j } for (int i = 0; i < m_Instances.numInstances(); i++) { for (int j = 0; j < m_NumClusters; j++) { tempI[j].add(m_Instances.instance(i), m_ClusterDistribution[i][j]); // instance weight holds the posterior prob of the instance in the cluster } } // Calculates cluster centroids for (int j = 0; j < m_NumClusters; j++) { double [] values = new double[m_Instances.numAttributes()]; if (m_FastMode && m_isSparseInstance) { values = meanOrMode(tempI[j]); // uses fast meanOrMode } else { for (int k = 0; k < m_Instances.numAttributes(); k++) { values[k] = tempI[j].meanOrMode(k); // uses usual meanOrMode } } // cluster centroids are dense in SPKMeans m_ClusterCentroids.add(new Instance(1.0, values)); if (m_Algorithm == ALGORITHM_SPHERICAL) { try { normalize(m_ClusterCentroids.instance(j)); } catch (Exception e) { e.printStackTrace(); } } } for (int j = 0; j < m_NumClusters; j++) { tempI[j] = null; // free memory for garbage collector to pick up } } /** Actual KMeans function */ protected void runEM() throws Exception { boolean converged = false; m_Iterations = 0; m_ClusterDistribution = new double [m_Instances.numInstances()][m_NumClusters]; // initialize cluster distribution from the cluster assignments // after initial neighborhoods have been built for (int i=0; i<m_Instances.numInstances(); i++) { for (int j=0; j<m_NumClusters; j++) { m_ClusterDistribution[i][j] = 0; } if (m_ClusterAssignments[i] != -1 && m_ClusterAssignments[i] < m_NumClusters) { m_ClusterDistribution[i][m_ClusterAssignments[i]] = 1; } } double oldObjective = m_objFunDecreasing ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; while (!converged) { // E-step: updates m_Objective if (m_verbose) { System.out.println("Doing E-step ..."); } m_Objective = findAssignments(); // finds assignments and calculates objective function // M-step if (m_verbose) { System.out.println("Doing M-step ..."); } updateClusterCentroids(); m_Iterations++; // anneal the value of kappa if (!m_objFunDecreasing) { if (m_Kappa < m_MaxKappaSim) { m_Kappa *= 2; } } else { if (m_Kappa < m_MaxKappaDist) { m_Kappa += 2; } } // Convergence check if(Math.abs(oldObjective - m_Objective) > m_ObjFunConvergenceDifference) { System.out.println("Objective function: " + m_Objective + ", numIterations = " + m_Iterations); converged = false; } else { converged = true; System.out.println("Final Objective function is: " + m_Objective + ", numIterations = " + m_Iterations); } if ((!m_objFunDecreasing && oldObjective > m_Objective) || (m_objFunDecreasing && oldObjective < m_Objective)) { // throw new Exception("Oscillations => bug in objective function/EM step!!"); } oldObjective = m_Objective; } }
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?