⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 Weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    EM.java *    Copyright (C) 1999 Mark Hall * */package  weka.clusterers;import  java.io.*;import  java.util.*;import  weka.core.*;import  weka.estimators.*;/** * Simple EM (expectation maximisation) class. <p> *  * EM assigns a probability distribution to each instance which * indicates the probability of it belonging to each of the clusters. * EM can decide how many clusters to create by cross validation, or you * may specify apriori how many clusters to generate. <p> * <br> * The cross validation performed to determine the number of clusters * is done in the following steps:<br> * 1. the number of clusters is set to 1<br> * 2. the training set is split randomly into 10 folds.<br> * 3. EM is performed 10 times using the 10 folds the usual CV way.<br> * 4. the loglikelihood is averaged over all 10 results.<br> * 5. if loglikelihood has increased the number of clusters is increased by 1 * and the program continues at step 2. <br> *<br> * The number of folds is fixed to 10, as long as the number of instances in * the training set is not smaller 10. If this is the case the number of folds * is set equal to the number of instances.<p> * * Valid options are:<p> * * -V <br> * Verbose. 
<p> * * -N <number of clusters> <br> * Specify the number of clusters to generate. If omitted, * EM will use cross validation to select the number of clusters * automatically. <p> * * -I <max iterations> <br> * Terminate after this many iterations if EM has not converged. <p> * * -S <seed> <br> * Specify random number seed. <p> * * -M <num> <br> * Set the minimum allowable standard deviation for normal density calculation. * <p> * * @author Mark Hall (mhall@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */public class EM  extends DistributionClusterer  implements OptionHandler{  /** hold the discrete estimators for each cluster */  private Estimator m_model[][];  /** hold the normal estimators for each cluster */  private double m_modelNormal[][][];  /** default minimum standard deviation */  private double m_minStdDev = 1e-6;  /** hold the weights of each instance for each cluster */  private double m_weights[][];  /** the prior probabilities for clusters */  private double m_priors[];  /** the loglikelihood of the data */  private double m_loglikely;  /** training instances */  private Instances m_theInstances = null;  /** number of clusters selected by the user or cross validation */  private int m_num_clusters;  /** the initial number of clusters requested by the user--- -1 if      xval is to be used to find the number of clusters */  private int m_initialNumClusters;  /** number of attributes */  private int m_num_attribs;  /** number of training instances */  private int m_num_instances;  /** maximum iterations to perform */  private int m_max_iterations;  /** attribute min values */  private double [] m_minValues;  /** attribute max values */  private double [] m_maxValues;  /** random numbers and seed */  private Random m_rr;  private int m_rseed;  /** Constant for normal distribution. */  private static double m_normConst = Math.sqrt(2*Math.PI);  /** Verbose? 
* Set by setDebug(), read by getDebug() and getOptions().
   */
  private boolean m_verbose;

  /**
   * Returns a string describing this clusterer.
   *
   * @return a description of the clusterer suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Cluster data using expectation maximization";
  }

  /**
   * Returns an enumeration describing the available options. <p>
   *
   * Valid options are:<p>
   *
   * -V <br>
   * Verbose. <p>
   *
   * -N <number of clusters> <br>
   * Specify the number of clusters to generate. If omitted,
   * EM will use cross validation to select the number of clusters
   * automatically. <p>
   *
   * -I <max iterations> <br>
   * Terminate after this many iterations if EM has not converged. <p>
   *
   * -S <seed> <br>
   * Specify random number seed. <p>
   *
   * -M <num> <br>
   * Set the minimum allowable standard deviation for normal density
   * calculation. <p>
   *
   * Each continuation line of a description starts with "\n\t" so the
   * "-I" and "-S" defaults are formatted like the other option descriptions.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions () {
    // exactly five options are registered below
    Vector newVector = new Vector(5);

    newVector.addElement(new Option("\tnumber of clusters. If omitted or"
                                    + "\n\t-1 specified, then cross "
                                    + "validation is used to\n\tselect the "
                                    + "number of clusters.", "N", 1,
                                    "-N <num>"));
    newVector.addElement(new Option("\tmax iterations.\n\t(default 100)",
                                    "I", 1, "-I <num>"));
    newVector.addElement(new Option("\trandom number seed.\n\t(default 1)",
                                    "S", 1, "-S <num>"));
    newVector.addElement(new Option("\tverbose.", "V", 0, "-V"));
    newVector.addElement(new Option("\tminimum allowable standard deviation "
                                    + "for normal density computation "
                                    + "\n\t(default 1e-6)",
                                    "M", 1, "-M <num>"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options.
* @param options the list of options as an array of strings   * @exception Exception if an option is not supported   *   **/  public void setOptions (String[] options)    throws Exception {    resetOptions();    setDebug(Utils.getFlag('V', options));    String optionString = Utils.getOption('I', options);    if (optionString.length() != 0) {      setMaxIterations(Integer.parseInt(optionString));    }    optionString = Utils.getOption('N', options);    if (optionString.length() != 0) {      setNumClusters(Integer.parseInt(optionString));    }    optionString = Utils.getOption('S', options);    if (optionString.length() != 0) {      setSeed(Integer.parseInt(optionString));    }    optionString = Utils.getOption('M', options);    if (optionString.length() != 0) {      setMinStdDev((new Double(optionString)).doubleValue());    }  }  /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String minStdDevTipText() {    return "set minimum allowable standard deviation";  }  /**   * Set the minimum value for standard deviation when calculating   * normal density. Reducing this value can help prevent arithmetic   * overflow resulting from multiplying large densities (arising from small   * standard deviations) when there are many singleton or near singleton   * values.   * @param m minimum value for standard deviation   */  public void setMinStdDev(double m) {    m_minStdDev = m;  }  /**   * Get the minimum allowable standard deviation.   
* @return the minimum allowable standard deviation
   */
  public double getMinStdDev() {
    return m_minStdDev;
  }

  /**
   * Tip text for the seed property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String seedTipText() {
    return "random number seed";
  }

  /**
   * Stores the random number seed.
   *
   * @param s the seed
   */
  public void setSeed (int s) {
    m_rseed = s;
  }

  /**
   * Returns the stored random number seed.
   *
   * @return the seed
   */
  public int getSeed () {
    return m_rseed;
  }

  /**
   * Tip text for the numClusters property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "set number of clusters. -1 to select number of "
      + "clusters automatically by cross validation.";
  }

  /**
   * Sets the number of clusters to generate. Any negative value is stored
   * as -1, which requests selection of the number of clusters by cross
   * validation.
   *
   * @param n the number of clusters, or a negative value for cross validation
   * @exception Exception if n is 0
   */
  public void setNumClusters (int n)
    throws Exception {
    if (n == 0) {
      throw new Exception("Number of clusters must be > 0. (or -1 to select "
                          + "by cross validation).");
    }
    // collapse every negative request onto the single sentinel -1
    final int requested = (n < 0) ? -1 : n;
    m_num_clusters = requested;
    m_initialNumClusters = requested;
  }

  /**
   * Returns the requested number of clusters (-1 when the number is to be
   * selected by cross validation).
   *
   * @return the number of clusters.
*/
  public int getNumClusters () {
    return m_initialNumClusters;
  }

  /**
   * Tip text for the maxIterations property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxIterationsTipText() {
    return "maximum number of iterations";
  }

  /**
   * Sets the cap on the number of EM iterations.
   *
   * @param i the number of iterations; must be at least 1
   * @exception Exception if i is less than 1
   */
  public void setMaxIterations (int i)
    throws Exception {
    if (i >= 1) {
      m_max_iterations = i;
      return;
    }
    throw new Exception("Maximum number of iterations must be > 0!");
  }

  /**
   * Returns the cap on the number of EM iterations.
   *
   * @return the number of iterations
   */
  public int getMaxIterations () {
    return m_max_iterations;
  }

  /**
   * Turns verbose (debug) output on or off.
   *
   * @param v true for verbose output
   */
  public void setDebug (boolean v) {
    m_verbose = v;
  }

  /**
   * Reports whether verbose (debug) output is on.
   *
   * @return true if debug mode is set
   */
  public boolean getDebug () {
    return m_verbose;
  }

  /**
   * Builds the command-line form of the current EM settings. The result
   * always has 9 slots; when -V is not emitted the last slot is "".
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions () {
    String[] result = new String[9];
    int pos = 0;

    if (m_verbose) {
      result[pos++] = "-V";
    }

    String[] flagValuePairs = {
      "-I", String.valueOf(m_max_iterations),
      "-N", String.valueOf(getNumClusters()),
      "-S", String.valueOf(m_rseed),
      "-M", String.valueOf(getMinStdDev())
    };
    for (int p = 0; p < flagValuePairs.length; p++) {
      result[pos++] = flagValuePairs[p];
    }

    // pad any unused trailing slot with an empty string
    while (pos < result.length) {
      result[pos++] = "";
    }
    return result;
  }

  /**
   * Initialise estimators and storage.
*   * @param inst the instances   * @param num_cl the number of clusters   **/  private void EM_Init (Instances inst, int num_cl)    throws Exception {    int i, j, k;    m_weights = new double[inst.numInstances()][num_cl];    m_model = new DiscreteEstimator[num_cl][m_num_attribs];    m_modelNormal = new double[num_cl][m_num_attribs][3];    m_priors = new double[num_cl];    for (i = 0; i < num_cl; i++) {      for (j = 0; j < m_num_attribs; j++) {	if (inst.attribute(j).isNominal()) {	  m_model[i][j] = new DiscreteEstimator(m_theInstances.						attribute(j).numValues()						, true);	  for (k=0; k<m_theInstances.attribute(j).numValues(); k++) {	    m_model[i][j].addValue(k, 10*m_rr.nextDouble());	  }	}	else {	  double delta_init = m_maxValues[j]-m_minValues[j];	  m_modelNormal[i][j][0] = m_minValues[j]+delta_init*m_rr.nextDouble();	  m_modelNormal[i][j][1] = delta_init/(2*num_cl);	  m_modelNormal[i][j][2] = 1.0;	}      }    }        // initially equal priors    for (j = 0; j < num_cl; j++) {      m_priors[j] += 1.0;    }    Utils.normalize(m_priors);  }  /**   * calculate prior probabilites for the clusters   *   * @param inst the instances   * @param num_cl the number of clusters   * @exception Exception if priors can't be calculated   **/  private void estimate_priors (Instances inst, int num_cl)    throws Exception {    for (int i = 0; i < num_cl; i++) {      m_priors[i] = 0.0;    }    for (int i = 0; i < inst.numInstances(); i++) {      for (int j = 0; j < num_cl; j++) {        m_priors[j] += m_weights[i][j];      }    }    Utils.normalize(m_priors);  }  /**   * Density function of normal distribution.   
* @param x input value   * @param mean mean of distribution   * @param stdDev standard deviation of distribution   */  private double normalDens (double x, double mean, double stdDev) {    double diff = x - mean;       return  (1/(m_normConst*stdDev))*Math.exp(-(diff*diff/(2*stdDev*stdDev)));  }  /**   * New probability estimators for an iteration   *   * @param num_cl the numbe of clusters   */  private void new_estimators (int num_cl) {    for (int i = 0; i < num_cl; i++) {      for (int j = 0; j < m_num_attribs; j++) {        if (m_theInstances.attribute(j).isNominal()) {          m_model[i][j] = new DiscreteEstimator(m_theInstances.						attribute(j).numValues()						, true);        }        else {          m_modelNormal[i][j][0] = m_modelNormal[i][j][1] = 	    m_modelNormal[i][j][2] = 0.0;        }      }    }  }  /**   * The M step of the EM algorithm.   * @param inst the training instances   * @param num_cl the number of clusters   */  private void M (Instances inst, int num_cl)    throws Exception {    int i, j, l;    new_estimators(num_cl);    for (i = 0; i < num_cl; i++) {      for (j = 0; j < m_num_attribs; j++) {        for (l = 0; l < inst.numInstances(); l++) {          if (!inst.instance(l).isMissing(j)) {            if (inst.attribute(j).isNominal()) {              m_model[i][j].addValue(inst.instance(l).value(j), 				     m_weights[l][i]);            }            else {              m_modelNormal[i][j][0] += (inst.instance(l).value(j) * 					 m_weights[l][i]);              m_modelNormal[i][j][2] += m_weights[l][i];              m_modelNormal[i][j][1] += (inst.instance(l).value(j) * 					 inst.instance(l).value(j)*m_weights[l][i]);            }          }        }      }    }        // calcualte mean and std deviation for numeric attributes    for (j = 0; j < m_num_attribs; j++) {      if (!inst.attribute(j).isNominal()) {        for (i = 0; i < num_cl; i++) {          if (m_modelNormal[i][j][2] < 0) {            m_modelNormal[i][j][1] = 0;          
} else {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -