// EnsembleSelectionLibrary.java
// [NOTE(review): web-scrape artifact — code-hosting site header ("字号:" font-size
//  widget) removed from this spot; original newlines of the file were lost in
//  extraction.]
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * EnsembleSelectionLibrary.java * Copyright (C) 2006 Robert Jung * */package weka.classifiers.meta.ensembleSelection;import weka.classifiers.Classifier;import weka.classifiers.EnsembleLibrary;import weka.classifiers.EnsembleLibraryModel;import weka.classifiers.meta.EnsembleSelection;import weka.core.Instances;import java.beans.PropertyChangeListener;import java.beans.PropertyChangeSupport;import java.io.File;import java.io.FileWriter;import java.io.InputStream;import java.io.Serializable;import java.io.UnsupportedEncodingException;import java.text.DateFormat;import java.text.SimpleDateFormat;import java.util.Date;import java.util.HashSet;import java.util.Iterator;import java.util.Set;import java.util.TreeSet;import java.util.zip.Adler32;/** * This class represents an ensemble library. That is a * collection of models that will be combined via the * ensemble selection algorithm. This class is responsible for * tracking all of the unique model specifications in the current * library and trainined them when asked. There are also methods * to save/load library model list files. 
* * @author Robert Jung * @author David Michael * @version $Revision: 1.1 $ */public class EnsembleSelectionLibrary extends EnsembleLibrary implements Serializable { /** for serialization */ private static final long serialVersionUID = -6444026512552917835L; /** the working ensemble library directory. */ private File m_workingDirectory; /** tha name of the model list file storing the list of * models currently being used by the model library */ private String m_modelListFile = null; /** the training data used to build the library. One per fold.*/ private Instances[] m_trainingData; /** the test data used for hillclimbing. One per fold. */ private Instances[] m_hillclimbData; /** the predictions of each model. Built by trainAll. First index is * for the model. Second is for the instance. third is for the class * (we use distributionForInstance). */ private double[][][] m_predictions; /** the random seed used to partition the training data into * validation and training folds */ private int m_seed; /** the number of folds */ private int m_folds; /** the ratio of validation data used to train the model */ private double m_validationRatio; /** A helper class for notifying listeners when working directory changes */ private transient PropertyChangeSupport m_workingDirectoryPropertySupport = new PropertyChangeSupport(this); /** Whether we should print debug messages. */ public transient boolean m_Debug = true; /** * Creates a default libary. Library should be associated with * */ public EnsembleSelectionLibrary() { super(); m_workingDirectory = new File(EnsembleSelection.getDefaultWorkingDirectory()); } /** * Creates a default libary. 
Library should be associated with * a working directory * * @param dir the working directory form the ensemble library * @param seed the seed value * @param folds the number of folds * @param validationRatio the ratio to use */ public EnsembleSelectionLibrary(String dir, int seed, int folds, double validationRatio) { super(); if (dir != null) m_workingDirectory = new File(dir); m_seed = seed; m_folds = folds; m_validationRatio = validationRatio; } /** * This constructor will create a library from a model * list file given by the file name argument * * @param libraryFileName the library filename */ public EnsembleSelectionLibrary(String libraryFileName) { super(); File libraryFile = new File(libraryFileName); try { EnsembleLibrary.loadLibrary(libraryFile, this); } catch (Exception e) { System.err.println("Could not load specified library file: "+libraryFileName); } } /** * This constructor will create a library from the given XML stream. * * @param stream the XML library stream */ public EnsembleSelectionLibrary(InputStream stream) { super(); try { EnsembleLibrary.loadLibrary(stream, this); } catch (Exception e) { System.err.println("Could not load library from XML stream: " + e); } } /** * Set debug flag for the library and all its models. The debug flag * determines whether we print debugging information to stdout. * * @param debug if true debug mode is on */ public void setDebug(boolean debug) { m_Debug = debug; Iterator it = getModels().iterator(); while (it.hasNext()) { ((EnsembleSelectionLibraryModel)it.next()).setDebug(m_Debug); } } /** * Sets the validation-set ratio. This is the portion of the * training set that is set aside for hillclimbing. Note that * this value is ignored if we are doing cross-validation * (indicated by the number of folds being > 1). * * @param validationRatio the new ratio */ public void setValidationRatio(double validationRatio) { m_validationRatio = validationRatio; } /** * Set the number of folds for cross validation. 
If the number * of folds is > 1, the validation ratio is ignored. * * @param numFolds the number of folds to use */ public void setNumFolds(int numFolds) { m_folds = numFolds; } /** * This method will iterate through the TreeMap of models and * train all models that do not currently exist (are not * yet trained). * <p/> * Returns the data set which should be used for hillclimbing. * <p/> * If training a model fails then an error will * be sent to stdout and that model will be removed from the * TreeMap. FIXME Should we maybe raise an exception instead? * * @param data the data to work on * @param directory the working directory * @param algorithm the type of algorithm * @return the data that should be used for hillclimbing * @throws Exception if something goes wrong */ public Instances trainAll(Instances data, String directory, int algorithm) throws Exception { createWorkingDirectory(directory); //craete the directory if it doesn't already exist String dataDirectoryName = getDataDirectoryName(data); File dataDirectory = new File(directory, dataDirectoryName); if (!dataDirectory.exists()) { dataDirectory.mkdirs(); } //Now create a record of all the models trained. 
This will be a .mlf //flat file with a file name based on the time/date of training //DateFormat formatter = new SimpleDateFormat("yyyy.MM.dd.HH.mm"); //String dateString = formatter.format(new Date()); //Go ahead and save in both formats just in case: DateFormat formatter = new SimpleDateFormat("yyyy.MM.dd.HH.mm"); String modelListFileName = formatter.format(new Date())+"_"+size()+"_models.mlf"; //String modelListFileName = dataDirectory.getName()+".mlf"; File modelListFile = new File(dataDirectory.getPath(), modelListFileName); EnsembleLibrary.saveLibrary(modelListFile, this, null); //modelListFileName = dataDirectory.getName()+".model.xml"; modelListFileName = formatter.format(new Date())+"_"+size()+"_models.model.xml"; modelListFile = new File(dataDirectory.getPath(), modelListFileName); EnsembleLibrary.saveLibrary(modelListFile, this, null); //log the instances used just in case we need to know... String arf = data.toString(); FileWriter f = new FileWriter(new File(dataDirectory.getPath(), dataDirectory.getName()+".arff")); f.write(arf); f.close(); // m_trainingData will contain the datasets used for training models for each fold. m_trainingData = new Instances[m_folds]; // m_hillclimbData will contain the dataset which we will use for hillclimbing - // m_hillclimbData[i] should be disjoint from m_trainingData[i]. m_hillclimbData = new Instances[m_folds]; // validationSet is all of the hillclimbing data from all folds, in the same // order as it is in m_hillclimbData Instances validationSet; if (m_folds > 1) { validationSet = new Instances(data, data.numInstances()); //make a new set //with the same capacity and header as data. 
//instances may come from CV functions in //different order, so we'll make sure the //validation set's order matches that of //the concatenated testCV sets for (int i=0; i < m_folds; ++i) { m_trainingData[i] = data.trainCV(m_folds, i); m_hillclimbData[i] = data.testCV(m_folds, i); } // If we're doing "embedded CV" we can hillclimb on // the entire training set, so we just put all of the hillclimbData // from all folds in to validationSet (making sure it's in the appropriate // order). for (int i=0; i < m_folds; ++i) { for (int j=0; j < m_hillclimbData[i].numInstances(); ++j) { validationSet.add(m_hillclimbData[i].instance(j)); } } } else { // Otherwise, we're not doing CV, we're just using a validation set. // Partition the data set in to a training set and a hillclimb set // based on the m_validationRatio. int validation_size = (int)(data.numInstances() * m_validationRatio); m_trainingData[0] = new Instances(data, 0, data.numInstances() - validation_size); m_hillclimbData[0] = new Instances(data, data.numInstances() - validation_size, validation_size); validationSet = m_hillclimbData[0];
// [NOTE(review): file truncated above — the rest of trainAll and the remaining
//  members of EnsembleSelectionLibrary are missing from this scraped copy.
//  Code-hosting site footer (keyboard-shortcut help: copy, search, full-screen,
//  theme, font-size bindings) removed from this spot.]