📄 em.java
          // (excerpt begins mid-method: tail of the default standard
          // deviation (m_defSds) computation for numeric attributes)
          double currentVal, deltaSum = 0;
          int distinct = 0;

          for (int j = 1; j < inst.numInstances(); j++) {
            Instance currentInst = inst.instance(j);
            if (currentInst.isMissing(i)) {
              break;
            }
            currentVal = currentInst.value(i);
            if (currentVal != lastVal) {
              deltaSum += currentVal - lastVal;
              lastVal = currentVal;
              distinct++;
            }
          }

          if (distinct > 0) {
            m_defSds[i] = deltaSum / distinct;
          }
        }
      }
    }
  }

  /**
   * Initialise estimators and storage.
   *
   * @param inst the instances
   * @param num_cl the number of clusters
   **/
  private void EM_Init (Instances inst, int num_cl)
    throws Exception {
    m_weights = new double[inst.numInstances()][num_cl];
    int z;
    m_model = new Estimator[num_cl][m_num_attribs];
    m_modelNormal = new double[num_cl][m_num_attribs][3];
    m_priors = new double[num_cl];

    for (int i = 0; i < inst.numInstances(); i++) {
      for (int j = 0; j < num_cl; j++) {
        m_weights[i][j] = m_rr.nextDouble();
      }
      Utils.normalize(m_weights[i]);
    }

    // initial priors
    estimate_priors(inst, num_cl);
  }

  /**
   * Calculate prior probabilities for the clusters.
   *
   * @param inst the instances
   * @param num_cl the number of clusters
   * @exception Exception if priors can't be calculated
   **/
  private void estimate_priors (Instances inst, int num_cl)
    throws Exception {
    for (int i = 0; i < num_cl; i++) {
      m_priors[i] = 0.0;
    }

    for (int i = 0; i < inst.numInstances(); i++) {
      for (int j = 0; j < num_cl; j++) {
        m_priors[j] += m_weights[i][j];
      }
    }

    Utils.normalize(m_priors);
  }

  /**
   * Density function of the normal distribution.
   *
   * @param x input value
   * @param mean mean of distribution
   * @param stdDev standard deviation of distribution
   */
  private double normalDens (double x, double mean, double stdDev) {
    double diff = x - mean;
    return (1 / (m_normConst * stdDev))
      * Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
  }

  /**
   * Added by Waleed Kadous. This is a more accurate representation.
   */
  private double normalIns (double x, double mean, double stdDev) {
    float diff = (float) Math.abs((x - mean) / stdDev);
    double retval = 2 * FastMath.normalCDF(-diff);
    // System.out.println("Returning " + retval + " for x = " + x
    //                    + " mean = " + mean + " sd = " + stdDev);
    return retval;
  }

  /**
   * New probability estimators for an iteration.
   *
   * @param num_cl the number of clusters
   */
  private void new_estimators (int num_cl) {
    for (int i = 0; i < num_cl; i++) {
      for (int j = 0; j < m_num_attribs; j++) {
        if (m_theInstances.attribute(j).isNominal()) {
          m_model[i][j] =
            new DiscreteEstimator(m_theInstances.attribute(j).numValues(),
                                  true);
        }
        else {
          m_modelNormal[i][j][0] = m_modelNormal[i][j][1] =
            m_modelNormal[i][j][2] = 0.0;
        }
      }
    }
  }

  /**
   * The M step of the EM algorithm.
   *
   * @param inst the training instances
   * @param num_cl the number of clusters
   */
  private void M (Instances inst, int num_cl)
    throws Exception {
    int i, j, l;
    new_estimators(num_cl);

    for (i = 0; i < num_cl; i++) {
      for (j = 0; j < m_num_attribs; j++) {
        for (l = 0; l < inst.numInstances(); l++) {
          if (!inst.instance(l).isMissing(j)) {
            if (inst.attribute(j).isNominal()) {
              m_model[i][j].addValue(inst.instance(l).value(j),
                                     m_weights[l][i]);
            }
            else {
              m_modelNormal[i][j][0] +=
                (inst.instance(l).value(j) * m_weights[l][i]);
              m_modelNormal[i][j][2] += m_weights[l][i];
              m_modelNormal[i][j][1] +=
                (inst.instance(l).value(j)
                 * inst.instance(l).value(j) * m_weights[l][i]);
            }
          }
        }
      }
    }

    // calculate mean and std deviation for numeric attributes
    for (j = 0; j < m_num_attribs; j++) {
      if (!inst.attribute(j).isNominal()) {
        for (i = 0; i < num_cl; i++) {
          if (Utils.smOrEq(m_modelNormal[i][j][2], 1)) {
            m_modelNormal[i][j][1] = /* m_defSds[j] / (2 * 3); */ 1e-6;
          }
          else {
            // variance
            m_modelNormal[i][j][1] =
              (m_modelNormal[i][j][1]
               - (m_modelNormal[i][j][0] * m_modelNormal[i][j][0]
                  / m_modelNormal[i][j][2]))
              / (m_modelNormal[i][j][2] - 1);
          }

          // std dev
          if (m_modelNormal[i][j][1] <= 0.0) {
            m_modelNormal[i][j][1] = /* m_defSds[j] / (2 * 3); */ 1e-6;
          }
          m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);

          // mean
          if (m_modelNormal[i][j][2] > 0.0) {
            m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
          }
        }
      }
    }
  }

  /**
   * The E step of the EM algorithm. Estimate cluster membership
   * probabilities.
   *
   * @param inst the training instances
   * @param num_cl the number of clusters
   * @return the average log likelihood
   */
  private double E (Instances inst, int num_cl)
    throws Exception {
    int i, j, l;
    double prob;
    double loglk = 0.0;

    for (l = 0; l < inst.numInstances(); l++) {
      for (i = 0; i < num_cl; i++) {
        prob = 1.0;

        for (j = 0; j < m_num_attribs; j++) {
          if (!inst.instance(l).isMissing(j)) {
            if (inst.attribute(j).isNominal()) {
              prob *= m_model[i][j].getProbability(inst.instance(l).value(j));
            }
            else {
              // numeric attribute
              prob *= normalDens(inst.instance(l).value(j),
                                 m_modelNormal[i][j][0],
                                 m_modelNormal[i][j][1]);
            }
          }
        }

        m_weights[l][i] = (prob * m_priors[i]);
      }

      double temp1 = 0;
      for (i = 0; i < num_cl; i++) {
        temp1 += m_weights[l][i];
      }
      if (temp1 > 0) {
        loglk += Math.log(temp1);
      }

      // normalise the weights for this instance
      Utils.normalize(m_weights[l]);
    }

    // reestimate priors
    estimate_priors(inst, num_cl);
    return loglk / inst.numInstances();
  }

  /**
   * Constructor.
   **/
  public EM () {
    resetOptions();
  }

  /**
   * Reset to default options.
   */
  protected void resetOptions () {
    m_max_iterations = 100;
    m_rseed = 100;
    m_num_clusters = -1;
    m_initialNumClusters = -1;
    m_verbose = false;
  }

  /**
   * Outputs the generated clusters into a string.
   */
  public String toString () {
    StringBuffer text = new StringBuffer();
    text.append("\nEM\n==\n");

    if (m_initialNumClusters == -1) {
      text.append("\nNumber of clusters selected by cross validation: "
                  + m_num_clusters + "\n");
    }
    else {
      text.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    for (int j = 0; j < m_num_clusters; j++) {
      text.append("\nCluster: " + j + " Prior probability: "
                  + Utils.doubleToString(m_priors[j], 6, 4) + "\n\n");

      for (int i = 0; i < m_num_attribs; i++) {
        text.append("Attribute: " + m_theInstances.attribute(i).name() + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            text.append(m_model[j][i].toString());
          }
        }
        else {
          text.append("Normal Distribution. Mean = "
                      + Utils.doubleToString(m_modelNormal[j][i][0], 8, 4)
                      + " StdDev = "
                      + Utils.doubleToString(m_modelNormal[j][i][1], 8, 4)
                      + "\n");
        }
      }
    }

    return text.toString();
  }

  /**
   * Describes a single cluster as a string.
   *
   * @param cluster the index of the cluster to describe
   */
  public String describeCluster (int cluster) {
    StringBuffer text = new StringBuffer();
    text.append("\nCluster: " + cluster + " Prior probability: "
                + Utils.doubleToString(m_priors[cluster], 6, 4) + "\n");

    for (int i = 0; i < m_num_attribs; i++) {
      text.append("Attribute: " + m_theInstances.attribute(i).name());

      if (m_theInstances.attribute(i).isNominal()) {
        if (m_model[cluster][i] != null) {
          text.append(m_model[cluster][i].toString());
        }
      }
      else {
        text.append(" mean = "
                    + Utils.doubleToString(m_modelNormal[cluster][i][0], 8, 4)
                    + " sd = "
                    + Utils.doubleToString(m_modelNormal[cluster][i][1], 8, 4)
                    + "\n");
      }
    }

    return text.toString();
  }

  /**
   * Verbose output for debugging.
   *
   * @param inst the training instances
   */
  private void EM_Report (Instances inst) {
    int i, j, l, m;
    System.out.println("======================================");

    for (j = 0; j < m_num_clusters; j++) {
      for (i = 0; i < m_num_attribs; i++) {
        // (listing truncated here)