📄 batchweightlearner.java
字号:
lambda = new double[] {0.01, 0.1, 0.2, 0.4, 0.6, 1.0, 2.0, 4.0}; // TMP // later: read from config file
a = b = null;
// --- create an inner model applier (tmp: mySVM applier) --- // TMP // RK/2002/05/25: temporary hack
// SVMLearner svmLearner = (SVMLearner) getLearner(); // TMP // RK/2002/05/25: temporary hack
modelApplierAndPerformanceEvaluator = getModelApplierAndPerformanceEvaluator(); // RK/2002/06/09: ...
// String modelApplierName = this.getName() + ".innerModelApplier"; // TMP // RK/2002/05/25: temporary hack
// OperatorParams modelApplierParameters = new OperatorParams(); // TMP // RK/2002/05/25: temporary hack
// modelApplier = new SVMApplier(); // TMP // RK/2002/05/25: temporary hack; better specified in experiment config.!
// modelApplier.setName(modelApplierName); // TMP // RK/2002/05/25: temporary hack; better specified in experiment config.!
// // modelApplier.setParent(this); // TMP // RK/2002/05/25: temporary hack; better specified in experiment config.!
// modelApplier.setParent(getLearner()); // TMP // RK/2002/06/06: temporary hack; better specified in experiment config.!
// modelApplierParameters.setName (modelApplierName);
// // //// set parent for parentlookup of SVM parameters:
// // modelApplierParameters.setParent (this.getOperatorParameters()); // ??
// modelApplierParameters.setAttribute ("parentlookup", (getLearner()).getParentLookup()+1); // RK/2002/06/06: temporary hack
// //// OR:
// //// copy mySVM kernel parameters and global parameters from SVMLearner to inner SVMApplier:
// for (int i = 0; i < SVMLearner.KERNEL_PARAMETER.length; i++) {
//     String param = svmLearner.getParameter(SVMLearner.KERNEL_PARAMETER[i]);
//     if (param != null)
//         ParameterService.setParameter(modelApplier, SVMLearner.KERNEL_PARAMETER[i], param); // ?? OK ?? ...
// }
// for (int i = 0; i < SVMLearner.PARAMETER.length; i++) {
//     String param = svmLearner.getParameter(SVMLearner.PARAMETER[i]);
//     if (param != null)
//         ParameterService.setParameter(modelApplier, SVMLearner.PARAMETER[i], param); // ?? OK ?? ...
// }
////
// ParameterService.registerOperator (modelApplierParameters); // RK/2002/05/25: temporary hack
//
// String batchEvaluatorName = this.getName() + ".innerBatchEvaluator"; // TMP // RK/2002/05/25: temporary hack
// OperatorParams batchEvaluatorParameters = new OperatorParams(); // TMP // RK/2002/05/25: temporary hack
// batchEvaluator = new PerformanceEvaluator(); // TMP // RK/2002/05/25: temporary hack
// batchEvaluator.setName(batchEvaluatorName); // TMP // RK/2002/05/25: temporary hack
// batchEvaluator.setParent(this); // TMP // RK/2002/05/25: temporary hack
// batchEvaluator.add("classification_error"); // TMP // RK/2002/05/25: temporary hack
// batchEvaluatorParameters.setName(batchEvaluatorName); // TMP // RK/2002/05/25: temporary hack
//
// ... pass or set input (when calling this operator) & retrieve output (afterwards) ? ...
// ParameterService.registerOperator(batchEvaluatorParameters); // TMP // RK/2002/05/25: temporary hack
}

/** Computes the example weight for each example in the given batch as
 *  <code>weight(batch) := e^(-lambda * batch)</code>.
 *  Note: <tt>batch</tt> is <i>not</i> the absolute batch number, but the difference between the currently
 *  last batch number and the number of the batch whose examples are to be weighted, i.e. the number of
 *  batches the examples lay in the past (= before the currently last batch).
 */
private double computeGlobalExponentialWeight (double lambda, int batch) {
    // Delegate to the double overload so the formula lives in exactly one place.
    return computeGlobalExponentialWeight(lambda, (double) batch);
}

/** Computes the example weight for each example in the given batch as
 *  <code>weight(batch) := e^(-lambda * batch)</code>.
 *  Note: <tt>batch</tt> is <i>not</i> the absolute batch number, but the difference between the currently
 *  last batch number and the number of the batch whose examples are to be weighted, i.e. the number of
 *  batches the examples lay in the past (= before the currently last batch).
 */
private double computeGlobalExponentialWeight (double lambda, double batch) {
    return Math.exp(-lambda * batch);
}

/** Computes the example weight for each example in the given batch as
 *  <code>weight(a,b) := (tanh((x-a)/b) + 1) / 2</code>.
 *  Note: <tt>batch</tt> is the absolute batch number, <i>not</i> the difference between the currently
 *  last batch number and the number of the batch whose examples are to be weighted, i.e. the number of
 *  batches the examples lay in the past (= before the currently last batch).
 */
private double computeGlobalSigmoidalWeight (double a, double b, int batch) {
    // Delegate to the double overload so the formula lives in exactly one place.
    return computeGlobalSigmoidalWeight(a, b, (double) batch);
}

/** Computes the example weight for each example in the given batch as
 *  <code>weight(a,b) := (tanh((x-a)/b) + 1) / 2</code>.
 *  Note: <tt>batch</tt> is the absolute batch number, <i>not</i> the difference between the currently
 *  last batch number and the number of the batch whose examples are to be weighted, i.e. the number of
 *  batches the examples lay in the past (= before the currently last batch).
 */
private double computeGlobalSigmoidalWeight (double a, double b, double batch) {
    // NOTE(review): MathFunctions.tanh is a project helper (this code predates Java 5's Math.tanh);
    // only swap to java.lang.Math.tanh after confirming the helper computes the exact hyperbolic tangent.
    return (edu.udo.cs.yale.tools.MathFunctions.tanh((batch - a) / b) + 1.0) / 2.0;
}

// /** ... */
// private double computeLocalWeight (double lambda, double batch) {
//     return 1.0; // ...
// }
// /** ... */
// private double computeCombinedWeight (double lambda, double batch) {
//     return 1.0; // ...
// }
// /** compute example weight for each example in the given batch using the weighting scheme defined by the parameter
// * <tt>weighting_scheme</tt> of this operator.
// */ // private double computeWeight (double lambda, double batch) { // switch (weightingScheme) { // case GLOBAL_EXAMPLE_WEIGHTING: return computeGlobalExponentialWeight(lambda,batch); // case LOCAL_EXAMPLE_WEIGHTING: return computeLocalWeight(lambda,batch); // case COMBINED_GLOBAL_AND_LOCAL_EXAMPLE_WEIGHTING: return computeCombinedWeight(lambda,batch); // default: return computeGlobalExponentialWeight(lambda,batch); // } // // return computeGlobalWeight(lambda,batch); // this line should never be reached, but satisfies the compiler ;^) // } /** applies the inner learner (= first encapsulated inner operator). */ // protected IOContainer learn(ExampleSet trainingSet) throws OperatorException { // return learnResult = getLearner().apply(getInput().append(new IOObject[] { trainingSet })); // } /** applies the inner applier and evaluator (= second encapsulated inner operator). */ protected IOContainer evaluate (Model learnResult, ExampleSet testSet) throws OperatorException { if (learnResult == null) { throw new FatalException("Wrong use of BatchWeightLearner.evaluate(Model,ExampleSet): " + "learned model is null."); } IOContainer input = new IOContainer(new IOObject[] { learnResult, testSet }); IOContainer result = getModelApplierAndPerformanceEvaluator().apply(input); return result; } /** set weights of all training examples (= all examples in past batches) */ private void setExponentialTrainingExampleWeights (BatchedExampleSet exampleSet, double lambda) throws OperatorException { ExampleReader exampleIterator = exampleSet.getExampleReader(); Example currentExample = null; Attribute batchIndexAttribute = exampleSet.getBatchIndexAttribute(); double currentBatch = 0.0; double weight = 0.0; // --- global weighting loop --- while (exampleIterator.hasNext()) { currentExample = exampleIterator.next(); currentBatch = currentExample.getValue (batchIndexAttribute); if (currentBatch < ((double)exampleSet.getLastBatch())) { // only change weight of examples in past batches 
//// ... ^- bei '<=' automatisch Gewicht = 1 ? sonst unbedingt so setzen !! potentieller Fehler !! ... // // weight = computeWeight (lambda, (exampleSet.getLastBatch() - currentBatch)); // RK/2002/05/25: old weight = computeGlobalExponentialWeight (lambda, (exampleSet.getLastBatch() - currentBatch)); // RK/2002/05/25: new if (weight < minimalWeight) { weight = 0.0; } // do not consider weights below the treshold try { currentExample.setWeight (weight); } catch (MethodNotSupportedException e) { throw new FatalException("BatchWeightLearner '" + getName() + "': The example set passed to " + "this operator must contain examples with weights.", e); } } } if (weightingScheme == GLOBAL_EXAMPLE_WEIGHTING) return; // --- local or combined weighting --- localizeTrainingExampleWeights (exampleSet); } /** set weights of all training examples (= all examples in past batches); this method requires that global * example weights have already been set before and hence may only be called from * <code>setExponentialTrainingExampleWeights</code> or <code>setSigmoidialTrainingExampleWeights</code>. 
 */
private void localizeTrainingExampleWeights (BatchedExampleSet exampleSet) throws OperatorException {
    // --- evaluate the batches for local and combined weighting ---
    Model learnedModel = ((Learner) getLearner()).learn(exampleSet); // learn a first model (with global weighting only)
    BatchedExampleSet currentExampleSet; // for each batch: current test set = current batch
    double[] localWeight = new double[exampleSet.getLastBatch()]; // local weight of each batch
    double sumOfWeights = 0.0;
    for (int b=0; b < exampleSet.getLastBatch(); b++) {
        // Evaluate the first model on batch b alone; the batch's local weight is derived from that error.
        currentExampleSet = new BatchedExampleSet (exampleSet, exampleSet.getBatchIndexAttribute(), b, b);
        IOContainer evalOutput = evaluate (learnedModel, currentExampleSet); // apply & evaluate model
        PerformanceVector performance = (PerformanceVector)evalOutput.getInput(PerformanceVector.class); // get performance results
        //// localWeight[b] = 1.0 - Math.min(1.0,2.0*performance.get(0).getValue()); // or: getValue("classification_error")
        // Piecewise-linear ramp on the main criterion (presumably an error measure -- see the
        // commented-out "classification_error" line above): error <= 0.1 -> weight 1,
        // error >= 0.3 -> weight 0, linear in between.
        if (performance.getMainCriterion().getValue() <= 0.1)
            localWeight[b] = 1.0;
        else if (performance.getMainCriterion().getValue() >= 0.3)
            localWeight[b] = 0.0;
        else
            localWeight[b] = 1.0 - 5.0*(performance.getMainCriterion().getValue()-0.1);
        sumOfWeights += localWeight[b];
    }
    // --- compute normalized local weights ---
    // Normalize so the local weights sum to 1; if every batch got weight 0, fall back to 1.0 for all.
    for (int b=0; b < exampleSet.getLastBatch(); b++) {
        if (sumOfWeights == 0.0) { localWeight[b] = 1.0; }
        else { localWeight[b] /= sumOfWeights; }
    }
    // --- set local and combined weights respectively ---
    Attribute batchIndexAttribute = exampleSet.getBatchIndexAttribute();
    ExampleReader exampleIterator = exampleSet.getExampleReader();
    Example currentExample = null;
    int currentBatch = 0;
    double weight = 0.0;
    while (exampleIterator.hasNext()) {
        currentExample = exampleIterator.next();
        currentBatch = (int) currentExample.getValue (batchIndexAttribute);
        if (currentBatch < exampleSet.getLastBatch()) { // only change weight of examples in past batches
            if (weightingScheme == LOCAL_EXAMPLE_WEIGHTING) {
                weight = localWeight[currentBatch];
            }
            else { // COMBINED_GLOBAL_AND_LOCAL_EXAMPLE_WEIGHTING
                // combined = previously-set global weight times the batch's local weight
                weight = currentExample.getWeight() * localWeight[currentBatch];
            }
            if (weight < minimalWeight) { weight = 0.0; } // do not consider weights below the threshold
            try {
                currentExample.setWeight (weight);
            }
            catch (MethodNotSupportedException e) {
                throw new FatalException("BatchWeightLearner '" + getName() + "': The example set passed to " + "this operator must contain examples with weights.", e);
            }
        }
    }
}

/** Sets the weights of all training examples (= all examples in past batches) using the global
 *  sigmoidal weighting function (tanh((x-a)/b)+1)/2 of the <i>absolute</i> batch number (unlike the
 *  exponential variant, which weights by distance from the last batch); examples in the current (last)
 *  batch are left untouched. If the weighting scheme is not purely global, local/combined weighting is
 *  applied afterwards via <code>localizeTrainingExampleWeights</code>.
 */
private void setSigmoidalTrainingExampleWeights (BatchedExampleSet exampleSet, double a, double b) throws OperatorException {
    ExampleReader exampleIterator = exampleSet.getExampleReader();
    Example currentExample = null;
    Attribute batchIndexAttribute = exampleSet.getBatchIndexAttribute();
    double currentBatch = 0.0;
    double weight = 0.0;
    // --- global weighting loop ---
    while (exampleIterator.hasNext()) {
        currentExample = exampleIterator.next();
        currentBatch = currentExample.getValue (batchIndexAttribute);
        if (currentBatch < ((double)exampleSet.getLastBatch())) { // only change weight of examples in past batches
            weight = computeGlobalSigmoidalWeight (a, b, currentBatch);
            if (weight < minimalWeight) { weight = 0.0; } // do not consider weights below the threshold
            try {
                currentExample.setWeight (weight);
            }
            catch (MethodNotSupportedException e) {
                throw new FatalException("BatchWeightLearner '" + getName() + "': The example set passed to " + "this operator must contain examples with weights.", e);
            }
        }
    }
    if (weightingScheme == GLOBAL_EXAMPLE_WEIGHTING)
        return;
    // --- local or combined weighting ---
    localizeTrainingExampleWeights (exampleSet);
}

/** This method learns a classification model and returns it (along with a performance estimation
 * (xi-alpha-estimation of the classification error)).
 */
public Model learn (ExampleSet inputSet) throws OperatorException { // abstract method in super classes Learner and BatchLearner
    // NOTE(review): unlike the two following messages, this one carries no trailing "\n" -- confirm intended.
    LogService.logMessage ("BatchWeightLearner '" + getName() + "': start learning...", LogService.TASK);
    LogService.logMessage ("  Weighting function: " + WEIGHTING_FUNCTION_TYPE_NAME[weightingFunction] + "\n", LogService.TASK);
    LogService.logMessage ("  Weighting scheme: " + WEIGHTING_SCHEME_TYPE_NAME[weightingScheme] + "\n", LogService.TASK);
    performanceEstimation = null; // reset; recomputed by the chosen learnWith*Weighting method
    // Dispatch on the configured weighting function.
    if (weightingFunction == EXPONENTIAL_FUNCTION) {
        return learnWithExponentialWeighting (inputSet);
    }
    else if (weightingFunction == SIGMOIDAL_FUNCTION) {
        return learnWithSigmoidalWeighting (inputSet);
    }
    else { // weightingFunction == ZERO_ONE_FUNCTION
        return learnWithZeroOneWeighting (inputSet);
    }
}

/** This method learns a classification model and returns it (along with a performance estimation
 *  (xi-alpha-estimation of the classification error)). Searches over the <code>lambda</code>
 *  candidates for the exponential weighting function.
 */
private Model learnWithExponentialWeighting (ExampleSet inputSet) throws OperatorException {
    BatchedExampleSet exampleSet = (BatchedExampleSet) inputSet;
    Learner learner = (Learner)(getLearner());
    PerformanceVector currentPerformance = null;
    double bestLambda = lambda[0];
    double bestError = 1.0;
    double currentError = 1.0;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -