⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 em.java

📁 数据挖掘中聚类的算法
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    }        for (l = 0; l < inst.numInstances(); l++) {      m = Utils.maxIndex(m_weights[l]);      System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) 		       + " Class " + m + "\t");      for (j = 0; j < m_num_clusters; j++) {	System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + "  ");      }      System.out.println();    }  }  /**   * estimate the number of clusters by cross validation on the training   * data.   *   * @throws Exception if something goes wrong   */  private void CVClusters ()    throws Exception {    double CVLogLikely = -Double.MAX_VALUE;    double templl, tll;    boolean CVincreased = true;    m_num_clusters = 1;    int num_clusters = m_num_clusters;    int i;    Random cvr;    Instances trainCopy;    int numFolds = (m_theInstances.numInstances() < 10)       ? m_theInstances.numInstances()       : 10;    boolean ok = true;    int seed = getSeed();    int restartCount = 0;    CLUSTER_SEARCH: while (CVincreased) {      // theInstances.stratify(10);              CVincreased = false;      cvr = new Random(getSeed());      trainCopy = new Instances(m_theInstances);      trainCopy.randomize(cvr);      templl = 0.0;      for (i = 0; i < numFolds; i++) {	Instances cvTrain = trainCopy.trainCV(numFolds, i, cvr);	if (num_clusters > cvTrain.numInstances()) {	  break CLUSTER_SEARCH;	}	Instances cvTest = trainCopy.testCV(numFolds, i);	m_rr = new Random(seed);        for (int z=0; z<10; z++) m_rr.nextDouble();	m_num_clusters = num_clusters;	EM_Init(cvTrain);	try {	  iterate(cvTrain, false);	} catch (Exception ex) {	  // catch any problems - i.e. 
empty clusters occuring	  ex.printStackTrace();          //          System.err.println("Restarting after CV training failure ("+num_clusters+" clusters");          seed++;          restartCount++;          ok = false;          if (restartCount > 5) {            break CLUSTER_SEARCH;          }	  break;	}        try {          tll = E(cvTest, false);        } catch (Exception ex) {          // catch any problems - i.e. empty clusters occuring          //          ex.printStackTrace();          ex.printStackTrace();          //          System.err.println("Restarting after CV testing failure ("+num_clusters+" clusters");          //          throw new Exception(ex);           seed++;          restartCount++;          ok = false;          if (restartCount > 5) {            break CLUSTER_SEARCH;          }          break;        }	if (m_verbose) {	  System.out.println("# clust: " + num_clusters + " Fold: " + i 			     + " Loglikely: " + tll);	}	templl += tll;      }      if (ok) {        restartCount = 0;        seed = getSeed();        templl /= (double)numFolds;                if (m_verbose) {          System.out.println("==================================="                              + "==============\n# clust: "                              + num_clusters                              + " Mean Loglikely: "                              + templl                              + "\n================================"                              + "=================");        }                if (templl > CVLogLikely) {          CVLogLikely = templl;          CVincreased = true;          num_clusters++;        }      }    }    if (m_verbose) {      System.out.println("Number of clusters: " + (num_clusters - 1));    }    m_num_clusters = num_clusters - 1;  }  /**   * Returns the number of clusters.   *   * @return the number of clusters generated for a training dataset.   
* @throws Exception if number of clusters could not be returned   * successfully   */  public int numberOfClusters ()    throws Exception {    if (m_num_clusters == -1) {      throw  new Exception("Haven't generated any clusters!");    }    return  m_num_clusters;  } /**  * Updates the minimum and maximum values for all the attributes  * based on a new instance.  *  * @param instance the new instance  */  private void updateMinMax(Instance instance) {        for (int j = 0; j < m_theInstances.numAttributes(); j++) {      if (!instance.isMissing(j)) {	if (Double.isNaN(m_minValues[j])) {	  m_minValues[j] = instance.value(j);	  m_maxValues[j] = instance.value(j);	} else {	  if (instance.value(j) < m_minValues[j]) {	    m_minValues[j] = instance.value(j);	  } else {	    if (instance.value(j) > m_maxValues[j]) {	      m_maxValues[j] = instance.value(j);	    }	  }	}      }    }  }  /**   * Returns default capabilities of the clusterer (i.e., the ones of    * SimpleKMeans).   *   * @return      the capabilities of this clusterer   */  public Capabilities getCapabilities() {    Capabilities result = new SimpleKMeans().getCapabilities();    result.setOwner(this);    return result;  }    /**   * Generates a clusterer. Has to initialize all fields of the clusterer   * that are not being set via options.   *   * @param data set of instances serving as training data    * @throws Exception if the clusterer has not been    * generated successfully   */  public void buildClusterer (Instances data)    throws Exception {        // can clusterer handle the data?    
getCapabilities().testWithFail(data);    m_replaceMissing = new ReplaceMissingValues();    Instances instances = new Instances(data);    instances.setClassIndex(-1);    m_replaceMissing.setInputFormat(instances);    data = weka.filters.Filter.useFilter(instances, m_replaceMissing);    instances = null;        m_theInstances = data;    // calculate min and max values for attributes    m_minValues = new double [m_theInstances.numAttributes()];    m_maxValues = new double [m_theInstances.numAttributes()];    for (int i = 0; i < m_theInstances.numAttributes(); i++) {      m_minValues[i] = m_maxValues[i] = Double.NaN;    }    for (int i = 0; i < m_theInstances.numInstances(); i++) {      updateMinMax(m_theInstances.instance(i));    }    doEM();        // save memory    m_theInstances = new Instances(m_theInstances,0);  }  /**   * Returns the cluster priors.   *    * @return the cluster priors   */  public double[] clusterPriors() {    double[] n = new double[m_priors.length];      System.arraycopy(m_priors, 0, n, 0, n.length);    return n;  }  /**   * Computes the log of the conditional density (per cluster) for a given instance.   
*    * @param inst the instance to compute the density for   * @return an array containing the estimated densities   * @throws Exception if the density could not be computed   * successfully   */  public double[] logDensityPerClusterForInstance(Instance inst) throws Exception {    int i, j;    double logprob;    double[] wghts = new double[m_num_clusters];        m_replaceMissing.input(inst);    inst = m_replaceMissing.output();    for (i = 0; i < m_num_clusters; i++) {      //      System.err.println("Cluster : "+i);      logprob = 0.0;      for (j = 0; j < m_num_attribs; j++) {	if (!inst.isMissing(j)) {	  if (inst.attribute(j).isNominal()) {	    logprob += Math.log(m_model[i][j].getProbability(inst.value(j)));	  }	  else { // numeric attribute	    logprob += logNormalDens(inst.value(j), 				     m_modelNormal[i][j][0], 				     m_modelNormal[i][j][1]);	    /*	    System.err.println(logNormalDens(inst.value(j), 				     m_modelNormal[i][j][0], 				     m_modelNormal[i][j][1]) + " "); */	  }	}      }      //      System.err.println("");      wghts[i] = logprob;    }    return  wghts;  }  /**   * Perform the EM algorithm   *    * @throws Exception if something goes wrong   */  private void doEM ()    throws Exception {        if (m_verbose) {      System.out.println("Seed: " + getSeed());    }    m_rr = new Random(getSeed());    // throw away numbers to avoid problem of similar initial numbers    // from a similar seed    for (int i=0; i<10; i++) m_rr.nextDouble();    m_num_instances = m_theInstances.numInstances();    m_num_attribs = m_theInstances.numAttributes();    if (m_verbose) {      System.out.println("Number of instances: " 			 + m_num_instances 			 + "\nNumber of atts: " 			 + m_num_attribs 			 + "\n");    }    // setDefaultStdDevs(theInstances);    // cross validate to determine number of clusters?    
if (m_initialNumClusters == -1) {      if (m_theInstances.numInstances() > 9) {	CVClusters();	m_rr = new Random(getSeed());	for (int i=0; i<10; i++) m_rr.nextDouble();      } else {	m_num_clusters = 1;      }    }    // fit full training set    EM_Init(m_theInstances);    m_loglikely = iterate(m_theInstances, m_verbose);  }  /**   * iterates the E and M steps until the log likelihood of the data   * converges.   *   * @param inst the training instances.   * @param report be verbose.   * @return the log likelihood of the data   * @throws Exception if something goes wrong   */  private double iterate (Instances inst, boolean report)    throws Exception {        int i;    double llkold = 0.0;    double llk = 0.0;    if (report) {      EM_Report(inst);    }    boolean ok = false;    int seed = getSeed();    int restartCount = 0;    while (!ok) {      try {        for (i = 0; i < m_max_iterations; i++) {          llkold = llk;          llk = E(inst, true);                    if (report) {            System.out.println("Loglikely: " + llk);          }                    if (i > 0) {            if ((llk - llkold) < 1e-6) {              break;            }          }          M(inst);        }        ok = true;      } catch (Exception ex) {        //        System.err.println("Restarting after training failure");        ex.printStackTrace();        seed++;        restartCount++;        m_rr = new Random(seed);        for (int z = 0; z < 10; z++) {          m_rr.nextDouble(); m_rr.nextInt();        }        if (restartCount > 5) {          //          System.err.println("Reducing the number of clusters");          m_num_clusters--;          restartCount = 0;        }        EM_Init(m_theInstances);      }    }          if (report) {      EM_Report(inst);    }    return  llk;  }  // ============  // Test method.  // ============  /**   * Main method for testing this class.   
*
   * Creates a fresh EM clusterer and hands it, together with the
   * command-line arguments, to runClusterer.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-T test file] [-N number of clusters] [-S random seed]
   */
  public static void main (String[] argv) {
    EM clusterer = new EM();
    runClusterer(clusterer, argv);
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -