📄 em.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* EM.java
* Copyright (C) 1999 Mark Hall
*
*/
package weka.clusterers;
import java.io.*;
import java.util.*;
import weka.core.*;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.estimators.*;
/**
* Simple EM (expectation maximisation) class. <p>
*
* EM assigns a probability distribution to each instance which
* indicates the probability of it belonging to each of the clusters.
* EM can decide how many clusters to create by cross validation, or you
* may specify a priori how many clusters to generate. <p>
* <br>
* The cross validation performed to determine the number of clusters
* is done in the following steps:<br>
* 1. the number of clusters is set to 1<br>
* 2. the training set is split randomly into 10 folds.<br>
* 3. EM is performed 10 times using the 10 folds the usual CV way.<br>
* 4. the loglikelihood is averaged over all 10 results.<br>
* 5. if loglikelihood has increased the number of clusters is increased by 1
* and the program continues at step 2. <br>
*<br>
* The number of folds is fixed to 10, as long as the number of instances in
* the training set is not smaller than 10. If it is smaller, the number of
* folds is set equal to the number of instances.<p>
*
* Valid options are:<p>
*
* -V <br>
* Verbose. <p>
*
* -N &lt;number of clusters&gt; <br>
* Specify the number of clusters to generate. If omitted,
* EM will use cross validation to select the number of clusters
* automatically. <p>
*
* -I &lt;max iterations&gt; <br>
* Terminate after this many iterations if EM has not converged. <p>
*
* -S &lt;seed&gt; <br>
* Specify random number seed. <p>
*
* -M &lt;num&gt; <br>
* Set the minimum allowable standard deviation for normal density calculation.
* <p>
*
* @author Mark Hall (mhall@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision: 1.1 $
*/
public class EM
extends DensityBasedClusterer
implements NumberOfClustersRequestable,
OptionHandler, WeightedInstancesHandler {
/**
 * Discrete estimators for each cluster.
 * NOTE(review): presumably indexed [cluster][attribute] -- confirm against
 * the code that populates this array.
 */
private Estimator m_model[][];
/**
 * Normal (Gaussian) estimates for each cluster.
 * NOTE(review): presumably indexed [cluster][attribute][parameter] -- confirm
 * against the code that populates this array.
 */
private double m_modelNormal[][][];
/**
 * Minimum allowable standard deviation. Initialised to 1e-6 and replaced
 * by setMinStdDev().
 */
private double m_minStdDev = 1e-6;
/**
 * Array supplied through setMinStdDevPerAtt().
 * NOTE(review): presumably one minimum standard deviation per attribute --
 * confirm where it is read.
 */
private double [] m_minStdDevPerAtt;
/**
 * Weights of each instance for each cluster.
 * NOTE(review): index order (instance vs. cluster first) is not visible
 * here -- confirm.
 */
private double m_weights[][];
/** Prior probabilities of the clusters. */
private double m_priors[];
/** Log-likelihood of the data. */
private double m_loglikely;
/** Training instances; null until assigned. */
private Instances m_theInstances = null;
/**
 * Number of clusters selected by the user or by cross validation.
 * setNumClusters() stores -1 here when a negative count is requested.
 */
private int m_num_clusters;
/**
 * Number of clusters initially requested by the user; -1 if cross
 * validation is to be used to find the number of clusters. This is the
 * value reported by getNumClusters().
 */
private int m_initialNumClusters;
/** Number of attributes. */
private int m_num_attribs;
/** Number of training instances. */
private int m_num_instances;
/** Maximum number of iterations to perform; set by setMaxIterations(). */
private int m_max_iterations;
/** Attribute minimum values. */
private double [] m_minValues;
/** Attribute maximum values. */
private double [] m_maxValues;
/** Random number generator. */
private Random m_rr;
/** Random number seed; set by setSeed(). */
private int m_rseed;
/** Verbose output flag; set by setDebug(). */
private boolean m_verbose;
/** Filter used to globally replace missing values. */
private ReplaceMissingValues m_replaceMissing;
/**
 * Describes this clusterer in a single sentence.
 *
 * @return a fixed description string, intended for display in the
 * explorer/experimenter gui
 */
public String globalInfo() {
  final String description = "Cluster data using expectation maximization";
  return description;
}
/**
 * Returns an enumeration describing the available options. <p>
 *
 * Valid options are:<p>
 *
 * -V <br>
 * Verbose. <p>
 *
 * -N &lt;number of clusters&gt; <br>
 * Specify the number of clusters to generate. If omitted,
 * EM will use cross validation to select the number of clusters
 * automatically. <p>
 *
 * -I &lt;max iterations&gt; <br>
 * Terminate after this many iterations if EM has not converged. <p>
 *
 * -S &lt;seed&gt; <br>
 * Specify random number seed. <p>
 *
 * -M &lt;num&gt; <br>
 * Set the minimum allowable standard deviation for normal density
 * calculation. <p>
 *
 * The options are returned in the order N, I, S, V, M.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions () {
  // Exactly five options are registered below.
  Vector newVector = new Vector(5);

  newVector.addElement(new Option("\tnumber of clusters. If omitted or"
                                  + "\n\t-1 specified, then cross "
                                  + "validation is used to\n\tselect the "
                                  + "number of clusters.",
                                  "N", 1, "-N <num>"));
  // Continuation lines start with a tab so that they line up with the
  // first line of the description, as the -N and -M descriptions do.
  newVector.addElement(new Option("\tmax iterations.\n\t(default 100)",
                                  "I", 1, "-I <num>"));
  newVector.addElement(new Option("\trandom number seed.\n\t(default 1)",
                                  "S", 1, "-S <num>"));
  newVector.addElement(new Option("\tverbose.", "V", 0, "-V"));
  newVector.addElement(new Option("\tminimum allowable standard deviation "
                                  + "for normal density computation "
                                  + "\n\t(default 1e-6)",
                                  "M", 1, "-M <num>"));

  return newVector.elements();
}
/**
 * Parses a given list of options.
 *
 * All options are first reset via resetOptions(). The -V flag is then
 * applied, followed by -I (max iterations), -N (number of clusters),
 * -S (seed) and -M (minimum standard deviation), in that order. An option
 * whose value comes back as the empty string is left at whatever
 * resetOptions() established.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported, or if a value cannot
 * be parsed or is rejected by the corresponding setter
 */
public void setOptions (String[] options)
  throws Exception {
  resetOptions();
  setDebug(Utils.getFlag('V', options));

  // Each option is fetched and applied before the next one is fetched, so
  // a failure part-way through leaves the later options unread.
  String maxIterationsArg = Utils.getOption('I', options);
  if (maxIterationsArg.length() > 0) {
    setMaxIterations(Integer.parseInt(maxIterationsArg));
  }

  String numClustersArg = Utils.getOption('N', options);
  if (numClustersArg.length() > 0) {
    setNumClusters(Integer.parseInt(numClustersArg));
  }

  String seedArg = Utils.getOption('S', options);
  if (seedArg.length() > 0) {
    setSeed(Integer.parseInt(seedArg));
  }

  String minStdDevArg = Utils.getOption('M', options);
  if (minStdDevArg.length() > 0) {
    setMinStdDev(Double.parseDouble(minStdDevArg));
  }
}
/**
 * Tip text for the minStdDev property.
 *
 * @return a fixed tip string, intended for display in the
 * explorer/experimenter gui
 */
public String minStdDevTipText() {
  final String tip = "set minimum allowable standard deviation";
  return tip;
}
/**
 * Set the minimum value for standard deviation when calculating
 * normal density. The value is stored as given; no range check is made.
 *
 * NOTE(review): the previous comment said that reducing this value helps
 * prevent arithmetic overflow from multiplying large densities. Since it
 * also said those densities arise from small standard deviations, it is
 * presumably raising this minimum that helps -- confirm against the
 * density computation.
 *
 * @param m minimum value for standard deviation
 */
public void setMinStdDev(double m) {
  this.m_minStdDev = m;
}
/**
 * Set the array of minimum standard deviations. The array is stored by
 * reference (not copied), so later changes made by the caller are visible
 * to this object.
 * NOTE(review): presumably one entry per attribute -- confirm where the
 * array is read.
 *
 * @param m the array of minimum standard deviations
 */
public void setMinStdDevPerAtt(double [] m) {
  this.m_minStdDevPerAtt = m;
}
/**
 * Get the minimum allowable standard deviation.
 *
 * @return the value last stored by setMinStdDev(), or the initial value
 * if it has not been changed
 */
public double getMinStdDev() {
  return this.m_minStdDev;
}
/**
 * Tip text for the seed property.
 *
 * @return a fixed tip string, intended for display in the
 * explorer/experimenter gui
 */
public String seedTipText() {
  final String tip = "random number seed";
  return tip;
}
/**
 * Set the random number seed. The value is only recorded here.
 *
 * @param s the seed
 */
public void setSeed (int s) {
  this.m_rseed = s;
}
/**
 * Get the random number seed.
 *
 * @return the currently stored seed
 */
public int getSeed () {
  return this.m_rseed;
}
/**
 * Tip text for the numClusters property.
 *
 * @return a fixed tip string, intended for display in the
 * explorer/experimenter gui
 */
public String numClustersTipText() {
  // Both parts are compile-time constants, so the result is one constant.
  final String head = "set number of clusters. -1 to select number of clusters ";
  final String tail = "automatically by cross validation.";
  return head + tail;
}
/**
 * Set the number of clusters (-1 to select by CV).
 *
 * Any negative value is normalised to -1. The resulting value is stored
 * both as the current number of clusters and as the initially requested
 * number.
 *
 * @param n the number of clusters
 * @exception Exception if n is 0
 */
public void setNumClusters (int n)
  throws Exception {
  if (n == 0) {
    throw new Exception("Number of clusters must be > 0. (or -1 to "
                        + "select by cross validation).");
  }

  // Every negative request means "choose by cross validation".
  final int requested = (n < 0) ? -1 : n;
  m_num_clusters = requested;
  m_initialNumClusters = requested;
}
/**
 * Get the number of clusters.
 *
 * This reports the initially requested number of clusters (the value
 * recorded by setNumClusters(), -1 when selection by cross validation was
 * requested), not the separately stored current number of clusters.
 *
 * @return the requested number of clusters.
 */
public int getNumClusters () {
  return this.m_initialNumClusters;
}
/**
 * Tip text for the maxIterations property.
 *
 * @return a fixed tip string, intended for display in the
 * explorer/experimenter gui
 */
public String maxIterationsTipText() {
  final String tip = "maximum number of iterations";
  return tip;
}
/**
 * Set the maximum number of iterations to perform.
 *
 * @param i the number of iterations; must be at least 1
 * @exception Exception if i is less than 1
 */
public void setMaxIterations (int i)
  throws Exception {
  final boolean positive = i >= 1;
  if (!positive) {
    throw new Exception("Maximum number of iterations must be > 0!");
  }
  this.m_max_iterations = i;
}
/**
 * Get the maximum number of iterations.
 *
 * @return the currently stored maximum number of iterations
 */
public int getMaxIterations () {
  return this.m_max_iterations;
}
/**
 * Set debug mode, i.e. switch verbose output on or off.
 *
 * @param v true for verbose output
 */
public void setDebug (boolean v) {
  this.m_verbose = v;
}
/**
 * Get debug mode.
 *
 * @return true if debug mode (verbose output) is set
 */
public boolean getDebug () {
  return this.m_verbose;
}
/**
* Gets the current settings of EM.
*
* @return an array of strings suitable for passing to setOptions()
*/
public String[] getOptions () {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -