⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 serialexecutorloop.cpp

📁 粗糙集应用软件
💻 CPP
📖 第 1 页 / 共 2 页
字号:
	SerialExecutorLoop *self = const_cast(SerialExecutorLoop *, this);

	message.Notify("Initializing sampling scheme...");

	// Initialize sampler.
	if (!self->InitializeSamplingScheme(*table, masked, rng))
		return NULL;

	// Get desired output type.
	Id outputtype = GetOutputType();

	// The structure to output.
	Handle<Structure> output;

	Vector(float) accuracies;
	Vector(float) rocareas;
	Vector(float) rocstderrs;

	accuracies.reserve(GetN());
	rocareas.reserve(GetN());
	rocstderrs.reserve(GetN());

	int i, j;

	// Do training/testing loop.
	for (i = 0; i < GetN(); i++) {

		if (!message.Progress("Executing iteration " + String::Format(i + 1) + "...", i, GetN())) {
			StaticCleanUp(*self, outputtype);
			return NULL;
		}

		message.Notify("Sampling tables...");

		// Initialize training/testing tables before sampling. Since we don't know if
		// the input structures get physically modified by any of the algorithms in the
		// pipeline, we have to operate on duplicates for each iteration.
		Handle<DecisionTable> training = dynamic_cast(DecisionTable *, table->Duplicate());
		Handle<DecisionTable> testing  = dynamic_cast(DecisionTable *, table->Duplicate());

		// Sample training/testing tables.
		if (!SampleTables(i, rng, *training, *testing, masked)) {
			Message::Error("Failed to sample tables.");
			StaticCleanUp(*self, outputtype);
			return NULL;
		}

		Vector(String) parameters_training_copy = parameters_training;
		Vector(String) parameters_testing_copy  = parameters_testing;

		// Replace all iteration no. "macros" with index of current iteration. Try to be case-tolerant.
		for (j = 0; j < parameters_training_copy.size(); j++) {
			parameters_training_copy[j].Replace("#iteration#", String::Format(i));
			parameters_training_copy[j].Replace("#ITERATION#", String::Format(i));
		}
		for (j = 0; j < parameters_testing_copy.size(); j++) {
			parameters_testing_copy[j].Replace("#iteration#", String::Format(i));
			parameters_testing_copy[j].Replace("#ITERATION#", String::Format(i));
		}

		// Set up training pipeline. (Executed by separate object.)
		SerialExecutor pipeline_training;

		// We want the originating decision table so that we get the dictionary as well.
		pipeline_training.SetOutputType(DECISIONTABLE);

		if (!message.Progress("Executing iteration " + String::Format(i + 1) + ", training pipeline...", i, GetN())) {
			StaticCleanUp(*self, outputtype);
			return NULL;
		}

		// Execute training pipeline.
		Handle <DecisionTable> parenttable = dynamic_cast(DecisionTable *, pipeline_training.ExecuteCommands(*training, algorithms_training, parameters_training_copy, stream, false));

		if (parenttable == NULL) {
			Message::Error("Failed to execute training pipeline.");
			StaticCleanUp(*self, outputtype);
			return NULL;
		}

		// Get the last rule set derived from the output decision table.
		Handle<Rules> rules = StaticGetRules(*parenttable);

		if (rules == NULL) {
			Message::Error("Unable to locate rule set generated in training pipeline.");
			StaticCleanUp(*self, outputtype);
			return NULL;
		}

		// Temporarily set rules so that they can be employed in the testing pipeline.
		self->SetRules(rules);

		// Set up a temporary project so that FindParent method will function.
		Handle<Project> project = Creator::Project();

		if (!ProjectManager::InsertProject(project.GetPointer())) {
			StaticCleanUp(*self, outputtype);
			return NULL;
		}

		if (!project->AppendChild(parenttable.GetPointer())) {
			StaticCleanUp(*self, outputtype);
			StaticCleanUp(*project, *parenttable);
			return NULL;
		}

		// Set up training pipeline. (Executed by this object.)
		self->SetOutputType(BATCHCLASSIFICATION);

		if (!message.Progress("Executing iteration " + String::Format(i + 1) + ", testing pipeline...", i, GetN())) {
			StaticCleanUp(*self, outputtype);
			return NULL;
		}

		// Execute testing pipeline.
		Handle<BatchClassification> batchclassification = dynamic_cast(BatchClassification *, SerialExecutor::ExecuteCommands(*testing, algorithms_testing, parameters_testing_copy, stream));

		if (batchclassification == NULL) {
			Message::Error("Failed to execute testing pipeline.");
			StaticCleanUp(*self, outputtype);
			StaticCleanUp(*project, *parenttable);
			return NULL;
		}

		// Display matrix contents in log file.
		if (!SaveLogEntry(stream, *batchclassification)) {
			StaticCleanUp(*self, outputtype);
			StaticCleanUp(*project, *parenttable);
			return NULL;
		}

		// Collect statistics for this loop.
		accuracies.push_back(batchclassification->GetConfusionMatrix().GetDiagonalRatio());

		if (batchclassification->GetROCArea() != Undefined::Float())
			rocareas.push_back(batchclassification->GetROCArea());

		if (batchclassification->GetROCStandardError() != Undefined::Float())
			rocstderrs.push_back(batchclassification->GetROCStandardError());

		// Try to adhere to specified output type.
		if (batchclassification->IsA(outputtype))
			output = batchclassification.GetPointer();
		else if (rules->IsA(outputtype))
			output = rules.GetPointer();
		else if (parenttable->IsA(outputtype))
			output = parenttable.GetPointer();
		else if (structure.IsA(outputtype))
			output = &structure;

		// Clean up temporary project stuff.
		StaticCleanUp(*project, *parenttable);

	}

	// Clean up.
	StaticCleanUp(*self, outputtype);

	// Display statistics in log file.
	if (!SaveLogStatistics(stream, accuracies, "Accuracy"))
		return NULL;

	if (!rocareas.empty() && !SaveLogStatistics(stream, rocareas, "ROC.AUC"))
		return NULL;

	if (!rocstderrs.empty() && !SaveLogStatistics(stream, rocstderrs, "ROC.AUC.SE"))
		return NULL;

	if (output == NULL) {
		if (outputtype == Undefined::Id())
			output = &structure;
		else
			Message::Error("Unable to return a structure of type " + IdHolder::GetClassname(GetOutputType()) + ".");
	}

	return output.Release();

}

//-------------------------------------------------------------------
// Methods inherited from SerialExecutor.
//===================================================================

//-------------------------------------------------------------------
// Method........: SetSpecialParameters
// Author........: Aleksander Øhrn
// Date..........:
// Description...: Used in the testing pipeline.
//                 If the algorithm is a batch classifier, then set
//                 rules produced in training pipeline.
// Comments......:
// Revisions.....:
//===================================================================

bool
SerialExecutorLoop::SetSpecialParameters(Algorithm &algorithm, const String &parameters) const {

	// Only batch classifiers need special treatment; every other algorithm is accepted as-is.
	if (!algorithm.IsA(BATCHCLASSIFIER))
		return true;

	// Cast to verified type. The algorithm is only borrowed from the caller, so a plain
	// non-owning pointer is used instead of a Handle, which would needlessly take (and on
	// destruction drop) a reference on an object this method does not own.
	BatchClassifier *batchclassifier = dynamic_cast(BatchClassifier *, &algorithm);

	// The type-id check above and the C++ cast are separate mechanisms; if they ever
	// disagree, fail cleanly instead of dereferencing a null pointer.
	if (batchclassifier == NULL)
		return false;

	// Hand the classifier the rules currently held by this object.
	if (!batchclassifier->SetRules(GetRules().GetPointer()))
		return false;

	// Set parameters again, if some of them are for the embedded classifier.
	// NOTE(review): the return value is ignored, exactly as before; presumably because not
	// every parameter is meant for the embedded classifier - confirm.
	algorithm.SetParameters(parameters);

	return true;

}

//-------------------------------------------------------------------
// New virtual methods.
//===================================================================

//-------------------------------------------------------------------
// Method........: SplitCommands
// Author........: Aleksander Øhrn
// Date..........:
// Description...: Splits the input commands/parameters and returns
//                 (in-place) two command/parameter sets.
// Comments......:
// Revisions.....:
//===================================================================

bool
SerialExecutorLoop::SplitCommands(int index, const Algorithm::Handles &algorithms,  const Vector(String) &parameters, Algorithm::Handles &algorithms1, Vector(String) &parameters1, Algorithm::Handles &algorithms2, Vector(String) &parameters2) const {

	// The four output sets always start out empty, even when the split is rejected below.
	algorithms1.erase(algorithms1.begin(), algorithms1.end());
	algorithms2.erase(algorithms2.begin(), algorithms2.end());
	parameters1.erase(parameters1.begin(), parameters1.end());
	parameters2.erase(parameters2.begin(), parameters2.end());

	// Every algorithm must be paired with exactly one parameter string.
	if (parameters.size() != algorithms.size())
		return false;

	const int count = algorithms.size();

	// The split point may range from 0 (empty first set) up to count (empty second set).
	bool inrange = (index >= 0) && (index <= count);

	if (!inrange)
		return false;

	int k;

	// Commands before the split point form the first set...
	for (k = 0; k < index; k++) {
		algorithms1.push_back(algorithms[k]);
		parameters1.push_back(parameters[k]);
	}

	// ...and the command at the split point plus everything after it form the second set.
	for (k = index; k < count; k++) {
		algorithms2.push_back(algorithms[k]);
		parameters2.push_back(parameters[k]);
	}

	return true;

}

//-------------------------------------------------------------------
// Method........: InitializeSamplingScheme
// Author........: Aleksander Øhrn
// Date..........:
// Description...: Initializes internals used for partitioning
//                 data into training and testing sets.
//
//                 Default implementation does nothing.
// Comments......:
// Revisions.....:
//===================================================================

bool
SerialExecutorLoop::InitializeSamplingScheme(const DecisionTable &/*table*/, bool /*masked*/, const RNG &/*rng*/) {

	// Base behaviour: there is no sampling state to prepare, so set-up always succeeds.
	// Derived loops that need set-up can override this virtual hook.
	const bool succeeded = true;

	return succeeded;

}

//-------------------------------------------------------------------
// Method........: SaveLogEntry
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......:
// Revisions.....: A

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -