📄 em.java
  /**
   * Gets the current settings of EM.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {
    int i;
    Vector result;
    String[] options;

    result = new Vector();

    result.add("-I");
    result.add("" + m_max_iterations);
    result.add("-N");
    result.add("" + getNumClusters());
    result.add("-M");
    result.add("" + getMinStdDev());

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Initialise estimators and storage.
   *
   * @param inst the instances
   * @throws Exception if initialization fails
   **/
  private void EM_Init(Instances inst) throws Exception {
    int i, j, k;

    // run k means 10 times and choose best solution
    SimpleKMeans bestK = null;
    double bestSqE = Double.MAX_VALUE;
    for (i = 0; i < 10; i++) {
      SimpleKMeans sk = new SimpleKMeans();
      sk.setSeed(m_rr.nextInt());
      sk.setNumClusters(m_num_clusters);
      sk.buildClusterer(inst);
      if (sk.getSquaredError() < bestSqE) {
        bestSqE = sk.getSquaredError();
        bestK = sk;
      }
    }

    // initialize with best k-means solution
    m_num_clusters = bestK.numberOfClusters();
    m_weights = new double[inst.numInstances()][m_num_clusters];
    m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
    m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
    m_priors = new double[m_num_clusters];
    Instances centers = bestK.getClusterCentroids();
    Instances stdD = bestK.getClusterStandardDevs();
    int[][][] nominalCounts = bestK.getClusterNominalCounts();
    int[] clusterSizes = bestK.getClusterSizes();

    for (i = 0; i < m_num_clusters; i++) {
      Instance center = centers.instance(i);
      for (j = 0; j < m_num_attribs; j++) {
        if (inst.attribute(j).isNominal()) {
          m_model[i][j] =
            new DiscreteEstimator(m_theInstances.attribute(j).numValues(), true);
          for (k = 0; k < inst.attribute(j).numValues(); k++) {
            m_model[i][j].addValue(k, nominalCounts[i][j][k]);
          }
        } else {
          double minStdD = (m_minStdDevPerAtt != null)
            ? m_minStdDevPerAtt[j]
            : m_minStdDev;
          double mean = (center.isMissing(j))
            ? inst.meanOrMode(j)
            : center.value(j);
          m_modelNormal[i][j][0] = mean;
          double stdv = (stdD.instance(i).isMissing(j))
            ? ((m_maxValues[j] - m_minValues[j]) / (2 * m_num_clusters))
            : stdD.instance(i).value(j);
          if (stdv < minStdD) {
            stdv = inst.attributeStats(j).numericStats.stdDev;
            if (Double.isInfinite(stdv)) {
              stdv = minStdD;
            }
            if (stdv < minStdD) {
              stdv = minStdD;
            }
          }
          if (stdv <= 0) {
            stdv = m_minStdDev;
          }

          m_modelNormal[i][j][1] = stdv;
          m_modelNormal[i][j][2] = 1.0;
        }
      }
    }

    for (j = 0; j < m_num_clusters; j++) {
      // m_priors[j] += 1.0;
      m_priors[j] = clusterSizes[j];
    }
    Utils.normalize(m_priors);
  }

  /**
   * Calculate prior probabilities for the clusters.
   *
   * @param inst the instances
   * @throws Exception if priors can't be calculated
   **/
  private void estimate_priors(Instances inst) throws Exception {
    for (int i = 0; i < m_num_clusters; i++) {
      m_priors[i] = 0.0;
    }

    for (int i = 0; i < inst.numInstances(); i++) {
      for (int j = 0; j < m_num_clusters; j++) {
        m_priors[j] += inst.instance(i).weight() * m_weights[i][j];
      }
    }

    Utils.normalize(m_priors);
  }

  /** Constant for normal distribution. */
  private static double m_normConst = Math.log(Math.sqrt(2 * Math.PI));
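  // Note on initialization (summary of EM_Init above): the mixture is seeded
  // with the best of ten SimpleKMeans runs (the one with the lowest squared
  // error), and the initial mixing proportions come from the k-means cluster
  // sizes,
  //
  //   priors[j] = clusterSizes[j] / (sum over k of clusterSizes[k]),
  //
  // which is what the final copy loop followed by Utils.normalize(m_priors)
  // computes.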
  /**
   * Density function of normal distribution.
   *
   * @param x input value
   * @param mean mean of distribution
   * @param stdDev standard deviation of distribution
   * @return the density
   */
  private double logNormalDens(double x, double mean, double stdDev) {
    double diff = x - mean;
    // System.err.println("x: " + x + " mean: " + mean + " diff: " + diff + " stdv: " + stdDev);
    // System.err.println("diff*diff/(2*stdv*stdv): " + (diff * diff / (2 * stdDev * stdDev)));
    return -(diff * diff / (2 * stdDev * stdDev)) - m_normConst - Math.log(stdDev);
  }

  /**
   * New probability estimators for an iteration.
   */
  private void new_estimators() {
    for (int i = 0; i < m_num_clusters; i++) {
      for (int j = 0; j < m_num_attribs; j++) {
        if (m_theInstances.attribute(j).isNominal()) {
          m_model[i][j] =
            new DiscreteEstimator(m_theInstances.attribute(j).numValues(), true);
        } else {
          m_modelNormal[i][j][0] = m_modelNormal[i][j][1] = m_modelNormal[i][j][2] = 0.0;
        }
      }
    }
  }

  /**
   * The M step of the EM algorithm.
   *
   * @param inst the training instances
   * @throws Exception if something goes wrong
   */
  private void M(Instances inst) throws Exception {
    int i, j, l;

    new_estimators();

    for (i = 0; i < m_num_clusters; i++) {
      for (j = 0; j < m_num_attribs; j++) {
        for (l = 0; l < inst.numInstances(); l++) {
          Instance in = inst.instance(l);
          if (!in.isMissing(j)) {
            if (inst.attribute(j).isNominal()) {
              m_model[i][j].addValue(in.value(j), in.weight() * m_weights[l][i]);
            } else {
              m_modelNormal[i][j][0] += (in.value(j) * in.weight() * m_weights[l][i]);
              m_modelNormal[i][j][2] += in.weight() * m_weights[l][i];
              m_modelNormal[i][j][1] += (in.value(j) * in.value(j) * in.weight() * m_weights[l][i]);
            }
          }
        }
      }
    }

    // calculate mean and std deviation for numeric attributes
    for (j = 0; j < m_num_attribs; j++) {
      if (!inst.attribute(j).isNominal()) {
        for (i = 0; i < m_num_clusters; i++) {
          if (m_modelNormal[i][j][2] <= 0) {
            m_modelNormal[i][j][1] = Double.MAX_VALUE;
            // m_modelNormal[i][j][0] = 0;
            m_modelNormal[i][j][0] = m_minStdDev;
          } else {
            // variance
            m_modelNormal[i][j][1] = (m_modelNormal[i][j][1]
              - (m_modelNormal[i][j][0] * m_modelNormal[i][j][0] / m_modelNormal[i][j][2]))
              / (m_modelNormal[i][j][2]);
            if (m_modelNormal[i][j][1] < 0) {
              m_modelNormal[i][j][1] = 0;
            }

            // std dev
            double minStdD = (m_minStdDevPerAtt != null)
              ? m_minStdDevPerAtt[j]
              : m_minStdDev;

            m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);

            if ((m_modelNormal[i][j][1] <= minStdD)) {
              m_modelNormal[i][j][1] = inst.attributeStats(j).numericStats.stdDev;
              if ((m_modelNormal[i][j][1] <= minStdD)) {
                m_modelNormal[i][j][1] = minStdD;
              }
            }
            if ((m_modelNormal[i][j][1] <= 0)) {
              m_modelNormal[i][j][1] = m_minStdDev;
            }
            if (Double.isInfinite(m_modelNormal[i][j][1])) {
              m_modelNormal[i][j][1] = m_minStdDev;
            }

            // mean
            m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
          }
        }
      }
    }
  }

  /**
   * The E step of the EM algorithm. Estimate cluster membership
   * probabilities.
   *
   * @param inst the training instances
   * @param change_weights whether to change the weights
   * @return the average log likelihood
   * @throws Exception if computation fails
   */
  private double E(Instances inst, boolean change_weights) throws Exception {
    double loglk = 0.0, sOW = 0.0;

    for (int l = 0; l < inst.numInstances(); l++) {
      Instance in = inst.instance(l);

      loglk += in.weight() * logDensityForInstance(in);
      sOW += in.weight();

      if (change_weights) {
        m_weights[l] = distributionForInstance(in);
      }
    }

    // reestimate priors
    if (change_weights) {
      estimate_priors(inst);
    }
    return loglk / sOW;
  }
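  // Note on the numeric model: logNormalDens above returns the log of the
  // univariate Gaussian density,
  //
  //   log p(x | mean, stdDev) = -(x - mean)^2 / (2 * stdDev^2) - log(stdDev) - log(sqrt(2 * PI)),
  //
  // where the last term is the precomputed m_normConst. In M(), the slots
  // m_modelNormal[i][j][0], [1], and [2] accumulate the weighted sum of values,
  // the weighted sum of squares, and the total weight for cluster i and
  // attribute j; the weighted mean and the (floored) standard deviation are
  // then recovered from these sufficient statistics.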
  /**
   * Constructor.
   **/
  public EM() {
    super();

    m_SeedDefault = 100;
    resetOptions();
  }

  /**
   * Reset to default options.
   */
  protected void resetOptions() {
    m_minStdDev = 1e-6;
    m_max_iterations = 100;
    m_Seed = m_SeedDefault;
    m_num_clusters = -1;
    m_initialNumClusters = -1;
    m_verbose = false;
  }

  /**
   * Return the normal distributions for the cluster models.
   *
   * @return a <code>double[][][]</code> value
   */
  public double[][][] getClusterModelsNumericAtts() {
    return m_modelNormal;
  }

  /**
   * Return the priors for the clusters.
   *
   * @return a <code>double[]</code> value
   */
  public double[] getClusterPriors() {
    return m_priors;
  }

  /**
   * Outputs the generated clusters into a string.
   *
   * @return the clusterer in string representation
   */
  public String toString() {
    if (m_priors == null) {
      return "No clusterer built yet!";
    }
    StringBuffer text = new StringBuffer();
    text.append("\nEM\n==\n");
    if (m_initialNumClusters == -1) {
      text.append("\nNumber of clusters selected by cross validation: "
        + m_num_clusters + "\n");
    } else {
      text.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    for (int j = 0; j < m_num_clusters; j++) {
      text.append("\nCluster: " + j + " Prior probability: "
        + Utils.doubleToString(m_priors[j], 4) + "\n\n");

      for (int i = 0; i < m_num_attribs; i++) {
        text.append("Attribute: " + m_theInstances.attribute(i).name() + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            text.append(m_model[j][i].toString());
          }
        } else {
          text.append("Normal Distribution. Mean = "
            + Utils.doubleToString(m_modelNormal[j][i][0], 4)
            + " StdDev = "
            + Utils.doubleToString(m_modelNormal[j][i][1], 4) + "\n");
        }
      }
    }

    return text.toString();
  }

  /**
   * Verbose output for debugging.
   *
   * @param inst the training instances
   */
  private void EM_Report(Instances inst) {
    int i, j, l, m;
    System.out.println("======================================");

    for (j = 0; j < m_num_clusters; j++) {
      for (i = 0; i < m_num_attribs; i++) {
        System.out.println("Clust: " + j + " att: " + i + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            System.out.println(m_model[j][i].toString());
          }
        } else {
          System.out.println("Normal Distribution. Mean = "
            + Utils.doubleToString(m_modelNormal[j][i][0], 8, 4)
            + " StandardDev = "
            + Utils.doubleToString(m_modelNormal[j][i][1], 8, 4)
            + " WeightSum = "
            + Utils.doubleToString(m_modelNormal[j][i][2], 8, 4));
        }
      }
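The listing above is a fragment of Weka's EM clusterer. For orientation, here is a minimal usage sketch of how this class is typically driven through the standard Weka API; it assumes the class is packaged as weka.clusterers.EM, and the file name data.arff and the parameter values are placeholders rather than anything taken from the source above.

import weka.clusterers.ClusterEvaluation;
import weka.clusterers.EM;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class EMDemo {
  public static void main(String[] args) throws Exception {
    // Load a dataset (placeholder path).
    Instances data = new DataSource("data.arff").getDataSet();

    // Configure EM: -1 clusters means "select the number of clusters by
    // cross-validation", mirroring the m_num_clusters = -1 default set in
    // resetOptions().
    EM em = new EM();
    em.setNumClusters(-1);
    em.setMaxIterations(100);
    em.buildClusterer(data);

    // Print the per-cluster priors and models (the toString() shown above).
    System.out.println(em);

    // Evaluate cluster assignments and log likelihood on the same data.
    ClusterEvaluation eval = new ClusterEvaluation();
    eval.setClusterer(em);
    eval.evaluateClusterer(data);
    System.out.println(eval.clusterResultsToString());
  }
}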