⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    }
    
    for (l = 0; l < inst.numInstances(); l++) {
      m = Utils.maxIndex(m_weights[l]);
      System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) 
		       + " Class " + m + "\t");
      for (j = 0; j < m_num_clusters; j++) {
	System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + "  ");
      }
      System.out.println();
    }
  }


  /**
   * Estimates the number of clusters by cross-validation on the training
   * data.
   *
   * Starting from one cluster, each candidate number of clusters is scored
   * by its mean held-out log likelihood over numFolds folds (10, or one
   * fold per instance when there are fewer than 10 instances). The
   * candidate is increased for as long as this mean keeps improving; the
   * last candidate that improved it is stored in m_num_clusters (never
   * less than 1).
   *
   * If fitting or evaluating a fold throws (e.g. because a cluster became
   * empty), the same candidate is retried from the first fold with the
   * next random seed. After more than five consecutive failures the search
   * stops.
   *
   * @exception Exception if cross-validation could not be performed
   */
  private void CVClusters ()
    throws Exception {
    double CVLogLikely = -Double.MAX_VALUE;
    boolean CVincreased = true;
    m_num_clusters = 1;
    int num_clusters = m_num_clusters;
    int numFolds = (m_theInstances.numInstances() < 10) 
      ? m_theInstances.numInstances() 
      : 10;

    int seed = m_rseed;
    int restartCount = 0;
    CLUSTER_SEARCH: while (CVincreased) {
      CVincreased = false;
      // Reset on every pass. Previously this flag was set once outside the
      // loop, so a single failure suppressed all later passes.
      boolean ok = true;

      // The same fold split (seeded with m_rseed) is used for every candidate.
      Random cvr = new Random(m_rseed);
      Instances trainCopy = new Instances(m_theInstances);
      trainCopy.randomize(cvr);
      double templl = 0.0;

      for (int i = 0; i < numFolds; i++) {
        Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);
        if (num_clusters > cvTrain.numInstances()) {
          // cannot fit more clusters than there are training instances
          break CLUSTER_SEARCH;
        }
        Instances cvTest = trainCopy.testCV(numFolds, i);

        // fresh, burned-in RNG for this fold's initialisation
        m_rr = new Random(seed);
        for (int z = 0; z < 10; z++) {
          m_rr.nextDouble();
        }
        m_num_clusters = num_clusters;
        EM_Init(cvTrain);

        double tll;
        try {
          iterate(cvTrain, false);
          tll = E(cvTest, false);
        } catch (Exception ex) {
          // catch any problems - i.e. empty clusters occurring
          ex.printStackTrace();
          seed++;
          restartCount++;
          if (restartCount > 5) {
            break CLUSTER_SEARCH;
          }
          // Retry the same number of clusters with the next seed. Without
          // this, CVincreased is still false here and the search would end.
          ok = false;
          CVincreased = true;
          break;
        }

        if (m_verbose) {
          System.out.println("# clust: " + num_clusters + " Fold: " + i 
                             + " Loglikely: " + tll);
        }
        templl += tll;
      }

      if (ok) {
        // a full CV pass succeeded: reset the retry state
        restartCount = 0;
        seed = m_rseed;
        templl /= (double)numFolds;

        if (m_verbose) {
          System.out.println("===================================" 
                             + "==============\n# clust: " 
                             + num_clusters 
                             + " Mean Loglikely: " 
                             + templl 
                             + "\n================================" 
                             + "=================");
        }

        if (templl > CVLogLikely) {
          CVLogLikely = templl;
          CVincreased = true;
          num_clusters++;
        }
      }
    }

    // num_clusters is the first candidate that did not improve (or could not
    // be evaluated), so the chosen value is one less. Never go below 1.
    m_num_clusters = Math.max(1, num_clusters - 1);

    if (m_verbose) {
      System.out.println("Number of clusters: " + m_num_clusters);
    }
  }


  /**
   * Reports how many clusters the model currently holds.
   *
   * @return the number of clusters produced for the training data
   * @exception Exception if no clusters have been generated yet
   * (the cluster count is still -1)
   */
  public int numberOfClusters ()
    throws Exception {
    if (m_num_clusters != -1) {
      return m_num_clusters;
    }
    throw new Exception("Haven't generated any clusters!");
  }

 /**
  * Updates the minimum and maximum values for all the attributes
  * based on a new instance.
  *
  * @param instance the new instance
  */
  private void updateMinMax(Instance instance) {
    
    for (int j = 0; j < m_theInstances.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
	if (Double.isNaN(m_minValues[j])) {
	  m_minValues[j] = instance.value(j);
	  m_maxValues[j] = instance.value(j);
	} else {
	  if (instance.value(j) < m_minValues[j]) {
	    m_minValues[j] = instance.value(j);
	  } else {
	    if (instance.value(j) > m_maxValues[j]) {
	      m_maxValues[j] = instance.value(j);
	    }
	  }
	}
      }
    }
  }
  
  /**
   * Generates a clusterer. Has to initialize all fields of the clusterer
   * that are not being set via options.
   *
   * @param data set of instances serving as training data 
   * @exception Exception if the clusterer has not been 
   * generated successfully
   */
  public void buildClusterer (Instances data)
    throws Exception {
    if (data.checkForStringAttributes()) {
      throw  new Exception("Can't handle string attributes!");
    }
    
    m_replaceMissing = new ReplaceMissingValues();
    Instances instances = new Instances(data);
    instances.setClassIndex(-1);
    m_replaceMissing.setInputFormat(instances);
    data = weka.filters.Filter.useFilter(instances, m_replaceMissing);
    instances = null;
    
    m_theInstances = data;

    // calculate min and max values for attributes
    m_minValues = new double [m_theInstances.numAttributes()];
    m_maxValues = new double [m_theInstances.numAttributes()];
    for (int i = 0; i < m_theInstances.numAttributes(); i++) {
      m_minValues[i] = m_maxValues[i] = Double.NaN;
    }
    for (int i = 0; i < m_theInstances.numInstances(); i++) {
      updateMinMax(m_theInstances.instance(i));
    }

    doEM();
    
    // save memory
    m_theInstances = new Instances(m_theInstances,0);
  }

  /**
   * Returns the cluster priors.
   */
  public double[] clusterPriors() {

    double[] n = new double[m_priors.length];
  
    System.arraycopy(m_priors, 0, n, 0, n.length);
    return n;
  }

  /**
   * Computes the log of the conditional density of an instance under each
   * cluster's model.
   *
   * Missing values in the instance are first replaced with the
   * ReplaceMissingValues filter set up during training. For each cluster,
   * the result sums the log probabilities of the nominal attribute values
   * and the log normal densities of the numeric attribute values.
   *
   * @param inst the instance to compute the densities for
   * @return an array with one entry per cluster, holding the log density
   * of the instance under that cluster
   * @exception Exception if the density could not be computed
   * successfully
   */
  public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {

    double[] wghts = new double[m_num_clusters];

    // apply the same missing-value replacement that was used on the training data
    m_replaceMissing.input(inst);
    inst = m_replaceMissing.output();

    for (int i = 0; i < m_num_clusters; i++) {
      double logprob = 0.0;

      for (int j = 0; j < m_num_attribs; j++) {
        if (inst.isMissing(j)) {
          continue;
        }
        if (inst.attribute(j).isNominal()) {
          logprob += Math.log(m_model[i][j].getProbability(inst.value(j)));
        } else {
          // numeric attribute; presumably m_modelNormal[i][j] holds
          // {mean, std dev} for this cluster/attribute -- confirm in EM_Init/M
          logprob += logNormalDens(inst.value(j), 
                                   m_modelNormal[i][j][0], 
                                   m_modelNormal[i][j][1]);
        }
      }

      wghts[i] = logprob;
    }
    return  wghts;
  }


  /**
   * Perform the EM algorithm
   */
  private void doEM ()
    throws Exception {
    if (m_verbose) {
      System.out.println("Seed: " + m_rseed);
    }

    m_rr = new Random(m_rseed);

    // throw away numbers to avoid problem of similar initial numbers
    // from a similar seed
    for (int i=0; i<10; i++) m_rr.nextDouble();

    m_num_instances = m_theInstances.numInstances();
    m_num_attribs = m_theInstances.numAttributes();

    if (m_verbose) {
      System.out.println("Number of instances: " 
			 + m_num_instances 
			 + "\nNumber of atts: " 
			 + m_num_attribs 
			 + "\n");
    }

    // setDefaultStdDevs(theInstances);
    // cross validate to determine number of clusters?
    if (m_initialNumClusters == -1) {
      if (m_theInstances.numInstances() > 9) {
	CVClusters();
	m_rr = new Random(m_rseed);
	for (int i=0; i<10; i++) m_rr.nextDouble();
      } else {
	m_num_clusters = 1;
      }
    }

    // fit full training set
    EM_Init(m_theInstances);
    m_loglikely = iterate(m_theInstances, m_verbose);
  }


  /**
   * Iterates the E and M steps until the log likelihood of the data
   * converges (improves by less than 1e-6 between iterations) or
   * m_max_iterations iterations have been run.
   *
   * If an iteration throws (e.g. because a cluster became empty), the
   * random generator is reseeded and the model is re-initialised on
   * <code>inst</code>, then fitting starts again. After more than five
   * consecutive failures the number of clusters is reduced by one; if only
   * one cluster is left, the failure is rethrown instead.
   *
   * @param inst the training instances.
   * @param report be verbose.
   * @return the log likelihood of the data
   * @exception Exception if fitting keeps failing even with a single cluster
   */
  private double iterate (Instances inst, boolean report)
    throws Exception {
    double llkold = 0.0;
    double llk = 0.0;

    if (report) {
      EM_Report(inst);
    }

    boolean ok = false;
    int seed = m_rseed;
    int restartCount = 0;
    while (!ok) {
      try {
        for (int i = 0; i < m_max_iterations; i++) {
          llkold = llk;
          llk = E(inst, true);
          
          if (report) {
            System.out.println("Loglikely: " + llk);
          }
          
          if (i > 0 && (llk - llkold) < 1e-6) {
            break;
          }
          M(inst);
        }
        ok = true;
      } catch (Exception ex) {
        ex.printStackTrace();
        seed++;
        restartCount++;
        m_rr = new Random(seed);
        for (int z = 0; z < 10; z++) {
          m_rr.nextDouble(); m_rr.nextInt();
        }
        if (restartCount > 5) {
          if (m_num_clusters <= 1) {
            // Cannot reduce further. Give up rather than continue with zero
            // clusters or retry forever.
            throw ex;
          }
          m_num_clusters--;
          restartCount = 0;
        }
        // Re-initialise on the data actually being fitted. When called from
        // cross-validation, inst is a training fold, not m_theInstances.
        EM_Init(inst);
      }
    }
      
    if (report) {
      EM_Report(inst);
    }

    return  llk;
  }


  // ============
  // Test method.
  // ============
  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
  public static void main (String[] argv) {
    try {
      System.out.println(ClusterEvaluation.
			 evaluateClusterer(new EM(), argv));
    }
    catch (Exception e) {
      System.out.println(e.getMessage());
      e.printStackTrace();
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -