⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ensembleselection.java

📁 代码是一个分类器的实现,其中使用了部分weka的源代码。可以将项目导入eclipse运行
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
	System.out.println("" + modelPerformance[i] + " "	    + modelNamesArray[modelIndexes[i]]);      }    }    // We're now ready to build our array of the models which were chosen    // and there associated weights.    m_chosen_models = new EnsembleSelectionLibraryModel[chosenModels];    m_chosen_model_weights = new int[chosenModels];        libraryIndex = 0;    // chosenIndex indexes over the models which were chosen by    // EnsembleSelection    // (those which have non-zero weight).    int chosenIndex = 0;    iter = m_library.getModels().iterator();    while (iter.hasNext()) {      int weightOfModel = modelWeights[libraryIndex++];            EnsembleSelectionLibraryModel model = (EnsembleSelectionLibraryModel) iter      .next();            if (weightOfModel > 0) {	// If the model was chosen at least once, add it to our array	// of chosen models and weights.	m_chosen_models[chosenIndex] = model;	m_chosen_model_weights[chosenIndex] = weightOfModel;	// Note that the EnsembleSelectionLibraryModel may not be	// "loaded" -	// that is, its classifier(s) may be null pointers. That's okay	// -	// we'll "rehydrate" them later, if and when we need to.	++chosenIndex;      }    }  }    /**   * Calculates the class membership probabilities for the given test instance.   *   * @param instance the instance to be classified   * @return predicted class probability distribution   * @throws Exception if instance could not be classified   * successfully   */  public double[] distributionForInstance(Instance instance) throws Exception {    String stringInstance = instance.toString();    double cachedPreds[][] = null;        if (m_cachedPredictions != null) {      // If we have any cached predictions (i.e., if cachePredictions was      // called), look for a cached set of predictions for this instance.      if (m_cachedPredictions.containsKey(stringInstance)) {	cachedPreds = (double[][]) m_cachedPredictions.get(stringInstance);      }    }    double[] prediction = new double[instance.numClasses()];    for (int i = 0; i < prediction.length; ++i) {      prediction[i] = 0.0;    }        // Now do a weighted average of the predictions of each of our models.    for (int i = 0; i < m_chosen_models.length; ++i) {      double[] predictionForThisModel = null;      if (cachedPreds == null) {	// If there are no predictions cached, we'll load the model's	// classifier(s) in to memory and get the predictions.	m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());	predictionForThisModel = m_chosen_models[i].getAveragePrediction(instance);	// We could release the model here to save memory, but we assume	// that there is enough available since we're not using the	// prediction caching functionality. If we load and release a	// model	// every time we need to get a prediction for an instance, it	// can be	// prohibitively slow.      } else {	// If it's cached, just get it from the array of cached preds	// for this instance.	predictionForThisModel = cachedPreds[i];      }      // We have encountered a bug where MultilayerPerceptron returns a      // null      // prediction array. If that happens, we just don't count that model      // in      // our ensemble prediction.      if (predictionForThisModel != null) {	// Okay, the model returned a valid prediction array, so we'll	// add the appropriate fraction of this model's prediction.	for (int j = 0; j < prediction.length; ++j) {	  prediction[j] += m_chosen_model_weights[i] * predictionForThisModel[j] / m_total_weight;	}      }    }    // normalize to add up to 1.    if (instance.classAttribute().isNominal()) {      if (Utils.sum(prediction) > 0)	Utils.normalize(prediction);    }    return prediction;  }    /**   * This function tests whether or not a given path is appropriate for being   * the working directory. Specifically, we care that we can write to the   * path and that it doesn't point to a "non-directory" file handle.   *    * @param dir		the directory to test   * @return 		true if the directory is valid   */  private boolean validWorkingDirectory(String dir) {        boolean valid = false;        File f = new File((dir));        if (f.exists()) {      if (f.isDirectory() && f.canWrite())	valid = true;    } else {      if (f.canWrite())	valid = true;    }        return valid;      }    /**   * This method tries to find a reasonable path name for the ensemble working   * directory where models and files will be stored.   *    *    * @return true if m_workingDirectory now has a valid file name   */  public static String getDefaultWorkingDirectory() {        String defaultDirectory = new String("");        boolean success = false;        int i = 1;        while (i < MAX_DEFAULT_DIRECTORIES && !success) {            File f = new File(System.getProperty("user.home"), "Ensemble-" + i);            if (!f.exists() && f.getParentFile().canWrite()) {	defaultDirectory = f.getPath();	success = true;      }      i++;          }        if (!success) {      defaultDirectory = new String("");      // should we print an error or something?    }        return defaultDirectory;  }    /**   * Output a representation of this classifier   *    * @return	a string representation of the classifier   */  public String toString() {    // We just print out the models which were selected, and the number    // of times each was selected.    String result = new String();    if (m_chosen_models != null) {      for (int i = 0; i < m_chosen_models.length; ++i) {	result += m_chosen_model_weights[i];	result += " " + m_chosen_models[i].getStringRepresentation()	+ "\n";      }    } else {      result = "No models selected.";    }    return result;  }    /**   * Cache predictions for the individual base classifiers in the ensemble   * with respect to the given dataset. This is used so that when testing a   * large ensemble on a test set, we don't have to keep the models in memory.   *    * @param test 	The instances for which to cache predictions.   * @throws Exception 	if somethng goes wrong   */  private void cachePredictions(Instances test) throws Exception {    m_cachedPredictions = new HashMap();    Evaluation evalModel = null;    Instances originalInstances = null;    // If the verbose flag is set, we'll also print out the performances of    // all the individual models w.r.t. this test set while we're at it.    boolean printModelPerformances = getVerboseOutput();    if (printModelPerformances) {      // To get performances, we need to keep the class attribute.      originalInstances = new Instances(test);    }        // For each model, we'll go through the dataset and get predictions.    // The idea is we want to only have one model in memory at a time, so    // we'll    // load one model in to memory, get all its predictions, and add them to    // the    // hash map. Then we can release it from memory and move on to the next.    for (int i = 0; i < m_chosen_models.length; ++i) {      if (printModelPerformances) {	// If we're going to print predictions, we need to make a new	// Evaluation object.	evalModel = new Evaluation(originalInstances);      }            Date startTime = new Date();            // Load the model in to memory.      m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());      // Now loop through all the instances and get the model's      // predictions.      for (int j = 0; j < test.numInstances(); ++j) {	Instance currentInstance = test.instance(j);	// When we're looking for a cached prediction later, we'll only	// have the non-class attributes, so we set the class missing	// here	// in order to make the string match up properly.	currentInstance.setClassMissing();	String stringInstance = currentInstance.toString();		// When we come in here with the first model, the instance will	// not	// yet be part of the map.	if (!m_cachedPredictions.containsKey(stringInstance)) {	  // The instance isn't in the map yet, so add it.	  // For each instance, we store a two-dimensional array - the	  // first	  // index is over all the models in the ensemble, and the	  // second	  // index is over the (i.e., typical prediction array).	  int predSize = test.classAttribute().isNumeric() ? 1 : test	      .classAttribute().numValues();	  double predictionArray[][] = new double[m_chosen_models.length][predSize];	  m_cachedPredictions.put(stringInstance, predictionArray);	}	// Get the array from the map which is associated with this	// instance	double predictions[][] = (double[][]) m_cachedPredictions	.get(stringInstance);	// And add our model's prediction for it.	predictions[i] = m_chosen_models[i].getAveragePrediction(test	    .instance(j));		if (printModelPerformances) {	  evalModel.evaluateModelOnceAndRecordPrediction(	      predictions[i], originalInstances.instance(j));	}      }      // Now we're done with model #i, so we can release it.      m_chosen_models[i].releaseModel();            Date endTime = new Date();      long diff = endTime.getTime() - startTime.getTime();            if (m_Debug)	System.out.println("Test time for "	    + m_chosen_models[i].getStringRepresentation()	    + " was: " + diff);            if (printModelPerformances) {	String output = new String(m_chosen_models[i]	                                           .getStringRepresentation()	                                           + ": ");	output += "\tRMSE:" + evalModel.rootMeanSquaredError();	output += "\tACC:" + evalModel.pctCorrect();	if (test.numClasses() == 2) {	  // For multiclass problems, we could print these too, but	  // it's	  // not clear which class we should use in that case... so	  // instead	  // we only print these metrics for binary classification	  // problems.	  output += "\tROC:" + evalModel.areaUnderROC(1);	  output += "\tPREC:" + evalModel.precision(1);	  output += "\tFSCR:" + evalModel.fMeasure(1);	}	System.out.println(output);      }    }  }    /**   * Return the technical information.  There is actually another   * paper that describes our current method of CV for this classifier   * TODO: Cite Technical report when published   *    * @return the technical information about this class   */  public TechnicalInformation getTechnicalInformation() {        TechnicalInformation result;        result = new TechnicalInformation(Type.INPROCEEDINGS);    result.setValue(Field.AUTHOR, "Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes");    result.setValue(Field.TITLE, "Ensemble Selection from Libraries of Models");    result.setValue(Field.BOOKTITLE, "21st International Conference on Machine Learning");    result.setValue(Field.YEAR, "2004");        return result;  }    /**   * Executes the classifier from commandline.   *    * @param argv   *            should contain the following arguments: -t training file [-T   *            test file] [-c class index]   */  public static void main(String[] argv) {        try {            String options[] = (String[]) argv.clone();            // do we get the input from XML instead of normal parameters?      String xml = Utils.getOption("xml", options);      if (!xml.equals(""))	options = new XMLOptions(xml).toArray();            String trainFileName = Utils.getOption('t', options);      String objectInputFileName = Utils.getOption('l', options);      String testFileName = Utils.getOption('T', options);            if (testFileName.length() != 0 && objectInputFileName.length() != 0	  && trainFileName.length() == 0) {		System.out.println("Caching predictions");		EnsembleSelection classifier = null;		BufferedReader testReader = new BufferedReader(new FileReader(	    testFileName));		// Set up the Instances Object	Instances test;	int classIndex = -1;	String classIndexString = Utils.getOption('c', options);	if (classIndexString.length() != 0) {	  classIndex = Integer.parseInt(classIndexString);	}		test = new Instances(testReader, 1);	if (classIndex != -1) {	  test.setClassIndex(classIndex - 1);	} else {	  test.setClassIndex(test.numAttributes() - 1);	}	if (classIndex > test.numAttributes()) {	  throw new Exception("Index of class attribute too large.");	}		while (test.readInstance(testReader)) {	  	}	testReader.close();		// Now yoink the EnsembleSelection Object from the fileSystem		InputStream is = new FileInputStream(objectInputFileName);	if (objectInputFileName.endsWith(".gz")) {	  is = new GZIPInputStream(is);	}		// load from KOML?	if (!(objectInputFileName.endsWith("UpdateableClassifier.koml") && KOML	    .isPresent())) {	  ObjectInputStream objectInputStream = new ObjectInputStream(	      is);	  classifier = (EnsembleSelection) objectInputStream	  .readObject();	  objectInputStream.close();	} else {	  BufferedInputStream xmlInputStream = new BufferedInputStream(	      is);	  classifier = (EnsembleSelection) KOML.read(xmlInputStream);	  xmlInputStream.close();	}		String workingDir = Utils.getOption('W', argv);	if (!workingDir.equals("")) {	  classifier.setWorkingDirectory(new File(workingDir));	}		classifier.setDebug(Utils.getFlag('D', argv));	classifier.setVerboseOutput(Utils.getFlag('O', argv));		classifier.cachePredictions(test);		// Now we write the model back out to the file system.	String objectOutputFileName = objectInputFileName;	OutputStream os = new FileOutputStream(objectOutputFileName);	// binary	if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName	    .endsWith(".koml") && KOML.isPresent()))) {	  if (objectOutputFileName.endsWith(".gz")) {	    os = new GZIPOutputStream(os);	  }	  ObjectOutputStream objectOutputStream = new ObjectOutputStream(	      os);	  objectOutputStream.writeObject(classifier);	  objectOutputStream.flush();	  objectOutputStream.close();	}	// KOML/XML	else {	  BufferedOutputStream xmlOutputStream = new BufferedOutputStream(	      os);	  if (objectOutputFileName.endsWith(".xml")) {	    XMLSerialization xmlSerial = new XMLClassifier();	    xmlSerial.write(xmlOutputStream, classifier);	  } else	    // whether KOML is present has already been checked	    // if not present -> ".koml" is interpreted as binary - see	    // above	    if (objectOutputFileName.endsWith(".koml")) {	      KOML.write(xmlOutputStream, classifier);	    }	  xmlOutputStream.close();	}	      }            System.out.println(Evaluation.evaluateModel(	  new EnsembleSelection(), argv));          } catch (Exception e) {      if (    (e.getMessage() != null) 	   && (e.getMessage().indexOf("General options") == -1) )	e.printStackTrace();      else	System.err.println(e.getMessage());    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -