⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mpckmeans.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 Weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
    if (m_ClusterAssignments[instIdx] != bestCluster) {      if (m_ClusterAssignments[instIdx] >= 0 && m_ClusterAssignments[instIdx] < m_NumClusters) {	//if (m_verbose) {	System.out.println("Moving instance " + instIdx + " from cluster "			   + m_ClusterAssignments[instIdx] + " to cluster " + bestCluster			   + " penalty:" + ((float)penaltyForInstance(instIdx, m_ClusterAssignments[instIdx]))			   + "=>" + ((float)lowestPenalty));       }      moved = 1;      m_ClusterAssignments[instIdx] = bestCluster;     }    if (m_verbose) {      System.out.println("Assigning instance " + instIdx + " to cluster "			 + bestCluster);    }    return moved;  }  /** Delegate the distance calculation to the method appropriate for the current metric   */  public double penaltyForInstance(int instIdx, int centroidIdx) throws Exception {    m_objVarianceCurrPoint = 0;    m_objCannotLinksCurrPoint = 0;    m_objMustLinksCurrPoint = 0;    m_objNormalizerCurrPoint = 0;    int violatedConstraints = 0;     // variance contribution    Instance instance = m_Instances.instance(instIdx);    Instance centroid =  m_ClusterCentroids.instance(centroidIdx);    m_objVarianceCurrPoint = m_metrics[centroidIdx].penalty(instance, centroid);    // regularizer and normalizer contribution    if (m_Trainable == TRAINING_INTERNAL) {      m_objNormalizerCurrPoint = -m_logTerms[centroidIdx];     }    // only add the constraints if seedable or constrained    //    if (m_Seedable || (m_Trainable != TRAINING_NONE)) {       // Sugato: replacing, in order to be able to run MKMeans (no    // constraint violation, only metric learning)    if (m_Seedable) {      Object list =  m_instanceConstraintHash.get(new Integer(instIdx));      if (list != null) {   // there are constraints associated with this instance	ArrayList constraintList = (ArrayList) list;	for (int i = 0; i < constraintList.size(); i++) {	  InstancePair pair = (InstancePair) constraintList.get(i);	  int firstIdx = pair.first;	  int secondIdx = pair.second;	  
Instance instance1 = m_Instances.instance(firstIdx);	  Instance instance2 = m_Instances.instance(secondIdx);	  int otherIdx = (firstIdx == instIdx) ? 	    m_ClusterAssignments[secondIdx] : m_ClusterAssignments[firstIdx];	  	  // check whether the constraint is violated	  if (otherIdx != -1 && otherIdx < m_NumClusters) { 	    if (otherIdx != centroidIdx && 		pair.linkType == InstancePair.MUST_LINK) {	      violatedConstraints++; 	      // split penalty in half between the two involved clusters	      if (m_useMultipleMetrics) {  		double penalty1 = m_metrics[otherIdx].penaltySymmetric(instance1, instance2);		double penalty2 = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2);		m_objMustLinksCurrPoint += 0.5 * m_MLweight * (penalty1 + penalty2);	      } else {		double penalty = m_metric.penaltySymmetric(instance1, instance2);		m_objMustLinksCurrPoint += m_MLweight * penalty;	      }	    } else if (otherIdx == centroidIdx &&		       pair.linkType == InstancePair.CANNOT_LINK) {	      violatedConstraints++; 	      double penalty = m_metrics[centroidIdx].penaltySymmetric(instance1, instance2);	      m_objCannotLinksCurrPoint +=  m_CLweight *		(m_maxCLPenalties[centroidIdx] - penalty);	      if (m_maxCLPenalties[centroidIdx] - penalty < 0) {		System.out.println("***NEGATIVE*** penalty: " + penalty + " for CL constraint"); 	      }	    }	  }	}      }    }    double total = m_objVarianceCurrPoint  + m_objCannotLinksCurrPoint       + m_objMustLinksCurrPoint + m_objNormalizerCurrPoint;    if(m_verbose) {      System.out.println("Final penalty for instance " + instIdx + " and centroid "			 + centroidIdx + " is: " + total);    }    return total;  }    /** M-step of the KMeans clustering algorithm -- updates cluster centroids   */  protected void updateClusterCentroids() throws Exception {    Instances [] tempI = new Instances[m_NumClusters];    Instances tempCentroids = m_ClusterCentroids;    Instances tempNewCentroids = new Instances(m_Instances, m_NumClusters);     
m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);    // tempI[i] stores the cluster instances for cluster i    for (int i = 0; i < m_NumClusters; i++) {      tempI[i] = new Instances(m_Instances, 0);     }    for (int i = 0; i < m_Instances.numInstances(); i++) {      tempI[m_ClusterAssignments[i]].add(m_Instances.instance(i));    }    // Calculates cluster centroids    for (int i = 0; i < m_NumClusters; i++) {      double [] values = new double[m_Instances.numAttributes()];      Instance centroid = null;            if (m_isSparseInstance) { // uses fast meanOrMode	values = ClusterUtils.meanOrMode(tempI[i]);	centroid = new SparseInstance(1.0, values);      } else { // non-sparse, go through each attribute	for (int j = 0; j < m_Instances.numAttributes(); j++) {	  values[j] = tempI[i].meanOrMode(j); // uses usual meanOrMode	}	centroid = new Instance(1.0, values);      }      //        // debugging:  compare  previous centroid w/current://        double w = 0; //        for (int j = 0; j < m_Instances.numAttributes(); j++)  w += values[j] * values[j];//        double w1 = 0; //        for (int j = 0; j < m_Instances.numAttributes(); j++)  w1 += tempCentroids.instance(i).value(j) * tempCentroids.instance(i).value(j);     //        System.out.println("\tOldCentroid=" + w1);//        System.out.println("\tNewCentroid=" + w); //        double prevObj = 0, currObj = 0;//        for (int j = 0; j < tempI[i].numInstances(); j++) {//  	Instance instance = tempI[i].instance(j);//  	double prevPen = m_metrics[i].penalty(instance, tempCentroids.instance(i));//  	double currPen = m_metrics[i].penalty(instance, centroid);//  	prevObj += prevPen;//  	currObj += currPen; //  	//System.out.println("\t\t" + j + " " + prevPen + " -> " + currPen + "\t" + prevObj + " -> " + currObj); //        }//        // dump instances out if there is a problem.//        System.out.println("\t\t" + prevObj + " -> " + currObj); //        if (currObj > prevObj) {//  	PrintWriter out = new 
PrintWriter(new BufferedOutputStream(new FileOutputStream("/tmp/INST.arff")), true);//  	out.println(new Instances(tempI[i], 0));//  	out.println(centroid);//  	out.println(tempCentroids.instance(i)); //  	for (int j = 0; j < tempI[i].numInstances(); j++) {//  	  out.println(tempI[i].instance(j));//  	}//  	out.close();//  	System.out.println("  Updated cluster " + i + "("//  			   + tempI[i].numInstances());//  	System.exit(0); //        }             // if we are using a smoothing metric, smooth the centroids      if (m_metric instanceof SmoothingMetric &&	  ((SmoothingMetric) m_metric).getUseSmoothing()) {	System.out.println("\tSmoothing..."); 	SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric;	centroid = smoothingMetric.smoothInstance(centroid);       }      //   DEBUGGING:  replaced line under with block below      m_ClusterCentroids.add(centroid); //        {//  	tempNewCentroids.add(centroid);//  	m_ClusterCentroids.delete(); //  	for (int j = 0; j <= i; j++) {//  	  m_ClusterCentroids.add(tempNewCentroids.instance(j));//  	}//  	for (int j = i+1; j < m_NumClusters; j++) {//  	  m_ClusterCentroids.add(tempCentroids.instance(j));//  	} //  	double objBackup = m_Objective;//  	System.out.println("  Updated cluster " + i + "("//  			   + tempI[i].numInstances() + "); obj=" +//  			   calculateObjectiveFunction(false));//  	m_Objective = objBackup;//        }            // in SPKMeans, cluster centroids need to be normalized      if (m_metric.doesNormalizeData()) {	m_metric.normalizeInstanceWeighted(m_ClusterCentroids.instance(i));      }    }    if (m_metric instanceof SmoothingMetric &&	((SmoothingMetric) m_metric).getUseSmoothing())      updateSmoothingMetrics();               for (int i = 0; i < m_NumClusters; i++)      tempI[i] = null; // free memory  }  /** M-step of the KMeans clustering algorithm -- updates metric   *  weights. 
Invoked only when we're using non-Potts model   *  and metric is trainable   */  protected void updateMetricWeights() throws Exception {    if (m_useMultipleMetrics) {      for (int i = 0; i < m_NumClusters; i++) {	m_metricLearners[i].trainMetric(i);      }     } else {      m_metricLearner.trainMetric(-1);    }     InitNormalizerRegularizer();  }   /** checks for convergence */  public boolean convergenceCheck(double oldObjective,				  double newObjective) throws Exception {    boolean converged = false;    // Convergence check    if(Math.abs(oldObjective - newObjective) < m_ObjFunConvergenceDifference) {      System.out.println("Final objective function is: " + newObjective);      converged = true;    }    // number of iterations check    if (m_numBlankIterations >= m_maxBlankIterations) {      System.out.println("Max blank iterations reached ...\n");      System.out.println("Final objective function is: " + newObjective);      converged = true;    }    if (m_Iterations >= m_maxIterations) {      System.out.println("Max iterations reached ...\n");      System.out.println("Final objective function is: " + newObjective);      converged = true;    }    return converged;  }  /** calculates objective function */  public double calculateObjectiveFunction(boolean isComplete) throws Exception {    System.out.println("\tCalculating objective function ...");    // update the oldObjective only if previous estimate of m_Objective    // was complete    if (isComplete) {      m_OldObjective = m_Objective;    }    m_Objective = 0;    m_objVariance = 0;    m_objMustLinks = 0;    m_objCannotLinks = 0;    m_objNormalizer = 0;    // Some debugging code:  tracking per-cluster objective    double[] objectives = new double[m_NumClusters];     // temporarily halve weights since every constraint is counted twice    double tempML = m_MLweight;    double tempCL = m_CLweight;    m_MLweight = tempML/2;    m_CLweight = tempCL/2;         if (m_verbose) {      System.out.println("Must link 
weight: " + m_MLweight);      System.out.println("Cannot link weight: " + m_CLweight);        }    for (int i=0; i<m_Instances.numInstances(); i++) {      if (m_isOfflineMetric) {	double dist = m_metric.penalty(m_Instances.instance(i),				       m_ClusterCentroids.instance(m_ClusterAssignments[i]));	m_Objective += dist;	if (m_verbose) {	  System.out.println("Component for " + i + " = " + dist);	}      }      else {	double penalty = penaltyForInstance(i, m_ClusterAssignments[i]);	objectives[m_ClusterAssignments[i]] += penalty;	m_Objective += penalty; 	m_objVariance += m_objVarianceCurrPoint;	m_objMustLinks += m_objMustLinksCurrPoint;	m_objCannotLinks += m_objCannotLinksCurrPoint;	m_objNormalizer += m_objNormalizerCurrPoint;      }    }    m_Objective -= m_objRegularizer;    m_MLweight = tempML;    m_CLweight = tempCL; // reset the values of the constraint weights    // debugging:  reporting per-cluster objectives    for (int i = 0; i < m_NumClusters; i++) {      System.out.println("\t\tCluster " + i + " obj=" + objectives[i]);     }    System.out.println("\tTotalObj=" + m_Objective);     // Oscillation check    if ((float)m_OldObjective < (float)m_Objective) {      System.out.println("WHOA!!!  
Oscillations => bug in EM step?");      System.out.println("Old objective:" + (float)m_OldObjective			 + " < New objective: " + (float)m_Objective);     }//      // TEMPORARY BLAH//      System.out.println("\tvar=" + ((float)m_objVariance)//  			 + "\tC=" + ((float)m_objCannotLinks)//  			 + "\tM=" + ((float)m_objMustLinks)//  			 + "\tLOG=" + ((float)m_objNormalizer) //  			 + "\tREG=" + ((float)m_objRegularizer));    return m_Objective;  }    /** Actual KMeans function */  protected void runKMeans() throws Exception {    boolean converged = false;    m_Iterations = 0;    m_numBlankIterations = 0;     m_Objective = Double.POSITIVE_INFINITY;     if (!m_isOfflineMetric) {      if (m_useMultipleMetrics) {	for (int i = 0; i < m_metrics.length; i++) {	  m_metrics[i].resetMetric();	  m_metricLearners[i].resetLearner();	}       } else { 	m_metric.resetMetric();	m_metricLearner.resetLearner();      }      // initialize max CL penalties      if (m_ConstraintsHash.size() > 0) {	m_maxCLPenalties = calculateMaxCLPenalties();      }    }    // initialize m_ClusterAssignments    for (int i=0; i<m_NumClusters; i++) {      m_ClusterAssignments[i] = -1;    }    PrintStream fincoh = null;    if (m_ConstraintIncoherenceFile != null) {      fincoh = new PrintStream(new FileOutputStream(m_ConstraintIncoherenceFile));    }    while (!converged) {      System.out.println("\n" + m_Iterations + ". Objective function: " + ((float)m_Objective));      m_OldObjective = m_Objective;             // E-step      int numMovedPoints = findBestAssignments();      m_numBlankIterations = (numMovedPoints == 0) ? 
m_numBlankIterations+1 : 0;       //      calculateObjectiveFunction(false);      System.out.println((float)m_Objective + " - Objective function after point assignment(CALC)");      System.out.println("\tvar=" + ((float)m_objVariance) 			 + "\tC=" + ((float)m_objCannotLinks) 			 + "\tM=" + ((float)m_objMustLinks) 			 + "\tLOG=" + ((float)m_objNormalizer) 			 + "\tREG=" + ((float)m_objRegularizer));           // M-step      updateClusterCentroids();      //      calculateObjectiveFunction(false);      System.out.println((float)m_Objective + " - Objective function after centroid estimation");      System.out.println("\tvar=" + ((float)m_objVariance)			 + "\tC=" + ((float)m_objCannotLinks)			 + "\tM=" + ((float)m_objMustLinks)			 + "\tLOG=" + ((float)m_objNormalizer) 			 + "\tREG=" + ((float)m_objRegularizer));            if (m_Trainable == TRAINING_INTERNAL && !m_isOfflineMetric) {	updateMetricWeights();	if (m_verbose) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -