⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    EM.java
 *    Copyright (C) 1999 Mark Hall
 *
 */

package  weka.clusterers;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;

/**
 * Simple EM (expectation maximisation) class. <p>
 * 
 * EM assigns a probability distribution to each instance which
 * indicates the probability of it belonging to each of the clusters.
 * EM can decide how many clusters to create by cross validation, or you
 * may specify a priori how many clusters to generate. <p>
 * <br>
 * The cross validation performed to determine the number of clusters
 * is done in the following steps:<br>
 * 1. the number of clusters is set to 1<br>
 * 2. the training set is split randomly into 10 folds.<br>
 * 3. EM is performed 10 times using the 10 folds the usual CV way.<br>
 * 4. the loglikelihood is averaged over all 10 results.<br>
 * 5. if loglikelihood has increased the number of clusters is increased by 1
 * and the program continues at step 2. <br>
 *<br>
 * The number of folds is fixed to 10, as long as the number of instances in
 * the training set is not smaller than 10. If it is smaller, the number of
 * folds is set equal to the number of instances.<p>
 *
 * Valid options are:<p>
 *
 * -V <br>
 * Verbose. <p>
 *
 * -N <number of clusters> <br>
 * Specify the number of clusters to generate. If omitted,
 * EM will use cross validation to select the number of clusters
 * automatically. <p>
 *
 * -I <max iterations> <br>
 * Terminate after this many iterations if EM has not converged. <p>
 *
 * -S <seed> <br>
 * Specify random number seed. <p>
 *
 * -M <num> <br>
 * Set the minimum allowable standard deviation for normal density calculation.
 * <p>
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class EM
  extends DensityBasedClusterer
  implements NumberOfClustersRequestable,
	     OptionHandler, WeightedInstancesHandler {

  /**
	 * 
	 */
	private static final long serialVersionUID = -8531884870645586142L;

/** hold the discrete estimators for each cluster */
  private Estimator m_model[][];

  /** hold the normal estimators for each cluster */
  private double m_modelNormal[][][];

  /** default minimum standard deviation */
  private double m_minStdDev = 1e-6;

  /** hold the weights of each instance for each cluster */
  private double m_weights[][];

  /** the prior probabilities for clusters */
  private double m_priors[];

  /** the loglikelihood of the data */
  private double m_loglikely;

  /** training instances */
  private Instances m_theInstances = null;

  /** number of clusters selected by the user or cross validation */
  private int m_num_clusters;

  /** the initial number of clusters requested by the user--- -1 if
      xval is to be used to find the number of clusters */
  private int m_initialNumClusters;

  /** number of attributes */
  private int m_num_attribs;

  /** number of training instances */
  private int m_num_instances;

  /** maximum iterations to perform */
  private int m_max_iterations;

  /** attribute min values */
  private double [] m_minValues;

  /** attribute max values */
  private double [] m_maxValues;

  /** random numbers and seed */
  private Random m_rr;
  private int m_rseed;

  /** Verbose? */
  private boolean m_verbose;

  /**
   * Returns a string describing this clusterer
   * @return a description of the evaluator suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    // Short, human-readable summary of what this clusterer does.
    final String description = "Cluster data using expectation maximization";
    return description;
  }


  /**
   * Returns an enumeration describing the available options. <p>
   *
   * Valid options are:<p>
   *
   * -V <br>
   * Verbose. <p>
   *
   * -N <number of clusters> <br>
   * Specify the number of clusters to generate. If omitted,
   * EM will use cross validation to select the number of clusters
   * automatically. <p>
   *
   * -I <max iterations> <br>
   * Terminate after this many iterations if EM has not converged. <p>
   *
   * -S <seed> <br>
   * Specify random number seed. <p>
   *
   * -M <num> <br>
   *  Set the minimum allowable standard deviation for normal density 
   * calculation. <p>
   *
   * @return an enumeration of all the available options.
   *
   **/
  public Enumeration listOptions () {
    // Option descriptors in the order they are reported: -N, -I, -S, -V, -M.
    Option[] descriptors = {
      new Option("\tnumber of clusters. If omitted or\n\t-1 specified, then cross "
                 + "validation is used to\n\tselect the number of clusters.",
                 "N", 1, "-N <num>"),
      new Option("\tmax iterations.\n(default 100)",
                 "I", 1, "-I <num>"),
      new Option("\trandom number seed.\n(default 1)",
                 "S", 1, "-S <num>"),
      new Option("\tverbose.",
                 "V", 0, "-V"),
      new Option("\tminimum allowable standard deviation for normal density "
                 + "computation \n\t(default 1e-6)",
                 "M", 1, "-M <num>")
    };

    Vector result = new Vector(6);
    for (int idx = 0; idx < descriptors.length; idx++) {
      result.addElement(descriptors[idx]);
    }
    return result.elements();
  }


  /**
   * Parses a given list of options.
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   *
   **/
  public void setOptions (String[] options)
    throws Exception {
    // Reset first (resetOptions is defined elsewhere in this class), then
    // apply only the options that were actually supplied.
    resetOptions();
    setDebug(Utils.getFlag('V', options));

    // -I <max iterations>
    String maxIterationsArg = Utils.getOption('I', options);
    if (maxIterationsArg.length() > 0) {
      setMaxIterations(Integer.parseInt(maxIterationsArg));
    }

    // -N <number of clusters>
    String numClustersArg = Utils.getOption('N', options);
    if (numClustersArg.length() > 0) {
      setNumClusters(Integer.parseInt(numClustersArg));
    }

    // -S <seed>
    String seedArg = Utils.getOption('S', options);
    if (seedArg.length() > 0) {
      setSeed(Integer.parseInt(seedArg));
    }

    // -M <minimum standard deviation>
    String minStdDevArg = Utils.getOption('M', options);
    if (minStdDevArg.length() > 0) {
      setMinStdDev(Double.parseDouble(minStdDevArg));
    }
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minStdDevTipText() {
    // Tool tip shown for the minStdDev property.
    final String tip = "set minimum allowable standard deviation";
    return tip;
  }

  /**
   * Set the minimum value for standard deviation when calculating
   * normal density. Raising this value can help prevent arithmetic
   * overflow resulting from multiplying large densities (arising from small
   * standard deviations) when there are many singleton or near singleton
   * values.
   * @param m minimum value for standard deviation
   */
  public void setMinStdDev(double m) {
    // Stored as given; no range check is applied here.
    this.m_minStdDev = m;
  }

  /**
   * Get the minimum allowable standard deviation.
   * @return the minimum allowable standard deviation
   */
  public double getMinStdDev() {
    // Current lower bound for standard deviations.
    return this.m_minStdDev;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String seedTipText() {
    // Tool tip shown for the seed property.
    final String tip = "random number seed";
    return tip;
  }


  /**
   * Set the random number seed
   *
   * @param s the seed
   */
  public void setSeed (int s) {
    // Only records the seed; the generator itself is not touched here.
    this.m_rseed = s;
  }


  /**
   * Get the random number seed
   *
   * @return the seed
   */
  public int getSeed () {
    // The seed most recently stored by setSeed.
    return this.m_rseed;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    // Tool tip shown for the numClusters property.
    final String tip =
      "set number of clusters. -1 to select number of clusters automatically "
      + "by cross validation.";
    return tip;
  }

  /**
   * Set the number of clusters (-1 to select by CV).
   *
   * @param n the number of clusters
   * @exception Exception if n is 0
   */
  public void setNumClusters (int n)
    throws Exception {
    // Zero clusters is the only value rejected outright.
    if (n == 0) {
      throw new Exception(
        "Number of clusters must be > 0. (or -1 to select by cross validation).");
    }

    // Every negative request is stored as -1, the marker for
    // "choose the number of clusters by cross validation".
    final int requested = (n < 0) ? -1 : n;
    m_num_clusters = requested;
    m_initialNumClusters = requested;
  }


  /**
   * Get the number of clusters
   *
   * @return the number of clusters.
   */
  public int getNumClusters () {
    // Reports the number originally requested (-1 for cross validation),
    // not the working value held in m_num_clusters.
    return this.m_initialNumClusters;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxIterationsTipText() {
    // Tool tip shown for the maxIterations property.
    final String tip = "maximum number of iterations";
    return tip;
  }

  /**
   * Set the maximum number of iterations to perform
   *
   * @param i the number of iterations
   * @exception Exception if i is less than 1
   */
  public void setMaxIterations (int i)
    throws Exception {
    // Accept any positive count; everything else is rejected below.
    if (i >= 1) {
      m_max_iterations = i;
      return;
    }

    throw new Exception("Maximum number of iterations must be > 0!");
  }


  /**
   * Get the maximum number of iterations
   *
   * @return the number of iterations
   */
  public int getMaxIterations () {
    // Current iteration limit.
    return this.m_max_iterations;
  }


  /**
   * Set debug mode - verbose output
   *
   * @param v true for verbose output
   */
  public void setDebug (boolean v) {
    // "Debug" and "verbose" are the same switch in this class.
    this.m_verbose = v;
  }


  /**
   * Get debug mode
   *
   * @return true if debug mode is set
   */
  public boolean getDebug () {
    // Mirrors the verbose flag.
    return this.m_verbose;
  }


  /**
   * Gets the current settings of EM.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions () {
    // Flag/value pairs that are always reported, in output order.
    String[] pairs = {
      "-I", Integer.toString(m_max_iterations),
      "-N", Integer.toString(getNumClusters()),
      "-S", Integer.toString(m_rseed),
      "-M", Double.toString(getMinStdDev())
    };

    // Fixed-size result: an optional -V, the four pairs, then "" padding.
    String[] result = new String[9];
    int next = 0;

    if (m_verbose) {
      result[next++] = "-V";
    }

    System.arraycopy(pairs, 0, result, next, pairs.length);
    next += pairs.length;

    for (; next < result.length; next++) {
      result[next] = "";
    }

    return result;
  }

  /**
   * Initialise estimators and storage.
   *
   * @param inst the instances
   **/
  private void EM_Init (Instances inst)
    throws Exception {
    int i, j, k;

    // Run k-means 10 times with different seeds and keep the run with the
    // smallest squared error.
    SimpleKMeans bestK = null;
    double bestSqE = Double.MAX_VALUE;
    for (i = 0; i < 10; i++) {
      SimpleKMeans sk = new SimpleKMeans();
      sk.setSeed(m_rr.nextInt());
      sk.setNumClusters(m_num_clusters);
      sk.buildClusterer(inst);
      double sqE = sk.getSquaredError();
      if (sqE < bestSqE) {
        bestSqE = sqE;
        bestK = sk;
      }
    }

    // A NaN or infinite squared error never compares below the initial bound,
    // so no run may have been selected. Fail with a clear message instead of
    // a NullPointerException further down.
    if (bestK == null) {
      throw new Exception("Unable to initialise EM: none of the k-means runs "
                          + "produced a squared error below Double.MAX_VALUE.");
    }

    // Initialise storage from the best k-means solution. The number of
    // clusters is taken from what k-means reports.
    m_num_clusters = bestK.numberOfClusters();
    m_weights = new double[inst.numInstances()][m_num_clusters];
    m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
    m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
    m_priors = new double[m_num_clusters];
    Instances centers = bestK.getClusterCentroids();
    Instances stdD = bestK.getClusterStandardDevs();
    int [][][] nominalCounts = bestK.getClusterNominalCounts();
    int [] clusterSizes = bestK.getClusterSizes();

    for (i = 0; i < m_num_clusters; i++) {
      Instance center = centers.instance(i);
      Instance clusterStdDev = stdD.instance(i);
      for (j = 0; j < m_num_attribs; j++) {
        if (inst.attribute(j).isNominal()) {
          // Nominal attribute: seed a discrete estimator with the per-value
          // counts k-means reports for this cluster.
          int numValues = inst.attribute(j).numValues();
          m_model[i][j] = new DiscreteEstimator(m_theInstances.
                                                attribute(j).numValues()
                                                , true);
          for (k = 0; k < numValues; k++) {
            m_model[i][j].addValue(k, nominalCounts[i][j][k]);
          }
        } else {
          // Otherwise use the centroid value as the mean, falling back to the
          // attribute's overall mean/mode when the centroid value is missing.
          double mean = (center.isMissing(j))
            ? inst.meanOrMode(j)
            : center.value(j);
          m_modelNormal[i][j][0] = mean;

          // Standard deviation: the cluster's own value, or an even share of
          // the attribute range when that value is missing.
          double stdv = (clusterStdDev.isMissing(j))
            ? ((m_maxValues[j] - m_minValues[j]) / (2 * m_num_clusters))
            : clusterStdDev.value(j);

          // Too small: try the attribute-wide standard deviation, then clamp
          // to the configured minimum.
          if (stdv < m_minStdDev) {
            stdv = inst.attributeStats(j).numericStats.stdDev;
            if (stdv < m_minStdDev) {
              stdv = m_minStdDev;
            }
          }
          m_modelNormal[i][j][1] = stdv;
          // Third slot starts at 1.0; its meaning is set by code outside
          // this method.
          m_modelNormal[i][j][2] = 1.0;
        }
      }
    }

    // Priors start out proportional to the k-means cluster sizes.
    for (j = 0; j < m_num_clusters; j++) {
      m_priors[j] = clusterSizes[j];
    }
    Utils.normalize(m_priors);
  }


  /**
   * calculate prior probabilities for the clusters
   *
   * @param inst the instances
   * @exception Exception if priors can't be calculated
   **/
  private void estimate_priors (Instances inst)
    throws Exception {

    // Clear the accumulators.
    for (int cluster = 0; cluster < m_num_clusters; cluster++) {
      m_priors[cluster] = 0.0;
    }

    // Each instance adds (instance weight x membership weight) to every
    // cluster's running total.
    final int total = inst.numInstances();
    for (int row = 0; row < total; row++) {
      int cluster = 0;
      while (cluster < m_num_clusters) {
        m_priors[cluster] += inst.instance(row).weight() * m_weights[row][cluster];
        cluster++;
      }
    }

    // Rescale the totals via Utils.normalize.
    Utils.normalize(m_priors);
  }


  /**
   * Constant term of the log normal density: log(sqrt(2 * pi)).
   * Computed once and subtracted in logNormalDens.
   */
  private static double m_normConst = Math.log(Math.sqrt(2*Math.PI));

  /**
   * Density function of normal distribution.
   * @param x input value
   * @param mean mean of distribution
   * @param stdDev standard deviation of distribution
   */
  private double logNormalDens (double x, double mean, double stdDev) {
    // log N(x | mean, stdDev) =
    //   -(x - mean)^2 / (2 * stdDev^2) - log(sqrt(2 * pi)) - log(stdDev)
    final double deviation = x - mean;
    final double quadraticTerm = deviation * deviation / (2 * stdDev * stdDev);
    final double logStdDev = Math.log(stdDev);

    return -quadraticTerm - m_normConst - logStdDev;
  }

  /**
   * New probability estimators for an iteration
   *
   * Takes no parameters; it works on the current m_num_clusters clusters.
   */
  private void new_estimators () {
    for (int i = 0; i < m_num_clusters; i++) {
      for (int j = 0; j < m_num_attribs; j++) {
        if (m_theInstances.attribute(j).isNominal()) {
          m_model[i][j] = new DiscreteEstimator(m_theInstances.
						attribute(j).numValues()
						, true);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -