pcsoftkmeans.java
From "wekaUT: a Weka-based … developed by the University of Texas at Austin" · Java code · 1,952 lines total · page 1/5
   * @param labeledData labeled instances to be used as seeds
   * @param unlabeledData unlabeled instances
   * @param classIndex attribute index in labeledData which holds class info
   * @param numClusters number of clusters
   * @exception Exception if something goes wrong.
   */
  public void buildClusterer(Instances labeledData, Instances unlabeledData,
                             int classIndex, int numClusters) throws Exception {
    /// ----- NOT USED FOR PCSoftKMeans!!! ----- ///
    if (m_Algorithm == ALGORITHM_SPHERICAL) {
      for (int i = 0; i < labeledData.numInstances(); i++) {
        normalize(labeledData.instance(i));
      }
      for (int i = 0; i < unlabeledData.numInstances(); i++) {
        normalize(unlabeledData.instance(i));
      }
    }

    // remove labels of labeledData before putting in seedHash
    Instances clusterData = new Instances(labeledData);
    clusterData.deleteClassAttribute();

    // create seedHash from labeledData
    if (m_Seedable) {
      Seeder seeder = new Seeder(clusterData, labeledData);
      setSeedHash(seeder.getAllSeeds());
    }

    // add unlabeled data to labeled data (labels removed), not the
    // other way around, so that the hash table entries are consistent
    // with the labeled data without labels
    for (int i = 0; i < unlabeledData.numInstances(); i++) {
      clusterData.add(unlabeledData.instance(i));
    }
    if (m_verbose) {
      System.out.println("combinedData has size: " + clusterData.numInstances() + "\n");
    }

    // learn metric using labeled data, then cluster both the labeled and unlabeled data
    m_metric.buildMetric(labeledData);
    m_metricBuilt = true;
    buildClusterer(clusterData, numClusters);
  }

  /**
   * Reset all values that have been learned
   */
  public void resetClusterer() throws Exception {
    if (m_metric instanceof LearnableMetric) {
      ((LearnableMetric) m_metric).resetMetric();
    }
    m_SeedHash = null;
    m_ConstraintsHash = null;
  }

  /** Set default perturbation value
   * @param p perturbation fraction
   */
  public void setDefaultPerturb(double p) {
    m_DefaultPerturb = p;
  }

  /** Get default perturbation value
   * @return perturbation fraction
   */
  public double getDefaultPerturb() {
    return m_DefaultPerturb;
  }

  /** Turn seeding on and off
   * @param seedable should seeding be done?
   */
  public void setSeedable(boolean seedable) {
    m_Seedable = seedable;
  }
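  // --- Usage sketch (added commentary, not part of the original wekaUT source). ---
  // A minimal illustration of how the semi-supervised buildClusterer entry point
  // above might be driven. The ARFF file names and the enclosing class name
  // PCSoftKMeans are assumptions inferred from this file's name.
  //
  //   Instances labeled = new Instances(new java.io.BufferedReader(
  //       new java.io.FileReader("labeled.arff")));        // hypothetical labeled data
  //   labeled.setClassIndex(labeled.numAttributes() - 1);  // last attribute holds the class
  //   Instances unlabeled = new Instances(new java.io.BufferedReader(
  //       new java.io.FileReader("unlabeled.arff")));      // hypothetical unlabeled data
  //   PCSoftKMeans clusterer = new PCSoftKMeans();
  //   clusterer.setSeedable(true);                         // seed clusters from the labels
  //   clusterer.buildClusterer(labeled, unlabeled, labeled.classIndex(), 5);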
  /** Is seeding performed?
   * @return is seeding being done?
   */
  public boolean getSeedable() {
    return m_Seedable;
  }

  /**
   * We can have clusterers that don't utilize seeding
   */
  public boolean seedable() {
    return m_Seedable;
  }

  /** Creates the global cluster centroid */
  protected void createCentroids() throws Exception {
    // initialize using m_NumCurrentClusters neighborhoods (< m_NumClusters), make random for rest
    System.out.println("Creating centroids");
    if (m_verbose)
      System.out.println("Current number of clusters: " + m_NumCurrentClusters);

    // compute centroids of all clusters
    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
    for (int i = 0; i < m_NumCurrentClusters; i++) {
      if (m_SumOfClusterInstances[i] != null) {
        if (m_verbose) {
          System.out.println("Normalizing cluster center " + i);
        }
        if (!m_objFunDecreasing) {
          normalize(m_SumOfClusterInstances[i]);
        } else {
          normalizeByWeight(m_SumOfClusterInstances[i]);
        }
      }
      m_SumOfClusterInstances[i].setDataset(m_Instances);
      m_ClusterCentroids.add(m_SumOfClusterInstances[i]);
    }

    // fill up remaining by randomPerturbInit
    if (m_NumCurrentClusters < m_NumClusters) {
      // find global centroid
      System.out.println("Creating global centroid");
      double[] globalValues = new double[m_Instances.numAttributes()];
      if (m_FastMode && m_isSparseInstance) {
        globalValues = meanOrMode(m_Instances); // uses fast meanOrMode
      } else {
        for (int j = 0; j < m_Instances.numAttributes(); j++) {
          globalValues[j] = m_Instances.meanOrMode(j); // uses usual meanOrMode
        }
      }

      // global centroid is dense in SPKMeans
      m_GlobalCentroid = new Instance(1.0, globalValues);
      m_GlobalCentroid.setDataset(m_Instances);

      Random random = new Random(m_RandomSeed);

      // normalize before random perturbation
      if (!m_objFunDecreasing) {
        normalizeInstance(m_GlobalCentroid);
      }

      System.out.println("Creating " + (m_NumClusters - m_NumCurrentClusters)
                         + " random centroids");
      for (int i = m_NumCurrentClusters; i < m_NumClusters; i++) {
        double[] values = new double[m_Instances.numAttributes()];
        double normalizer = 0;
        for (int j = 0; j < m_Instances.numAttributes(); j++) {
          values[j] = m_GlobalCentroid.value(j)
            * (1 + m_DefaultPerturb * (random.nextFloat() - 0.5));
          normalizer += values[j] * values[j];
        }
        if (!m_objFunDecreasing) {
          normalizer = Math.sqrt(normalizer);
          for (int j = 0; j < m_Instances.numAttributes(); j++) {
            values[j] /= normalizer;
          }
        }
        // values suitably normalized at this point if required
        if (m_isSparseInstance) {
          // sparse for consistency with other cluster centroids
          m_ClusterCentroids.add(new SparseInstance(1.0, values));
        } else {
          m_ClusterCentroids.add(new Instance(1.0, values));
        }
      }
    }
    System.out.println("Finished creating centroids");
    m_NumCurrentClusters = m_NumClusters;
  }
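  // Note on the random initialization above (added commentary): each extra
  // centroid is the global centroid with attribute j scaled by
  // (1 + p * (U - 0.5)), where U is uniform on [0, 1) and p = m_DefaultPerturb,
  // i.e. a relative perturbation of at most +/- p/2 per attribute. When the
  // objective function is not decreasing (the spherical case), the perturbed
  // vector is then rescaled to unit L2 norm so it stays on the unit sphere.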
  /** adding other inferred ML and CL links to m_ConstraintsHash, from
   * m_NeighborSets */
  protected void addMLAndCLTransitiveClosure(int[] indices) throws Exception {
    // add all ML links within neighborhoods selected as clusters
    if (m_verbose) {
      for (int j = 0; j < m_NumCurrentClusters; j++) {
        int i = j;
        if (indices != null) {
          i = indices[j];
        }
        System.out.println("Neighborhood list " + j + " is:");
        System.out.println(m_NeighborSets[i]);
      }
    }
    for (int j = 0; j < m_NumCurrentClusters; j++) {
      int i = j;
      if (indices != null) {
        i = indices[j];
      }
      if (m_NeighborSets[i] != null) {
        Iterator iter1 = m_NeighborSets[i].iterator();
        while (iter1.hasNext()) {
          int first = ((Integer) iter1.next()).intValue();
          Iterator iter2 = m_NeighborSets[i].iterator();
          while (iter2.hasNext()) {
            int second = ((Integer) iter2.next()).intValue();
            if (first < second) {
              InstancePair pair = new InstancePair(first, second,
                                                   InstancePair.DONT_CARE_LINK);
              if (!m_ConstraintsHash.containsKey(pair)) {
                m_ConstraintsHash.put(pair, new Integer(InstancePair.MUST_LINK));
                if (m_verbose) {
                  System.out.println("Adding inferred ML link: " + pair);
                }
                if (!m_SeedHash.contains(new Integer(first))) {
                  m_SeedHash.add(new Integer(first));
                }
                if (!m_SeedHash.contains(new Integer(second))) {
                  m_SeedHash.add(new Integer(second));
                }
              }
            }
          }
        }
      }
    }

    // add all CL links between clusters
    for (int ii = 0; ii < m_NumCurrentClusters; ii++) {
      int i = ii;
      if (indices != null) {
        i = indices[ii];
      }
      if (m_NeighborSets[i] != null) {
        Iterator iter1 = m_NeighborSets[i].iterator();
        while (iter1.hasNext()) {
          int index1 = ((Integer) iter1.next()).intValue();
          for (int jj = ii + 1; jj < m_NumCurrentClusters; jj++) {
            int j = jj;
            if (indices != null) {
              j = indices[jj];
            }
            if (m_NeighborSets[j] != null) {
              Iterator iter2 = m_NeighborSets[j].iterator();
              while (iter2.hasNext()) {
                int index2 = ((Integer) iter2.next()).intValue();
                int first = (index1 < index2) ? index1 : index2;
                int second = (index1 >= index2) ? index1 : index2;
                if (first == second) {
                  throw new Exception("Same instance " + first
                      + " cannot be in cluster: " + i + " and cluster " + j);
                }
                InstancePair pair = new InstancePair(first, second,
                                                     InstancePair.DONT_CARE_LINK);
                if (!m_ConstraintsHash.containsKey(pair)) {
                  m_ConstraintsHash.put(pair, new Integer(InstancePair.CANNOT_LINK));
                  if (m_verbose) {
                    System.out.println("Adding inferred CL link: " + pair);
                  }
                  if (!m_SeedHash.contains(new Integer(first))) {
                    m_SeedHash.add(new Integer(first));
                  }
                  if (!m_SeedHash.contains(new Integer(second))) {
                    m_SeedHash.add(new Integer(second));
                  }
                }
              }
            }
          }
        }
      }
    }
  }

  /** Main Depth First Search routine */
  protected void DFS() throws Exception {
    int[] vertexColor = new int[m_Instances.numInstances()];
    m_NumCurrentClusters = 0;
    for (int u = 0; u < m_Instances.numInstances(); u++) {
      vertexColor[u] = WHITE;
    }
    for (int u = 0; u < m_Instances.numInstances(); u++) {
      if (m_AdjacencyList[u] != null && vertexColor[u] == WHITE) {
        m_NeighborSets[m_NumCurrentClusters] = new HashSet();
        DFS_VISIT(u, vertexColor); // finds whole neighbourhood of u
        m_NumCurrentClusters++;
      }
    }
  }

  /** Recursive subroutine for DFS */
  protected void DFS_VISIT(int u, int[] vertexColor) throws Exception {
    vertexColor[u] = GRAY;
    Iterator iter = null;
    if (m_AdjacencyList[u] != null) {
      iter = m_AdjacencyList[u].iterator();
      while (iter.hasNext()) {
        int j = ((Integer) iter.next()).intValue();
        if (vertexColor[j] == WHITE) { // if the vertex is still undiscovered
          DFS_VISIT(j, vertexColor);
        }
      }
    }
    // update stats for u
    m_ClusterAssignments[u] = m_NumCurrentClusters;
    m_NeighborSets[m_NumCurrentClusters].add(new Integer(u));
    m_SumOfClusterInstances[m_NumCurrentClusters] =
      sumWithInstance(m_SumOfClusterInstances[m_NumCurrentClusters], m_Instances.instance(u));
    vertexColor[u] = BLACK;
  }
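  // Worked example of the two routines above (added commentary, hypothetical
  // data): suppose the constraint hash holds must-links (0,1), (1,2) and (4,5).
  // The adjacency list then has edges 0-1, 1-2 and 4-5, so DFS() discovers the
  // connected components {0, 1, 2} and {4, 5}: m_NumCurrentClusters ends up as
  // 2, each component becomes one neighborhood set, and m_SumOfClusterInstances
  // accumulates the vector sum of each component's instances as DFS_VISIT
  // blackens its vertices.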
  /** Initialization routine for non-active algorithm */
  protected void nonActivePairwiseInit() throws Exception {
    m_NeighborSets = new HashSet[m_Instances.numInstances()];
    m_SumOfClusterInstances = new Instance[m_Instances.numInstances()];
    m_AdjacencyList = new HashSet[m_Instances.numInstances()];
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      m_ClusterAssignments[i] = -1;
    }
    if (m_ConstraintsHash != null) {
      Set pointPairs = (Set) m_ConstraintsHash.keySet();
      Iterator pairItr = pointPairs.iterator();
      System.out.println("In non-active init");
      // iterate over the pairs in ConstraintHash to create Adjacency List
      while (pairItr.hasNext()) {
        InstancePair pair = (InstancePair) pairItr.next();
        int linkType = ((Integer) m_ConstraintsHash.get(pair)).intValue();
        if (m_verbose)
          System.out.println(pair + ": type = " + linkType);
        if (linkType == InstancePair.MUST_LINK) { // concerned with MUST-LINK in Adjacency List
          if (m_AdjacencyList[pair.first] == null) {
            m_AdjacencyList[pair.first] = new HashSet();
          }
          if (!m_AdjacencyList[pair.first].contains(new Integer(pair.second))) {
            m_AdjacencyList[pair.first].add(new Integer(pair.second));
          }
          if (m_AdjacencyList[pair.second] == null) {
            m_AdjacencyList[pair.second] = new HashSet();
          }
          if (!m_AdjacencyList[pair.second].contains(new Integer(pair.first))) {
            m_AdjacencyList[pair.second].add(new Integer(pair.first));
          }
        }
      }
      // DFS for finding connected components in Adjacency List, updates required stats
      DFS();
    }

    if (!m_Seedable) {
      // don't perform any seeding, initialize from random
      m_NumCurrentClusters = 0;
    }

    // if the required number of clusters has been obtained, wrap up
    if (m_NumCurrentClusters >= m_NumClusters) {
      if (m_verbose) {
        System.out.println("Got the required number of clusters ...");
        System.out.println("num clusters: " + m_NumClusters
            + ", num current clusters: " + m_NumCurrentClusters);
      }
      int[] clusterSizes = new int[m_NumCurrentClusters];
      for (int i = 0; i < m_NumCurrentClusters; i++) {
        if (m_verbose) {
          System.out.println("Neighbor set: " + i + " has size: " + m_NeighborSets[i].size());
        }
        clusterSizes[i] = -m_NeighborSets[i].size(); // for reverse sort (decreasing order)
      }
      int[] indices = Utils.sort(clusterSizes);
      Instance[] clusterCentroids = new Instance[m_NumClusters];

      // compute centroids of m_NumClusters clusters
      m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);
      for (int j = 0; j < m_NumClusters; j++) {
        int i = indices[j];
        if (m_SumOfClusterInstances[i] != null) {
          if (m_verbose) {
            System.out.println("Normalizing instance " + i);
          }
          if (!m_objFunDecreasing) {
            normalize(m_SumOfClusterInstances[i]);
          } else {
            normalizeByWeight(m_SumOfClusterInstances[i]);
          }
        }
        Iterator iter = m_NeighborSets[i].iterator();
        while (iter.hasNext()) { // assign points of new cluster
          int instNumber = ((Integer) iter.next()).intValue();
          if (m_verbose) {
            System.out.println("Assigning " + instNumber + " to cluster: " + j);
          }
          // have to re-assign after sorting