📄 em.java
字号:
/* * EM.java * Copyright (C) 1999 Mark Hall * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */package tclass.clusteralg; import java.io.*;import java.util.*;import weka.core.*;import weka.estimators.*;import tclass.util.FastMath; import weka.clusterers.*; /** * Simple EM (expectation maximisation) class. <p> * * EM assigns a probability distribution to each instance which * indicates the probability of it belonging to each of the clusters. * EM can decide how many clusters to create by cross validation, or you * may specify a priori how many clusters to generate. <p> * * Valid options are:<p> * * -V <br> * Verbose. <p> * * -N <number of clusters> <br> * Specify the number of clusters to generate. If omitted, * EM will use cross validation to select the number of clusters * automatically. <p> * * -I <max iterations> <br> * Terminate after this many iterations if EM has not converged. <p> * * -S <seed> <br> * Specify random number seed. 
<p> * * @author Mark Hall (mhall@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */public class EM extends DistributionClusterer implements OptionHandler{ /** hold the discrete estimators for each cluster */ private Estimator m_model[][]; /** hold the normal estimators for each cluster */ private double m_modelNormal[][][]; /** hold the weights of each instance for each cluster */ private double m_weights[][]; /** hold default standard deviations for numeric attributes */ private double m_defSds[]; /** the prior probabilities for clusters */ private double m_priors[]; /** the loglikelihood of the data */ private double m_loglikely; /** the number of cross-validations to do */ private int num_cvs; /** training instances */ private Instances m_theInstances = null; /** number of clusters selected by the user or cross validation */ private int m_num_clusters; /** the initial number of clusters requested by the user--- -1 if xval is to be used to find the number of clusters */ private int m_initialNumClusters; /** number of attributes */ private int m_num_attribs; /** number of training instances */ private int m_num_instances; /** maximum iterations to perform */ private int m_max_iterations; /** random numbers and seed */ private Random m_rr; private int m_rseed; /** Constant for normal distribution. */ private static double m_normConst = Math.sqrt(2*Math.PI); /** Verbose? */ private boolean m_verbose; /** * Returns an enumeration describing the available options. <p> * * Valid options are:<p> * * -V <br> * Verbose. <p> * * -N <number of clusters> <br * Specify the number of clusters to generate. If omitted, * EM will use cross validation to select the number of clusters * automatically. <p> * * -I <max iterations> <br> * Terminate after this many iterations if EM has not converged. <p> * * -S <seed> <br> * Specify random number seed. 
<p> * * @return an enumeration of all the available options * **/ public Enumeration listOptions () { Vector newVector = new Vector(5); newVector.addElement(new Option("\tnumber of clusters. If omitted or" + "\n\t-1 specified, then cross " + "validation is used to\n\tselect the " + "number of clusters.", "N", 1 , "-N <num>")); newVector.addElement(new Option("\tmax iterations.\n(default 100)", "I" , 1, "-I <num>")); newVector.addElement(new Option("\trandom number seed.\n(default 1)" , "S", 1, "-S <num>")); newVector.addElement(new Option("\tverbose.", "V", 0, "-V")); return newVector.elements(); } /** * Parses a given list of options. * @param options the list of options as an array of strings * @exception Exception if an option is not supported * **/ public void setOptions (String[] options) throws Exception { resetOptions(); setDebug(Utils.getFlag('V', options)); String optionString = Utils.getOption('I', options); if (optionString.length() != 0) { setMaxIterations(Integer.parseInt(optionString)); } optionString = Utils.getOption('N', options); if (optionString.length() != 0) { setNumClusters(Integer.parseInt(optionString)); } optionString = Utils.getOption('S', options); if (optionString.length() != 0) { setSeed(Integer.parseInt(optionString)); } } public void setNumCVs(int num_cvs){ this.num_cvs = num_cvs; } public int getNumCVs(){ return num_cvs; } /** * Set the random number seed * * @param s the seed */ public void setSeed (int s) { m_rseed = s; } /** * Get the random number seed * * @return the seed */ public int getSeed () { return m_rseed; } /** * Set the number of clusters (-1 to select by CV). * * @param n the number of clusters * @exception Exception if n is 0 */ public void setNumClusters (int n) throws Exception { if (n == 0) { throw new Exception("Number of clusters must be > 0. 
(or -1 to " + "select by cross validation)."); } if (n < 0) { m_num_clusters = -1; m_initialNumClusters = -1; } else { m_num_clusters = n; m_initialNumClusters = n; } } /** * Get the number of clusters * * @return the number of clusters. */ public int getNumClusters () { return m_initialNumClusters; } /** * Set the maximum number of iterations to perform * * @param i the number of iterations * @exception Exception if i is less than 1 */ public void setMaxIterations (int i) throws Exception { if (i < 1) { throw new Exception("Maximum number of iterations must be > 0!"); } m_max_iterations = i; } /** * Get the maximum number of iterations * * @return the number of iterations */ public int getMaxIterations () { return m_max_iterations; } /** * Set debug mode - verbose output * * @param v true for verbose output */ public void setDebug (boolean v) { m_verbose = v; } /** * Get debug mode * * @return true if debug mode is set */ public boolean getDebug () { return m_verbose; } /** * Gets the current settings of EM. * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions () { String[] options = new String[7]; int current = 0; if (m_verbose) { options[current++] = "-V"; } options[current++] = "-I"; options[current++] = "" + m_max_iterations; options[current++] = "-N"; options[current++] = "" + getNumClusters(); options[current++] = "-S"; options[current++] = "" + m_rseed; while (current < options.length) { options[current++] = ""; } return options; } /** * Sets default standard devs for numeric attributes based on the * differences between their sorted values. 
* @param inst the instances **/ private void setDefaultStdDevs (Instances inst) { int i; Instances copyI = new Instances(inst); inst = copyI; m_defSds = new double[m_num_attribs]; for (i = 0; i < m_num_attribs; i++) { m_defSds[i] = 0.01; } for (i = 0; i < m_num_attribs; i++) { if (inst.attribute(i).isNumeric()) { inst.sort(i); if ((inst.numInstances() > 0) && !inst.instance(0).isMissing(i)) { double lastVal = inst.instance(0).value(i);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -