📄 em.java
字号:
System.out.println("Clust: " + j + " att: " + i + "\n"); if (m_theInstances.attribute(i).isNominal()) { if (m_model[j][i] != null) { System.out.println(m_model[j][i].toString()); } } else { System.out.println("Normal Distribution. Mean = " + Utils.doubleToString(m_modelNormal[j][i][0] , 8, 4) + " StandardDev = " + Utils.doubleToString(m_modelNormal[j][i][1] , 8, 4) + " WeightSum = " + Utils.doubleToString(m_modelNormal[j][i][2] , 8, 4)); } } } for (l = 0; l < inst.numInstances(); l++) { m = Utils.maxIndex(m_weights[l]); System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) + " Class " + m + "\t"); for (j = 0; j < m_num_clusters; j++) { System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + " "); } System.out.println(); } } /** * estimate the number of clusters by cross validation on the training * data. * * @return the number of clusters selected */ private int CVClusters () throws Exception { double CVLogLikely = -Double.MAX_VALUE; double templl, tll; boolean CVdecreased = true; int num_cl = 1; int i; Random cvr; Instances trainCopy; while (CVdecreased) { CVdecreased = false; cvr = new Random(m_rseed); trainCopy = new Instances(m_theInstances); trainCopy.randomize(cvr); // theInstances.stratify(10); templl = 0.0; for (i = 0; i < num_cvs; i++) { Instances cvTrain = trainCopy.trainCV(num_cvs, i); Instances cvTest = trainCopy.testCV(num_cvs, i); EM_Init(cvTrain, num_cl); iterate(cvTrain, num_cl, false); tll = E(cvTest, num_cl); if (m_verbose) { System.out.println("# clust: " + num_cl + " Fold: " + i + " Loglikely: " + tll); } templl += tll; } templl /= num_cvs; if (m_verbose) { System.out.println("===================================" + "==============\n# clust: " + num_cl + " Mean Loglikely: " + templl + "\n================================" + "================="); } if (templl > CVLogLikely) { CVLogLikely = templl; CVdecreased = true; num_cl++; } } if (m_verbose) { System.out.println("Number of clusters: " + (num_cl - 1)); } return num_cl - 1; } 
/** * Returns the number of clusters. * * @return the number of clusters generated for a training dataset. * @exception Exception if number of clusters could not be returned * successfully */ public int numberOfClusters () throws Exception { if (m_num_clusters == -1) { throw new Exception("Haven't generated any clusters!"); } return m_num_clusters; } /** * Classifies a given instance. * * @param instance the instance to be assigned to a cluster * @return the number of the assigned cluster as an interger * if the class is enumerated, otherwise the predicted value * @exception Exception if instance could not be classified * successfully */ public void buildClusterer (Instances data) throws Exception { if (data.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } m_theInstances = data; doEM(); } /** * Predicts the cluster memberships for a given instance. * * @param data set of test instances * @param instance the instance to be assigned a cluster. * @return an array containing the estimated membership * probabilities of the test instance in each cluster (this * should sum to at most 1) * @exception Exception if distribution could not be * computed successfully */ public double[] distributionForInstance (Instance inst) throws Exception { int i, j; double prob; double[] wghts = new double[m_num_clusters]; for (i = 0; i < m_num_clusters; i++) { prob = 1.0; for (j = 0; j < m_num_attribs; j++) { if (!inst.isMissing(j)) { if (inst.attribute(j).isNominal()) { prob *= m_model[i][j].getProbability(inst.value(j)); } else { // numeric attribute prob *= normalDens(inst.value(j), m_modelNormal[i][j][0], m_modelNormal[i][j][1]); } } } wghts[i] = (prob*m_priors[i]); } return wghts; } public double probInstanceInCluster (Instance inst, int cluster) throws Exception{ double prob = 1; // System.out.println("Instance " + inst.toString()); for (int j = 0; j < m_num_attribs; j++) { if (!inst.isMissing(j)) { if (inst.attribute(j).isNominal()) { prob *= 
m_model[cluster][j].getProbability(inst.value(j)); } else { // numeric attribute prob *= normalIns(inst.value(j), m_modelNormal[cluster][j][0], m_modelNormal[cluster][j][1]); // System.out.println("Prob is: " + prob); } } } prob *= m_priors[cluster]; return prob; } /** * Perform the EM algorithm */ private void doEM () throws Exception { if (m_verbose) { System.out.println("Seed: " + m_rseed); } m_rr = new Random(m_rseed); m_num_instances = m_theInstances.numInstances(); m_num_attribs = m_theInstances.numAttributes(); if (m_verbose) { System.out.println("Number of instances: " + m_num_instances + "\nNumber of atts: " + m_num_attribs + "\n"); } // setDefaultStdDevs(theInstances); // cross validate to determine number of clusters? if (m_initialNumClusters == -1) { m_num_clusters = CVClusters(); } // fit full training set EM_Init(m_theInstances, m_num_clusters); m_loglikely = iterate(m_theInstances, m_num_clusters, m_verbose); } /** * iterates the M and E steps until the log likelihood of the data * converges. * * @param inst the training instances. * @param num_cl the number of clusters. * @param report be verbose. 
* @return the log likelihood of the data
 * @exception Exception if an M or E step fails
 */
private double iterate (Instances inst, int num_cl, boolean report) throws Exception {
  int i;
  double llkold = 0.0;
  double llk = 0.0;

  if (report) {
    EM_Report(inst);
  }

  for (i = 0; i < m_max_iterations; i++) {
    M(inst, num_cl);
    llkold = llk;
    llk = E(inst, num_cl);

    if (report) {
      System.out.println("Loglikely: " + llk);
    }

    // The first pass has no previous likelihood to compare against. From the
    // second pass on, stop once the gain drops below 1e-6 (this includes any
    // decrease).
    if (i > 0 && (llk - llkold) < 1e-6) {
      break;
    }
  }

  if (report) {
    EM_Report(inst);
  }

  return llk;
}

/**
 * Computes the density of an instance under the whole model as the sum of
 * its prior-weighted per-cluster values.
 *
 * @param inst the instance to evaluate
 * @return the sum of the values returned by <code>weightsForInstance</code>
 * @exception Exception if the weights could not be computed
 */
public double densityForInstance(Instance inst) throws Exception {
  return Utils.sum(weightsForInstance(inst));
}

/**
 * Computes one prior-weighted value per cluster for the given instance.
 *
 * For each cluster the terms of all non-missing attributes are multiplied:
 * nominal attributes contribute the probability reported by the cluster's
 * estimator, all other attributes contribute <code>normalDens</code>
 * evaluated with slots 0 and 1 of the cluster's <code>m_modelNormal</code>
 * entry. The product is then multiplied by the cluster prior. The result is
 * not normalised.
 *
 * @param inst the instance to evaluate
 * @return an array of length <code>m_num_clusters</code> holding the
 * weighted values
 * @exception Exception if a value could not be computed
 */
protected double[] weightsForInstance(Instance inst) throws Exception {
  int i, j;
  double prob;
  double[] wghts = new double[m_num_clusters];

  for (i = 0; i < m_num_clusters; i++) {
    prob = 1.0;

    for (j = 0; j < m_num_attribs; j++) {
      if (!inst.isMissing(j)) {
        if (inst.attribute(j).isNominal()) {
          prob *= m_model[i][j].getProbability(inst.value(j));
        } else {
          // numeric attribute
          prob *= normalDens(inst.value(j),
                             m_modelNormal[i][j][0],
                             m_modelNormal[i][j][1]);
        }
      }
    }

    wghts[i] = (prob*m_priors[i]);
  }

  return wghts;
}

// ============
// Test method.
// ============
/**
 * Main method for testing this class.
 *
 * @param argv should contain the following arguments: <p>
 * -t training file [-T test file] [-N number of clusters] [-S random seed]
 */
public static void main (String[] argv) {
  try {
    System.out.println(ClusterEvaluation.
                       evaluateClusterer(new EM(), argv));
  }
  catch (Exception e) {
    // Only the message is printed to standard output; no stack trace.
    System.out.println(e.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -