⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 activefeatureacquisitioncvresultproducer.java

📁 wekaUT 是 University of Texas at Austin（德克萨斯大学奥斯汀分校）开发的、基于 Weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    ActiveFeatureAcquisitionCVResultProducer.java *    Copyright (C) 2003 Prem Melville * */package weka.experiment;import java.util.*;import java.io.*;import weka.classifiers.*;import weka.core.*;/** * Does an N-fold cross-validation, but generates a learning curve by * also varying the number of training examples. Creates a split that * uses increasingly larger fractions of the full training set from * the N fold but always using the same N-fold test set for * testing. If this is applied to an active learner, then the training * examples are selected actively by the learner from the pool of * unlabeled examples. If this is not used with an active learner, it * should produce the same results as the * LearningCurveCrossValidationResultProducer. 
* * @author Prem Melville (melville@cs.utexas.edu)  */public class ActiveFeatureAcquisitionCVResultProducer     implements ResultProducer, OptionHandler, AdditionalMeasureProducer {        /** Indices for range of non-local features (to be acquired) */    protected int m_NonLocalStartIndex = 15;     protected int m_NonLocalEndIndex = 39;        /** Ablational level */    protected double m_AblationLevel = -1;        /** Select features for ablation randomly */    protected boolean m_RandomAblation = false;        /** The dataset of interest */    protected Instances m_Instances;    /** The ResultListener to send results to */    protected ResultListener m_ResultListener = new CSVResultListener();    /** The number of folds in the cross-validation */    protected int m_NumFolds = 10;    /** Save raw output of split evaluators --- for debugging purposes */    protected boolean m_debugOutput = false;    /** The output zipper to use for saving raw splitEvaluator output */    protected OutputZipper m_ZipDest = null;    /** The destination output file/directory for raw output */    protected File m_OutputFile = new File(					   new File(System.getProperty("user.dir")), 					   "splitEvalutorOut.zip");    /** The SplitEvaluator used to generate results */    protected SplitEvaluator m_SplitEvaluator = new ClassifierSplitEvaluator();    /** The names of any additional measures to look for in SplitEvaluators */    protected String [] m_AdditionalMeasures = null;    /**      * The minimum number of instances to use. If this is zero, the first     * step will contain m_StepSize instances      */    protected int m_LowerSize = 0;      /**     * The maximum number of instances to use. 
-1 indicates no maximum      * (other than the total number of instances)     */    protected int m_UpperSize = -1;    /** The number of instances to add at each step */    protected int m_StepSize = 1;    /** The specific points to plot, either integers representing specific numbers of training examples,     * or decimal fractions representing percentages of the full training set*/    protected double[] m_PlotPoints;    /** The current dataset size during stepping */    protected int m_CurrentSize = 0;    /* The name of the key field containing the dataset name */    public static String DATASET_FIELD_NAME = "Dataset";    /* The name of the key field containing the run number */    public static String RUN_FIELD_NAME = "Run";    /* The name of the key field containing the fold number */    public static String FOLD_FIELD_NAME = "Fold";    /* The name of the result field containing the timestamp */    public static String TIMESTAMP_FIELD_NAME = "Date_time";    /* The name of the key field containing the learning rate step number */    public static String STEP_FIELD_NAME = "Total_instances";    /* The name of the key field containing the fraction of total instances used */    public static String FRACTION_FIELD_NAME = "Fraction_instances";    /* Indicates whether fractions or actual number of instances have been specified */    protected boolean m_IsFraction = false;        /**     * Returns a string describing this result producer     * @return a description of the result producer suitable for     * displaying in the explorer/experimenter gui     */    public String globalInfo() {	return "Performs a learning-curve cross validation run using a supplied "	    +"split evaluator. Trains on increasing subsets of the training data for each split, "	    +"repeatedly testing on the test set for that split after training on subsets of various sizes.";    }            /**     * Get the value of m_RandomAblation.     * @return value of m_RandomAblation.     
*/    public boolean getRandomAblation() {	return m_RandomAblation;    }        /**     * Set the value of m_RandomAblation.     * @param v  Value to assign to m_RandomAblation.     */    public void setRandomAblation(boolean  v) {	m_RandomAblation = v;    }        public void setAblationLevel(double abLevel){	m_AblationLevel = abLevel;    }    public double getAblationLevel(){	return m_AblationLevel;    }        /**     * Sets the dataset that results will be obtained for.     *     * @param instances a value of type 'Instances'.     */    public void setInstances(Instances instances) {    	m_Instances = instances;    }    /**     * Sets the object to send results of each run to.     *     * @param listener a value of type 'ResultListener'     */    public void setResultListener(ResultListener listener) {	m_ResultListener = listener;    }        /**     * Set a list of method names for additional measures to look for     * in SplitEvaluators. This could contain many measures (of which only a     * subset may be produceable by the current SplitEvaluator) if an experiment     * is the type that iterates over a set of properties.     * @param additionalMeasures an array of measure names, null if none     */    public void setAdditionalMeasures(String [] additionalMeasures) {	m_AdditionalMeasures = additionalMeasures;	if (m_SplitEvaluator != null) {	    System.err.println("LearningCurveCrossValidationResultProducer: setting additional "			       +"measures for "			       +"split evaluator");	    m_SplitEvaluator.setAdditionalMeasures(m_AdditionalMeasures);	}    }    /**     * Returns an enumeration of any additional measure names that might be     * in the SplitEvaluator     * @return an enumeration of the measure names     */    public Enumeration enumerateMeasures() {	Vector newVector = new Vector();	if (m_SplitEvaluator instanceof AdditionalMeasureProducer) {	    Enumeration en = ((AdditionalMeasureProducer)m_SplitEvaluator).		
enumerateMeasures();	    while (en.hasMoreElements()) {		String mname = (String)en.nextElement();		newVector.addElement(mname);	    }	}	return newVector.elements();    }      /**     * Returns the value of the named measure     * @param measureName the name of the measure to query for its value     * @return the value of the named measure     * @exception IllegalArgumentException if the named measure is not supported     */    public double getMeasure(String additionalMeasureName) {	if (m_SplitEvaluator instanceof AdditionalMeasureProducer) {	    return ((AdditionalMeasureProducer)m_SplitEvaluator).		getMeasure(additionalMeasureName);	} else {	    throw new IllegalArgumentException("LearningCurveCrossValidationResultProducer: "					       +"Can't return value for : "+additionalMeasureName					       +". "+m_SplitEvaluator.getClass().getName()+" "					       +"is not an AdditionalMeasureProducer");	}    }      /**     * Gets a Double representing the current date and time.     * eg: 1:46pm on 20/5/1999 -> 19990520.1346     *     * @return a value of type Double     */    public static Double getTimestamp() {	Calendar now = Calendar.getInstance(TimeZone.getTimeZone("UTC"));	double timestamp = now.get(Calendar.YEAR) * 10000	    + (now.get(Calendar.MONTH) + 1) * 100	    + now.get(Calendar.DAY_OF_MONTH)	    + now.get(Calendar.HOUR_OF_DAY) / 100.0	    + now.get(Calendar.MINUTE) / 10000.0;	return new Double(timestamp);    }      /**     * Prepare to generate results.     *     * @exception Exception if an error occurs during preprocessing.     */    public void preProcess() throws Exception {	if (m_SplitEvaluator == null) {	    throw new Exception("No SplitEvalutor set");	}	if (m_ResultListener == null) {	    throw new Exception("No ResultListener set");	}	m_ResultListener.preProcess(this);    }      /**     * Perform any postprocessing. When this method is called, it indicates     * that no more requests to generate results for the current experiment     * will be sent.  
   *     * @exception Exception if an error occurs     */    public void postProcess() throws Exception {	m_ResultListener.postProcess(this);	if (m_debugOutput) {	    if (m_ZipDest != null) {		m_ZipDest.finished();		m_ZipDest = null;	    }	}    }      /**     * Gets the keys for a specified run number. Different run     * numbers correspond to different randomizations of the data. Keys     * produced should be sent to the current ResultListener     *     * @param run the run number to get keys for.     * @exception Exception if a problem occurs while getting the keys     */    public void doRunKeys(int run) throws Exception {	int numExtraKeys;	if(m_IsFraction)	    numExtraKeys = 5;	else numExtraKeys = 4;		if (m_Instances == null) {	    throw new Exception("No Instances set");	}	if (m_ResultListener == null) {	    throw new Exception("No ResultListener set");	}	for (int fold = 0; fold < m_NumFolds; fold++) {	    int pointNum = 0;	    // For each subsample size	    if (m_PlotPoints != null) {		m_CurrentSize = plotPoint(0);	    }	    else if (m_LowerSize == 0) {		m_CurrentSize = m_StepSize;	    } else {		m_CurrentSize = m_LowerSize;	    }	    while (m_CurrentSize <= maxTrainSize()) {		// Add in some fields to the key like run and fold number, dataset name		Object [] seKey = m_SplitEvaluator.getKey();		Object [] key = new Object [seKey.length + numExtraKeys];		key[0] = Utils.backQuoteChars(m_Instances.relationName());		key[1] = "" + run;		key[2] = "" + (fold + 1);		key[3] = "" + m_CurrentSize;		if(m_IsFraction) key[4] = "" + m_PlotPoints[pointNum];		System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length);		if (m_ResultListener.isResultRequired(this, key)) {		    try {			m_ResultListener.acceptResult(this, key, null);		    } catch (Exception ex) {			// Save the train and test datasets for debugging purposes?			
throw ex;		    }		}		if (m_PlotPoints != null) {		    pointNum ++;		    m_CurrentSize = plotPoint(pointNum);		}		else {		    m_CurrentSize += m_StepSize;		}	    }	}    }    /**      * Get the maximum size of the training set based on  upperSize limit     * or maximum training set size from the n-fold CV      */    protected int maxTrainSize() {	if (m_UpperSize == -1 || m_PlotPoints != null)	    return (int)(m_Instances.numInstances()*(1 - 1/((double)m_NumFolds)));	else return m_UpperSize;    }    /**     * Gets the results for a specified run number. Different run     * numbers correspond to different randomizations of the data. Results     * produced should be sent to the current ResultListener     *     * @param run the run number to get results for.     * @exception Exception if a problem occurs while getting the results     */    public void doRun(int run) throws Exception {	int numExtraKeys;	if(m_IsFraction)	    numExtraKeys = 5;	else numExtraKeys = 4;		if (getRawOutput()) {	    if (m_ZipDest == null) {		m_ZipDest = new OutputZipper(m_OutputFile);	    }	}	if (m_Instances == null) {	    throw new Exception("No Instances set");	}	if (m_ResultListener == null) {	    throw new Exception("No ResultListener set");	}	if(m_AblationLevel>=0 && !m_RandomAblation){	    int numAtts = m_Instances.numAttributes()-1;	    m_NonLocalStartIndex = (int) (m_AblationLevel * numAtts);	    if(m_NonLocalStartIndex<1) m_NonLocalStartIndex=1;//atleast one feature should be available to begin with	    m_NonLocalEndIndex = numAtts-1;//exclude the class attribute	    assert (numAtts==m_Instances.classIndex());	}	// Randomize on a copy of the original dataset	Instances runInstances = new Instances(m_Instances);	runInstances.randomize(new Random(run));	if (runInstances.classAttribute().isNominal()) {	    runInstances.stratify(m_NumFolds);	}	for (int fold = 0; fold < m_NumFolds; fold++) {//For each fold	    Instances fullTrain = runInstances.trainCV(m_NumFolds, fold);	    // Randomly shuffle 
stratified training set for fold: added by Sugato	    fullTrain.randomize(new Random(fold));	    	    Instances global = new Instances(fullTrain,0);//only global instances	    Instances local = new Instances(fullTrain);//only local instance	    HashMap localToGlobal = new HashMap();//map local instances to global instances	    //only copy header and allocate space for training set	    Instances train=new Instances(fullTrain,fullTrain.numInstances());	    boolean firstPoint = true;	    int prevSize = 0;	    Instances test = runInstances.testCV(m_NumFolds, fold);	    int pointNum = 0;	    	    ablateFeatures(local, run*100+fold);//replace non local features with missing	    //associate local instances with global instances	    for(int index=0; index<fullTrain.numInstances(); index++){		localToGlobal.put(local.instance(index),fullTrain.instance(index));	    }	    	    // For each subsample size	    if (m_PlotPoints != null) {		m_CurrentSize = plotPoint(0);	    }	    else if (m_LowerSize == 0) {		m_CurrentSize = m_StepSize;	    } else {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -