⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    String[] options = new String[9];
    int current = 0;

    if (m_verbose) {
      options[current++] = "-V";
    }

    options[current++] = "-I";
    options[current++] = "" + m_max_iterations;
    options[current++] = "-N";
    options[current++] = "" + getNumClusters();
    options[current++] = "-S";
    options[current++] = "" + m_rseed;
    options[current++] = "-M";
    options[current++] = ""+getMinStdDev();

    while (current < options.length) {
      options[current++] = "";
    }

    return  options;
  }

  /**
   * Initialise estimators and storage from a k-means solution. SimpleKMeans
   * is built ten times, each with a seed taken from m_rr, and the run with
   * the smallest squared error provides the starting centroids, standard
   * deviations, nominal counts and cluster sizes.
   *
   * @param inst the instances
   * @exception Exception if one of the called routines fails
   **/
  private void EM_Init (Instances inst)
    throws Exception {

    // keep the k-means run with the lowest squared error out of ten
    SimpleKMeans winner = null;
    double lowestError = Double.MAX_VALUE;
    for (int run = 0; run < 10; run++) {
      SimpleKMeans candidate = new SimpleKMeans();
      candidate.setSeed(m_rr.nextInt());
      candidate.setNumClusters(m_num_clusters);
      candidate.buildClusterer(inst);
      double error = candidate.getSquaredError();
      if (error < lowestError) {
        lowestError = error;
        winner = candidate;
      }
    }

    // size all storage by the cluster count reported by the chosen run
    m_num_clusters = winner.numberOfClusters();
    m_weights = new double[inst.numInstances()][m_num_clusters];
    m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
    m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
    m_priors = new double[m_num_clusters];
    Instances centroids = winner.getClusterCentroids();
    Instances stdDevs = winner.getClusterStandardDevs();
    int[][][] valueCounts = winner.getClusterNominalCounts();
    int[] sizes = winner.getClusterSizes();

    for (int c = 0; c < m_num_clusters; c++) {
      Instance centroid = centroids.instance(c);

      for (int a = 0; a < m_num_attribs; a++) {

        if (inst.attribute(a).isNominal()) {
          // nominal attribute: seed a discrete estimator with the k-means counts
          DiscreteEstimator estimator =
            new DiscreteEstimator(m_theInstances.attribute(a).numValues(), true);
          m_model[c][a] = estimator;
          for (int v = 0; v < inst.attribute(a).numValues(); v++) {
            estimator.addValue(v, valueCounts[c][a][v]);
          }
          continue;
        }

        // numeric attribute: slot 0 = mean, slot 1 = std dev, slot 2 = weight sum
        double floor;
        if (m_minStdDevPerAtt != null) {
          floor = m_minStdDevPerAtt[a];
        } else {
          floor = m_minStdDev;
        }

        double mean;
        if (centroid.isMissing(a)) {
          mean = inst.meanOrMode(a);
        } else {
          mean = centroid.value(a);
        }
        m_modelNormal[c][a][0] = mean;

        double spread;
        if (stdDevs.instance(c).isMissing(a)) {
          // no std dev from k-means: use a slice of the attribute's range
          spread = (m_maxValues[a] - m_minValues[a]) / (2 * m_num_clusters);
        } else {
          spread = stdDevs.instance(c).value(a);
        }

        if (spread < floor) {
          // too narrow: fall back to the std dev over all instances,
          // and to the floor itself if that is infinite or still too small
          spread = inst.attributeStats(a).numericStats.stdDev;
          if (Double.isInfinite(spread) || spread < floor) {
            spread = floor;
          }
        }
        if (spread <= 0) {
          spread = m_minStdDev;
        }

        m_modelNormal[c][a][1] = spread;
        m_modelNormal[c][a][2] = 1.0;
      }
    }

    // priors start out proportional to the k-means cluster sizes
    for (int c = 0; c < m_num_clusters; c++) {
      m_priors[c] = sizes[c];
    }
    Utils.normalize(m_priors);
  }


  /**
   * Calculate prior probabilities for the clusters. Each prior is the sum,
   * over all instances, of the instance weight times that instance's
   * membership value for the cluster; the sums are then passed to
   * Utils.normalize.
   *
   * @param inst the instances
   * @exception Exception if priors can't be calculated
   **/
  private void estimate_priors (Instances inst)
    throws Exception {

    int cluster;

    // start every accumulator from zero
    for (cluster = 0; cluster < m_num_clusters; cluster++) {
      m_priors[cluster] = 0.0;
    }

    // accumulate weighted memberships instance by instance
    int row = 0;
    while (row < inst.numInstances()) {
      for (cluster = 0; cluster < m_num_clusters; cluster++) {
        m_priors[cluster] += inst.instance(row).weight() * m_weights[row][cluster];
      }
      row++;
    }

    Utils.normalize(m_priors);
  }


  /**
   * Constant for normal distribution: log(sqrt(2 * pi)), the term that
   * logNormalDens subtracts. Declared final so the value cannot be
   * reassigned after class initialization.
   */
  private static final double m_normConst = Math.log(Math.sqrt(2*Math.PI));

  /**
   * Log of the density function of the normal distribution.
   *
   * @param x input value
   * @param mean mean of distribution
   * @param stdDev standard deviation of distribution
   * @return the natural logarithm of the normal density at x
   */
  private double logNormalDens (double x, double mean, double stdDev) {

    double deviation = x - mean;
    double twiceVariance = 2 * stdDev * stdDev;
    double exponent = deviation * deviation / twiceVariance;

    // log of exp(-exponent) / (sqrt(2 * pi) * stdDev)
    return -exponent - m_normConst - Math.log(stdDev);
  }

  /**
   * New probability estimators for an iteration. Every nominal attribute of
   * every cluster receives a fresh DiscreteEstimator; for all other
   * attributes the three m_modelNormal slots are set back to zero.
   */
  private void new_estimators () {
    for (int cluster = 0; cluster < m_num_clusters; cluster++) {
      for (int att = 0; att < m_num_attribs; att++) {

        if (!m_theInstances.attribute(att).isNominal()) {
          // numeric attribute: clear the accumulated statistics
          double[] slots = m_modelNormal[cluster][att];
          slots[2] = 0.0;
          slots[1] = 0.0;
          slots[0] = 0.0;
          continue;
        }

        int numValues = m_theInstances.attribute(att).numValues();
        m_model[cluster][att] = new DiscreteEstimator(numValues, true);
      }
    }
  }


  /**
   * The M step of the EM algorithm. Resets the estimators, feeds every
   * non-missing attribute value into them weighted by instance weight times
   * cluster membership, and then converts the sums collected for numeric
   * attributes into a mean and a standard deviation.
   *
   * @param inst the training instances
   * @exception Exception if one of the called routines fails
   */
  private void M (Instances inst)
    throws Exception {

    new_estimators();

    // pass 1: accumulate weighted statistics per cluster and attribute
    for (int c = 0; c < m_num_clusters; c++) {
      for (int a = 0; a < m_num_attribs; a++) {
        for (int row = 0; row < inst.numInstances(); row++) {
          Instance current = inst.instance(row);
          if (current.isMissing(a)) {
            continue;
          }

          if (inst.attribute(a).isNominal()) {
            m_model[c][a].addValue(current.value(a),
                                   current.weight() * m_weights[row][c]);
            continue;
          }

          // slot 0: weighted sum, slot 1: weighted sum of squares,
          // slot 2: sum of weights
          double[] sums = m_modelNormal[c][a];
          double value = current.value(a);
          double weight = current.weight();
          double membership = m_weights[row][c];
          sums[0] += value * weight * membership;
          sums[2] += weight * membership;
          sums[1] += value * value * weight * membership;
        }
      }
    }

    // pass 2: calculate mean and std deviation for numeric attributes
    for (int a = 0; a < m_num_attribs; a++) {
      if (inst.attribute(a).isNominal()) {
        continue;
      }

      for (int c = 0; c < m_num_clusters; c++) {
        double[] stats = m_modelNormal[c][a];

        if (stats[2] <= 0) {
          // no weight collected: std dev becomes Double.MAX_VALUE and the
          // mean slot is set to m_minStdDev
          stats[1] = Double.MAX_VALUE;
          stats[0] = m_minStdDev;
          continue;
        }

        // variance, clamped so it is never negative
        stats[1] = (stats[1] - (stats[0] * stats[0] / stats[2])) / stats[2];
        if (stats[1] < 0) {
          stats[1] = 0;
        }

        // smallest std dev accepted for this attribute
        double floor;
        if (m_minStdDevPerAtt != null) {
          floor = m_minStdDevPerAtt[a];
        } else {
          floor = m_minStdDev;
        }

        // std dev
        stats[1] = Math.sqrt(stats[1]);

        if (stats[1] <= floor) {
          // too narrow: try the std dev over all instances, else the floor
          double overall = inst.attributeStats(a).numericStats.stdDev;
          stats[1] = (overall <= floor) ? floor : overall;
        }
        if (stats[1] <= 0 || Double.isInfinite(stats[1])) {
          stats[1] = m_minStdDev;
        }

        // mean
        stats[0] /= stats[2];
      }
    }
  }

  /**
   * The E step of the EM algorithm. Sums the weighted log density of every
   * instance and, when requested, replaces the stored cluster membership
   * probabilities and re-estimates the priors.
   *
   * @param inst the training instances
   * @param change_weights if true, m_weights is overwritten with the result
   * of distributionForInstance for each instance and the priors are
   * re-estimated afterwards; if false only the likelihood is computed
   * @return the average log likelihood
   * @exception Exception if one of the called routines fails
   */
  private double E (Instances inst, boolean change_weights)
    throws Exception {

    double weightedLogLik = 0.0;
    double totalWeight = 0.0;

    for (int row = 0; row < inst.numInstances(); row++) {
      Instance current = inst.instance(row);
      double weight = current.weight();

      weightedLogLik += weight * logDensityForInstance(current);
      totalWeight += weight;

      if (change_weights) {
        m_weights[row] = distributionForInstance(current);
      }
    }

    if (!change_weights) {
      return weightedLogLik / totalWeight;
    }

    // reestimate priors
    estimate_priors(inst);
    return weightedLogLik / totalWeight;
  }
  
  
  /**
   * Constructor. Puts every option at the value assigned by resetOptions.
   *
   **/
  public EM () {
    this.resetOptions();
  }


  /**
   * Reset to default options: not verbose, random seed 100, at most 100
   * iterations, minimum standard deviation 1e-6, and both cluster counts
   * set to -1.
   */
  protected void resetOptions () {
    m_verbose = false;
    m_rseed = 100;
    m_max_iterations = 100;
    m_minStdDev = 1e-6;

    // toString reports "selected by cross validation" while
    // m_initialNumClusters is still -1
    m_initialNumClusters = -1;
    m_num_clusters = -1;
  }

  /**
   * Return the normal distributions for the cluster models. The array is
   * indexed [cluster][attribute][slot], where slot 0 holds the mean, slot 1
   * the standard deviation and slot 2 the weight sum. The internal array
   * itself is handed out, not a copy.
   *
   * @return a <code>double[][][]</code> value
   */
  public double [][][] getClusterModelsNumericAtts() {
    double[][][] normalModels = m_modelNormal;
    return normalModels;
  }

  /**
   * Return the priors for the clusters, one entry per cluster. The internal
   * array itself is handed out, not a copy.
   *
   * @return a <code>double[]</code> value
   */
  public double [] getClusterPriors() {
    double[] priors = m_priors;
    return priors;
  }

  /**
   * Outputs the generated clusters into a string: the number of clusters,
   * then for each cluster its prior probability and, per attribute, either
   * the discrete estimator or the mean and standard deviation of the normal
   * distribution.
   *
   * @return the description, or a fixed message if m_priors is still null
   */
  public String toString () {
    if (m_priors == null) {
      return "No clusterer built yet!";
    }

    StringBuffer out = new StringBuffer("\nEM\n==\n");

    // the headline depends on whether an initial cluster count was recorded
    String headline;
    if (m_initialNumClusters == -1) {
      headline = "\nNumber of clusters selected by cross validation: ";
    } else {
      headline = "\nNumber of clusters: ";
    }
    out.append(headline).append(m_num_clusters).append("\n");

    for (int cluster = 0; cluster < m_num_clusters; cluster++) {
      out.append("\nCluster: ").append(cluster)
        .append(" Prior probability: ")
        .append(Utils.doubleToString(m_priors[cluster], 4))
        .append("\n\n");

      for (int att = 0; att < m_num_attribs; att++) {
        out.append("Attribute: ")
          .append(m_theInstances.attribute(att).name())
          .append("\n");

        if (!m_theInstances.attribute(att).isNominal()) {
          double[] params = m_modelNormal[cluster][att];
          out.append("Normal Distribution. Mean = ")
            .append(Utils.doubleToString(params[0], 4))
            .append(" StdDev = ")
            .append(Utils.doubleToString(params[1], 4))
            .append("\n");
        } else if (m_model[cluster][att] != null) {
          out.append(m_model[cluster][att].toString());
        }
      }
    }

    return out.toString();
  }


  /**
   * verbose output for debugging
   * @param inst the training instances
   */
  private void EM_Report (Instances inst) {
    int i, j, l, m;
    System.out.println("======================================");

    for (j = 0; j < m_num_clusters; j++) {
      for (i = 0; i < m_num_attribs; i++) {
	System.out.println("Clust: " + j + " att: " + i + "\n");

	if (m_theInstances.attribute(i).isNominal()) {
	  if (m_model[j][i] != null) {
	    System.out.println(m_model[j][i].toString());
	  }
	}
	else {
	  System.out.println("Normal Distribution. Mean = " 
			     + Utils.doubleToString(m_modelNormal[j][i][0]
						    , 8, 4) 
			     + " StandardDev = " 
			     + Utils.doubleToString(m_modelNormal[j][i][1]
						    , 8, 4) 
			     + " WeightSum = " 
			     + Utils.doubleToString(m_modelNormal[j][i][2]
						    , 8, 4));
	}
      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -