⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mpckmeans.java

📁 wekaUT 是美国德克萨斯大学奥斯汀分校（University of Texas at Austin）开发的、基于 Weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
      for (int i = 0; i < labeledPairs.size(); i++) {	InstancePair pair = (InstancePair) labeledPairs.get(i);		Integer firstInt = new Integer(pair.first);	Integer secondInt = new Integer(pair.second);	// for first point 	if(!m_SeedHash.contains(firstInt)) { // add instances with constraints to seedHash	  if (m_verbose) {	    System.out.println("Adding " + firstInt + " to seedHash");	  }	  m_SeedHash.add(firstInt);	}		// for second point 	if(!m_SeedHash.contains(secondInt)) {	  m_SeedHash.add(secondInt);	  if (m_verbose) {	    System.out.println("Adding " + secondInt + " to seedHash");	  }	}	if (pair.first >= pair.second) {	  throw new Exception("Ordering reversed - something wrong!!");	} else {	  InstancePair newPair = null;	  newPair = new InstancePair(pair.first, pair.second, InstancePair.DONT_CARE_LINK);	  m_ConstraintsHash.put(newPair, new Integer(pair.linkType)); // WLOG first < second	  if (m_verbose) {	    System.out.println("Adding constraint (" + pair.first +","+pair.second+"), " + pair.linkType);	  }	  	  // hash the constraints for the instances involved	  Object constraintList1 = m_instanceConstraintHash.get(firstInt);	  if (constraintList1 == null) {	    ArrayList constraintList = new ArrayList();	    constraintList.add(pair);	    m_instanceConstraintHash.put(firstInt, constraintList);	  } else {	    ((ArrayList)constraintList1).add(pair);	  }	  Object constraintList2 = m_instanceConstraintHash.get(secondInt);	  if (constraintList2 == null) {	    ArrayList constraintList = new ArrayList();	    constraintList.add(pair);	    m_instanceConstraintHash.put(secondInt, constraintList);	  } else {	    ((ArrayList)constraintList2).add(pair);	  }	}      }    }    m_StartingIndexOfTest = startingIndexOfTest;    if (m_verbose) {      System.out.println("Starting index of test: " + m_StartingIndexOfTest);    }    // learn metric using labeled data,    // then cluster both the labeled and unlabeled data    System.out.println("Initializing metric: " + m_metric);    
m_metric.buildMetric(unlabeledData);    m_metricBuilt = true;    m_metricLearner.setMetric(m_metric);    m_metricLearner.setClusterer(this);    // normalize all data for SPKMeans    if (m_metric.doesNormalizeData()) {      for (int i=0; i<unlabeledData.numInstances(); i++) {	m_metric.normalizeInstanceWeighted(unlabeledData.instance(i));      }    }    // either create a new metric if multiple metrics,    // or just point them all to m_metric    m_metrics = new LearnableMetric[numClusters];    m_metricLearners = new MPCKMeansMetricLearner[numClusters];    for (int i = 0; i < m_metrics.length; i++) {      if (m_useMultipleMetrics) {	m_metrics[i] = (LearnableMetric) m_metric.clone();	m_metricLearners[i] = (MPCKMeansMetricLearner) m_metricLearner.clone();	m_metricLearners[i].setMetric(m_metrics[i]);	m_metricLearners[i].setClusterer(this);      } else { 	m_metrics[i] = m_metric;	m_metricLearners[i] = m_metricLearner;      }     }     buildClusterer(unlabeledData, numClusters);  }  /**   * Generates a clusterer. 
Instances in data have to be   * either all sparse or all non-sparse   *   * @param data set of instances serving as training data    * @exception Exception if the clusterer has not been    * generated successfully   */  public void buildClusterer(Instances data) throws Exception {    System.out.println("ML weight=" + m_MLweight);    System.out.println("CL weight= " + m_CLweight);    System.out.println("LOG term weight=" + m_logTermWeight);    System.out.println("Regularizer weight= " + m_regularizerTermWeight);    m_RandomNumberGenerator = new Random(m_RandomSeed);    if (m_metric instanceof OfflineLearnableMetric) {      m_isOfflineMetric = true;    } else {      m_isOfflineMetric = false;    }    // Don't rebuild the metric if it was already trained    if (!m_metricBuilt) {      m_metric.buildMetric(data);      m_metricBuilt = true;      m_metricLearner.setMetric(m_metric);      m_metricLearner.setClusterer(this);            m_metrics = new LearnableMetric[m_NumClusters];      m_metricLearners = new MPCKMeansMetricLearner[m_NumClusters];      for (int i = 0; i < m_metrics.length; i++) {	if (m_useMultipleMetrics) {	  m_metrics[i] = (LearnableMetric) m_metric.clone();	  m_metricLearners[i] = (MPCKMeansMetricLearner) m_metricLearner.clone();	  m_metricLearners[i].setMetric(m_metrics[i]);	  m_metricLearners[i].setClusterer(this); 	} else { 	  m_metrics[i] = m_metric;	  m_metricLearners[i] = m_metricLearner;	}       }     }    setInstances(data);    m_ClusterCentroids = new Instances(m_Instances, m_NumClusters);    m_ClusterAssignments = new int [m_Instances.numInstances()];    if (m_Instances.checkForNominalAttributes() &&	m_Instances.checkForStringAttributes()) {      throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n");    }    m_ClusterCentroids = m_Initializer.initialize();    // if all instances are smoothed by the metric, the centroids    // need to be smoothed too (note that this is independent of    // centroid smoothing 
performed by K-Means)    if (m_metric instanceof InstanceConverter) {      System.out.println("Converting centroids...");      Instances convertedCentroids = new Instances(m_ClusterCentroids, m_NumClusters);      for (int i = 0; i < m_ClusterCentroids.numInstances(); i++) {	Instance centroid = m_ClusterCentroids.instance(i); 	convertedCentroids.add(((InstanceConverter)m_metric).convertInstance(centroid));      }      m_ClusterCentroids.delete();      for (int i = 0; i < convertedCentroids.numInstances(); i++) {	m_ClusterCentroids.add(convertedCentroids.instance(i));      }    }         System.out.println("Done initializing clustering ...");    getIndexClusters();    if (m_verbose && m_Seedable) {      printIndexClusters();      for (int i=0; i<m_NumClusters; i++) {	System.out.println("Centroid " + i + ": " + m_ClusterCentroids.instance(i));      }    }    // Some extra work for smoothing metrics    if (m_metric instanceof SmoothingMetric &&	((SmoothingMetric) m_metric).getUseSmoothing()) {       SmoothingMetric smoothingMetric = (SmoothingMetric) m_metric;      Instances smoothedCentroids = new Instances(m_Instances, m_NumClusters);            for (int i = 0; i < m_ClusterCentroids.numInstances(); i++) {	Instance smoothedCentroid =	  smoothingMetric.smoothInstance(m_ClusterCentroids.instance(i)); 	smoothedCentroids.add(smoothedCentroid);      }      m_ClusterCentroids = smoothedCentroids;      updateSmoothingMetrics();           }    runKMeans();  }  protected void updateSmoothingMetrics() {    if (m_useMultipleMetrics) {      for (int i = 0; i < m_NumClusters; i++) { 	((SmoothingMetric)m_metrics[i]).updateAlpha();      }    } else {      ((SmoothingMetric)m_metric).updateAlpha();    }  }   /**   * Reset all values that have been learned   */  public void resetClusterer()  throws Exception{    m_metric.resetMetric();    if (m_useMultipleMetrics) {      for (int i = 0; i < m_metrics.length; i++) {	m_metrics[i].resetMetric();      }    }        m_SeedHash = null;    
m_ConstraintsHash = null;    m_instanceConstraintHash = null;  }  /** Turn seeding on and off   * @param seedable should seeding be done?   */  public void setSeedable(boolean seedable) {    m_Seedable = seedable;  }  /** Turn metric learning on and off   * @param trainable should metric learning be done?   */  public void setTrainable(SelectedTag trainable) {    if (trainable.getTags() == TAGS_TRAINING) {      if (m_verbose) {	System.out.println("Trainable: " + trainable.getSelectedTag().getReadable());      }      m_Trainable = trainable.getSelectedTag().getID();    }  }  /** Is seeding performed?   * @return is seeding being done?   */  public boolean getSeedable() {    return m_Seedable;  }  /** Is metric learning performed?   * @return is metric learning being done?   */  public SelectedTag getTrainable() {    return new SelectedTag(m_Trainable, TAGS_TRAINING);  }    /**   * We can have clusterers that don't utilize seeding   */  public boolean seedable() {    return m_Seedable;  }  /** Outputs the current clustering   *   * @exception Exception if something goes wrong   */  public void printIndexClusters() throws Exception {    if (m_IndexClusters == null)      throw new Exception ("Clusters were not created");    for (int i = 0; i < m_NumClusters; i++) {      HashSet cluster = m_IndexClusters[i];      if (cluster == null) {	System.out.println("Cluster " + i + " is null");      }      else {	System.out.println ("Cluster " + i + " consists of " + cluster.size() + " elements");	Iterator iter = cluster.iterator();	while(iter.hasNext()) {	  int idx = ((Integer) iter.next()).intValue();	  Instance inst = m_TotalTrainWithLabels.instance(idx);	  if (m_TotalTrainWithLabels.classIndex() >= 0) {	    System.out.println("\t\t" + idx + ":" + inst.classAttribute().value((int) inst.classValue()));	  }	}      }    }  }  /** E-step of the KMeans clustering algorithm -- find best cluster   * assignments. 
Returns the number of points moved in this step    */  protected int findBestAssignments() throws Exception {    int moved = 0;    double distance = 0;    m_Objective = 0;    m_objVariance = 0;    m_objCannotLinks = 0;    m_objMustLinks = 0;    m_objNormalizer = 0;    // Initialize the regularizer and normalizer hashes    InitNormalizerRegularizer();        if (m_isOfflineMetric) {      moved = assignAllInstancesToClusters();    } else {      moved = assignPoints();    }    if (m_verbose) {       System.out.println("  " + moved + " points moved in this E-step");    }    return moved;   }  /** Initialize m_logTerms and m_regularizerTerms */  protected void InitNormalizerRegularizer() {     m_logTerms = new double[m_NumClusters];    m_objRegularizer = 0;    if (m_useMultipleMetrics) {      for (int i = 0; i < m_NumClusters; i++) {	m_logTerms[i] = m_logTermWeight * m_metrics[i].getNormalizer(); 	if (m_regularize) {	  m_objRegularizer += m_regularizerTermWeight * m_metrics[i].regularizer(); 	}      }     } else {  // we fill the logTerms with the log(det) of the only weight matrix      m_logTerms[0] = m_logTermWeight * m_metric.getNormalizer();      for (int i = 1; i < m_logTerms.length; i++) {	m_logTerms[i] = m_logTerms[0];      }       if (m_regularize) {	m_objRegularizer = m_regularizerTermWeight * m_metric.regularizer();       }    }  }    /** Decides which assignment strategy to use based on argument passed in */  int assignPoints() throws Exception {    int moved = 0;    moved = m_Assigner.assign();    m_Objective = m_objVariance + m_objMustLinks      + m_objCannotLinks + m_objNormalizer - m_objRegularizer;    if (m_verbose) {      System.out.println((float)m_Objective + " - Objective function (incomplete) after assignment");      System.out.println("\tvar=" + ((float)m_objVariance)			 + "\tC=" + ((float)m_objCannotLinks)			 + "\tM=" + ((float)m_objMustLinks)			 + "\tLOG=" + ((float)m_objNormalizer) 			 + "\tREG=" + ((float)m_objRegularizer));    }    // TODO:  
add a m_fast switch and put the following line inside it.    //    calculateObjectiveFunction();               return moved;  }  /**   * Classifies the instance using the current clustering, considering constraints   *   * @param instance the instance to be assigned to a cluster   * @return the number of the assigned cluster as an integer if the   * class is enumerated, otherwise the predicted value   * @exception Exception if instance could not be classified   * successfully    */  public int assignInstanceToClusterWithConstraints(int instIdx) throws Exception {    int bestCluster = 0;    double lowestPenalty = Double.MAX_VALUE;    int moved = 0;    // try each cluster and find one with lowest penalty    for (int i = 0; i < m_NumClusters; i++) {      double penalty = penaltyForInstance(instIdx, i);      if (penalty < lowestPenalty) {	lowestPenalty = penalty;	bestCluster = i;	m_objVarianceCurrPointBest = m_objVarianceCurrPoint;	m_objNormalizerCurrPointBest = m_objNormalizerCurrPoint;	m_objMustLinksCurrPointBest = m_objMustLinksCurrPoint;	m_objCannotLinksCurrPointBest = m_objCannotLinksCurrPoint;      }    }        m_objVariance += m_objVarianceCurrPointBest;    m_objNormalizer += m_objNormalizerCurrPointBest;    m_objMustLinks += m_objMustLinksCurrPointBest;    m_objCannotLinks += m_objCannotLinksCurrPointBest;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -