⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
        }
        else {
          m_modelNormal[i][j][0] = m_modelNormal[i][j][1] = 
	    m_modelNormal[i][j][2] = 0.0;
        }
      }
    }
  }


  /**
   * The M step of the EM algorithm. Rebuilds the per-cluster attribute
   * models from the current membership weights in <code>m_weights</code>.
   *
   * For nominal attributes the weighted value is added to the matching
   * <code>m_model</code> estimator. For numeric attributes the three slots
   * of <code>m_modelNormal[cluster][attribute]</code> first accumulate
   * the weighted sum (slot 0), the weighted sum of squares (slot 1) and
   * the sum of weights (slot 2); slots 0 and 1 are then turned into the
   * mean and the standard deviation.
   *
   * @param inst the training instances
   * @exception Exception declared by the signature
   */
  private void M (Instances inst)
    throws Exception {

    // start from fresh estimators (helper is defined outside this section)
    new_estimators();

    // pass 1: accumulate weighted sufficient statistics
    for (int c = 0; c < m_num_clusters; c++) {
      for (int a = 0; a < m_num_attribs; a++) {
        for (int n = 0; n < inst.numInstances(); n++) {
          Instance in = inst.instance(n);
          if (in.isMissing(a)) {
            continue;
          }
          if (inst.attribute(a).isNominal()) {
            m_model[c][a].addValue(in.value(a),
                                   in.weight() * m_weights[n][c]);
          }
          else {
            double[] stats = m_modelNormal[c][a];
            stats[0] += (in.value(a) * in.weight() *
                         m_weights[n][c]);
            stats[2] += in.weight() * m_weights[n][c];
            stats[1] += (in.value(a) *
                         in.value(a) * in.weight() * m_weights[n][c]);
          }
        }
      }
    }

    // pass 2: calculate mean and std deviation for numeric attributes
    for (int a = 0; a < m_num_attribs; a++) {
      if (inst.attribute(a).isNominal()) {
        continue;
      }
      for (int c = 0; c < m_num_clusters; c++) {
        double[] stats = m_modelNormal[c][a];

        if (stats[2] <= 0) {
          // no weight landed in this cluster for this attribute:
          // std dev becomes MAX_VALUE and the mean slot is set to m_minStdDev
          stats[1] = Double.MAX_VALUE;
          stats[0] = m_minStdDev;
          continue;
        }

        if (stats[2] > 0) {
          // variance
          stats[1] = (stats[1] -
                      (stats[0] *
                       stats[0] /
                       stats[2])) /
            (stats[2]);

          // guard against a negative result of the subtraction above
          if (stats[1] < 0) {
            stats[1] = 0;
          }

          // std dev
          stats[1] = Math.sqrt(stats[1]);

          // too small: fall back to the attribute's overall std dev,
          // and to m_minStdDev if that is too small as well
          if (stats[1] <= m_minStdDev) {
            stats[1] = inst.attributeStats(a).numericStats.stdDev;
            if (stats[1] <= m_minStdDev) {
              stats[1] = m_minStdDev;
            }
          }
        }
        else {
          // only reached when the weight sum is neither <= 0 nor > 0 (NaN)
          stats[1] = m_minStdDev;
        }

        // mean
        stats[0] /= stats[2];
      }
    }
  }

  /**
   * The E step of the EM algorithm. Sums the weighted log density of
   * every instance and, when requested, re-estimates the cluster
   * membership probabilities and the cluster priors.
   *
   * @param inst the instances to evaluate
   * @param change_weights if true, <code>m_weights</code> is overwritten
   * with the new membership distribution of each instance and the priors
   * are re-estimated; if false only the log likelihood is computed
   * @return the average log likelihood (weighted sum of log densities
   * divided by the sum of instance weights)
   * @exception Exception declared by the signature
   */
  private double E (Instances inst, boolean change_weights)
    throws Exception {

    double weightedLogLik = 0.0;
    double totalWeight = 0.0;

    for (int idx = 0; idx < inst.numInstances(); idx++) {
      Instance current = inst.instance(idx);

      weightedLogLik += current.weight() * logDensityForInstance(current);
      totalWeight += current.weight();

      if (change_weights) {
        m_weights[idx] = distributionForInstance(current);
      }
    }

    // reestimate priors
    if (change_weights) {
      estimate_priors(inst);
    }

    return weightedLogLik / totalWeight;
  }
  
  
  /**
   * Creates a clusterer with all options set to the values assigned by
   * {@link #resetOptions()}.
   */
  public EM () {
    this.resetOptions();
  }


  /**
   * Restores the default option values: verbose output off, number of
   * clusters and initial number of clusters set to -1, random seed 100,
   * at most 100 iterations, and a minimum standard deviation of 1e-6.
   */
  protected void resetOptions () {
    m_verbose = false;
    m_initialNumClusters = -1;
    m_num_clusters = -1;
    m_rseed = 100;
    m_max_iterations = 100;
    m_minStdDev = 1e-6;
  }

  /**
   * Return the normal distributions for the cluster models. The array is
   * indexed as [cluster][attribute][slot], where slot 0 holds the mean,
   * slot 1 the standard deviation and slot 2 the sum of weights. The
   * internal array itself is returned, not a copy.
   *
   * @return a <code>double[][][]</code> value
   */
  public double [][][] getClusterModelsNumericAtts() {
    return this.m_modelNormal;
  }

  /**
   * Return the priors for the clusters. The internal array itself is
   * returned, not a copy.
   *
   * @return a <code>double[]</code> value
   */
  public double [] getClusterPriors() {
    return this.m_priors;
  }

  /**
   * Outputs the generated clusters into a string: the number of
   * clusters, then for every cluster its prior probability and, per
   * attribute, either the nominal estimator or the mean and standard
   * deviation of the normal distribution.
   *
   * @return the description, or a fixed message if no priors exist yet
   */
  public String toString () {
    if (m_priors == null) {
      return "No clusterer built yet!";
    }

    StringBuffer out = new StringBuffer();
    out.append("\nEM\n==\n");

    // -1 means the number of clusters was not fixed by the user
    if (m_initialNumClusters == -1) {
      out.append("\nNumber of clusters selected by cross validation: ")
        .append(m_num_clusters).append("\n");
    } else {
      out.append("\nNumber of clusters: ").append(m_num_clusters).append("\n");
    }

    for (int cluster = 0; cluster < m_num_clusters; cluster++) {
      out.append("\nCluster: ").append(cluster)
        .append(" Prior probability: ")
        .append(Utils.doubleToString(m_priors[cluster], 4))
        .append("\n\n");

      for (int att = 0; att < m_num_attribs; att++) {
        out.append("Attribute: ")
          .append(m_theInstances.attribute(att).name())
          .append("\n");

        if (!m_theInstances.attribute(att).isNominal()) {
          out.append("Normal Distribution. Mean = ")
            .append(Utils.doubleToString(m_modelNormal[cluster][att][0], 4))
            .append(" StdDev = ")
            .append(Utils.doubleToString(m_modelNormal[cluster][att][1], 4))
            .append("\n");
        }
        else if (m_model[cluster][att] != null) {
          out.append(m_model[cluster][att].toString());
        }
      }
    }

    return out.toString();
  }


  /**
   * Verbose output for debugging. Prints every cluster/attribute model
   * to standard output, followed by one line per instance showing the
   * index of its largest membership weight and all of its weights.
   *
   * @param inst the training instances
   */
  private void EM_Report (Instances inst) {
    System.out.println("======================================");

    // models
    for (int cluster = 0; cluster < m_num_clusters; cluster++) {
      for (int att = 0; att < m_num_attribs; att++) {
        System.out.println("Clust: " + cluster + " att: " + att + "\n");

        if (!m_theInstances.attribute(att).isNominal()) {
          double[] stats = m_modelNormal[cluster][att];
          String mean = Utils.doubleToString(stats[0], 8, 4);
          String stdDev = Utils.doubleToString(stats[1], 8, 4);
          String weightSum = Utils.doubleToString(stats[2], 8, 4);
          System.out.println("Normal Distribution. Mean = "
                             + mean
                             + " StandardDev = "
                             + stdDev
                             + " WeightSum = "
                             + weightSum);
        }
        else if (m_model[cluster][att] != null) {
          System.out.println(m_model[cluster][att].toString());
        }
      }
    }

    // membership weights
    for (int row = 0; row < inst.numInstances(); row++) {
      int best = Utils.maxIndex(m_weights[row]);
      System.out.print("Inst " + Utils.doubleToString((double)row, 5, 0)
                       + " Class " + best + "\t");
      for (int cluster = 0; cluster < m_num_clusters; cluster++) {
        System.out.print(Utils.doubleToString(m_weights[row][cluster], 7, 5)
                         + "  ");
      }
      System.out.println();
    }
  }


  /**
   * Estimate the number of clusters by cross validation on the training
   * data. Starting with one cluster, the cluster count is increased as
   * long as the mean log likelihood over the folds keeps improving;
   * <code>m_num_clusters</code> ends up at the last count that improved.
   *
   * @exception Exception declared by the signature
   */
  private void CVClusters ()
    throws Exception {

    double bestLogLikely = -Double.MAX_VALUE;
    // at most 10 folds, fewer when there are fewer instances
    final int numFolds = Math.min(m_theInstances.numInstances(), 10);
    boolean improved;

    m_num_clusters = 1;

    do {
      // same seed for every candidate cluster count
      Random cvr = new Random(m_rseed);
      Instances trainCopy = new Instances(m_theInstances);
      trainCopy.randomize(cvr);

      double meanLogLikely = 0.0;

      for (int fold = 0; fold < numFolds; fold++) {
        Instances cvTrain = trainCopy.trainCV(numFolds, fold, cvr);
        Instances cvTest = trainCopy.testCV(numFolds, fold);
        m_rr = new Random(m_rseed);
        EM_Init(cvTrain);
        iterate(cvTrain, false);
        // score the held-out fold without touching the weights
        double foldLogLikely = E(cvTest, false);

        if (m_verbose) {
          System.out.println("# clust: " + m_num_clusters + " Fold: " + fold
                             + " Loglikely: " + foldLogLikely);
        }
        meanLogLikely += foldLogLikely;
      }

      meanLogLikely /= (double)numFolds;

      if (m_verbose) {
        System.out.println("==================================="
                           + "==============\n# clust: "
                           + m_num_clusters
                           + " Mean Loglikely: "
                           + meanLogLikely
                           + "\n================================"
                           + "=================");
      }

      improved = meanLogLikely > bestLogLikely;
      if (improved) {
        bestLogLikely = meanLogLikely;
        m_num_clusters++;
      }
    } while (improved);

    if (m_verbose) {
      System.out.println("Number of clusters: " + (m_num_clusters - 1));
    }

    // the loop stopped one past the last count that improved
    m_num_clusters--;
  }


  /**
   * Returns the number of clusters.
   *
   * @return the number of clusters generated for a training dataset.
   * @exception Exception if the cluster count is still -1, i.e. no
   * clusters have been generated
   */
  public int numberOfClusters ()
    throws Exception {
    if (m_num_clusters != -1) {
      return m_num_clusters;
    }

    throw new Exception("Haven't generated any clusters!");
  }

 /**
  * Updates the minimum and maximum values for all the attributes
  * based on a new instance. Missing values are skipped; an attribute
  * whose stored minimum is still NaN takes the value as both minimum
  * and maximum.
  *
  * @param instance the new instance
  */
  private void updateMinMax(Instance instance) {

    final int numAtts = m_theInstances.numAttributes();

    for (int att = 0; att < numAtts; att++) {
      if (instance.isMissing(att)) {
        continue;
      }

      double value = instance.value(att);

      if (Double.isNaN(m_minValues[att])) {
        // first value seen for this attribute
        m_minValues[att] = value;
        m_maxValues[att] = value;
      } else if (value < m_minValues[att]) {
        m_minValues[att] = value;
      } else if (value > m_maxValues[att]) {
        m_maxValues[att] = value;
      }
    }
  }
  
  /**
   * Generates a clusterer. Records the per-attribute minimum and maximum
   * of the data, runs EM, and afterwards keeps only an empty copy of the
   * dataset.
   *
   * @param data set of instances serving as training data
   * @exception Exception if the data contains string attributes or the
   * clusterer has not been generated successfully
   */
  public void buildClusterer (Instances data)
    throws Exception {
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    m_theInstances = data;

    // calculate min and max values for attributes; NaN marks "not seen yet"
    final int numAtts = m_theInstances.numAttributes();
    m_minValues = new double [numAtts];
    m_maxValues = new double [numAtts];
    for (int att = 0; att < numAtts; att++) {
      m_minValues[att] = Double.NaN;
      m_maxValues[att] = Double.NaN;
    }

    final int numInsts = m_theInstances.numInstances();
    for (int row = 0; row < numInsts; row++) {
      updateMinMax(m_theInstances.instance(row));
    }

    doEM();

    // save memory
    m_theInstances = new Instances(m_theInstances, 0);
  }

  /**
   * Returns the cluster priors.
   *
   * @return a newly allocated copy of the prior array
   */
  public double[] clusterPriors() {
    return (double[]) m_priors.clone();
  }

  /**
   * Computes the log of the conditional density (per cluster) for a given
   * instance. For each cluster the log probabilities of all non-missing
   * attribute values are summed: nominal attributes use the estimator in
   * <code>m_model</code>, numeric attributes use the log normal density
   * with the cluster's mean and standard deviation.
   *
   * @param inst the instance to compute the density for
   * @return an array containing the log density for each cluster
   * @exception Exception if the density could not be computed
   * successfully
   */
  public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {

    double[] logDensities = new double[m_num_clusters];

    for (int cluster = 0; cluster < m_num_clusters; cluster++) {
      double sum = 0.0;

      for (int att = 0; att < m_num_attribs; att++) {
        if (inst.isMissing(att)) {
          continue;
        }

        if (inst.attribute(att).isNominal()) {
          sum += Math.log(m_model[cluster][att].getProbability(inst.value(att)));
        }
        else { // numeric attribute
          sum += logNormalDens(inst.value(att),
                               m_modelNormal[cluster][att][0],
                               m_modelNormal[cluster][att][1]);
        }
      }

      logDensities[cluster] = sum;
    }

    return logDensities;
  }


  /**
   * Perform the EM algorithm: seed the random generator, pick the number
   * of clusters by cross validation when none was requested, then fit the
   * full training set and store the resulting log likelihood.
   *
   * @exception Exception declared by the signature
   */
  private void doEM ()
    throws Exception {
    if (m_verbose) {
      System.out.println("Seed: " + m_rseed);
    }

    m_rr = new Random(m_rseed);

    // throw away numbers to avoid problem of similar initial numbers
    // from a similar seed
    for (int skip = 0; skip < 10; skip++) {
      m_rr.nextDouble();
    }

    m_num_instances = m_theInstances.numInstances();
    m_num_attribs = m_theInstances.numAttributes();

    if (m_verbose) {
      System.out.println("Number of instances: "
                         + m_num_instances
                         + "\nNumber of atts: "
                         + m_num_attribs
                         + "\n");
    }

    // cross validate to determine number of clusters?
    if (m_initialNumClusters == -1) {
      if (m_theInstances.numInstances() > 9) {
        CVClusters();
        // restart the generator so the final fit does not depend on
        // how much randomness the cross validation consumed
        m_rr = new Random(m_rseed);
        for (int skip = 0; skip < 10; skip++) {
          m_rr.nextDouble();
        }
      } else {
        // fewer than 10 instances: no cross validation, one cluster
        m_num_clusters = 1;
      }
    }

    // fit full training set
    EM_Init(m_theInstances);
    m_loglikely = iterate(m_theInstances, m_verbose);
  }


  /**
   * Iterates the E and M steps until the log likelihood of the data
   * converges (improvement below 1e-6 after the first round) or the
   * maximum number of iterations is reached.
   *
   * @param inst the training instances.
   * @param report be verbose.
   * @return the log likelihood of the data
   * @exception Exception declared by the signature
   */
  private double iterate (Instances inst, boolean report)
    throws Exception {

    double previous = 0.0;
    double current = 0.0;

    if (report) {
      EM_Report(inst);
    }

    for (int round = 0; round < m_max_iterations; round++) {
      previous = current;
      current = E(inst, true);

      if (report) {
        System.out.println("Loglikely: " + current);
      }

      // the first round has no earlier value to compare against
      if (round > 0 && (current - previous) < 1e-6) {
        break;
      }

      M(inst);
    }

    if (report) {
      EM_Report(inst);
    }

    return current;
  }


  // ============
  // Test method.
  // ============
  /**
   * Main method for testing this class. Evaluates a new EM clusterer with
   * the given arguments and prints the result; on failure the exception
   * message and stack trace are printed instead.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
  public static void main (String[] argv) {
    try {
      EM clusterer = new EM();
      System.out.println(ClusterEvaluation.evaluateClusterer(clusterer, argv));
    }
    catch (Exception failure) {
      System.out.println(failure.getMessage());
      failure.printStackTrace();
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -