📄 batchweightlearner.java
字号:
Model bestModel = null; Model currentModel = null; for(int i=0; i < lambda.length; i++) { // for each lambda LogService.logMessage("BatchWeightLearner '" + getName() + "': start learning for lambda = " + lambda[i] + " ...", LogService.TASK); setExponentialTrainingExampleWeights (exampleSet, lambda[i]); // * set the example weights //// learn model and estimate its performance on the training set: currentModel = learner.learn (exampleSet); if (learner.canEstimatePerformance()) { currentPerformance = learner.getEstimatedPerformance(); } else { throw new FatalException("BatchWeightLearner '" + getName() + "': The enclosed learner must be able " + "to estimate its performance based on the training set. The enclosed learner " + learner.getName() + "is not able to do this."); } // currentError = (currentPerformance.get(0)).getValue(); // RK/2003/04/30: old version currentError = (currentPerformance.getMainCriterion()).getValue(); // RK/2003/04/30: new version //// if current performance is better than best performance found so far, store it if (currentError < bestError) { bestLambda = lambda[i]; bestError = currentError; bestModel = currentModel; performanceEstimation = currentPerformance; } } LogService.logMessage ("BatchWeightLearner '" + this.getName() + "': selected weighting factor lambda = " + bestLambda + " for current batch = " + exampleSet.getLastBatch() + " (estimated classification error: " + bestError + ").", LogService.TASK); return bestModel; } /** This method learns a classification model and returns it (along with a performance estimation * (xi-alpha-estimation of the classification error)). 
*/ private Model learnWithSigmoidalWeighting (ExampleSet inputSet) throws OperatorException { BatchedExampleSet exampleSet = (BatchedExampleSet) inputSet; Learner learner = (Learner)(getLearner()); PerformanceVector currentPerformance = null; double bestA = 0.0; double bestB = 1.0; double bestError = 1.0; double currentError = 1.0; Model bestModel = null; Model currentModel = null; // --- determine sensible values for the parameters a, b of the sigmoidal weighting function --- // sigmoidal function: w(a,b) := (tanh ((x-a)/b) + 1) / 2 (w(a,b) is in [0,1]) a = new double[4]; // a[ai] = "center" of the sigmoidal "switch" b = new double[5]; // b[bi] = "width" of the sigmoidal "switch" // a[0] = exampleSet.getLastBatch(); a[1] = exampleSet.getLastBatch() * 0.85; a[2] = exampleSet.getLastBatch() * 0.68; a[3] = exampleSet.getLastBatch() * 0.50; // b[0] = exampleSet.getLastBatch() * 0.01; b[1] = exampleSet.getLastBatch() * 0.10; b[2] = exampleSet.getLastBatch() * 0.30; b[3] = exampleSet.getLastBatch() * 0.60; b[4] = exampleSet.getLastBatch() * 5.00; // --- brute force test of all combinations of parameters a, b and selection of best combination --- for(int ai=0; ai < a.length; ai++) { // for value of parameter a for(int bi=0; bi < b.length; bi++) { // for value of parameter b LogService.logMessage("BatchWeightLearner '" + getName() + "': start learning for a = " + a[ai] + ", b = " + b[bi] + " ...", LogService.TASK); setSigmoidalTrainingExampleWeights (exampleSet, a[ai], b[bi]); // * set the example weights // learn model and estimate its performance on the training set: currentModel = learner.learn (exampleSet); if (learner.canEstimatePerformance()) { currentPerformance = learner.getEstimatedPerformance(); } else { throw new FatalException("BatchWeightLearner '" + getName() + "': The enclosed learner must be able " + "to estimate its performance based on the training set. 
The enclosed learner " + learner.getName() + "is not able to do this."); } // currentError = (currentPerformance.get(0)).getValue(); // RK/2003/04/30: old version currentError = (currentPerformance.getMainCriterion()).getValue(); // RK/2003/04/30: new version // if current performance is better than best performance found so far, store it if (currentError < bestError) { bestA = a[ai]; bestB = b[bi]; bestError = currentError; bestModel = currentModel; performanceEstimation = currentPerformance; } } } LogService.logMessage ("BatchWeightLearner '" + this.getName() + "': selected weighting parameters a = " + bestA + ", b = " + bestB + " for current batch = " + exampleSet.getLastBatch() + " (estimaed classification error: " + bestError + ").", LogService.TASK); return bestModel; } /** This method learns a classification model and returns it (along with a performance estimation * (xi-alpha-estimation of the classification error)). * This methods sets the weights of all training examples (= all examples in past batches) to one * (if the error the model learned on the currently last batch makes on the current batch is smaller * than twice the error of this model on the currently last batch) or zero (otherwise). 
 */
private Model learnWithZeroOneWeighting (ExampleSet inputSet) throws OperatorException {
    BatchedExampleSet exampleSet = (BatchedExampleSet) inputSet;
    Attribute batchIndexAttribute = exampleSet.getBatchIndexAttribute();
    int lastBatch = exampleSet.getLastBatch();
    int currentBatch = 0;
    Learner learner = (Learner)(getLearner());

    // ===== learn initial model learned on currently last batch only =====
    BatchedExampleSet currentExampleSet = new BatchedExampleSet (exampleSet, batchIndexAttribute, lastBatch, lastBatch);
    Model initialModel = learner.learn(currentExampleSet); // initial model learned on currently last batch only
    Model finalModel = null; // final model learned on all batches selected for training

    // ===== determine error of the initial model on the currently last batch =====
    IOContainer evalOutput = evaluate (initialModel, currentExampleSet); // apply & evaluate model
    PerformanceVector performance = (PerformanceVector)evalOutput.getInput(PerformanceVector.class); // get performance results
    double errorOnLastBatch = performance.getMainCriterion().getValue();
    double errorOnCurrentBatch = 1.0;

    // ===== apply model and determine batch weights (for each batch: current test set = current batch) =====
    double[] batchWeight = new double[lastBatch]; // local weight of each batch
    int noOfSelectedBatches = 0; // number of batches with weight set to one
    String selectedBatchesOutput = "";
    for (int b=0; b < lastBatch; b++) {
        currentExampleSet = new BatchedExampleSet (exampleSet, batchIndexAttribute, b, b);
        evalOutput = evaluate (initialModel, currentExampleSet); // apply & evaluate model
        performance = (PerformanceVector)evalOutput.getInput(PerformanceVector.class); // get performance results
        // errorOnCurrentBatch = performance.get(0).getValue(); // or: performance.getValue("classification_error") // RK/2003/04/30: old
        errorOnCurrentBatch = performance.getMainCriterion().getValue();
        LogService.logMessage ("BatchWeightLearner '" + getName() + "': error on current batch = "
            + errorOnCurrentBatch + " vs. error on last batch = " + errorOnLastBatch
            + " (batch " + b + " vs. " + lastBatch + ")", LogService.TASK);
        // NOTE(review): the class Javadoc says "twice the error", but the code uses 5.0;
        // the author's own comment below flags this as an ERROR (use 2.0 * estimated error,
        // not the true error, because of the risk of over-fitting) — left unchanged here,
        // confirm intended threshold before altering behavior.
        if ((errorOnCurrentBatch < 5.0 * errorOnLastBatch) || (errorOnCurrentBatch < 0.2)) {
            // \__ RK/2002/09/24: ERROR: in first condition use ... < 2.0 * EstimatedError (not true error, wg. risk of over-fitting)
            batchWeight[b] = 1.0;
            noOfSelectedBatches++;
            selectedBatchesOutput += b + " ";
        }
        else {
            batchWeight[b] = 0.0;
        }
    }
    LogService.logMessage ("BatchWeightLearner '" + getName() + "': selection of the following batches for "
        + "training at the currently last batch " + lastBatch + ": " + selectedBatchesOutput, LogService.TASK);

    // ===== set new example weights =====
    // Propagate the per-batch 0/1 weights down to every example of a past batch;
    // examples of the currently last batch keep their existing weight.
    currentExampleSet = new BatchedExampleSet (exampleSet, batchIndexAttribute, 0, lastBatch);
    ExampleReader exampleIterator = currentExampleSet.getExampleReader();
    Example currentExample = null;
    while (exampleIterator.hasNext()) {
        currentExample = exampleIterator.next();
        currentBatch = (int) currentExample.getValue (batchIndexAttribute);
        if (currentBatch < exampleSet.getLastBatch()) { // only change weight of examples in past batches
            try {
                currentExample.setWeight (batchWeight[currentBatch]);
            } catch (MethodNotSupportedException e) {
                throw new FatalException("BatchWeightLearner '" + getName() + "': The example set passed to "
                    + "this operator must contain examples with weights.\n" + e.getMessage());
            }
        }
        // else if (currentBatch == exampleSet.getLastBatch()) {
        //     try {
        //         currentExample.setWeight (1.0);
        //     } catch (MethodNotSupportedException e) {
        //         LogService.logFatalException ("BatchWeightLearner '" + getName() + "': The example set passed to " +
        //             "this operator must contain examples with weights.", e);
        //     }
        // }
    }

    // ===== re-learn model =====
    // Final model is trained on all batches, with de-selected batches weighted to zero.
    finalModel = learner.learn(currentExampleSet);
    LogService.logMessage ("BatchWeightLearner '" + this.getName() + "': " + noOfSelectedBatches
        + " batches selected for re-training for currently last batch "
        + exampleSet.getLastBatch(), LogService.TASK);
    return finalModel;
}

/** This learner can always estimate its performance (the estimate is taken over
 *  from the enclosed learner during the weight search). */
public boolean canEstimatePerformance() {
    return true;
}

/** returns an object of the class <tt>EstimatedPerformance</tt> containing the xi-alpha-performance
 *  estimates of the learned mySVM model, if ... */
public PerformanceVector getEstimatedPerformance() {
    return performanceEstimation;
}

/** learns a model from a given example set (i.e. given in the input of the operator (<code>IOContainer</code>)).
 *  The model is stored under the name <tt>model_file</tt> (if specified). */
public IOObject[] apply() throws OperatorException { // from super class Learner
    // NOTE(review): the cast below makes apply() fail fast with a ClassCastException
    // if the input is not a BatchedExampleSet — presumably intentional; confirm.
    ExampleSet exampleSet = (BatchedExampleSet)getInput(ExampleSet.class);
    if (exampleSet==null) {
        throw new FatalException("BatchWeightLearner '"+getName()+"': No input example set!");
    }
    if (exampleSet.getNumberOfAttributes()==0) {
        throw new FatalException("BatchWeightLearner '"+getName()+"': Input example set has no attributes");
    }
    LogService.logMessage("BatchWeightLearner '" + getName() + "': Start learning...", LogService.TASK);
    Model model = learn(exampleSet);
    try {
        // optionally persist the learned model; failure to write is logged, not fatal
        if (modelFile != null)
            model.writeModel(getExperiment().resolveFileName(modelFile));
    } catch (IOException e) {
        LogService.logMessage("BatchWeightLearner '" + getName() + "' Can't write model file: "
            + modelFile, LogService.ERROR);
    }
    PerformanceVector perfVector = getEstimatedPerformance();
    if (perfVector == null) {
        return new IOObject[] { model };
    }
    else {
        return new IOObject[] { model, perfVector };
    }
}

/** Returns which input types this operator can handle and what the input array must look like. */
public Class[] getInputClasses() {
    return INPUT_CLASSES;
}

/** Returns which output types this operator delivers and what the output array will look like. */
public Class[] getOutputClasses() { // from super class Learner
    return canEstimatePerformance() ?
        new Class[] { Model.class, PerformanceVector.class } :
        new Class[] { Model.class };
}

// /** Sets the index of the class to use as "positive" (+1), e.g. // pass to enclosed learner ?? TMP / TO DO ??
//  * setPositiveLabelIndex(attribute.mapString("positive"))
//  */
// public void setPositiveLabelIndex(int index) {
//     this.positiveLabelIndex = index;
// }

/** Declares the user-visible parameters of this operator in addition to those of the super class. */
public List getParameterTypes() {
    List types = super.getParameterTypes();
    types.add(new ParameterTypeString("model_file",
        "If this parameter is set, the model is written to a file."));
    types.add(new ParameterTypeDouble("minimal_weight",
        "Threshold, under which example weights are considered equal to zero ... ", 0.0, 1.0, 0.0001));
    types.add(new ParameterTypeCategory("weighting_function", "[PENDING]",
        WEIGHTING_FUNCTION_TYPE_NAME, EXPONENTIAL_FUNCTION));
    types.add(new ParameterTypeCategory("weighting_scheme", "[PENDING]",
        WEIGHTING_SCHEME_TYPE_NAME, GLOBAL_EXAMPLE_WEIGHTING));
    return types;
}

/** Returns the minimum number of inner operators, which is equal to 1 for this operator. */
public int getMinNumberOfInnerOperators() {
    return 1;
}

/** Returns the maximum number of inner operators, which is equal to the maximum <tt>Integer</tt> value of Java for this operator. */
public int getMaxNumberOfInnerOperators() {
    return Integer.MAX_VALUE;
}

/** Returns the number of inner steps of this operator, which is the number of runs times the
 *  sum of the number of inner steps of all inner operators. */
public int getNumberOfSteps() {
    // TO DO: return (sum number of steps of all inner operators) * (number of runs) ?
    return 1;
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -