
em.java

wekaUT is a package of semi-supervised learning classifiers built on Weka, developed at the University of Texas at Austin.
Language: Java
            // variance
            m_modelNormal[i][j][1] = (m_modelNormal[i][j][1] -
                                      (m_modelNormal[i][j][0] *
                                       m_modelNormal[i][j][0] /
                                       m_modelNormal[i][j][2])) /
              m_modelNormal[i][j][2];

            // std dev
            m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]);
            if (m_modelNormal[i][j][1] <= m_minStdDev
                || Double.isNaN(m_modelNormal[i][j][1])) {
              m_modelNormal[i][j][1] = m_minStdDev;
            }

            // mean
            if (m_modelNormal[i][j][2] > 0.0) {
              m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
            }
          }
        }
      }
    }
  }

  /**
   * The E step of the EM algorithm. Estimate cluster membership
   * probabilities.
   *
   * @param inst the training instances
   * @param num_cl the number of clusters
   * @return the average log likelihood
   */
  private double E (Instances inst, int num_cl)
    throws Exception {

    int i, j, l;
    double prob;
    double loglk = 0.0;

    for (l = 0; l < inst.numInstances(); l++) {
      for (i = 0; i < num_cl; i++) {
        m_weights[l][i] = m_priors[i];
      }

      for (j = 0; j < m_num_attribs; j++) {
        double max = 0;
        for (i = 0; i < num_cl; i++) {
          if (!inst.instance(l).isMissing(j)) {
            if (inst.attribute(j).isNominal()) {
              m_weights[l][i] *=
                m_model[i][j].getProbability(inst.instance(l).value(j));
            }
            else {
              // numeric attribute
              m_weights[l][i] *= normalDens(inst.instance(l).value(j),
                                            m_modelNormal[i][j][0],
                                            m_modelNormal[i][j][1]);
              if (Double.isInfinite(m_weights[l][i])) {
                throw new Exception("Joint density has overflowed. Try "
                                    + "increasing the minimum allowable "
                                    + "standard deviation for normal "
                                    + "density calculation.");
              }
            }
            if (m_weights[l][i] > max) {
              max = m_weights[l][i];
            }
          }
        }
        if (max > 0 && max < 1e-75) { // check for underflow
          for (int zz = 0; zz < num_cl; zz++) {
            // rescale
            m_weights[l][zz] *= 1e75;
          }
        }
      }

      double temp1 = 0;

      for (i = 0; i < num_cl; i++) {
        temp1 += m_weights[l][i];
      }

      if (temp1 > 0) {
        loglk += Math.log(temp1);
      }

      // normalise the weights for this instance
      try {
        Utils.normalize(m_weights[l]);
      } catch (Exception e) {
        throw new Exception("An instance has zero cluster memberships. Try "
                            + "increasing the minimum allowable "
                            + "standard deviation for normal "
                            + "density calculation.");
      }
    }

    // reestimate priors
    estimate_priors(inst, num_cl);
    return loglk / inst.numInstances();
  }

  /**
   * Constructor.
   *
   **/
  public EM () {
    resetOptions();
  }

  /**
   * Reset to default options
   */
  protected void resetOptions () {
    m_minStdDev = 1e-6;
    m_max_iterations = 100;
    m_rseed = 100;
    m_num_clusters = -1;
    m_initialNumClusters = -1;
    m_verbose = false;
  }

  /**
   * Return the normal distributions for the cluster models
   *
   * @return a <code>double[][][]</code> value
   */
  public double [][][] getClusterModelsNumericAtts() {
    return m_modelNormal;
  }

  /**
   * Return the priors for the clusters
   *
   * @return a <code>double[]</code> value
   */
  public double [] getClusterPriors() {
    return m_priors;
  }

  /**
   * Outputs the generated clusters into a string.
   */
  public String toString () {
    StringBuffer text = new StringBuffer();
    text.append("\nEM\n==\n");
    if (m_initialNumClusters == -1) {
      text.append("\nNumber of clusters selected by cross validation: "
                  + m_num_clusters + "\n");
    } else {
      text.append("\nNumber of clusters: " + m_num_clusters + "\n");
    }

    for (int j = 0; j < m_num_clusters; j++) {
      text.append("\nCluster: " + j + " Prior probability: "
                  + Utils.doubleToString(m_priors[j], 4) + "\n\n");

      for (int i = 0; i < m_num_attribs; i++) {
        text.append("Attribute: " + m_theInstances.attribute(i).name() + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            text.append(m_model[j][i].toString());
          }
        }
        else {
          text.append("Normal Distribution. Mean = "
                      + Utils.doubleToString(m_modelNormal[j][i][0], 4)
                      + " StdDev = "
                      + Utils.doubleToString(m_modelNormal[j][i][1], 4)
                      + "\n");
        }
      }
    }

    return text.toString();
  }

  /**
   * verbose output for debugging
   * @param inst the training instances
   */
  private void EM_Report (Instances inst) {
    int i, j, l, m;
    System.out.println("======================================");

    for (j = 0; j < m_num_clusters; j++) {
      for (i = 0; i < m_num_attribs; i++) {
        System.out.println("Clust: " + j + " att: " + i + "\n");

        if (m_theInstances.attribute(i).isNominal()) {
          if (m_model[j][i] != null) {
            System.out.println(m_model[j][i].toString());
          }
        }
        else {
          System.out.println("Normal Distribution. Mean = "
                             + Utils.doubleToString(m_modelNormal[j][i][0], 8, 4)
                             + " StandardDev = "
                             + Utils.doubleToString(m_modelNormal[j][i][1], 8, 4)
                             + " WeightSum = "
                             + Utils.doubleToString(m_modelNormal[j][i][2], 8, 4));
        }
      }
    }

    for (l = 0; l < inst.numInstances(); l++) {
      m = Utils.maxIndex(m_weights[l]);
      System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0)
                       + " Class " + m + "\t");
      for (j = 0; j < m_num_clusters; j++) {
        System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + "  ");
      }
      System.out.println();
    }
  }

  /**
   * estimate the number of clusters by cross validation on the training
   * data.
   *
   * @return the number of clusters selected
   */
  private int CVClusters ()
    throws Exception {
    double CVLogLikely = -Double.MAX_VALUE;
    double templl, tll;
    boolean CVincreased = true;
    int num_cl = 1;
    int i;
    Random cvr;
    Instances trainCopy;
    int numFolds = (m_theInstances.numInstances() < 10)
      ? m_theInstances.numInstances()
      : 10;

    while (CVincreased) {
      CVincreased = false;
      cvr = new Random(m_rseed);
      trainCopy = new Instances(m_theInstances);
      trainCopy.randomize(cvr);
      // theInstances.stratify(10);
      templl = 0.0;

      for (i = 0; i < numFolds; i++) {
        Instances cvTrain = trainCopy.trainCV(numFolds, i);
        Instances cvTest = trainCopy.testCV(numFolds, i);
        EM_Init(cvTrain, num_cl);
        iterate(cvTrain, num_cl, false);
        tll = E(cvTest, num_cl);

        if (m_verbose) {
          System.out.println("# clust: " + num_cl + " Fold: " + i
                             + " Loglikely: " + tll);
        }
        templl += tll;
      }

      templl /= (double)numFolds;

      if (m_verbose) {
        System.out.println("==================================="
                           + "==============\n# clust: "
                           + num_cl
                           + " Mean Loglikely: "
                           + templl
                           + "\n================================"
                           + "=================");
      }

      if (templl > CVLogLikely) {
        CVLogLikely = templl;
        CVincreased = true;
        num_cl++;
      }
    }

    if (m_verbose) {
      System.out.println("Number of clusters: " + (num_cl - 1));
    }

    return num_cl - 1;
  }

  /**
   * Returns the number of clusters.
   *
   * @return the number of clusters generated for a training dataset.
   * @exception Exception if number of clusters could not be returned
   * successfully
   */
  public int numberOfClusters ()
    throws Exception {
    if (m_num_clusters == -1) {
      throw new Exception("Haven't generated any clusters!");
    }

    return m_num_clusters;
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {

    for (int j = 0; j < m_theInstances.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
        if (Double.isNaN(m_minValues[j])) {
          m_minValues[j] = instance.value(j);
          m_maxValues[j] = instance.value(j);
        } else {
          if (instance.value(j) < m_minValues[j]) {
            m_minValues[j] = instance.value(j);
          } else {
            if (instance.value(j) > m_maxValues[j]) {
              m_maxValues[j] = instance.value(j);
            }
          }
        }
      }
    }
  }

  /**
   * Generates a clusterer. Has to initialize all fields of the clusterer
   * that are not being set via options.
   *
   * @param data set of instances serving as training data
   * @exception Exception if the clusterer has not been
   * generated successfully
   */
  public void buildClusterer (Instances data)
    throws Exception {
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    m_theInstances = data;

    // calculate min and max values for attributes
    m_minValues = new double [m_theInstances.numAttributes()];
    m_maxValues = new double [m_theInstances.numAttributes()];
    for (int i = 0; i < m_theInstances.numAttributes(); i++) {
      m_minValues[i] = m_maxValues[i] = Double.NaN;
    }
    for (int i = 0; i < m_theInstances.numInstances(); i++) {
      updateMinMax(m_theInstances.instance(i));
    }

    doEM();

    // save memory
    m_theInstances = new Instances(m_theInstances, 0);
  }

  /**
   * Computes the density for a given instance.
   *
   * @param inst the instance to compute the density for
   * @return the density.
   * @exception Exception if the density could not be computed
   * successfully
   */
  public double densityForInstance(Instance inst) throws Exception {
    return Utils.sum(weightsForInstance(inst));
  }

  /**
   * Predicts the cluster memberships for a given instance.
   *
   * @param inst the instance to be assigned a cluster.
   * @return an array containing the estimated membership
   * probabilities of the test instance in each cluster (this
   * should sum to at most 1)
   * @exception Exception if distribution could not be
   * computed successfully
   */
  public double[] distributionForInstance (Instance inst)
    throws Exception {
    double [] distrib = weightsForInstance(inst);
    Utils.normalize(distrib);
    return distrib;
  }

  /**
   * Returns the weights (indicating cluster membership) for a given instance
   *
   * @param inst the instance to be assigned a cluster
   * @return an array of weights
   * @exception Exception if weights could not be computed
   */
  protected double[] weightsForInstance(Instance inst)
    throws Exception {
    int i, j;
    double prob;
    double[] wghts = new double[m_num_clusters];

    for (i = 0; i < m_num_clusters; i++) {
      prob = 1.0;

      for (j = 0; j < m_num_attribs; j++) {
        if (!inst.isMissing(j)) {
          if (inst.attribute(j).isNominal()) {
            prob *= m_model[i][j].getProbability(inst.value(j));
          }
          else { // numeric attribute
            prob *= normalDens(inst.value(j),
                               m_modelNormal[i][j][0],
                               m_modelNormal[i][j][1]);
          }
        }
      }

      wghts[i] = (prob * m_priors[i]);
    }

    return wghts;
  }

  /**
   * Perform the EM algorithm
   */
  private void doEM ()
    throws Exception {
    if (m_verbose) {
      System.out.println("Seed: " + m_rseed);
    }

    m_rr = new Random(m_rseed);

    // throw away numbers to avoid problem of similar initial numbers
    // from a similar seed
    for (int i = 0; i < 10; i++) m_rr.nextDouble();

    m_num_instances = m_theInstances.numInstances();
    m_num_attribs = m_theInstances.numAttributes();

    if (m_verbose) {
      System.out.println("Number of instances: "
                         + m_num_instances
                         + "\nNumber of atts: "
                         + m_num_attribs
                         + "\n");
    }

    // setDefaultStdDevs(theInstances);
    // cross validate to determine number of clusters?
    if (m_initialNumClusters == -1) {
      if (m_theInstances.numInstances() > 9) {
        m_num_clusters = CVClusters();
      } else {
        m_num_clusters = 1;
      }
    }

    // fit full training set
    EM_Init(m_theInstances, m_num_clusters);
    m_loglikely = iterate(m_theInstances, m_num_clusters, m_verbose);
  }

  /**
   * iterates the E and M steps until the log likelihood of the data
   * converges.
   *
   * @param inst the training instances.
   * @param num_cl the number of clusters.
   * @param report be verbose.
   * @return the log likelihood of the data
   */
  private double iterate (Instances inst, int num_cl, boolean report)
    throws Exception {
    int i;
    double llkold = 0.0;
    double llk = 0.0;

    if (report) {
      EM_Report(inst);
    }

    for (i = 0; i < m_max_iterations; i++) {
      llkold = llk;
      llk = E(inst, num_cl);

      if (report) {
        System.out.println("Loglikely: " + llk);
      }

      if (i > 0) {
        if ((llk - llkold) < 1e-6) {
          break;
        }
      }
      M(inst, num_cl);
    }

    if (report) {
      EM_Report(inst);
    }

    return llk;
  }

  // ============
  // Test method.
  // ============
  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
  public static void main (String[] argv) {
    try {
      System.out.println(ClusterEvaluation.evaluateClusterer(new EM(), argv));
    }
    catch (Exception e) {
      System.out.println(e.getMessage());
      e.printStackTrace();
    }
  }
}
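
Besides the command-line entry point in main(), the clusterer can be driven programmatically. The following is a minimal sketch, not part of the file above: it assumes the standard Weka package layout (weka.core.Instances, weka.clusterers.EM), and the demo class name and ARFF path are hypothetical. It loads a dataset, calls buildClusterer() (which triggers the cross-validated cluster selection in CVClusters() when no cluster count is set), and queries membership probabilities via distributionForInstance().

import java.io.BufferedReader;
import java.io.FileReader;
import weka.clusterers.EM;
import weka.core.Instances;

public class EMDemo {
  public static void main(String[] args) throws Exception {
    // Load training data from an ARFF file (path is hypothetical).
    Instances data =
      new Instances(new BufferedReader(new FileReader("iris.arff")));

    // Build the EM clusterer; with no options set, the number of
    // clusters is chosen by cross validation inside doEM()/CVClusters().
    EM em = new EM();
    em.buildClusterer(data);

    System.out.println("Clusters found: " + em.numberOfClusters());
    System.out.println(em); // toString() lists priors and per-attribute models

    // Normalised membership probabilities for the first instance.
    double[] dist = em.distributionForInstance(data.instance(0));
    for (int k = 0; k < dist.length; k++) {
      System.out.println("Cluster " + k + ": " + dist[k]);
    }
  }
}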
