📄 em.java
字号:
} for (l = 0; l < inst.numInstances(); l++) { m = Utils.maxIndex(m_weights[l]); System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) + " Class " + m + "\t"); for (j = 0; j < m_num_clusters; j++) { System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + " "); } System.out.println(); } } /** * estimate the number of clusters by cross validation on the training * data. * * @throws Exception if something goes wrong */ private void CVClusters () throws Exception { double CVLogLikely = -Double.MAX_VALUE; double templl, tll; boolean CVincreased = true; m_num_clusters = 1; int num_clusters = m_num_clusters; int i; Random cvr; Instances trainCopy; int numFolds = (m_theInstances.numInstances() < 10) ? m_theInstances.numInstances() : 10; boolean ok = true; int seed = getSeed(); int restartCount = 0; CLUSTER_SEARCH: while (CVincreased) { // theInstances.stratify(10); CVincreased = false; cvr = new Random(getSeed()); trainCopy = new Instances(m_theInstances); trainCopy.randomize(cvr); templl = 0.0; for (i = 0; i < numFolds; i++) { Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr); if (num_clusters > cvTrain.numInstances()) { break CLUSTER_SEARCH; } Instances cvTest = trainCopy.testCV(numFolds, i); m_rr = new Random(seed); for (int z=0; z<10; z++) m_rr.nextDouble(); m_num_clusters = num_clusters; EM_Init(cvTrain); try { iterate(cvTrain, false); } catch (Exception ex) { // catch any problems - i.e. empty clusters occuring ex.printStackTrace(); // System.err.println("Restarting after CV training failure ("+num_clusters+" clusters"); seed++; restartCount++; ok = false; if (restartCount > 5) { break CLUSTER_SEARCH; } break; } try { tll = E(cvTest, false); } catch (Exception ex) { // catch any problems - i.e. 
empty clusters occuring // ex.printStackTrace(); ex.printStackTrace(); // System.err.println("Restarting after CV testing failure ("+num_clusters+" clusters"); // throw new Exception(ex); seed++; restartCount++; ok = false; if (restartCount > 5) { break CLUSTER_SEARCH; } break; } if (m_verbose) { System.out.println("# clust: " + num_clusters + " Fold: " + i + " Loglikely: " + tll); } templl += tll; } if (ok) { restartCount = 0; seed = getSeed(); templl /= (double)numFolds; if (m_verbose) { System.out.println("===================================" + "==============\n# clust: " + num_clusters + " Mean Loglikely: " + templl + "\n================================" + "================="); } if (templl > CVLogLikely) { CVLogLikely = templl; CVincreased = true; num_clusters++; } } } if (m_verbose) { System.out.println("Number of clusters: " + (num_clusters - 1)); } m_num_clusters = num_clusters - 1; } /** * Returns the number of clusters. * * @return the number of clusters generated for a training dataset. * @throws Exception if number of clusters could not be returned * successfully */ public int numberOfClusters () throws Exception { if (m_num_clusters == -1) { throw new Exception("Haven't generated any clusters!"); } return m_num_clusters; } /** * Updates the minimum and maximum values for all the attributes * based on a new instance. * * @param instance the new instance */ private void updateMinMax(Instance instance) { for (int j = 0; j < m_theInstances.numAttributes(); j++) { if (!instance.isMissing(j)) { if (Double.isNaN(m_minValues[j])) { m_minValues[j] = instance.value(j); m_maxValues[j] = instance.value(j); } else { if (instance.value(j) < m_minValues[j]) { m_minValues[j] = instance.value(j); } else { if (instance.value(j) > m_maxValues[j]) { m_maxValues[j] = instance.value(j); } } } } } } /** * Returns default capabilities of the clusterer (i.e., the ones of * SimpleKMeans). 
* * @return the capabilities of this clusterer */ public Capabilities getCapabilities() { Capabilities result = new SimpleKMeans().getCapabilities(); result.setOwner(this); return result; } /** * Generates a clusterer. Has to initialize all fields of the clusterer * that are not being set via options. * * @param data set of instances serving as training data * @throws Exception if the clusterer has not been * generated successfully */ public void buildClusterer (Instances data) throws Exception { // can clusterer handle the data? getCapabilities().testWithFail(data); m_replaceMissing = new ReplaceMissingValues(); Instances instances = new Instances(data); instances.setClassIndex(-1); m_replaceMissing.setInputFormat(instances); data = weka.filters.Filter.useFilter(instances, m_replaceMissing); instances = null; m_theInstances = data; // calculate min and max values for attributes m_minValues = new double [m_theInstances.numAttributes()]; m_maxValues = new double [m_theInstances.numAttributes()]; for (int i = 0; i < m_theInstances.numAttributes(); i++) { m_minValues[i] = m_maxValues[i] = Double.NaN; } for (int i = 0; i < m_theInstances.numInstances(); i++) { updateMinMax(m_theInstances.instance(i)); } doEM(); // save memory m_theInstances = new Instances(m_theInstances,0); } /** * Returns the cluster priors. * * @return the cluster priors */ public double[] clusterPriors() { double[] n = new double[m_priors.length]; System.arraycopy(m_priors, 0, n, 0, n.length); return n; } /** * Computes the log of the conditional density (per cluster) for a given instance. 
* * @param inst the instance to compute the density for * @return an array containing the estimated densities * @throws Exception if the density could not be computed * successfully */ public double[] logDensityPerClusterForInstance(Instance inst) throws Exception { int i, j; double logprob; double[] wghts = new double[m_num_clusters]; m_replaceMissing.input(inst); inst = m_replaceMissing.output(); for (i = 0; i < m_num_clusters; i++) { // System.err.println("Cluster : "+i); logprob = 0.0; for (j = 0; j < m_num_attribs; j++) { if (!inst.isMissing(j)) { if (inst.attribute(j).isNominal()) { logprob += Math.log(m_model[i][j].getProbability(inst.value(j))); } else { // numeric attribute logprob += logNormalDens(inst.value(j), m_modelNormal[i][j][0], m_modelNormal[i][j][1]); /* System.err.println(logNormalDens(inst.value(j), m_modelNormal[i][j][0], m_modelNormal[i][j][1]) + " "); */ } } } // System.err.println(""); wghts[i] = logprob; } return wghts; } /** * Perform the EM algorithm * * @throws Exception if something goes wrong */ private void doEM () throws Exception { if (m_verbose) { System.out.println("Seed: " + getSeed()); } m_rr = new Random(getSeed()); // throw away numbers to avoid problem of similar initial numbers // from a similar seed for (int i=0; i<10; i++) m_rr.nextDouble(); m_num_instances = m_theInstances.numInstances(); m_num_attribs = m_theInstances.numAttributes(); if (m_verbose) { System.out.println("Number of instances: " + m_num_instances + "\nNumber of atts: " + m_num_attribs + "\n"); } // setDefaultStdDevs(theInstances); // cross validate to determine number of clusters? if (m_initialNumClusters == -1) { if (m_theInstances.numInstances() > 9) { CVClusters(); m_rr = new Random(getSeed()); for (int i=0; i<10; i++) m_rr.nextDouble(); } else { m_num_clusters = 1; } } // fit full training set EM_Init(m_theInstances); m_loglikely = iterate(m_theInstances, m_verbose); } /** * iterates the E and M steps until the log likelihood of the data * converges. 
*
   * <p>At most {@code m_max_iterations} E/M rounds are run. Iteration stops
   * early once the log-likelihood improves by less than 1e-6. If an
   * exception occurs (e.g. an empty cluster), EM is re-initialised on the
   * same instances with the next seed. After more than five failed
   * restarts, the number of clusters is reduced by one and the restart
   * count is reset. If it cannot be reduced below one, the failure is
   * rethrown.
   *
   * @param inst the training instances.
   * @param report be verbose.
   * @return the log likelihood of the data
   * @throws Exception if something goes wrong
   */
  private double iterate (Instances inst, boolean report) throws Exception {
    int i;
    double llkold = 0.0;
    double llk = 0.0;

    if (report) {
      EM_Report(inst);
    }

    boolean ok = false;
    int seed = getSeed();
    int restartCount = 0;
    while (!ok) {
      try {
        for (i = 0; i < m_max_iterations; i++) {
          llkold = llk;
          llk = E(inst, true);

          if (report) {
            System.out.println("Loglikely: " + llk);
          }

          // converged (or got worse) compared to the previous round
          if (i > 0 && (llk - llkold) < 1e-6) {
            break;
          }
          M(inst);
        }
        ok = true;
      } catch (Exception ex) {
        ex.printStackTrace();
        seed++;
        restartCount++;
        m_rr = new Random(seed);
        for (int z = 0; z < 10; z++) {
          m_rr.nextDouble();
          m_rr.nextInt();
        }
        if (restartCount > 5) {
          if (m_num_clusters <= 1) {
            // nothing left to reduce: fail instead of fitting zero clusters
            throw new Exception("EM failed to fit even a single cluster", ex);
          }
          m_num_clusters--;
          restartCount = 0;
        }
        // re-initialise from the instances actually being fitted; during
        // cross-validation these are only the training folds
        EM_Init(inst);
      }
    }

    if (report) {
      EM_Report(inst);
    }

    return llk;
  }

  // ============
  // Test method.
  // ============

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
  public static void main (String[] argv) {
    runClusterer(new EM(), argv);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -