seededkmeans.java

来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 2,032 行 · 第 1/5 页

JAVA
2,032
字号
      String metricName = metricSpec[0];       metricSpec[0] = "";      if (m_Verbose) {	System.out.println("Metric name: " + metricName + "\nMetric parameters: " + concatStringArray(metricSpec));      }      setMetric((LearnableMetric) LearnableMetric.forName(metricName, metricSpec));    }  }  /** A little helper to create a single String from an array of Strings   * @param strings an array of strings   * @returns a single concatenated string, separated by commas   */  public static String concatStringArray(String[] strings) {    String result = new String();    for (int i = 0; i < strings.length; i++) {      result = result + "\"" + strings[i] + "\" ";    }    return result;  }   /**      * return a string describing this clusterer   *   * @return a description of the clusterer as a string   */  public String toString() {    StringBuffer temp = new StringBuffer();    temp.append("\nkMeans\n======\n");    temp.append("\nNumber of iterations: " + m_Iterations+"\n");    temp.append("\nCluster centroids:\n");    for (int i = 0; i < m_NumClusters; i++) {      temp.append("\nCluster "+i+"\n\t");      /*        temp.append(m_ClusterCentroids.instance(i));	for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {	if (m_ClusterCentroids.attribute(j).isNominal()) {	temp.append(" "+m_ClusterCentroids.attribute(j).	value((int)m_ClusterCentroids.instance(i).value(j)));	} 	else {	temp.append(" "+m_ClusterCentroids.instance(i).value(j));	}	}      */    }    temp.append("\n");    return temp.toString();  }  /**   * set the verbosity level of the clusterer   * @param verbose messages on(true) or off (false)   */  public void setVerbose (boolean verbose) {    m_Verbose = verbose;  }  /**   * get the verbosity level of the clusterer   * @return messages on(true) or off (false)   */  public boolean getVerbose () {    return m_Verbose;  }       /**   * Train the clusterer using specified parameters   *   * @param instances Instances to be used for training   */  public void trainClusterer (Instances instances) throws Exception {    if (m_metric instanceof LearnableMetric) {      if (((LearnableMetric)m_metric).getTrainable()) {	((LearnableMetric)m_metric).learnMetric(instances);      }      else {	throw new Exception ("Metric is not trainable");      }    }    else {      throw new Exception ("Metric is not trainable");    }  }  /** Normalizes Instance or SparseInstance   *   * @author Sugato Basu   * @param inst Instance to be normalized   */  public void normalize(Instance inst) throws Exception {    if (inst instanceof SparseInstance) {      normalizeSparseInstance(inst);    }    else {      ((LearnableMetric) m_metric).normalizeInstanceWeighted(inst);    }  }  /** Normalizes the values of a normal Instance   *   * @author Sugato Basu   * @param inst Instance to be normalized   */  public void normalizeInstance(Instance inst) throws Exception{    double norm = 0;    double values [] = inst.toDoubleArray();    if (inst instanceof SparseInstance) {      throw new Exception("Use normalizeSparseInstance function");    }    for (int i=0; i<values.length; i++) {      if (i != inst.classIndex()) { // don't normalize the class index 	norm += values[i] * values[i];      }    }    norm = Math.sqrt(norm);    for (int i=0; i<values.length; i++) {      if (i != inst.classIndex()) { // don't normalize the class index 	if (norm == 0) {	  values[i]= 0;	} else {	  values[i] /= norm;	}      }    }    inst.setValueArray(values);  }  /** Normalizes the values of a SparseInstance   *   * @author Sugato Basu   * @param inst SparseInstance to be normalized   */  public void normalizeSparseInstance(Instance inst) throws Exception{    double norm=0;    int length = inst.numValues();    if (!(inst instanceof SparseInstance)) {      throw new Exception("Use normalizeInstance function");    }    for (int i=0; i<length; i++) {      if (inst.index(i) != inst.classIndex()) { // don't normalize the class index	norm += inst.valueSparse(i) * inst.valueSparse(i);      }    }    norm = Math.sqrt(norm);    for (int i=0; i<length; i++) { // don't normalize the class index      if (inst.index(i) != inst.classIndex()) {	inst.setValueSparse(i, inst.valueSparse(i)/norm);      }    }  }    /** Fast version of meanOrMode - streamlined from Instances.meanOrMode for efficiency    *  Does not check for missing attributes, assumes numeric attributes, assumes Sparse instances   */  protected double[] meanOrMode(Instances insts) {    int numAttributes = insts.numAttributes();    double [] value = new double[numAttributes];    double weight = 0;        for (int i=0; i<numAttributes; i++) {      value[i] = 0;    }    for (int j=0; j<insts.numInstances(); j++) {      SparseInstance inst = (SparseInstance) (insts.instance(j));      weight += inst.weight();      for (int i=0; i<inst.numValues(); i++) {	int indexOfIndex = inst.index(i);	value[indexOfIndex]  += inst.weight() * inst.valueSparse(i);      }    }        if (Utils.eq(weight, 0)) {      for (int k=0; k<numAttributes; k++) {	value[k] = 0;      }    }    else {      for (int k=0; k<numAttributes; k++) {	value[k] = value[k] / weight;      }    }    return value;  }  /**   * Main method for testing this class.   *   */  public static void main (String[] args) {    try {          String dataSet = new String("news");      //String dataSet = new String("iris");      if (dataSet.equals("iris")) {	//////// Iris data	String datafile = "/u/ml/software/weka-latest/data/iris.arff";		// set up the data	FileReader reader = new FileReader (datafile);	Instances data = new Instances (reader);		// Make the last attribute be the class 	int theClass = data.numAttributes();	data.setClassIndex(theClass-1); // starts with 0		// Remove the class labels before clustering		Instances clusterData = new Instances(data);	clusterData.deleteClassAttribute();	// #clusters = #classes	int num_clusters = data.numClasses();		// cluster with seeding      	Instances seeds = new Instances(data,0,5);	seeds.add(data.instance(50));	seeds.add(data.instance(51));	seeds.add(data.instance(52));	seeds.add(data.instance(53));	seeds.add(data.instance(54));		seeds.add(data.instance(100));	seeds.add(data.instance(101));	seeds.add(data.instance(102));	seeds.add(data.instance(103));	seeds.add(data.instance(104));	data.delete(104);	data.delete(103);	data.delete(102);	data.delete(101);	data.delete(100);	data.delete(54);	data.delete(53);	data.delete(52);	data.delete(51);	data.delete(50);	data.delete(4);	data.delete(3);	data.delete(2);	data.delete(1);	data.delete(0);		System.out.println("\nClustering the iris data with seeding, using seeded KMeans...\n");      	WeightedEuclidean euclidean = new WeightedEuclidean();	SeededKMeans kmeans = new SeededKMeans (euclidean);	kmeans.resetClusterer();	kmeans.setVerbose(false);	kmeans.setSeedingMethod(new SelectedTag(SEEDING_SEEDED, TAGS_SEEDING));	kmeans.setAlgorithm(new SelectedTag(ALGORITHM_SIMPLE, TAGS_ALGORITHM));	euclidean.setExternal(false);	euclidean.setTrainable(false);	// phase 1 test	kmeans.setSeedable(false);	kmeans.buildClusterer(null, clusterData, theClass, data, 150);	// phase 2 test	//kmeans.setSeedable(true);	//kmeans.buildClusterer(seeds, clusterData, theClass, data, 150);	kmeans.getIndexClusters();	kmeans.printIndexClusters();	//	kmeans.setVerbose(true);	kmeans.bestInstancesForActiveLearning(50);      }      else if (dataSet.equals("news")) {	//////// Text data - 3000 documents	String datafile = "/u/ml/data/CCSfiles/arffFromCCS/cmu-newsgroup-clean-1000_fromCCS.arff";	System.out.println("\nClustering complete newsgroup data with seeding, using constrained KMeans...\n");      	// set up the data	FileReader reader = new FileReader (datafile);	Instances data = new Instances (reader);	System.out.println("Initial data has size: " + data.numInstances());		// Make the last attribute be the class 	int theClass = data.numAttributes();	data.setClassIndex(theClass-1); // starts with 0	int num_clusters = data.numClasses();		// cluster with seeding              Instances seeds = new Instances(data, 0);	/*	seeds.add(data.instance(994));	seeds.add(data.instance(1431));	seeds.add(data.instance(1612));	seeds.add(data.instance(1747));	seeds.add(data.instance(2205));	seeds.add(data.instance(2736));	data.delete(2736);	data.delete(2205);	data.delete(1747);	data.delete(1612);	data.delete(1431);	data.delete(994);	seeds.add(data.instance(1000));	seeds.add(data.instance(1001));	seeds.add(data.instance(1002));	seeds.add(data.instance(1003));	seeds.add(data.instance(1004));	seeds.add(data.instance(2000));	seeds.add(data.instance(2001));	seeds.add(data.instance(2002));	seeds.add(data.instance(2003));	seeds.add(data.instance(2004));	//        System.out.println("Labeled data has size: " + seeds.numInstances() + ", number of attributes: " + data.numAttributes());	data.delete(2004);	data.delete(2003);	data.delete(2002);	data.delete(2001);	data.delete(2000);	data.delete(1004);	data.delete(1003);	data.delete(1002);	data.delete(1001);	data.delete(1000);	data.delete(4);	data.delete(3);	data.delete(2);	data.delete(1);	data.delete(0);	*/	System.out.println("Unlabeled data has size: " + data.numInstances());		// Remove the class labels before clustering		Instances clusterData = new Instances(data);	clusterData.deleteClassAttribute();	WeightedDotP dotp = new WeightedDotP();	dotp.setExternal(false);	dotp.setTrainable(false);	dotp.setLengthNormalized(false);	SeededKMeans kmeans = new SeededKMeans(dotp);	kmeans.setVerbose(false);	kmeans.setSeedingMethod(new SelectedTag(SEEDING_SEEDED, TAGS_SEEDING));	kmeans.setAlgorithm(new SelectedTag(ALGORITHM_SPHERICAL, TAGS_ALGORITHM));	kmeans.setNumClusters(3);	// phase 1 test	kmeans.setSeedable(false);	kmeans.buildClusterer(null, clusterData, theClass, data, data.numInstances());	// phase 2 test	//kmeans.setSeedable(true);	//kmeans.buildClusterer(seeds, clusterData, theClass, data, 3000);	kmeans.getIndexClusters();	kmeans.printIndexClusters();	//	kmeans.setVerbose(true);	//kmeans.bestInstancesForActiveLearning(50);      //  	// cluster with seeding for small newsgroup//  	seeds = new Instances(data, 0, 3);	//  	seeds.add(data.instance(100)); //  	seeds.add(data.instance(101));//  	seeds.add(data.instance(102));//  	seeds.add(data.instance(200));//  	seeds.add(data.instance(201));//  	seeds.add(data.instance(202));//  	seeds.add(data.instance(300));//  	seeds.add(data.instance(301));//  	seeds.add(data.instance(302));//  	seeds.add(data.instance(400));//  	seeds.add(data.instance(401));//  	seeds.add(data.instance(402));//  	seeds.add(data.instance(500));//  	seeds.add(data.instance(501));//  	seeds.add(data.instance(502));//  	seeds.add(data.instance(600));//  	seeds.add(data.instance(601));//  	seeds.add(data.instance(602));//  	seeds.add(data.instance(700));//  	seeds.add(data.instance(701));//  	seeds.add(data.instance(702));//  	seeds.add(data.instance(800));//  	seeds.add(data.instance(801));//  	seeds.add(data.instance(802));//  	seeds.add(data.instance(900));//  	seeds.add(data.instance(901));      //  	seeds.add(data.instance(902));//  	seeds.add(data.instance(1000)); //  	seeds.add(da

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?