📄 em.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * EM.java * Copyright (C) 1999 Mark Hall * */package weka.clusterers;import java.io.*;import java.util.*;import weka.core.*;import weka.estimators.*;/** * Simple EM (expectation maximisation) class. <p> * * EM assigns a probability distribution to each instance which * indicates the probability of it belonging to each of the clusters. * EM can decide how many clusters to create by cross validation, or you * may specify apriori how many clusters to generate. <p> * <br> * The cross validation performed to determine the number of clusters * is done in the following steps:<br> * 1. the number of clusters is set to 1<br> * 2. the training set is split randomly into 10 folds.<br> * 3. EM is performed 10 times using the 10 folds the usual CV way.<br> * 4. the loglikelihood is averaged over all 10 results.<br> * 5. if loglikelihood has increased the number of clusters is increased by 1 * and the program continues at step 2. <br> *<br> * The number of folds is fixed to 10, as long as the number of instances in * the training set is not smaller 10. If this is the case the number of folds * is set equal to the number of instances.<p> * * Valid options are:<p> * * -V <br> * Verbose. <p> * * -N <number of clusters> <br> * Specify the number of clusters to generate. 
If omitted, * EM will use cross validation to select the number of clusters * automatically. <p> * * -I <max iterations> <br> * Terminate after this many iterations if EM has not converged. <p> * * -S <seed> <br> * Specify random number seed. <p> * * -M <num> <br> * Set the minimum allowable standard deviation for normal density calculation. * <p> * * @author Mark Hall (mhall@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */public class EM extends DistributionClusterer implements OptionHandler{ /** hold the discrete estimators for each cluster */ private Estimator m_model[][]; /** hold the normal estimators for each cluster */ private double m_modelNormal[][][]; /** default minimum standard deviation */ private double m_minStdDev = 1e-6; /** hold the weights of each instance for each cluster */ private double m_weights[][]; /** the prior probabilities for clusters */ private double m_priors[]; /** the loglikelihood of the data */ private double m_loglikely; /** training instances */ private Instances m_theInstances = null; /** number of clusters selected by the user or cross validation */ private int m_num_clusters; /** the initial number of clusters requested by the user--- -1 if xval is to be used to find the number of clusters */ private int m_initialNumClusters; /** number of attributes */ private int m_num_attribs; /** number of training instances */ private int m_num_instances; /** maximum iterations to perform */ private int m_max_iterations; /** attribute min values */ private double [] m_minValues; /** attribute max values */ private double [] m_maxValues; /** random numbers and seed */ private Random m_rr; private int m_rseed; /** Constant for normal distribution. */ private static double m_normConst = Math.sqrt(2*Math.PI); /** Verbose? 
*/
  private boolean m_verbose;

  /**
   * Returns a string describing this clusterer.
   *
   * @return a description of the clusterer suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Cluster data using expectation maximization";
  }

  /**
   * Returns an enumeration describing the available options. <p>
   *
   * Valid options are:<p>
   *
   * -V <br>
   * Verbose. <p>
   *
   * -N &lt;number of clusters&gt; <br>
   * Specify the number of clusters to generate. If omitted (or -1),
   * EM will use cross validation to select the number of clusters
   * automatically. <p>
   *
   * -I &lt;max iterations&gt; <br>
   * Terminate after this many iterations if EM has not converged. <p>
   *
   * -S &lt;seed&gt; <br>
   * Specify random number seed. <p>
   *
   * -M &lt;num&gt; <br>
   * Set the minimum allowable standard deviation for normal density
   * calculation. <p>
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions () {
    // exactly five options are registered below
    Vector newVector = new Vector(5);

    newVector.addElement(new Option("\tnumber of clusters. If omitted or"
                                    + "\n\t-1 specified, then cross "
                                    + "validation is used to\n\tselect the "
                                    + "number of clusters.",
                                    "N", 1, "-N <num>"));
    // every continuation line of a description starts with "\t" so that the
    // "(default ...)" lines align with the rest of the help text
    newVector.addElement(new Option("\tmax iterations.\n\t(default 100)",
                                    "I", 1, "-I <num>"));
    newVector.addElement(new Option("\trandom number seed.\n\t(default 1)",
                                    "S", 1, "-S <num>"));
    newVector.addElement(new Option("\tverbose.", "V", 0, "-V"));
    newVector.addElement(new Option("\tminimum allowable standard deviation "
                                    + "for normal density computation "
                                    + "\n\t(default 1e-6)",
                                    "M", 1, "-M <num>"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options.
* @param options the list of options as an array of strings * @exception Exception if an option is not supported * **/ public void setOptions (String[] options) throws Exception { resetOptions(); setDebug(Utils.getFlag('V', options)); String optionString = Utils.getOption('I', options); if (optionString.length() != 0) { setMaxIterations(Integer.parseInt(optionString)); } optionString = Utils.getOption('N', options); if (optionString.length() != 0) { setNumClusters(Integer.parseInt(optionString)); } optionString = Utils.getOption('S', options); if (optionString.length() != 0) { setSeed(Integer.parseInt(optionString)); } optionString = Utils.getOption('M', options); if (optionString.length() != 0) { setMinStdDev((new Double(optionString)).doubleValue()); } } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String minStdDevTipText() { return "set minimum allowable standard deviation"; } /** * Set the minimum value for standard deviation when calculating * normal density. Reducing this value can help prevent arithmetic * overflow resulting from multiplying large densities (arising from small * standard deviations) when there are many singleton or near singleton * values. * @param m minimum value for standard deviation */ public void setMinStdDev(double m) { m_minStdDev = m; } /** * Get the minimum allowable standard deviation. 
* @return the minimum allowable standard deviation
   */
  public double getMinStdDev() {
    return this.m_minStdDev;
  }

  /**
   * Returns the tip text for the seed property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String seedTipText() {
    return "random number seed";
  }

  /**
   * Sets the random number seed.
   *
   * @param s the seed
   */
  public void setSeed (int s) {
    this.m_rseed = s;
  }

  /**
   * Returns the random number seed.
   *
   * @return the seed
   */
  public int getSeed () {
    return this.m_rseed;
  }

  /**
   * Returns the tip text for the number-of-clusters property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "set number of clusters. -1 to select number of clusters "
      + "automatically by cross validation.";
  }

  /**
   * Sets the number of clusters to generate. Any negative value is stored
   * as -1, which requests selection of the number by cross validation.
   *
   * @param n the number of clusters, or a negative value for CV selection
   * @exception Exception if n is 0
   */
  public void setNumClusters (int n) throws Exception {
    if (n == 0) {
      throw new Exception("Number of clusters must be > 0. (or -1 to "
                          + "select by cross validation).");
    }
    // the user's request and the working value receive the same setting
    final int requested = (n < 0) ? -1 : n;
    m_initialNumClusters = requested;
    m_num_clusters = requested;
  }

  /**
   * Get the number of clusters
   *
   * @return the number of clusters requested (-1 means the number is
   * selected by cross validation)
*/
  public int getNumClusters () {
    return this.m_initialNumClusters;
  }

  /**
   * Returns the tip text for the max-iterations property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxIterationsTipText() {
    return "maximum number of iterations";
  }

  /**
   * Sets the maximum number of EM iterations to perform.
   *
   * @param i the number of iterations
   * @exception Exception if i is less than 1
   */
  public void setMaxIterations (int i) throws Exception {
    if (i <= 0) {
      throw new Exception("Maximum number of iterations must be > 0!");
    }
    this.m_max_iterations = i;
  }

  /**
   * Returns the maximum number of iterations.
   *
   * @return the number of iterations
   */
  public int getMaxIterations () {
    return this.m_max_iterations;
  }

  /**
   * Turns debug mode (verbose output) on or off.
   *
   * @param v true for verbose output
   */
  public void setDebug (boolean v) {
    this.m_verbose = v;
  }

  /**
   * Reports whether debug mode (verbose output) is on.
   *
   * @return true if debug mode is set
   */
  public boolean getDebug () {
    return this.m_verbose;
  }

  /**
   * Gets the current settings of EM as a command-line style array: an
   * optional -V followed by -I, -N, -S and -M with their values. Slots left
   * unused (when -V is absent) are filled with empty strings.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions () {
    // room for -V plus four flag/value pairs
    final String[] result = new String[9];
    int pos = 0;

    if (m_verbose) {
      result[pos++] = "-V";
    }
    result[pos++] = "-I";
    result[pos++] = String.valueOf(m_max_iterations);
    result[pos++] = "-N";
    result[pos++] = String.valueOf(getNumClusters());
    result[pos++] = "-S";
    result[pos++] = String.valueOf(m_rseed);
    result[pos++] = "-M";
    result[pos++] = String.valueOf(getMinStdDev());

    for (; pos < result.length; pos++) {
      result[pos] = "";
    }
    return result;
  }

  /**
   * Initialise estimators and storage.
*
   * Allocates the weight matrix (one row per instance of inst, one column
   * per cluster), the per-cluster estimator arrays and the prior array,
   * then gives every cluster random starting parameters. Nominal attributes
   * get a DiscreteEstimator seeded with a random count in [0, 10) for each
   * value; numeric attributes get a starting value drawn from [min, max) of
   * the attribute, a spread of (max - min) / (2 * num_cl) and a third slot
   * set to 1.0. All clusters start with equal priors.
   * The random draws happen in a fixed order (cluster, then attribute, then
   * value), so the result depends on the current state of m_rr.
   *
   * @param inst the instances
   * @param num_cl the number of clusters
   * @exception Exception NOTE(review): nothing in this method throws
   * explicitly; presumably propagated from estimator/Utils calls - confirm
   */
  private void EM_Init (Instances inst, int num_cl) throws Exception {
    int i, j, k;

    m_weights = new double[inst.numInstances()][num_cl];
    m_model = new DiscreteEstimator[num_cl][m_num_attribs];
    // per cluster and numeric attribute: [0] presumably the mean, [1]
    // presumably the standard deviation, [2] a weight slot (M() accumulates
    // instance weights into it)
    m_modelNormal = new double[num_cl][m_num_attribs][3];
    m_priors = new double[num_cl];

    for (i = 0; i < num_cl; i++) {
      for (j = 0; j < m_num_attribs; j++) {
        if (inst.attribute(j).isNominal()) {
          // NOTE(review): nominal-ness is read from inst but the value count
          // from m_theInstances; presumably both share one header - confirm
          m_model[i][j] = new DiscreteEstimator(m_theInstances.
                                                attribute(j).numValues()
                                                , true);
          // random pseudo-counts break the symmetry between clusters
          for (k=0; k<m_theInstances.attribute(j).numValues(); k++) {
            m_model[i][j].addValue(k, 10*m_rr.nextDouble());
          }
        }
        else {
          double delta_init = m_maxValues[j]-m_minValues[j];
          // random starting value within the attribute's observed range
          m_modelNormal[i][j][0] = m_minValues[j]+delta_init*m_rr.nextDouble();
          // NOTE(review): if max == min this spread is 0; whether a minimum
          // (m_minStdDev) is applied before it is used is not visible here
          m_modelNormal[i][j][1] = delta_init/(2*num_cl);
          m_modelNormal[i][j][2] = 1.0;
        }
      }
    }

    // initially equal priors
    for (j = 0; j < num_cl; j++) {
      m_priors[j] += 1.0;
    }
    Utils.normalize(m_priors);
  }

  /**
   * Recomputes the cluster priors from the current instance weights: prior j
   * is the sum of m_weights[i][j] over every instance i of inst, after which
   * the priors are normalized with Utils.normalize (presumably so that they
   * sum to 1).
   * NOTE(review): assumes m_weights has at least inst.numInstances() rows
   * and num_cl columns - confirm at call sites.
   *
   * @param inst the instances
   * @param num_cl the number of clusters
   * @exception Exception if priors can't be calculated (presumably raised by
   * Utils.normalize, e.g. when all weights sum to zero - confirm)
   **/
  private void estimate_priors (Instances inst, int num_cl)
    throws Exception {

    // reset, then accumulate the column sums of the weight matrix
    for (int i = 0; i < num_cl; i++) {
      m_priors[i] = 0.0;
    }

    for (int i = 0; i < inst.numInstances(); i++) {
      for (int j = 0; j < num_cl; j++) {
        m_priors[j] += m_weights[i][j];
      }
    }

    Utils.normalize(m_priors);
  }

  /**
   * Density function of normal distribution.
* @param x input value * @param mean mean of distribution * @param stdDev standard deviation of distribution */ private double normalDens (double x, double mean, double stdDev) { double diff = x - mean; return (1/(m_normConst*stdDev))*Math.exp(-(diff*diff/(2*stdDev*stdDev))); } /** * New probability estimators for an iteration * * @param num_cl the numbe of clusters */ private void new_estimators (int num_cl) { for (int i = 0; i < num_cl; i++) { for (int j = 0; j < m_num_attribs; j++) { if (m_theInstances.attribute(j).isNominal()) { m_model[i][j] = new DiscreteEstimator(m_theInstances. attribute(j).numValues() , true); } else { m_modelNormal[i][j][0] = m_modelNormal[i][j][1] = m_modelNormal[i][j][2] = 0.0; } } } } /** * The M step of the EM algorithm. * @param inst the training instances * @param num_cl the number of clusters */ private void M (Instances inst, int num_cl) throws Exception { int i, j, l; new_estimators(num_cl); for (i = 0; i < num_cl; i++) { for (j = 0; j < m_num_attribs; j++) { for (l = 0; l < inst.numInstances(); l++) { if (!inst.instance(l).isMissing(j)) { if (inst.attribute(j).isNominal()) { m_model[i][j].addValue(inst.instance(l).value(j), m_weights[l][i]); } else { m_modelNormal[i][j][0] += (inst.instance(l).value(j) * m_weights[l][i]); m_modelNormal[i][j][2] += m_weights[l][i]; m_modelNormal[i][j][1] += (inst.instance(l).value(j) * inst.instance(l).value(j)*m_weights[l][i]); } } } } } // calcualte mean and std deviation for numeric attributes for (j = 0; j < m_num_attribs; j++) { if (!inst.attribute(j).isNominal()) { for (i = 0; i < num_cl; i++) { if (m_modelNormal[i][j][2] < 0) { m_modelNormal[i][j][1] = 0; } else {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -