⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ensembleselectionlibrarymodel.java

📁 代码是一个分类器的实现,其中使用了部分weka的源代码。可以将项目导入eclipse运行
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    EnsembleSelection.java *    Copyright (C) 2006 David Michael * */package weka.classifiers.meta.ensembleSelection;import weka.classifiers.Classifier;import weka.classifiers.EnsembleLibraryModel;import weka.core.Instance;import weka.core.Instances;import weka.core.Utils;import java.io.File;import java.io.FileInputStream;import java.io.FileOutputStream;import java.io.IOException;import java.io.ObjectInputStream;import java.io.ObjectOutput;import java.io.ObjectOutputStream;import java.io.Serializable;import java.io.UnsupportedEncodingException;import java.util.Date;import java.util.zip.Adler32;/** * This class represents a library model that is used for EnsembleSelection. At * this level the concept of cross validation is abstracted away. This class * keeps track of the performance statistics and bookkeeping information for its * "model type" accross all the CV folds. By "model type", I mean the * combination of both the Classifier type (e.g. J48), and its set of parameters * (e.g. -C 0.5 -X 1 -Y 5). So for example, if you are using 5 fold cross * validaiton, this model will keep an array of classifiers[] of length 5 and * will keep track of their performances accordingly. This class also has * methods to deal with serializing all of this information into the .elm file * that will represent this model. * <p/> * Also it is worth mentioning that another important function of this class is * to track all of the dataset information that was used to create this model. * This is because we want to protect users from doing foreseeably bad things. * e.g., trying to build an ensemble for a dataset with models that were trained * on the wrong partitioning of the dataset. This could lead to artificially high * performance due to the fact that instances used for the test set to gauge * performance could have accidentally been used to train the base classifiers. * So in a nutshell, we are preventing people from unintentionally "cheating" by * enforcing that the seed, #folds, validation ration, and the checksum of the  * Instances.toString() method ALL match exactly.  Otherwise we throw an  * exception. *  * @author  Robert Jung (mrbobjung@gmail.com) * @version $Revision: 1.1 $  */public class EnsembleSelectionLibraryModel  extends EnsembleLibraryModel  implements Serializable {    /**   * This is the serialVersionUID that SHOULD stay the same so that future   * modified versions of this class will be backwards compatible with older   * model versions.   */  private static final long serialVersionUID = -6426075459862947640L;    /** The default file extension for ensemble library models */  public static final String FILE_EXTENSION = ".elm";    /** the models */  private Classifier[] m_models = null;    /** The seed that was used to create this model */  private int m_seed;    /**   * The checksum of the instances.arff object that was used to create this   * model   */  private String m_checksum;    /** The validation ratio that was used to create this model */  private double m_validationRatio;    /**   * The number of folds, or number of CV models that was used to create this   * "model"   */  private int m_folds;    /**   * The .elm file name that this model should be saved/loaded to/from   */  private String m_fileName;    /**   * The debug flag as propagated from the main EnsembleSelection class.   */  public transient boolean m_Debug = true;    /**   * the validation predictions of this model. First index for the instance.   * third is for the class (we use distributionForInstance).   */  private double[][] m_validationPredictions = null; // = new double[0][0];    /**   * Default Constructor   */  public EnsembleSelectionLibraryModel() {  }    /**   * Constructor for LibaryModel   *    * @param classifier		the classifier to use   * @param seed		the random seed value   * @param checksum		the checksum   * @param validationRatio	the ration to use   * @param folds		the number of folds to use   */  public EnsembleSelectionLibraryModel(Classifier classifier, int seed,      String checksum, double validationRatio, int folds) {        super(classifier);        m_seed = seed;    m_checksum = checksum;    m_validationRatio = validationRatio;    m_models = null;    m_folds = folds;  }    /**   * This is used to propagate the m_Debug flag of the EnsembleSelection   * classifier to this class. There are things we would want to print out   * here also.   *    * @param debug	if true additional information is output   */  public void setDebug(boolean debug) {    m_Debug = debug;  }    /**   * Returns the average of the prediction of the models across all folds.   *    * @param instance	the instance to get predictions for   * @return		the average prediction   * @throws Exception	if something goes wrong   */  public double[] getAveragePrediction(Instance instance) throws Exception {        // Return the average prediction from all classifiers that make up    // this model.    double average[] = new double[instance.numClasses()];    for (int i = 0; i < m_folds; ++i) {      // Some models alter the instance (MultiLayerPerceptron), so we need      // to copy it.      Instance temp_instance = (Instance) instance.copy();      double[] pred = getFoldPrediction(temp_instance, i);      if (pred == null) {	// Some models have bugs whereby they can return a null	// prediction	// array (again, MultiLayerPerceptron). We return null, and this	// should be handled above in EnsembleSelection.	System.err.println("Null validation predictions given: "	    + getStringRepresentation());	return null;      }      if (i == 0) {	// The first time through the loop, just use the first returned	// prediction array. Just a simple optimization.	average = pred;      } else {	// For the rest, add the prediction to the average array.	for (int j = 0; j < pred.length; ++j) {	  average[j] += pred[j];	}      }    }    if (instance.classAttribute().isNominal()) {      // Normalize predictions for classes to add up to 1.      Utils.normalize(average);    } else {      average[0] /= m_folds;    }    return average;  }    /**   * Basic Constructor   *    * @param classifier	the classifier to use   */  public EnsembleSelectionLibraryModel(Classifier classifier) {    super(classifier);  }    /**   * Returns prediction of the classifier for the specified fold.   *    * @param instance   *            instance for which to make a prediction.   * @param fold   *            fold number of the classifier to use.   * @return the prediction for the classes   * @throws Exception if prediction fails   */  public double[] getFoldPrediction(Instance instance, int fold)    throws Exception {        return m_models[fold].distributionForInstance(instance);  }    /**   * Creates the model. If there are n folds, it constructs n classifiers   * using the current Classifier class and options. If the model has already   * been created or loaded, starts fresh.   *    * @param data		the data to work with   * @param hillclimbData	the data for hillclimbing   * @param dataDirectoryName	the directory to use   * @param algorithm		the type of algorithm   * @throws Exception		if something goeds wrong   */  public void createModel(Instances[] data, Instances[] hillclimbData,      String dataDirectoryName, int algorithm) throws Exception {        String modelFileName = getFileName(getStringRepresentation());        File modelFile = new File(dataDirectoryName, modelFileName);        String relativePath = (new File(dataDirectoryName)).getName()    + File.separatorChar + modelFileName;    // if (m_Debug) System.out.println("setting relative path to:    // "+relativePath);    setFileName(relativePath);        if (!modelFile.exists()) {            Date startTime = new Date();            String lockFileName = EnsembleSelectionLibraryModel      .getFileName(getStringRepresentation());      lockFileName = lockFileName.substring(0, lockFileName.length() - 3)      + "LCK";      File lockFile = new File(dataDirectoryName, lockFileName);            if (lockFile.exists()) {	if (m_Debug)	  System.out.println("Detected lock file.  Skipping: "	      + lockFileName);	throw new Exception("Lock File Detected: " + lockFile.getName());	      } else { // if (algorithm ==	// EnsembleSelection.ALGORITHM_BUILD_LIBRARY) {	// This lock file lets other computers that might be sharing the	// same file	// system that this model is already being trained so they know	// to move ahead	// and train other models.		if (lockFile.createNewFile()) {	  	  if (m_Debug)	    System.out	    .println("lock file created: " + lockFileName);	  	  if (m_Debug)	    System.out.println("Creating model in locked mode: "		+ modelFile.getPath());	  	  m_models = new Classifier[m_folds];	  for (int i = 0; i < m_folds; ++i) {	    	    try {	      m_models[i] = Classifier.forName(getModelClass()		  .getName(), null);	      m_models[i].setOptions(getOptions());	    } catch (Exception e) {	      throw new Exception("Invalid Options: "		  + e.getMessage());	    }	  }	  	  try {	    for (int i = 0; i < m_folds; ++i) {	      train(data[i], i);	    }	  } catch (Exception e) {	    throw new Exception("Could not Train: "		+ e.getMessage());	  }	  	  Date endTime = new Date();	  int diff = (int) (endTime.getTime() - startTime.getTime());	  	  // We don't need the actual model for hillclimbing. To save	  // memory, release	  // it.	  	  // if (!invalidModels.contains(model)) {	  // EnsembleLibraryModel.saveModel(dataDirectory.getPath(),	  // model);	  // model.releaseModel();	  // }	  if (m_Debug)	    System.out.println("Train time for " + modelFileName		+ " was: " + diff);	  	  if (m_Debug)	    System.out	    .println("Generating validation set predictions");	  	  startTime = new Date();	  	  int total = 0;	  for (int i = 0; i < m_folds; ++i) {	    total += hillclimbData[i].numInstances();	  }	  	  m_validationPredictions = new double[total][];	  	  int preds_index = 0;	  for (int i = 0; i < m_folds; ++i) {	    for (int j = 0; j < hillclimbData[i].numInstances(); ++j) {	      Instance temp = (Instance) hillclimbData[i]	                                               .instance(j).copy();// new	      // Instance(m_hillclimbData[i].instance(j));	      // must copy the instance because SOME classifiers	      // (I'm not pointing fingers...	      // MULTILAYERPERCEPTRON)	      // change the instance!	      	      m_validationPredictions[preds_index] = getFoldPrediction(		  temp, i);	      	      if (m_validationPredictions[preds_index] == null) {		throw new Exception(		    "Null validation predictions given: "		    + getStringRepresentation());	      }	      	      ++preds_index;	    }	  }	  	  endTime = new Date();	  diff = (int) (endTime.getTime() - startTime.getTime());	  	  // if (m_Debug) System.out.println("Generated a validation	  // set array of size: "+m_validationPredictions.length);	  if (m_Debug)	    System.out	    .println("Time to create validation predictions was: "		+ diff);	  	  EnsembleSelectionLibraryModel.saveModel(dataDirectoryName,	      this);	  	  if (m_Debug)	    System.out.println("deleting lock file: "		+ lockFileName);	  lockFile.delete();	  	} else {	  	  if (m_Debug)	    System.out	    .println("Could not create lock file.  Skipping: "		+ lockFileName);	  throw new Exception(	      "Could not create lock file.  Skipping: "	      + lockFile.getName());	  	}	      }          } else {      // This branch is responsible for loading a model from a .elm file            if (m_Debug)	System.out.println("Loading model: " + modelFile.getPath());      // now we need to check to see if the model is valid, if so then      // load it      Date startTime = new Date();            EnsembleSelectionLibraryModel newModel = loadModel(modelFile	  .getPath());            if (!newModel.getStringRepresentation().equals(	  getStringRepresentation()))	throw new EnsembleModelMismatchException(	    "String representations "	    + newModel.getStringRepresentation() + " and "	    + getStringRepresentation() + " not equal");            if (!newModel.getChecksum().equals(getChecksum()))	throw new EnsembleModelMismatchException("Checksums "	    + newModel.getChecksum() + " and " + getChecksum()	    + " not equal");            if (newModel.getSeed() != getSeed())

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -