
📄 budgetedlearningcurveresultproducer.java

📁 wekaUT is an extension of Weka, developed at the University of Texas at Austin, that provides classifiers for semi-supervised learning.
💻 JAVA
📖 Page 1 of 3
    public void doRun(int run) throws Exception {
        int numExtraKeys;
        if (m_IsFraction)
            numExtraKeys = 5;
        else
            numExtraKeys = 4;

        if (getRawOutput()) {
            if (m_ZipDest == null) {
                m_ZipDest = new OutputZipper(m_OutputFile);
            }
        }
        if (m_Instances == null) {
            throw new Exception("No Instances set");
        }
        if (m_ResultListener == null) {
            throw new Exception("No ResultListener set");
        }
        // Make local copy of data
        // Initialize dataset - based on initial ablation level
        // Run classifier
        // Get queries
        // Update dataset with result of query

        if (m_SplitEvaluator instanceof FeatureCostSensitiveClassifierSplitEvaluator) {
            ((FeatureCostSensitiveClassifierSplitEvaluator) m_SplitEvaluator).setFeatureCosts(m_FeatureCosts);
        }

        int numFeatures = numFeatures();
        // Randomize on a copy of the original dataset
        Instances runInstances = new Instances(m_Instances);
        runInstances.randomize(new Random(run));
        if (runInstances.classAttribute().isNominal()) {
            runInstances.stratify(m_NumFolds);
        }
        for (int fold = 0; fold < m_NumFolds; fold++) { // For each fold
            m_Cost = 0.0; // initialize cost counter
            Instances fullTrain = runInstances.trainCV(m_NumFolds, fold);
            // Randomly shuffle stratified training set for fold
            fullTrain.randomize(new Random(fold));
            Instances train = initializeData(fullTrain);
            boolean[][] queryMatrix = new boolean[train.numInstances()][numFeatures];
            // initially queryMatrix is set to all false
            boolean firstPoint = true;
            int prevSize = 0;
            Instances test = runInstances.testCV(m_NumFolds, fold);
            int pointNum = 0;
            // For each subsample size
            if (m_PlotPoints != null) {
                m_CurrentSize = plotPoint(0);
            } else if (m_LowerSize == 0) {
                m_CurrentSize = stepSize(m_StepSize);
            } else {
                m_CurrentSize = lowerSize(m_LowerSize);
            }

            // maxQueries should not exceed the total number of queries possible in the current fold
            int maxQueries = maxNumQueries();
            int trainSize = fullTrain.numInstances() * numFeatures();
            if (maxQueries > trainSize)
                maxQueries = trainSize;

            // System.out.println(train);

            while (m_CurrentSize <= maxQueries) {
                long selectionTimeStart, selectionTimeElapsed;
                // Add in some fields to the key like run and fold number, dataset name
                Object[] seKey = m_SplitEvaluator.getKey();
                Object[] key = new Object[seKey.length + numExtraKeys];
                key[0] = Utils.backQuoteChars(m_Instances.relationName());
                key[1] = "" + run;
                key[2] = "" + (fold + 1);
                key[3] = "" + m_CurrentSize;
                if (m_IsFraction)
                    key[4] = "" + m_PlotPoints[pointNum];
                System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length);
                if (m_ResultListener.isResultRequired(this, key)) {
                    try {
                        if (m_IsFraction)
                            System.out.println("Run:" + run + " Fold:" + fold + " Size:" + m_CurrentSize
                                               + " Fraction:" + m_PlotPoints[pointNum]);
                        else
                            System.out.println("Run:" + run + " Fold:" + fold + " Size:" + m_CurrentSize);

                        if (firstPoint) { // the first training set is always randomly selected
                            firstPoint = false;
                            selectionTimeStart = System.currentTimeMillis();
                            makeRandomQueries(train, fullTrain, m_CurrentSize - prevSize, queryMatrix, run * 1000 + fold);
                            selectionTimeElapsed = System.currentTimeMillis() - selectionTimeStart;
                        } else {
                            // use current classifier to actively select instance-feature pairs,
                            // acquire feature-values and add them to the training set
                            selectionTimeStart = System.currentTimeMillis();
                            makeQueries(train, fullTrain, m_CurrentSize - prevSize, queryMatrix, run * 1000 + fold);
                            selectionTimeElapsed = System.currentTimeMillis() - selectionTimeStart;
                        }

                        Object[] seResults = m_SplitEvaluator.getResult(train, test);
                        Object[] results = new Object[seResults.length + 3];
                        results[0] = getTimestamp();
                        results[1] = new Double(selectionTimeElapsed / 1000.0);
                        results[2] = new Double(m_Cost);
                        System.arraycopy(seResults, 0, results, 3, seResults.length);
                        if (m_debugOutput) {
                            String resultName = ("" + run + "." + (fold + 1) + "." + m_CurrentSize + "."
                                                 + Utils.backQuoteChars(runInstances.relationName())
                                                 + "."
                                                 + m_SplitEvaluator.toString()).replace(' ', '_');
                            resultName = Utils.removeSubstring(resultName, "weka.classifiers.");
                            resultName = Utils.removeSubstring(resultName, "weka.filters.");
                            resultName = Utils.removeSubstring(resultName, "weka.attributeSelection.");
                            m_ZipDest.zipit(m_SplitEvaluator.getRawResultOutput(), resultName);
                        }
                        m_ResultListener.acceptResult(this, key, results);
                    } catch (Exception ex) {
                        // Save the train and test datasets for debugging purposes?
                        throw ex;
                    }
                }
                prevSize = m_CurrentSize;
                if (m_PlotPoints != null) {
                    pointNum++;
                    m_CurrentSize = plotPoint(pointNum);
                } else {
                    m_CurrentSize += stepSize(m_StepSize);
                }

                // System.out.println(train);
            }
        }
    }

    // Create initial training set with only the class labels
    protected Instances initializeData(Instances fulltrain) {
        int numInstances = fulltrain.numInstances();
        int numAtts = fulltrain.numAttributes();
        int classIndex = fulltrain.classIndex();
        Instances train = new Instances(fulltrain, numInstances);
        for (int i = 0; i < numInstances; i++) {
            Instance newInst = new Instance(numAtts);
            newInst.setValue(classIndex, (fulltrain.instance(i)).classValue());
            train.add(newInst);
        }
        return train;
    }

    /**
     * Use the current classifier to actively select the specified number of instance-feature
     * queries to be made. Update the training set and the query matrix.
     *
     * @param train instances with missing feature-values
     * @param fulltrain instances with all feature-values
     * @param num number of queries to make
     * @param queryMatrix matrix to track available queries
     * @param seed random seed needed if the learner is not a BudgetedLearner
     */
    protected void makeQueries(Instances train, Instances fulltrain, int num, boolean[][] queryMatrix, int seed) throws Exception {
        Classifier classifier;
        try {
            classifier = ((ClassifierSplitEvaluator) m_SplitEvaluator).getClassifier();
        } catch (Exception ex) {
            throw new Exception("Budgeted learning is only implemented for evaluators of classifiers.");
        }

        if (classifier instanceof BudgetedLearner) {
            // get instance-feature pairs picked by the classifier
            Pair[] queries = ((BudgetedLearner) classifier).selectInstancesForFeatures(train, num, queryMatrix);
            if (queries.length != num)
                throw new Exception("Incorrect number of queries made!");
            transferQueries(train, fulltrain, queries, queryMatrix);
        } else { // randomly pick examples from local pool
            makeRandomQueries(train, fulltrain, num, queryMatrix, seed);
        }
    }

    // Randomly select queries to make
    protected void makeRandomQueries(Instances train, Instances fulltrain, int num, boolean[][] queryMatrix, int seed) throws Exception {
        int numInstances = fulltrain.numInstances();
        int numFeatures = numFeatures();
        // create a list of query pairs
        ArrayList allQueries = new ArrayList();
        for (int i = 0; i < numInstances; i++)
            for (int j = 0; j < numFeatures; j++)
                if (!queryMatrix[i][j])
                    allQueries.add(new Pair(i, j));

        // randomly select num queries
        Pair[] queries = new Pair[num];
        Random random = new Random(seed);
        System.out.println("Making random queries ...");
        System.out.print(allQueries.size() + " - " + num + " = ");
        for (int i = 0; i < num; i++) {
            int index = random.nextInt(allQueries.size());
            queries[i] = (Pair) allQueries.get(index);
            allQueries.remove(index);
        }
        System.out.println(allQueries.size());
        transferQueries(train, fulltrain, queries, queryMatrix);
    }

    // Transfer requested feature-values to the training set
    protected void transferQueries(Instances train, Instances fulltrain, Pair[] queries, boolean[][] queryMatrix) {
        int instanceIndex, featureIndex;
        for (int i = 0; i < queries.length; i++) {
            instanceIndex = (int) queries[i].first;
            featureIndex = (int) queries[i].second;
            if (queryMatrix[instanceIndex][featureIndex]) {
                System.err.println("Query tracking failure!");
            } else {
                (train.instance(instanceIndex)).setValue(featureIndex, (fulltrain.instance(instanceIndex)).value(featureIndex));
                m_Cost += m_FeatureCosts[featureIndex];
                System.out.println("Query for (" + instanceIndex + "," + featureIndex + ") = "
                                   + (train.instance(instanceIndex)).value(featureIndex)
                                   + "\tCost = " + m_FeatureCosts[featureIndex]
                                   + "\tTotalCost = " + m_Cost);
                queryMatrix[instanceIndex][featureIndex] = true;
            }
        }
    }

    /**
     * Add new instances to the given set of instances.
     *
     * @param data given instances
     * @param newData set of instances to add to given instances
     */
    protected void addInstances(Instances data, Instances newData) {
        for (int i = 0; i < newData.numInstances(); i++)
            data.add(newData.instance(i));
    }

    /** Determines if the points specified are fractions of the total number of examples */
    protected boolean setIsFraction() {
        if (m_PlotPoints != null) {
            if (!isInteger(m_PlotPoints[0])) // if the first point is not an integer
                m_IsFraction = true;
            else
                m_IsFraction = false;
        }
        //  else{
        //      if(!isInteger(m_StepSize))//if the step size is not an integer
        //          m_IsFraction = true;
        //      else
        //          m_IsFraction = false;
        //  }
        return m_IsFraction;
    }

    /** Return the number of instance-feature queries for the ith point on the
     * curve for plotPoints as specified.
     */
    protected int plotPoint(int i) {
        // If i is beyond the number of given plot points, return a value greater than the maximum queries
        if (i >= m_PlotPoints.length)
            return maxNumQueries() + 1;
        double point = m_PlotPoints[i];
        // If the plot point is an integer (other than a non-initial 1)
        // treat it as a specific number of queries
        if (isInteger(point) && !(Utils.eq(point, 1.0) && i != 0))
            return (int) point;
        else
            // Otherwise, treat it as a percentage of the full set
            return (int) Math.round(point * maxNumQueries());
    }

    /** Return true if the given double represents an integer value */
    protected static boolean isInteger(double val) {
        return Utils.eq(Math.floor(val), Math.ceil(val));
    }

    /**
     * Gets the names of each of the columns produced for a single run.
     * This method should really be static.
     *
     * @return an array containing the name of each column
     */
    public String[] getKeyNames() {
        String[] keyNames = m_SplitEvaluator.getKeyNames();
        // Add in the names of our extra key fields
        int numExtraKeys;
        if (m_IsFraction)
            numExtraKeys = 5;
        else
            numExtraKeys = 4;
        String[] newKeyNames = new String[keyNames.length + numExtraKeys];
        newKeyNames[0] = DATASET_FIELD_NAME;
        newKeyNames[1] = RUN_FIELD_NAME;
        newKeyNames[2] = FOLD_FIELD_NAME;
        newKeyNames[3] = STEP_FIELD_NAME;
        if (m_IsFraction)
            newKeyNames[4] = FRACTION_FIELD_NAME;
        System.arraycopy(keyNames, 0, newKeyNames, numExtraKeys, keyNames.length);
        return newKeyNames;
    }

    /**
     * Gets the data types of each of the columns produced for a single run.
     * This method should really be static.
     *
     * @return an array containing objects of the type of each column. The
     * objects should be Strings, or Doubles.
     */
    public Object[] getKeyTypes() {
        Object[] keyTypes = m_SplitEvaluator.getKeyTypes();
        int numExtraKeys;
        if (m_IsFraction)
            numExtraKeys = 5;
        else
            numExtraKeys = 4;
        // Add in the types of our extra fields
        Object[] newKeyTypes = new String[keyTypes.length + numExtraKeys];
        newKeyTypes[0] = new String();
        newKeyTypes[1] = new String();
        newKeyTypes[2] = new String();
        newKeyTypes[3] = new String();
        if (m_IsFraction)
            newKeyTypes[4] = new String();
        System.arraycopy(keyTypes, 0, newKeyTypes, numExtraKeys, keyTypes.length);
        return newKeyTypes;
    }

    /**
     * Gets the names of each of the columns produced for a single run.
     * This method should really be static.
     *
     * @return an array containing the name of each column
     */
    public String[] getResultNames() {
        String[] resultNames = m_SplitEvaluator.getResultNames();
        // Add in the names of our extra Result fields
        String[] newResultNames = new String[resultNames.length + 3];
        newResultNames[0] = TIMESTAMP_FIELD_NAME;
        newResultNames[1] = SELECTION_TIME_FIELD_NAME;
        newResultNames[2] = COST_FIELD_NAME;
        System.arraycopy(resultNames, 0, newResultNames, 3, resultNames.length);
        return newResultNames;
    }

    /**
     * Gets the data types of each of the columns produced for a single run.
     * This method should really be static.
     *
     * @return an array containing objects of the type of each column. The
     * objects should be Strings, or Doubles.
     */
    public Object[] getResultTypes() {
        Object[] resultTypes = m_SplitEvaluator.getResultTypes();
        // Add in the types of our extra Result fields
        Object[] newResultTypes = new Object[resultTypes.length + 3];
        newResultTypes[0] = new Double(0);
        newResultTypes[1] = new Double(0);
        newResultTypes[2] = new Double(0);
        System.arraycopy(resultTypes, 0, newResultTypes, 3, resultTypes.length);
        return newResultTypes;
    }

    /**
     * Gets a description of the internal settings of the result
     * producer, sufficient for distinguishing a ResultProducer
     * instance from another with different settings (ignoring
     * those settings set through this interface). For example,
     * a cross-validation ResultProducer may have a setting for the
     * number of folds. For a given state, the results produced should
     * be compatible. Typically if a ResultProducer is an OptionHandler,
     * this string will represent the command line arguments required
     * to set the ResultProducer to that state.
     *
     * @return the description of the ResultProducer state, or null
     * if no state is defined
     */
    public String getCompatibilityState() {
        String result = "-X " + m_NumFolds + " -S " + getStepSize() +
            " -L " + getLowerSize() + " -U " + getUpperSize() + " ";
        if (m_SplitEvaluator == null) {
            result += "<null SplitEvaluator>";
        } else {
            result += "-W " + m_SplitEvaluator.getClass().getName();
        }
        return result + " --";
    }

    /**
     * Returns the tip text for this property
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String outputFileTipText() {
        return "Set the destination for saving raw output. If the rawOutput "
            + "option is selected, then output from the splitEvaluator for "
            + "individual folds is saved. If the destination is a directory, "
            + "then each output is saved to an individual gzip file; if the "
            + "destination is a file, then each output is saved as an entry "
            + "in a zip file.";
    }

    /**
     * Get the value of OutputFile.
     *
     * @return Value of OutputFile.
     */
    public File getOutputFile() {
        return m_OutputFile;
    }

    /**
     * Set the value of OutputFile.
     *
     * @param newOutputFile Value to assign to OutputFile.
     */
    public void setOutputFile(File newOutputFile) {
        m_OutputFile = newOutputFile;
    }

    /**
     * Returns the tip text for this property
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String numFoldsTipText() {
        return "Number of folds to use in cross validation.";
    }

    /**
     * Get the value of NumFolds.
     *
     * @return Value of NumFolds.
     */
    public int getNumFolds() {
        return m_NumFolds;
    }

    /**
     * Set the value of NumFolds.
     *
     * @param newNumFolds Value to assign to NumFolds.
     */
    public void setNumFolds(int newNumFolds) {
        m_NumFolds = newNumFolds;
    }

    /**
     * Returns the tip text for this property
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String lowerSizeTipText() {
        return "Set the minimum number of instances in a training set. Setting zero "
            + "here will actually use <stepSize> number of instances at the first "
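The listing is cut off here by the page break; pages 2 and 3 contain the remaining property accessors. The random acquisition step in makeRandomQueries above is easy to see in isolation: every (instance, feature) cell not yet marked in the query matrix is enumerated, and queries are drawn from that pool uniformly at random without replacement. The following is a minimal, self-contained sketch of only that logic; it does not use the Weka or wekaUT APIs (Instances, Pair), and the Query class and RandomQueryDemo name are illustrative stand-ins, not part of the original code.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;

    // Illustration of budgeted random query selection over an (instance x feature) matrix.
    public class RandomQueryDemo {

        // Simple stand-in for wekaUT's Pair of (instance index, feature index).
        static final class Query {
            final int instance;
            final int feature;
            Query(int instance, int feature) {
                this.instance = instance;
                this.feature = feature;
            }
            @Override
            public String toString() {
                return "(" + instance + "," + feature + ")";
            }
        }

        static List<Query> makeRandomQueries(boolean[][] queryMatrix, int num, long seed) {
            // Collect every cell that is still available for querying.
            List<Query> pool = new ArrayList<Query>();
            for (int i = 0; i < queryMatrix.length; i++)
                for (int j = 0; j < queryMatrix[i].length; j++)
                    if (!queryMatrix[i][j])
                        pool.add(new Query(i, j));

            // Sample without replacement and mark each selected cell as queried.
            Random random = new Random(seed);
            List<Query> selected = new ArrayList<Query>();
            for (int k = 0; k < num && !pool.isEmpty(); k++) {
                Query q = pool.remove(random.nextInt(pool.size()));
                queryMatrix[q.instance][q.feature] = true;
                selected.add(q);
            }
            return selected;
        }

        public static void main(String[] args) {
            // 4 instances x 3 features, nothing queried yet; ask for 5 queries.
            boolean[][] queryMatrix = new boolean[4][3];
            System.out.println(makeRandomQueries(queryMatrix, 5, 42));
        }
    }

Seeding the generator from the run and fold (run * 1000 + fold in the producer above) keeps the random baseline reproducible when an experiment is rerun.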
