⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 batchweightlearner.java

📁 著名的开源仿真软件yale
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
	Model              bestModel          = null;	Model              currentModel       = null;	for(int i=0; i < lambda.length; i++) {                  // for each lambda	    LogService.logMessage("BatchWeightLearner '" + getName() + "': start learning for lambda = " + lambda[i] +				  " ...", LogService.TASK);	    setExponentialTrainingExampleWeights (exampleSet, lambda[i]);  // * set the example weights	    //// learn model and estimate its performance on the training set:	    currentModel = learner.learn (exampleSet);	    if (learner.canEstimatePerformance()) {		currentPerformance = learner.getEstimatedPerformance();	    } else {		throw new FatalException("BatchWeightLearner '" + getName() + 					 "': The enclosed learner must be able " +					 "to estimate its performance based on the training set. The enclosed learner " +					 learner.getName() + "is not able to do this.");	    }	    // currentError = (currentPerformance.get(0)).getValue();           // RK/2003/04/30: old version	    currentError = (currentPerformance.getMainCriterion()).getValue();  // RK/2003/04/30: new version	    //// if current performance is better than best performance found so far, store it	    if (currentError < bestError) {		bestLambda = lambda[i];		bestError = currentError;		bestModel = currentModel;		performanceEstimation = currentPerformance;	    }	}	LogService.logMessage ("BatchWeightLearner '" + this.getName() + "':  selected weighting factor lambda = " +			       bestLambda + " for current batch = " + exampleSet.getLastBatch() +			       " (estimated classification error: " + bestError + ").", LogService.TASK);	return bestModel;    }        /** This method learns a classification model and returns it (along with a performance estimation     *  (xi-alpha-estimation of the classification error)).     
*/    private Model  learnWithSigmoidalWeighting (ExampleSet inputSet)  throws OperatorException {	BatchedExampleSet  exampleSet         = (BatchedExampleSet) inputSet;	Learner            learner            = (Learner)(getLearner());	PerformanceVector  currentPerformance = null;	double             bestA              = 0.0;	double             bestB              = 1.0;	double             bestError          = 1.0;	double             currentError       = 1.0;	Model              bestModel          = null;	Model              currentModel       = null;	// --- determine sensible values for the parameters a, b of the sigmoidal weighting function ---	// sigmoidal function:  w(a,b) := (tanh ((x-a)/b) + 1) / 2  (w(a,b) is in [0,1])	a = new double[4];    // a[ai] = "center" of the sigmoidal "switch"	b = new double[5];    // b[bi] = "width" of the sigmoidal "switch"	//	a[0] = exampleSet.getLastBatch();	a[1] = exampleSet.getLastBatch() * 0.85;	a[2] = exampleSet.getLastBatch() * 0.68;	a[3] = exampleSet.getLastBatch() * 0.50;	//	b[0] = exampleSet.getLastBatch() * 0.01;	b[1] = exampleSet.getLastBatch() * 0.10;	b[2] = exampleSet.getLastBatch() * 0.30;	b[3] = exampleSet.getLastBatch() * 0.60;	b[4] = exampleSet.getLastBatch() * 5.00;	// --- brute force test of all combinations of parameters a, b and selection of best combination ---	for(int ai=0; ai < a.length; ai++) {                  // for value of parameter a	    for(int bi=0; bi < b.length; bi++) {                  // for value of parameter b		LogService.logMessage("BatchWeightLearner '" + getName() + "': start learning for a = " + a[ai] +				      ", b = " + b[bi] + " ...", LogService.TASK);		setSigmoidalTrainingExampleWeights (exampleSet, a[ai], b[bi]);  // * set the example weights		// learn model and estimate its performance on the training set:		currentModel = learner.learn (exampleSet);		if (learner.canEstimatePerformance()) {		    currentPerformance = learner.getEstimatedPerformance();		} else {		    throw new 
FatalException("BatchWeightLearner '" + getName() + "': The enclosed learner must be able " +					     "to estimate its performance based on the training set. The enclosed learner " +					     learner.getName() + "is not able to do this.");		}		// currentError = (currentPerformance.get(0)).getValue();           // RK/2003/04/30: old version		currentError = (currentPerformance.getMainCriterion()).getValue();  // RK/2003/04/30: new version		// if current performance is better than best performance found so far, store it		if (currentError < bestError) {		    bestA     = a[ai];		    bestB     = b[bi];		    bestError = currentError;		    bestModel = currentModel;		    performanceEstimation = currentPerformance;		}	    }	}	LogService.logMessage ("BatchWeightLearner '" + this.getName() + "':  selected weighting parameters a = " +			       bestA + ", b = " + bestB + " for current batch = " + exampleSet.getLastBatch() +			       " (estimaed classification error: " + bestError + ").", LogService.TASK);	return bestModel;    }        /** This method learns a classification model and returns it (along with a performance estimation     *  (xi-alpha-estimation of the classification error)).     *  This methods sets the weights of all training examples (= all examples in past batches) to one     *  (if the error the model learned on the currently last batch makes on the current batch is smaller     *  than twice the error of this model on the currently last batch) or zero (otherwise).     
*/    private Model  learnWithZeroOneWeighting (ExampleSet inputSet)  throws OperatorException {	BatchedExampleSet  exampleSet          = (BatchedExampleSet) inputSet;	Attribute          batchIndexAttribute = exampleSet.getBatchIndexAttribute();	int                lastBatch           = exampleSet.getLastBatch();	int                currentBatch        = 0;	Learner            learner             = (Learner)(getLearner());	// ===== learn initial model learned on currently last batch only =====	BatchedExampleSet  currentExampleSet   = new BatchedExampleSet (exampleSet, batchIndexAttribute, lastBatch, lastBatch);	Model              initialModel        = learner.learn(currentExampleSet);  // initial model learned on currently last batch only	Model              finalModel          = null;                      // final model learned on all batches selected for training	// ===== determine error of the initial model on the currently last batch =====	IOContainer        evalOutput       = evaluate (initialModel, currentExampleSet);                      // apply & evaluate model	PerformanceVector  performance      = (PerformanceVector)evalOutput.getInput(PerformanceVector.class); // get performance results	double             errorOnLastBatch = performance.getMainCriterion().getValue();	double             errorOnCurrentBatch = 1.0;	// ===== apply model and determine batch weights (for each batch: current test set = current batch) =====	double[]           batchWeight           = new double[lastBatch];   // local weight of each batch	int                noOfSelectedBatches   = 0;                       // number of batches with weight set to one	String             selectedBatchesOutput = "";	for (int b=0; b < lastBatch; b++) {	    currentExampleSet = new BatchedExampleSet (exampleSet, batchIndexAttribute, b, b);	    evalOutput  = evaluate (initialModel, currentExampleSet);                       // apply & evaluate model	    performance = 
(PerformanceVector)evalOutput.getInput(PerformanceVector.class);  // get performance results	    // errorOnCurrentBatch = performance.get(0).getValue();  // or:  performance.getValue("classification_error")  // RK/2003/04/30: old	    errorOnCurrentBatch = performance.getMainCriterion().getValue();	    LogService.logMessage ("BatchWeightLearner '" + getName() + "': error on current batch = " +				   errorOnCurrentBatch + "  vs.  error on last batch = " + errorOnLastBatch +				   " (batch " + b + " vs. " + lastBatch + ")", LogService.TASK);	    if ((errorOnCurrentBatch < 5.0 * errorOnLastBatch) || (errorOnCurrentBatch < 0.2)) {		// \__ RK/2002/09/24: ERROR: in frist condition use ... < 2.0 * EstimatedError (not true error, wg. risk of over-fitting)		batchWeight[b] = 1.0;		noOfSelectedBatches++;		selectedBatchesOutput += b + " ";	    } else {		batchWeight[b] = 0.0; 	    }	}	LogService.logMessage ("BatchWeightLearner '" + getName() + "': selection of the following batches for " +			       "training at the currently last batch " + lastBatch + ": " + selectedBatchesOutput,			       LogService.TASK);	// ===== set new example weights =====        currentExampleSet = new BatchedExampleSet (exampleSet, batchIndexAttribute, 0, lastBatch);	ExampleReader  exampleIterator = currentExampleSet.getExampleReader();	Example        currentExample  = null;	while (exampleIterator.hasNext()) {	    currentExample = exampleIterator.next();	    currentBatch = (int) currentExample.getValue (batchIndexAttribute);	    if (currentBatch < exampleSet.getLastBatch()) {   // only change weight of examples in past batches		try {		    currentExample.setWeight (batchWeight[currentBatch]);		} catch (MethodNotSupportedException e) {		    throw new FatalException("BatchWeightLearner '" + getName() + "': The example set passed to " +					     "this operator must contain examples with weights.\n" + e.getMessage());		}	    }	    // else if (currentBatch == exampleSet.getLastBatch()) {	    //   try {	    //	 
   currentExample.setWeight (1.0);	    //	 } catch (MethodNotSupportedException e) {	    //	    LogService.logFatalException ("BatchWeightLearner '" + getName() + "': The example set passed to " +	    //					  "this operator must contain examples with weights.", e);	    //	 }	    // }	}	// ===== re-learn model =====	finalModel = learner.learn(currentExampleSet);	LogService.logMessage ("BatchWeightLearner '" + this.getName() + "':  " + noOfSelectedBatches + 			       " batches selected for re-training for currently last batch " + 			       exampleSet.getLastBatch(), LogService.TASK);	return finalModel;    }    public boolean canEstimatePerformance() { return true; }        /** returns an object of the class <tt>EstimatedPerformance</tt> containing the xi-alpha-performance     *  estimates of the learned mySVM model, if ...     */    public PerformanceVector  getEstimatedPerformance() {        return performanceEstimation;    }    /** learns a model from a given example set (i.e. given in the input of the operator (<code>IOContainer</code>)).     *  The model is stored under the name <tt>model_file</tt> (if specified).     
*/
    public IOObject[] apply() throws OperatorException {                             // from super class Learner
        Object input = getInput(ExampleSet.class);
        if (input == null) {
            throw new FatalException("BatchWeightLearner '" + getName() + "': No input example set!");
        }
        // Check the type explicitly so that a wrong input yields a descriptive error, not a ClassCastException.
        if (!(input instanceof BatchedExampleSet)) {
            throw new FatalException("BatchWeightLearner '" + getName() + "': Input example set must be a BatchedExampleSet");
        }
        ExampleSet exampleSet = (BatchedExampleSet) input;
        if (exampleSet.getNumberOfAttributes() == 0) {
            throw new FatalException("BatchWeightLearner '" + getName() + "': Input example set has no attributes");
        }
        LogService.logMessage("BatchWeightLearner '" + getName() + "': Start learning...", LogService.TASK);

        Model model = learn(exampleSet);
        if (model == null) {
            // Fail with a descriptive error instead of a NullPointerException in writeModel()
            // or a null model in the operator output.
            throw new FatalException("BatchWeightLearner '" + getName() + "': No model was learned.");
        }

        if (modelFile != null) {
            try {
                model.writeModel(getExperiment().resolveFileName(modelFile));
            } catch (IOException e) {
                // Writing the model file is best effort: log the failure (with its cause) and continue.
                LogService.logMessage("BatchWeightLearner '" + getName() + "': Can't write model file: " + modelFile +
                                      " (" + e.getMessage() + ")", LogService.ERROR);
            }
        }

        PerformanceVector perfVector = getEstimatedPerformance();
        if (perfVector == null) {
            return new IOObject[] { model };
        }
        return new IOObject[] { model, perfVector };
    }

    /** Returns the input types this operator can process and the required layout of the input array. */
    public Class[] getInputClasses() {
        return INPUT_CLASSES;
    }

    /** Returns the output types this operator delivers and the layout of the output array. */
    public Class[] getOutputClasses() {                                                           // from super class Learner
        return canEstimatePerformance()
            ? new Class[] { Model.class, PerformanceVector.class }
            : new Class[] { Model.class };
    }

    // TODO: pass the index of the class to use as "positive" (+1), e.g. attribute.mapString("positive"),
    // to the enclosed learner (a commented-out setPositiveLabelIndex(int) placeholder used to live here).

    /** Returns the parameter types of this operator: those of the super class plus
     *  <tt>model_file</tt>, <tt>minimal_weight</tt>, <tt>weighting_function</tt> and <tt>weighting_scheme</tt>. */
    public List getParameterTypes() {
        List types = super.getParameterTypes();
        types.add(new ParameterTypeString("model_file",
                                          "If this parameter is set, the model is written to a file."));
        types.add(new ParameterTypeDouble("minimal_weight",
                                          "Threshold, under which example weights are considered equal to zero ... ",
                                          0.0, 1.0, 0.0001));
        types.add(new ParameterTypeCategory("weighting_function", "[PENDING]", WEIGHTING_FUNCTION_TYPE_NAME, EXPONENTIAL_FUNCTION));
        types.add(new ParameterTypeCategory("weighting_scheme", "[PENDING]", WEIGHTING_SCHEME_TYPE_NAME, GLOBAL_EXAMPLE_WEIGHTING));
        return types;
    }

    /** Returns the minimum number of inner operators, which is 1 for this operator. */
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /** Returns the maximum number of inner operators, which is <tt>Integer.MAX_VALUE</tt> for this operator. */
    public int getMaxNumberOfInnerOperators() {
        return Integer.MAX_VALUE;
    }

    /** Returns the number of inner steps of this operator. Currently this is always 1; the intended value
     *  (number of runs times the sum of the inner steps of all inner operators) is not implemented yet. */
    public int getNumberOfSteps() {
        // TO DO: return (sum number of steps of all inner operators) * (number of runs) ?
        return 1;
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -