em.java: source code from a data mining system (Java)
            if (inst.attribute(j).isNominal()) {
              m_model[i][j].addValue(inst.instance(l).value(j),
				     m_weights[l][i]);
            }
            else {
              m_modelNormal[i][j][0] += (inst.instance(l).value(j) *
					 m_weights[l][i]);
              m_modelNormal[i][j][2] += m_weights[l][i];
              m_modelNormal[i][j][1] += (inst.instance(l).value(j) *
					 inst.instance(l).value(j)*m_weights[l][i]);
            }
          }
        }
      }
    }

    // calculate mean and std deviation for numeric attributes
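    // m_modelNormal[i][j][0] holds the weighted sum of attribute values,
    // [1] the weighted sum of squared values and [2] the total weight for
    // cluster i / attribute j; below, [1] is converted first to the weighted
    // variance (E[x^2] - E[x]^2) and then to the standard deviation, and
    // [0] is divided by [2] to give the mean.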
    for (j = 0; j < m_num_attribs; j++) {
      if (!inst.attribute(j).isNominal()) {
        for (i = 0; i < num_cl; i++) {
          if (m_modelNormal[i][j][2] <= 0) { // no weight mass: avoid dividing by zero below
            m_modelNormal[i][j][1] = 0;
          } else {

	  // variance
	    m_modelNormal[i][j][1] = (m_modelNormal[i][j][1] -
				      (m_modelNormal[i][j][0] *
				       m_modelNormal[i][j][0] /
				       m_modelNormal[i][j][2])) /
	      m_modelNormal[i][j][2];

	    // std dev
	    m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);

	    if (m_modelNormal[i][j][1] <= m_minStdDev
		|| Double.isNaN(m_modelNormal[i][j][1])) {
	      m_modelNormal[i][j][1] =
		m_minStdDev;
	    }

	    // mean
	    if (m_modelNormal[i][j][2] > 0.0) {
	      m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
	    }
	  }
        }
      }
    }
  }


  /**
   * The E step of the EM algorithm. Estimate cluster membership
   * probabilities.
   *
   * @param inst the training instances
   * @param num_cl the number of clusters
   * @return the average log likelihood
   */
  private double E (Instances inst, int num_cl)
    throws Exception
  {
    int i, j, l;
    double loglk = 0.0;

    for (l = 0; l < inst.numInstances(); l++) {
      for (i = 0; i < num_cl; i++) {
	m_weights[l][i] = m_priors[i];
      }
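      // the weight for each cluster starts from its prior; per-attribute
      // likelihoods are multiplied in below, treating attributes as
      // conditionally independent given the cluster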
      for (j = 0; j < m_num_attribs; j++) {
	double max = 0;
	for (i = 0; i < num_cl; i++) {

          if (!inst.instance(l).isMissing(j)) {
            if (inst.attribute(j).isNominal()) {
              m_weights[l][i] *=
		m_model[i][j].getProbability(inst.instance(l).value(j));

            }
            else {
              // numeric attribute
              m_weights[l][i] *= normalDens(inst.instance(l).value(j),
					    m_modelNormal[i][j][0],
					    m_modelNormal[i][j][1]);
	      if (Double.isInfinite(m_weights[l][i])) {
		throw new Exception("Joint density has overflowed. Try "
				    +"increasing the minimum allowable "
				    +"standard deviation for normal "
				    +"density calculation.");
	      }
            }
	    if (m_weights[l][i] > max) {
	      max = m_weights[l][i];
	    }
          }
        }
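        // rescaling all of this instance's weights by the same constant
        // avoids underflow and leaves the normalised memberships unchanged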
	if (max > 0 && max < 1e-75) { // check for underflow
	  for (int zz = 0; zz < num_cl; zz++) {
	    // rescale
	    m_weights[l][zz] *= 1e75;
	  }
	}
      }

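      // sum of the (possibly rescaled) cluster weights for this instance;
      // its log contributes to the data log likelihood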
      double temp1 = 0;

      for (i = 0; i < num_cl; i++) {
        temp1 += m_weights[l][i];
      }

      if (temp1 > 0) {
        loglk += Math.log(temp1);
      }

      // normalise the weights for this instance
      try {
	Utils.normalize(m_weights[l]);
      } catch (Exception e) {
	throw new Exception("An instance has zero cluster memberships. Try "
			    +"increasing the minimum allowable "
			    +"standard deviation for normal "
			    +"density calculation.");
      }
    }

    // reestimate priors
    estimate_priors(inst, num_cl);
    return  loglk/inst.numInstances();
  }


  /**
   * Constructor.
   *
   **/
  public EM () {
    resetOptions();
  }


  /**
   * Reset to default options
   */
  protected void resetOptions () {
    m_minStdDev = 1e-6;
    m_max_iterations = 100;
    m_rseed = 100;
    m_num_clusters = -1;
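    // -1 => the user has not fixed the number of clusters; doEM() selects it
    // by cross validation (when there are at least 10 instances)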
    m_initialNumClusters = -1;
    m_verbose = false;
  }


  /**
   * Outputs the generated clusters into a string.
   */
  public String toString () {
    StringBuffer text = new StringBuffer();
    text.append("\nEM\n==\n");
    if (m_initialNumClusters == -1) {
      text.append("\nNumber of clusters selected by cross validation: "
		  +m_num_clusters+"\n");
    } else {
      text.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    for (int j = 0; j < m_num_clusters; j++) {
      text.append("\nCluster: " + j + " Prior probability: "
		  + Utils.doubleToString(m_priors[j], 4) + "\n\n");

      for (int i = 0; i < m_num_attribs; i++) {
        text.append("Attribute: " + m_theInstances.attribute(i).name() + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            text.append(m_model[j][i].toString());
          }
        }
        else {
          text.append("Normal Distribution. Mean = "
		      + Utils.doubleToString(m_modelNormal[j][i][0], 4)
		      + " StdDev = "
		      + Utils.doubleToString(m_modelNormal[j][i][1], 4)
		      + "\n");
        }
      }
    }

    return  text.toString();
  }


  /**
   * verbose output for debugging
   * @param inst the training instances
   */
  private void EM_Report (Instances inst) {
    int i, j, l, m;
    System.out.println("======================================");

    for (j = 0; j < m_num_clusters; j++) {
      for (i = 0; i < m_num_attribs; i++) {
	System.out.println("Clust: " + j + " att: " + i + "\n");

	if (m_theInstances.attribute(i).isNominal()) {
	  if (m_model[j][i] != null) {
	    System.out.println(m_model[j][i].toString());
	  }
	}
	else {
	  System.out.println("Normal Distribution. Mean = "
			     + Utils.doubleToString(m_modelNormal[j][i][0]
						    , 8, 4)
			     + " StandardDev = "
			     + Utils.doubleToString(m_modelNormal[j][i][1]
						    , 8, 4)
			     + " WeightSum = "
			     + Utils.doubleToString(m_modelNormal[j][i][2]
						    , 8, 4));
	}
      }
    }

    for (l = 0; l < inst.numInstances(); l++) {
      m = Utils.maxIndex(m_weights[l]);
      System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0)
		       + " Class " + m + "\t");
      for (j = 0; j < m_num_clusters; j++) {
	System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + "  ");
      }
      System.out.println();
    }
  }


  /**
   * estimate the number of clusters by cross validation on the training
   * data.
   *
   * @return the number of clusters selected
   */
  private int CVClusters ()
    throws Exception
  {
    double CVLogLikely = -Double.MAX_VALUE;
    double templl, tll;
    boolean CVdecreased = true;
    int num_cl = 1;
    int i;
    Random cvr;
    Instances trainCopy;
    int numFolds = (m_theInstances.numInstances() < 10)
      ? m_theInstances.numInstances()
      : 10;

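    // keep adding clusters as long as the mean cross-validated log
    // likelihood keeps improving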
    while (CVdecreased) {
      CVdecreased = false;
      cvr = new Random(m_rseed);
      trainCopy = new Instances(m_theInstances);
      trainCopy.randomize(cvr);
      // theInstances.stratify(10);
      templl = 0.0;

      for (i = 0; i < numFolds; i++) {
	Instances cvTrain = trainCopy.trainCV(numFolds, i);
	Instances cvTest = trainCopy.testCV(numFolds, i);
	EM_Init(cvTrain, num_cl);
	iterate(cvTrain, num_cl, false);
	tll = E(cvTest, num_cl);

	if (m_verbose) {
	  System.out.println("# clust: " + num_cl + " Fold: " + i
			     + " Loglikely: " + tll);
	}

	templl += tll;
      }

      templl /= (double)numFolds;

      if (m_verbose) {
	System.out.println("==================================="
			   + "==============\n# clust: "
			   + num_cl
			   + " Mean Loglikely: "
			   + templl
			   + "\n================================"
			   + "=================");
      }

      if (templl > CVLogLikely) {
	CVLogLikely = templl;
	CVdecreased = true;
	num_cl++;
      }
    }

    if (m_verbose) {
      System.out.println("Number of clusters: " + (num_cl - 1));
    }

    return  num_cl - 1;
  }


  /**
   * Returns the number of clusters.
   *
   * @return the number of clusters generated for a training dataset.
   * @exception Exception if number of clusters could not be returned
   * successfully
   */
  public int numberOfClusters ()
    throws Exception
  {
    if (m_num_clusters == -1) {
      throw  new Exception("Haven't generated any clusters!");
    }

    return  m_num_clusters;
  }


  /**
   * Generates a clusterer. Has to initialize all fields of the clusterer
   * that are not being set via options.
   *
   * @param data set of instances serving as training data
   * @exception Exception if the clusterer has not been
   * generated successfully
   */
  public void buildClusterer (Instances data)
    throws Exception {
    if (data.checkForStringAttributes()) {
      throw  new Exception("Can't handle string attributes!");
    }

    m_theInstances = data;
    doEM();

    // save memory
    m_theInstances = new Instances(m_theInstances,0);
  }

  /**
   * Computes the density for a given instance.
   *
   * @param inst the instance to compute the density for
   * @return the density.
   * @exception Exception if the density could not be computed
   * successfully
   */
  public double densityForInstance(Instance inst) throws Exception {
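    // total (unnormalised) density: the sum over clusters of
    // prior_i * p(inst | cluster i)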
    return Utils.sum(weightsForInstance(inst));
  }

  /**
   * Predicts the cluster memberships for a given instance.
   *
   * @param inst the instance to be assigned cluster memberships
   * @return an array containing the estimated membership
   * probabilities of the instance in each cluster (normalised to
   * sum to 1)
   * @exception Exception if distribution could not be
   * computed successfully
   */
  public double[] distributionForInstance (Instance inst)
    throws Exception {
    double [] distrib = weightsForInstance(inst);
    Utils.normalize(distrib);
    return distrib;
  }

  /**
   * Returns the weights (indicating cluster membership) for a given instance
   *
   * @param inst the instance to be assigned a cluster
   * @return an array of weights
   * @exception Exception if weights could not be computed
   */
  protected double[] weightsForInstance(Instance inst)
    throws Exception {

    int i, j;
    double prob;
    double[] wghts = new double[m_num_clusters];

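    // the weight for cluster i is the joint density prior_i * prod_j p(x_j | i);
    // missing attribute values are simply skipped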
    for (i = 0; i < m_num_clusters; i++) {
      prob = 1.0;

      for (j = 0; j < m_num_attribs; j++) {
	if (!inst.isMissing(j)) {
	  if (inst.attribute(j).isNominal()) {
	    prob *= m_model[i][j].getProbability(inst.value(j));
	  }
	  else { // numeric attribute
	    prob *= normalDens(inst.value(j),
			       m_modelNormal[i][j][0],
			       m_modelNormal[i][j][1]);
	  }
	}
      }

      wghts[i] = (prob*m_priors[i]);
    }

    return  wghts;
  }


  /**
   * Perform the EM algorithm
   */
  private void doEM ()
    throws Exception
  {
    if (m_verbose) {
      System.out.println("Seed: " + m_rseed);
    }

    m_rr = new Random(m_rseed);
    m_num_instances = m_theInstances.numInstances();
    m_num_attribs = m_theInstances.numAttributes();

    if (m_verbose) {
      System.out.println("Number of instances: "
			 + m_num_instances
			 + "\nNumber of atts: "
			 + m_num_attribs
			 + "\n");
    }

    // setDefaultStdDevs(theInstances);
    // cross validate to determine number of clusters?
    if (m_initialNumClusters == -1) {
      if (m_theInstances.numInstances() > 9) {
	m_num_clusters = CVClusters();
      } else {
	m_num_clusters = 1;
      }
    }

    // fit full training set
    EM_Init(m_theInstances, m_num_clusters);
    m_loglikely = iterate(m_theInstances, m_num_clusters, m_verbose);
  }


  /**
   * iterates the M and E steps until the log likelihood of the data
   * converges.
   *
   * @param inst the training instances.
   * @param num_cl the number of clusters.
   * @param report if true, print progress information.
   * @return the log likelihood of the data
   */
  private double iterate (Instances inst, int num_cl, boolean report)
    throws Exception
  {
    int i;
    double llkold = 0.0;
    double llk = 0.0;

    if (report) {
      EM_Report(inst);
    }

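    // alternate M and E steps; stop once the gain in average log likelihood
    // drops below 1e-6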
    for (i = 0; i < m_max_iterations; i++) {
      M(inst, num_cl);
      llkold = llk;
      llk = E(inst, num_cl);

      if (report) {
	System.out.println("Loglikely: " + llk);
      }

      if (i > 0) {
	if ((llk - llkold) < 1e-6) {
	  break;
	}
      }
    }

    if (report) {
      EM_Report(inst);
    }

    return  llk;
  }


  // ============
  // Test method.
  // ============
  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
  public static void main (String[] argv) {
    try {
      System.out.println(ClusterEvaluation.
			 evaluateClusterer(new EM(), argv));
    }
    catch (Exception e) {
      System.out.println(e.getMessage());
      e.printStackTrace();
    }
  }
}
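A minimal usage sketch for this clusterer (not part of em.java): it assumes the EM class above and the standard weka.core Instances class are on the classpath (add an import for EM if it lives in a package), and "data.arff" is a placeholder path to an ARFF dataset.

// EMUsageExample.java -- illustrative only; class, file and dataset names are hypothetical.
import java.io.BufferedReader;
import java.io.FileReader;

import weka.core.Instances;

public class EMUsageExample {
  public static void main(String[] args) throws Exception {
    // load an ARFF dataset ("data.arff" is a placeholder path)
    Instances data =
      new Instances(new BufferedReader(new FileReader("data.arff")));

    // with the default options (m_initialNumClusters == -1) the number of
    // clusters is chosen by cross validation inside doEM()
    EM clusterer = new EM();
    clusterer.buildClusterer(data);

    // print the fitted per-cluster distributions and the cluster count
    System.out.println(clusterer);
    System.out.println("Clusters found: " + clusterer.numberOfClusters());

    // estimated membership probabilities for the first training instance
    double[] dist = clusterer.distributionForInstance(data.instance(0));
    for (int k = 0; k < dist.length; k++) {
      System.out.println("Cluster " + k + ": " + dist[k]);
    }
  }
}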
