📄 budgetedlearningcurveresultproducer.java
字号:
public void doRun(int run) throws Exception { int numExtraKeys; if(m_IsFraction) numExtraKeys = 5; else numExtraKeys = 4; if (getRawOutput()) { if (m_ZipDest == null) { m_ZipDest = new OutputZipper(m_OutputFile); } } if (m_Instances == null) { throw new Exception("No Instances set"); } if (m_ResultListener == null) { throw new Exception("No ResultListener set"); } //Make local copy of data //Initialize dataset - based on initial ablation level //Run classifier //Get queries //Update dataset with result of query if(m_SplitEvaluator instanceof FeatureCostSensitiveClassifierSplitEvaluator){ ((FeatureCostSensitiveClassifierSplitEvaluator)m_SplitEvaluator).setFeatureCosts(m_FeatureCosts); } int numFeatures = numFeatures(); // Randomize on a copy of the original dataset Instances runInstances = new Instances(m_Instances); runInstances.randomize(new Random(run)); if (runInstances.classAttribute().isNominal()) { runInstances.stratify(m_NumFolds); } for (int fold = 0; fold < m_NumFolds; fold++) {//For each fold m_Cost = 0.0;//initialize cost counter Instances fullTrain = runInstances.trainCV(m_NumFolds, fold); // Randomly shuffle stratified training set for fold fullTrain.randomize(new Random(fold)); Instances train = initializeData(fullTrain); boolean [][]queryMatrix = new boolean[train.numInstances()][numFeatures]; //initially queryMatrix is set to all false boolean firstPoint = true; int prevSize = 0; Instances test = runInstances.testCV(m_NumFolds, fold); int pointNum = 0; // For each subsample size if (m_PlotPoints != null) { m_CurrentSize = plotPoint(0); } else if (m_LowerSize == 0) { m_CurrentSize = stepSize(m_StepSize); } else { m_CurrentSize = lowerSize(m_LowerSize); } //maxQueries should not exceed total number of queries possible in current fold int maxQueries = maxNumQueries(); int trainSize = fullTrain.numInstances()*numFeatures(); if(maxQueries > trainSize) maxQueries = trainSize; //System.out.println(train); while (m_CurrentSize <= maxQueries) { long selectionTimeStart, selectionTimeElapsed; // Add in some fields to the key like run and fold number, dataset name Object [] seKey = m_SplitEvaluator.getKey(); Object [] key = new Object [seKey.length + numExtraKeys]; key[0] = Utils.backQuoteChars(m_Instances.relationName()); key[1] = "" + run; key[2] = "" + (fold + 1); key[3] = "" + m_CurrentSize; if(m_IsFraction) key[4] = "" + m_PlotPoints[pointNum]; System.arraycopy(seKey, 0, key, numExtraKeys, seKey.length); if (m_ResultListener.isResultRequired(this, key)) { try { if(m_IsFraction) System.out.println("Run:" + run + " Fold:" + fold + " Size:" + m_CurrentSize + " Fraction:" + m_PlotPoints[pointNum]); else System.out.println("Run:" + run + " Fold:" + fold + " Size:" + m_CurrentSize); if(firstPoint){//the first training set is always randomly selected firstPoint = false; selectionTimeStart = System.currentTimeMillis(); makeRandomQueries(train,fullTrain,m_CurrentSize - prevSize,queryMatrix,run*1000+fold); selectionTimeElapsed = System.currentTimeMillis() - selectionTimeStart; }else{ //use current classifier to actively select instance-feature pairs //acquire features-values and add to the training set selectionTimeStart = System.currentTimeMillis(); makeQueries(train,fullTrain,m_CurrentSize - prevSize,queryMatrix,run*1000+fold); selectionTimeElapsed = System.currentTimeMillis() - selectionTimeStart; } Object [] seResults = m_SplitEvaluator.getResult(train, test); Object [] results = new Object [seResults.length + 3]; results[0] = getTimestamp(); results[1] = new Double(selectionTimeElapsed / 1000.0); results[2] = new Double(m_Cost); System.arraycopy(seResults, 0, results, 3, seResults.length); if (m_debugOutput) { String resultName = (""+run+"."+(fold+1)+"."+ m_CurrentSize + "." + Utils.backQuoteChars(runInstances.relationName()) +"." +m_SplitEvaluator.toString()).replace(' ','_'); resultName = Utils.removeSubstring(resultName, "weka.classifiers."); resultName = Utils.removeSubstring(resultName, "weka.filters."); resultName = Utils.removeSubstring(resultName, "weka.attributeSelection."); m_ZipDest.zipit(m_SplitEvaluator.getRawResultOutput(), resultName); } m_ResultListener.acceptResult(this, key, results); } catch (Exception ex) { // Save the train and test datasets for debugging purposes? throw ex; } } prevSize = m_CurrentSize; if (m_PlotPoints != null) { pointNum ++; m_CurrentSize = plotPoint(pointNum); } else { m_CurrentSize += stepSize(m_StepSize); } //System.out.println(train); } } } //Create initial training set with only the class labels protected Instances initializeData(Instances fulltrain){ int numInstances = fulltrain.numInstances(); int numAtts = fulltrain.numAttributes(); int classIndex = fulltrain.classIndex(); Instances train = new Instances(fulltrain, numInstances); for(int i=0; i<numInstances; i++){ Instance newInst = new Instance(numAtts); newInst.setValue(classIndex,(fulltrain.instance(i)).classValue()); train.add(newInst); } return train; } /** * Use current classifier to actively select specified number of instances-feature queries * to be made. Update training set and the query matrix * * @param train instances with missing feature-values * @param fulltrain instances with all feature-values * @param num number of queries to make * @param queryMatrix matrix to track available queries * @param seed random seed needed if the learner is not a BudgetedLearner */ protected void makeQueries(Instances train, Instances fulltrain, int num, boolean [][]queryMatrix, int seed) throws Exception{ Classifier classifier; try{ classifier = ((ClassifierSplitEvaluator)m_SplitEvaluator).getClassifier(); }catch (Exception ex){ throw new Exception("Budgeted learning is only implemented for evaluators of classifiers."); } if(classifier instanceof BudgetedLearner){ //get instance-feature pairs picked by the classifier Pair []queries = ((BudgetedLearner)classifier).selectInstancesForFeatures(train, num, queryMatrix); if(queries.length!=num) throw new Exception("Incorrect number of queries made!"); transferQueries(train, fulltrain, queries, queryMatrix); }else{//randomly pick examples from local pool makeRandomQueries(train, fulltrain, num, queryMatrix, seed); } } //Randomly select queries to make protected void makeRandomQueries(Instances train, Instances fulltrain, int num, boolean [][]queryMatrix, int seed) throws Exception{ int numInstances = fulltrain.numInstances(); int numFeatures = numFeatures(); //create a list of query pairs ArrayList allQueries = new ArrayList(); for(int i=0; i<numInstances; i++) for(int j=0; j<numFeatures; j++) if(!queryMatrix[i][j]) allQueries.add(new Pair(i,j)); //randomly select num queries Pair []queries = new Pair[num]; Random random = new Random(seed); System.out.println("Making random queries ..."); System.out.print(allQueries.size()+" - "+num+" = "); for(int i=0; i<num; i++){ int index = random.nextInt(allQueries.size()); queries[i] = (Pair) allQueries.get(index); allQueries.remove(index); } System.out.println(allQueries.size()); transferQueries(train, fulltrain, queries, queryMatrix); } //Transfer requested feature-value to training set protected void transferQueries(Instances train, Instances fulltrain, Pair []queries, boolean [][]queryMatrix){ int instanceIndex, featureIndex; for(int i=0; i<queries.length; i++){ instanceIndex = (int) queries[i].first; featureIndex = (int) queries[i].second; if(queryMatrix[instanceIndex][featureIndex]){ System.err.println("Query tracking failure!"); }else{ (train.instance(instanceIndex)).setValue(featureIndex,(fulltrain.instance(instanceIndex)).value(featureIndex)); m_Cost += m_FeatureCosts[featureIndex]; System.out.println("Query for ("+instanceIndex+","+featureIndex+") = "+ (train.instance(instanceIndex)).value(featureIndex) + "\tCost = "+m_FeatureCosts[featureIndex]+"\tTotalCost = "+m_Cost); queryMatrix[instanceIndex][featureIndex] = true; } } } /** * Add new instances to the given set of instances. * * @param data given instances * @param newData set of instances to add to given instances */ protected void addInstances(Instances data, Instances newData){ for(int i=0; i<newData.numInstances(); i++) data.add(newData.instance(i)); } /** Determines if the points specified are fractions of the total number of examples */ protected boolean setIsFraction(){ if (m_PlotPoints != null){ if(!isInteger(m_PlotPoints[0]))//if the first point is not an integer m_IsFraction = true; else m_IsFraction = false; }// else{// if(!isInteger(m_StepSize))//if the step size is not an integer// m_IsFraction = true;// else// m_IsFraction = false; // } return m_IsFraction; } /** Return the number of instance-feature queries for the ith point on the * curve for plotPoints as specified. */ protected int plotPoint(int i) { // If i beyond number of given plot points return a value greater than maximum queries if (i >= m_PlotPoints.length) return maxNumQueries() + 1; double point = m_PlotPoints[i]; // If plot point is an integer (other than a non-initial 1) // treat it as a specific number of queries if (isInteger(point) && !(Utils.eq(point, 1.0) && i!=0)) return (int)point; else // Otherwise, treat it as a percentage of the full set return (int)Math.round(point * maxNumQueries()); } /** Return true if the given double represents an integer value */ protected static boolean isInteger(double val) { return Utils.eq(Math.floor(val), Math.ceil(val)); } /** * Gets the names of each of the columns produced for a single run. * This method should really be static. * * @return an array containing the name of each column */ public String [] getKeyNames() { String [] keyNames = m_SplitEvaluator.getKeyNames(); // Add in the names of our extra key fields int numExtraKeys; if(m_IsFraction) numExtraKeys = 5; else numExtraKeys = 4; String [] newKeyNames = new String [keyNames.length + numExtraKeys]; newKeyNames[0] = DATASET_FIELD_NAME; newKeyNames[1] = RUN_FIELD_NAME; newKeyNames[2] = FOLD_FIELD_NAME; newKeyNames[3] = STEP_FIELD_NAME; if(m_IsFraction) newKeyNames[4] = FRACTION_FIELD_NAME; System.arraycopy(keyNames, 0, newKeyNames, numExtraKeys, keyNames.length); return newKeyNames; } /** * Gets the data types of each of the columns produced for a single run. * This method should really be static. * * @return an array containing objects of the type of each column. The * objects should be Strings, or Doubles. */ public Object [] getKeyTypes() { Object [] keyTypes = m_SplitEvaluator.getKeyTypes(); int numExtraKeys; if(m_IsFraction) numExtraKeys = 5; else numExtraKeys = 4; // Add in the types of our extra fields Object [] newKeyTypes = new String [keyTypes.length + numExtraKeys]; newKeyTypes[0] = new String(); newKeyTypes[1] = new String(); newKeyTypes[2] = new String(); newKeyTypes[3] = new String(); if(m_IsFraction) newKeyTypes[4] = new String(); System.arraycopy(keyTypes, 0, newKeyTypes, numExtraKeys, keyTypes.length); return newKeyTypes; } /** * Gets the names of each of the columns produced for a single run. * This method should really be static. * * @return an array containing the name of each column */ public String [] getResultNames() { String [] resultNames = m_SplitEvaluator.getResultNames(); // Add in the names of our extra Result fields String [] newResultNames = new String [resultNames.length + 3]; newResultNames[0] = TIMESTAMP_FIELD_NAME; newResultNames[1] = SELECTION_TIME_FIELD_NAME; newResultNames[2] = COST_FIELD_NAME; System.arraycopy(resultNames, 0, newResultNames, 3, resultNames.length); return newResultNames; } /** * Gets the data types of each of the columns produced for a single run. * This method should really be static. * * @return an array containing objects of the type of each column. The * objects should be Strings, or Doubles. */ public Object [] getResultTypes() { Object [] resultTypes = m_SplitEvaluator.getResultTypes(); // Add in the types of our extra Result fields Object [] newResultTypes = new Object [resultTypes.length + 3]; newResultTypes[0] = new Double(0); newResultTypes[1] = new Double(0); newResultTypes[2] = new Double(0); System.arraycopy(resultTypes, 0, newResultTypes, 3, resultTypes.length); return newResultTypes; } /** * Gets a description of the internal settings of the result * producer, sufficient for distinguishing a ResultProducer * instance from another with different settings (ignoring * those settings set through this interface). For example, * a cross-validation ResultProducer may have a setting for the * number of folds. For a given state, the results produced should * be compatible. Typically if a ResultProducer is an OptionHandler, * this string will represent the command line arguments required * to set the ResultProducer to that state. * * @return the description of the ResultProducer state, or null * if no state is defined */ public String getCompatibilityState() { String result = "-X " + m_NumFolds + " -S " + getStepSize() + " -L " + getLowerSize() + " -U " + getUpperSize() + " "; if (m_SplitEvaluator == null) { result += "<null SplitEvaluator>"; } else { result += "-W " + m_SplitEvaluator.getClass().getName(); } return result + " --"; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String outputFileTipText() { return "Set the destination for saving raw output. If the rawOutput " +"option is selected, then output from the splitEvaluator for " +"individual folds is saved. If the destination is a directory, " +"then each output is saved to an individual gzip file; if the " +"destination is a file, then each output is saved as an entry " +"in a zip file."; } /** * Get the value of OutputFile. * * @return Value of OutputFile. */ public File getOutputFile() { return m_OutputFile; } /** * Set the value of OutputFile. * * @param newOutputFile Value to assign to OutputFile. */ public void setOutputFile(File newOutputFile) { m_OutputFile = newOutputFile; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numFoldsTipText() { return "Number of folds to use in cross validation."; } /** * Get the value of NumFolds. * * @return Value of NumFolds. */ public int getNumFolds() { return m_NumFolds; } /** * Set the value of NumFolds. * * @param newNumFolds Value to assign to NumFolds. */ public void setNumFolds(int newNumFolds) { m_NumFolds = newNumFolds; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String lowerSizeTipText() { return "Set the minimum number of instances in a training set. Setting zero " + "here will actually use <stepSize> number of instances at the first "
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -