pcsoftkmeans.java

来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 1,952 行 · 第 1/5 页

JAVA
1,952
字号
	  m_ClusterAssignments[instNumber] = j;	}	m_SumOfClusterInstances[j].setDataset(m_Instances);	// m_SumOfClusterInstances suitably normalized now	m_ClusterCentroids.add(m_SumOfClusterInstances[i]);      }      for (int j=m_NumClusters; j < m_NumCurrentClusters; j++) {	int i = indices[j];	Iterator iter = m_NeighborSets[i].iterator();	while (iter.hasNext()) {	  int instNumber = ((Integer) iter.next()).intValue();	  if (m_verbose) {	    System.out.println("Assigning " + instNumber + " to cluster -1");	  }	  m_ClusterAssignments[instNumber] = -1;	}      }      m_NumCurrentClusters = m_NumClusters;      // adding other inferred ML and CL links      addMLAndCLTransitiveClosure(indices);      return;    }    else if( m_NumCurrentClusters < m_NumClusters ){      createCentroids();      addMLAndCLTransitiveClosure(null);    }  }  /** This function divides every attribute value in an instance by   *  the instance weight -- useful to find the mean of a cluster in   *  Euclidean space    *  @param inst Instance passed in for normalization (destructive update)   */  protected void normalizeByWeight(Instance inst) {    double weight = inst.weight();    if (m_verbose) {      System.out.println("Before weight normalization: " + inst);    }    if (inst instanceof SparseInstance) {       for (int i=0; i<inst.numValues(); i++) {	inst.setValueSparse(i, inst.valueSparse(i)/weight);      }    }    else if (!(inst instanceof SparseInstance)) {      for (int i=0; i<inst.numAttributes(); i++) {	inst.setValue(i, inst.value(i)/weight);      }    }    if (m_verbose) {      System.out.println("After weight normalization: " + inst);    }  }  /** Finds the sum of instance sum with instance inst    */  Instance sumWithInstance(Instance sum, Instance inst) throws Exception {    Instance newSum;    if (sum == null) {      if (m_isSparseInstance) {	newSum = new SparseInstance(inst);	newSum.setDataset(m_Instances);      }      else {	newSum = new Instance(inst);	newSum.setDataset(m_Instances);      }    }    else {      newSum = sumInstances(sum, inst);    }    return newSum;  }  /** Finds sum of 2 instances (handles sparse and non-sparse)   */  protected Instance sumInstances(Instance inst1, Instance inst2) throws Exception {    int numAttributes = inst1.numAttributes();    if (inst2.numAttributes() != numAttributes) {      throw new Exception ("Error!! inst1 and inst2 should have same number of attributes.");    }    //      if (m_verbose) {    //        System.out.println("Instance 1 is: " + inst1 + ", instance 2 is: " + inst2);    //      }    double weight1 = inst1.weight(), weight2 = inst2.weight();    double [] values = new double[numAttributes];    Instance newInst;        for (int i=0; i<numAttributes; i++) {      values[i] = 0;    }        if (inst1 instanceof SparseInstance && inst2 instanceof SparseInstance) {      for (int i=0; i<inst1.numValues(); i++) {	int indexOfIndex = inst1.index(i);	values[indexOfIndex] = inst1.valueSparse(i);      }      for (int i=0; i<inst2.numValues(); i++) {	int indexOfIndex = inst2.index(i);	values[indexOfIndex] += inst2.valueSparse(i);      }      newInst = new SparseInstance(weight1+weight2, values);      newInst.setDataset(m_Instances);    }    else if (!(inst1 instanceof SparseInstance) && !(inst2 instanceof SparseInstance)){      for (int i=0; i<numAttributes; i++) {	values[i] = inst1.value(i) + inst2.value(i);      }      newInst = new Instance(weight1+weight2, values);      newInst.setDataset(m_Instances);    }    else {      throw new Exception ("Error!! inst1 and inst2 should be both of same type -- sparse or non-sparse");    }    //      if (m_verbose) {    //        System.out.println("Sum instance is: " + newInst);    //      }    return newInst;  }  /** Outputs the current clustering   *   * @exception Exception if something goes wrong   */  public void printIndexClusters() throws Exception {    for (int j = 0; j < m_NumClusters; j++) {      System.out.println("Cluster " + j);      for (int i=0; i<m_Instances.numInstances(); i++) {	System.out.println("Point: " + i + ", prob: " + m_ClusterDistribution[i][j]);      }    }  }  /** E-step of the KMeans clustering algorithm -- find new cluster   assignments and new objective function   */  protected double findAssignments() throws Exception{    double m_Objective = 0;    for (int i=0; i<m_Instances.numInstances(); i++) {      Instance inst = m_Instances.instance(i);            try {	// Update cluster assignment probs	m_Objective += assignInstanceToClustersWithConstraints(i);      }      catch (Exception e) {	System.out.println("Could not find distance. Exception: " + e);	e.printStackTrace();      }    }    return m_Objective;  }  /**   * Classifies the instance using the current clustering considering   * constraints, updates cluster assignment probs   *   * @param instance the instance to be assigned to a cluster   * @exception Exception if instance could not be assigned to clusters   * successfully */  public double assignInstanceToClustersWithConstraints(int instIdx) throws Exception {    double objectiveForInstIdx = 0;    for (int j = 0; j < m_NumClusters; j++) {      if (!m_objFunDecreasing) {	double sim = similarityInPottsModel(instIdx, j);	m_ClusterDistribution[instIdx][j] = Math.exp(m_Kappa*sim); // similarity		//	System.out.println("Weight value for sim between instance " + instIdx + " and cluster " + j + " = " + m_ClusterDistribution[instIdx][j] + ", sim = " + sim);      } else {	double dist = squareDistanceInPottsModel(instIdx, j);	m_ClusterDistribution[instIdx][j] = Math.exp(-m_Kappa*dist); // distance	//	System.out.println("Weight value for dist between instance " + instIdx + " and cluster " + j + " = " + m_ClusterDistribution[instIdx][j] + ", dist = " + dist);      }    }    //    System.out.println();        if (!m_objFunDecreasing) {      objectiveForInstIdx = Math.log(Utils.sum(m_ClusterDistribution[instIdx]));    } else {      objectiveForInstIdx = -Math.log(Utils.sum(m_ClusterDistribution[instIdx]));    }    // normalize to get posterior probs of cluster assignment    Utils.normalize(m_ClusterDistribution[instIdx]);    if (m_verbose) {      System.out.println("Obj component is: " + objectiveForInstIdx);      System.out.println("Posteriors for instance: " + instIdx);      for (int j = 0; j < m_NumClusters; j++) {	System.out.print(m_ClusterDistribution[instIdx][j] + " ");      }      System.out.println();    }    return objectiveForInstIdx;  }  /** finds similarity between instance and centroid in Potts Model with Relaxation Labeling   */  double similarityInPottsModel(int instIdx, int centroidIdx) throws Exception{    double sim = m_metric.similarity(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx));    if (false) {    if (m_ConstraintsHash != null) {      HashMap cluster = m_IndexClusters[centroidIdx];      if (cluster != null) {	Iterator iter = cluster.entrySet().iterator();	while(iter.hasNext()) {	  int j = ((Integer) iter.next()).intValue();	  int first = (j<instIdx)? j:instIdx;	  int second = (j>=instIdx)? j:instIdx;	  InstancePair pair = new InstancePair(first, second, InstancePair.DONT_CARE_LINK);	  if (m_ConstraintsHash.containsKey(pair)) {	    int linkType = ((Integer) m_ConstraintsHash.get(pair)).intValue();	    if (linkType == InstancePair.MUST_LINK) {  // count up number of must-links satisfied, instead of number of must-links violated. So, add to sim	      if (m_verbose) {		System.out.println("Found satisfied must link between: " + first + " and " + second);	      }	      sim += m_MustLinkWeight;	    }	    else if (linkType == InstancePair.CANNOT_LINK) {	      if (m_verbose) {		System.out.println("Found violated cannot link between: " + first + " and " + second);	      }	      sim -= m_CannotLinkWeight;	    }	  }	}      } // end while    }    }    return sim;  }  /** finds squaredistance between instance and centroid in Potts Model with Relaxation Labeling   */  double squareDistanceInPottsModel(int instIdx, int centroidIdx) throws Exception{    double dist = m_metric.distance(m_Instances.instance(instIdx), m_ClusterCentroids.instance(centroidIdx));    dist *= dist; // doing the squaring here itself    if(m_verbose) {      System.out.println("Unconstrained distance between instance " + instIdx + " and centroid " + centroidIdx + " is: " + dist);    }    if (false) {    if (m_ConstraintsHash != null) {      HashMap cluster = m_IndexClusters[centroidIdx];      if (cluster != null) {	Iterator iter = cluster.entrySet().iterator();	while(iter.hasNext()) {	  int j = ((Integer) iter.next()).intValue();	  int first = (j < instIdx)? j:instIdx;	  int second = (j >= instIdx)? j:instIdx;	  InstancePair pair = new InstancePair(first, second, InstancePair.DONT_CARE_LINK);	  if (m_ConstraintsHash.containsKey(pair)) {	    int linkType = ((Integer) m_ConstraintsHash.get(pair)).intValue();	    if (linkType == InstancePair.MUST_LINK) { // count up number of must-links satisfied, instead of number of must-links violated. So, subtract from dist	      if (m_verbose) {		System.out.println("Found satisfied must link between: " + first + " and " + second);	      }	      dist -= m_MustLinkWeight;	    }	    else if (linkType == InstancePair.CANNOT_LINK) {	      if (m_verbose) {		System.out.println("Found violated cannot link between: " + first + " and " + second);	      }	      dist += m_CannotLinkWeight;	    }	  }	}      }    }    }    if(m_verbose) {      System.out.println("Final distance between instance " + instIdx + " and centroid " + centroidIdx + " is: " + dist);    }    return dist;  }  /** M-step of the KMeans clustering algorithm -- updates cluster centroids   */  protected void updateClusterCentroids() throws Exception {    // M-step: update cluster centroids    Instances [] tempI = new Instances[m_NumClusters];    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);        for (int j = 0; j < m_NumClusters; j++) {      tempI[j] = new Instances(m_Instances, 0); // tempI[j] stores the cluster instances for cluster j    }    for (int i = 0; i < m_Instances.numInstances(); i++) {      for (int j = 0; j < m_NumClusters; j++) {	tempI[j].add(m_Instances.instance(i), m_ClusterDistribution[i][j]); // instance weight holds the posterior prob of the instance in the cluster      }    }        // Calculates cluster centroids    for (int j = 0; j < m_NumClusters; j++) {      double [] values = new double[m_Instances.numAttributes()];      if (m_FastMode && m_isSparseInstance) {	values = meanOrMode(tempI[j]); // uses fast meanOrMode      }      else {	for (int k = 0; k < m_Instances.numAttributes(); k++) {	  values[k] = tempI[j].meanOrMode(k); // uses usual meanOrMode	}      }            // cluster centroids are dense in SPKMeans      m_ClusterCentroids.add(new Instance(1.0, values));      if (m_Algorithm == ALGORITHM_SPHERICAL) {	try {	  normalize(m_ClusterCentroids.instance(j));	}	catch (Exception e) {	  e.printStackTrace();	}      }    }    for (int j = 0; j < m_NumClusters; j++) {      tempI[j] = null; // free memory for garbage collector to pick up    }  }      /** Actual KMeans function */  protected void runEM() throws Exception {    boolean converged = false;    m_Iterations = 0;    m_ClusterDistribution = new double [m_Instances.numInstances()][m_NumClusters];        // initialize cluster distribution from the cluster assignments    // after initial neighborhoods have been built    for (int i=0; i<m_Instances.numInstances(); i++) {      for (int j=0; j<m_NumClusters; j++) {	m_ClusterDistribution[i][j] = 0;      }      if (m_ClusterAssignments[i] != -1 && m_ClusterAssignments[i] < m_NumClusters) {	m_ClusterDistribution[i][m_ClusterAssignments[i]] = 1;      }    }    double oldObjective = m_objFunDecreasing ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;    while (!converged) {      // E-step: updates m_Objective      if (m_verbose) {	System.out.println("Doing E-step ...");      }      m_Objective = findAssignments(); // finds assignments and calculates objective function      // M-step      if (m_verbose) {	System.out.println("Doing M-step ...");      }      updateClusterCentroids();       m_Iterations++;      // anneal the value of kappa      if (!m_objFunDecreasing) {	if (m_Kappa < m_MaxKappaSim) {	  m_Kappa *= 2;	}      } else {	if (m_Kappa < m_MaxKappaDist) {	  m_Kappa += 2;	}      }            // Convergence check      if(Math.abs(oldObjective - m_Objective) > m_ObjFunConvergenceDifference) {	System.out.println("Objective function: " + m_Objective + ", numIterations = " + m_Iterations);	converged = false;      }      else {	converged = true;	System.out.println("Final Objective function is: " + m_Objective + ", numIterations = " + m_Iterations);      }      if ((!m_objFunDecreasing && oldObjective > m_Objective) ||  	  (m_objFunDecreasing && oldObjective < m_Objective)) {	//	throw new Exception("Oscillations => bug in objective function/EM step!!");      }      oldObjective = m_Objective;    }  }

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?