📄 em.java
字号:
// variance m_modelNormal[i][j][1] = (m_modelNormal[i][j][1] - (m_modelNormal[i][j][0] * m_modelNormal[i][j][0] / m_modelNormal[i][j][2])) / m_modelNormal[i][j][2]; // std dev m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]); if (m_modelNormal[i][j][1] <= m_minStdDev || Double.isNaN(m_modelNormal[i][j][1])) { m_modelNormal[i][j][1] = m_minStdDev; } // mean if (m_modelNormal[i][j][2] > 0.0) { m_modelNormal[i][j][0] /= m_modelNormal[i][j][2]; } } } } } } /** * The E step of the EM algorithm. Estimate cluster membership * probabilities. * * @param inst the training instances * @param num_cl the number of clusters * @return the average log likelihood */ private double E (Instances inst, int num_cl) throws Exception { int i, j, l; double prob; double loglk = 0.0; for (l = 0; l < inst.numInstances(); l++) { for (i = 0; i < num_cl; i++) { m_weights[l][i] = m_priors[i]; } for (j = 0; j < m_num_attribs; j++) { double max = 0; for (i = 0; i < num_cl; i++) { if (!inst.instance(l).isMissing(j)) { if (inst.attribute(j).isNominal()) { m_weights[l][i] *= m_model[i][j].getProbability(inst.instance(l).value(j)); } else { // numeric attribute m_weights[l][i] *= normalDens(inst.instance(l).value(j), m_modelNormal[i][j][0], m_modelNormal[i][j][1]); if (Double.isInfinite(m_weights[l][i])) { throw new Exception("Joint density has overflowed. Try " +"increasing the minimum allowable " +"standard deviation for normal " +"density calculation."); } } if (m_weights[l][i] > max) { max = m_weights[l][i]; } } } if (max > 0 && max < 1e-75) { // check for underflow for (int zz = 0; zz < num_cl; zz++) { // rescale m_weights[l][zz] *= 1e75; } } } double temp1 = 0; for (i = 0; i < num_cl; i++) { temp1 += m_weights[l][i]; } if (temp1 > 0) { loglk += Math.log(temp1); } // normalise the weights for this instance try { Utils.normalize(m_weights[l]); } catch (Exception e) { throw new Exception("An instance has zero cluster memberships. Try " +"increasing the minimum allowable " +"standard deviation for normal " +"density calculation."); } } // reestimate priors estimate_priors(inst, num_cl); return loglk/inst.numInstances(); } /** * Constructor. * **/ public EM () { resetOptions(); } /** * Reset to default options */ protected void resetOptions () { m_minStdDev = 1e-6; m_max_iterations = 100; m_rseed = 100; m_num_clusters = -1; m_initialNumClusters = -1; m_verbose = false; } /** * Return the normal distributions for the cluster models * * @return a <code>double[][][]</code> value */ public double [][][] getClusterModelsNumericAtts() { return m_modelNormal; } /** * Return the priors for the clusters * * @return a <code>double[]</code> value */ public double [] getClusterPriors() { return m_priors; } /** * Outputs the generated clusters into a string. */ public String toString () { StringBuffer text = new StringBuffer(); text.append("\nEM\n==\n"); if (m_initialNumClusters == -1) { text.append("\nNumber of clusters selected by cross validation: " +m_num_clusters+"\n"); } else { text.append("\nNumber of clusters: " + m_num_clusters + "\n"); } for (int j = 0; j < m_num_clusters; j++) { text.append("\nCluster: " + j + " Prior probability: " + Utils.doubleToString(m_priors[j], 4) + "\n\n"); for (int i = 0; i < m_num_attribs; i++) { text.append("Attribute: " + m_theInstances.attribute(i).name() + "\n"); if (m_theInstances.attribute(i).isNominal()) { if (m_model[j][i] != null) { text.append(m_model[j][i].toString()); } } else { text.append("Normal Distribution. Mean = " + Utils.doubleToString(m_modelNormal[j][i][0], 4) + " StdDev = " + Utils.doubleToString(m_modelNormal[j][i][1], 4) + "\n"); } } } return text.toString(); } /** * verbose output for debugging * @param inst the training instances */ private void EM_Report (Instances inst) { int i, j, l, m; System.out.println("======================================"); for (j = 0; j < m_num_clusters; j++) { for (i = 0; i < m_num_attribs; i++) { System.out.println("Clust: " + j + " att: " + i + "\n"); if (m_theInstances.attribute(i).isNominal()) { if (m_model[j][i] != null) { System.out.println(m_model[j][i].toString()); } } else { System.out.println("Normal Distribution. Mean = " + Utils.doubleToString(m_modelNormal[j][i][0] , 8, 4) + " StandardDev = " + Utils.doubleToString(m_modelNormal[j][i][1] , 8, 4) + " WeightSum = " + Utils.doubleToString(m_modelNormal[j][i][2] , 8, 4)); } } } for (l = 0; l < inst.numInstances(); l++) { m = Utils.maxIndex(m_weights[l]); System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) + " Class " + m + "\t"); for (j = 0; j < m_num_clusters; j++) { System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + " "); } System.out.println(); } } /** * estimate the number of clusters by cross validation on the training * data. * * @return the number of clusters selected */ private int CVClusters () throws Exception { double CVLogLikely = -Double.MAX_VALUE; double templl, tll; boolean CVincreased = true; int num_cl = 1; int i; Random cvr; Instances trainCopy; int numFolds = (m_theInstances.numInstances() < 10) ? m_theInstances.numInstances() : 10; while (CVincreased) { CVincreased = false; cvr = new Random(m_rseed); trainCopy = new Instances(m_theInstances); trainCopy.randomize(cvr); // theInstances.stratify(10); templl = 0.0; for (i = 0; i < numFolds; i++) { Instances cvTrain = trainCopy.trainCV(numFolds, i); Instances cvTest = trainCopy.testCV(numFolds, i); EM_Init(cvTrain, num_cl); iterate(cvTrain, num_cl, false); tll = E(cvTest, num_cl); if (m_verbose) { System.out.println("# clust: " + num_cl + " Fold: " + i + " Loglikely: " + tll); } templl += tll; } templl /= (double)numFolds; if (m_verbose) { System.out.println("===================================" + "==============\n# clust: " + num_cl + " Mean Loglikely: " + templl + "\n================================" + "================="); } if (templl > CVLogLikely) { CVLogLikely = templl; CVincreased = true; num_cl++; } } if (m_verbose) { System.out.println("Number of clusters: " + (num_cl - 1)); } return num_cl - 1; } /** * Returns the number of clusters. * * @return the number of clusters generated for a training dataset. * @exception Exception if number of clusters could not be returned * successfully */ public int numberOfClusters () throws Exception { if (m_num_clusters == -1) { throw new Exception("Haven't generated any clusters!"); } return m_num_clusters; } /** * Updates the minimum and maximum values for all the attributes * based on a new instance. * * @param instance the new instance */ private void updateMinMax(Instance instance) { for (int j = 0; j < m_theInstances.numAttributes(); j++) { if (!instance.isMissing(j)) { if (Double.isNaN(m_minValues[j])) { m_minValues[j] = instance.value(j); m_maxValues[j] = instance.value(j); } else { if (instance.value(j) < m_minValues[j]) { m_minValues[j] = instance.value(j); } else { if (instance.value(j) > m_maxValues[j]) { m_maxValues[j] = instance.value(j); } } } } } } /** * Generates a clusterer. Has to initialize all fields of the clusterer * that are not being set via options. * * @param data set of instances serving as training data * @exception Exception if the clusterer has not been * generated successfully */ public void buildClusterer (Instances data) throws Exception { if (data.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } m_theInstances = data; // calculate min and max values for attributes m_minValues = new double [m_theInstances.numAttributes()]; m_maxValues = new double [m_theInstances.numAttributes()]; for (int i = 0; i < m_theInstances.numAttributes(); i++) { m_minValues[i] = m_maxValues[i] = Double.NaN; } for (int i = 0; i < m_theInstances.numInstances(); i++) { updateMinMax(m_theInstances.instance(i)); } doEM(); // save memory m_theInstances = new Instances(m_theInstances,0); } /** * Computes the density for a given instance. * * @param inst the instance to compute the density for * @return the density. * @exception Exception if the density could not be computed * successfully */ public double densityForInstance(Instance inst) throws Exception { return Utils.sum(weightsForInstance(inst)); } /** * Predicts the cluster memberships for a given instance. * * @param data set of test instances * @param instance the instance to be assigned a cluster. * @return an array containing the estimated membership * probabilities of the test instance in each cluster (this * should sum to at most 1) * @exception Exception if distribution could not be * computed successfully */ public double[] distributionForInstance (Instance inst) throws Exception { double [] distrib = weightsForInstance(inst); Utils.normalize(distrib); return distrib; } /** * Returns the weights (indicating cluster membership) for a given instance * * @param inst the instance to be assigned a cluster * @return an array of weights * @exception Exception if weights could not be computed */ protected double[] weightsForInstance(Instance inst) throws Exception { int i, j; double prob; double[] wghts = new double[m_num_clusters]; for (i = 0; i < m_num_clusters; i++) { prob = 1.0; for (j = 0; j < m_num_attribs; j++) { if (!inst.isMissing(j)) { if (inst.attribute(j).isNominal()) { prob *= m_model[i][j].getProbability(inst.value(j)); } else { // numeric attribute prob *= normalDens(inst.value(j), m_modelNormal[i][j][0], m_modelNormal[i][j][1]); } } } wghts[i] = (prob*m_priors[i]); } return wghts; } /** * Perform the EM algorithm */ private void doEM () throws Exception { if (m_verbose) { System.out.println("Seed: " + m_rseed); } m_rr = new Random(m_rseed); // throw away numbers to avoid problem of similar initial numbers // from a similar seed for (int i=0; i<10; i++) m_rr.nextDouble(); m_num_instances = m_theInstances.numInstances(); m_num_attribs = m_theInstances.numAttributes(); if (m_verbose) { System.out.println("Number of instances: " + m_num_instances + "\nNumber of atts: " + m_num_attribs + "\n"); } // setDefaultStdDevs(theInstances); // cross validate to determine number of clusters? if (m_initialNumClusters == -1) { if (m_theInstances.numInstances() > 9) { m_num_clusters = CVClusters(); } else { m_num_clusters = 1; } } // fit full training set EM_Init(m_theInstances, m_num_clusters); m_loglikely = iterate(m_theInstances, m_num_clusters, m_verbose); } /** * iterates the E and M steps until the log likelihood of the data * converges. * * @param inst the training instances. * @param num_cl the number of clusters. * @param report be verbose. * @return the log likelihood of the data */ private double iterate (Instances inst, int num_cl, boolean report) throws Exception { int i; double llkold = 0.0; double llk = 0.0; if (report) { EM_Report(inst); } for (i = 0; i < m_max_iterations; i++) { llkold = llk; llk = E(inst, num_cl); if (report) { System.out.println("Loglikely: " + llk); } if (i > 0) { if ((llk - llkold) < 1e-6) { break; } } M(inst, num_cl); } if (report) { EM_Report(inst); } return llk; } // ============ // Test method. // ============ /** * Main method for testing this class. * * @param argv should contain the following arguments: <p> * -t training file [-T test file] [-N number of clusters] [-S random seed] */ public static void main (String[] argv) { try { System.out.println(ClusterEvaluation. evaluateClusterer(new EM(), argv)); } catch (Exception e) { System.out.println(e.getMessage()); e.printStackTrace(); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -