// serialexecutorloop.cpp
  SerialExecutorLoop *self = const_cast(SerialExecutorLoop *, this);
  message.Notify("Initializing sampling scheme...");
  // Initialize sampler.
  if (!self->InitializeSamplingScheme(*table, masked, rng))
    return NULL;
  // Get desired output type.
  Id outputtype = GetOutputType();
  // The structure to output.
  Handle<Structure> output;
  Vector(float) accuracies;
  Vector(float) rocareas;
  Vector(float) rocstderrs;
  accuracies.reserve(GetN());
  rocareas.reserve(GetN());
  rocstderrs.reserve(GetN());
  int i, j;
  // Do training/testing loop.
  for (i = 0; i < GetN(); i++) {
    if (!message.Progress("Executing iteration " + String::Format(i + 1) + "...", i, GetN())) {
      StaticCleanUp(*self, outputtype);
      return NULL;
    }
    message.Notify("Sampling tables...");
    // Initialize training/testing tables before sampling. Since we don't know if
    // the input structures get physically modified by any of the algorithms in the
    // pipeline, we have to operate on duplicates for each iteration.
    Handle<DecisionTable> training = dynamic_cast(DecisionTable *, table->Duplicate());
    Handle<DecisionTable> testing = dynamic_cast(DecisionTable *, table->Duplicate());
    // Sample training/testing tables.
    if (!SampleTables(i, rng, *training, *testing, masked)) {
      Message::Error("Failed to sample tables.");
      StaticCleanUp(*self, outputtype);
      return NULL;
    }
    Vector(String) parameters_training_copy = parameters_training;
    Vector(String) parameters_testing_copy = parameters_testing;
    // Replace all iteration no. "macros" with index of current iteration. Try to be case-tolerant.
    for (j = 0; j < parameters_training_copy.size(); j++) {
      parameters_training_copy[j].Replace("#iteration#", String::Format(i));
      parameters_training_copy[j].Replace("#ITERATION#", String::Format(i));
    }
    for (j = 0; j < parameters_testing_copy.size(); j++) {
      parameters_testing_copy[j].Replace("#iteration#", String::Format(i));
      parameters_testing_copy[j].Replace("#ITERATION#", String::Format(i));
    }
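    // Example (hypothetical parameter string): a parameter such as
    // "RULES.FILENAME = rules_#iteration#.txt" would expand to "rules_0.txt" in the
    // first iteration, "rules_1.txt" in the second, and so on.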
    // Set up training pipeline. (Executed by separate object.)
    SerialExecutor pipeline_training;
    // We want the originating decision table so that we get the dictionary as well.
    pipeline_training.SetOutputType(DECISIONTABLE);
    if (!message.Progress("Executing iteration " + String::Format(i + 1) + ", training pipeline...", i, GetN())) {
      StaticCleanUp(*self, outputtype);
      return NULL;
    }
    // Execute training pipeline.
    Handle<DecisionTable> parenttable = dynamic_cast(DecisionTable *, pipeline_training.ExecuteCommands(*training, algorithms_training, parameters_training_copy, stream, false));
    if (parenttable == NULL) {
      Message::Error("Failed to execute training pipeline.");
      StaticCleanUp(*self, outputtype);
      return NULL;
    }
    // Get the last rule set derived from the output decision table.
    Handle<Rules> rules = StaticGetRules(*parenttable);
    if (rules == NULL) {
      Message::Error("Unable to locate rule set generated in training pipeline.");
      StaticCleanUp(*self, outputtype);
      return NULL;
    }
    // Temporarily set rules so that they can be employed in the testing pipeline.
    self->SetRules(rules);
    // Set up a temporary project so that FindParent method will function.
    Handle<Project> project = Creator::Project();
    if (!ProjectManager::InsertProject(project.GetPointer())) {
      StaticCleanUp(*self, outputtype);
      return NULL;
    }
    if (!project->AppendChild(parenttable.GetPointer())) {
      StaticCleanUp(*self, outputtype);
      StaticCleanUp(*project, *parenttable);
      return NULL;
    }
    // Set up testing pipeline. (Executed by this object.)
    self->SetOutputType(BATCHCLASSIFICATION);
    if (!message.Progress("Executing iteration " + String::Format(i + 1) + ", testing pipeline...", i, GetN())) {
      StaticCleanUp(*self, outputtype);
      return NULL;
    }
    // Execute testing pipeline.
    Handle<BatchClassification> batchclassification = dynamic_cast(BatchClassification *, SerialExecutor::ExecuteCommands(*testing, algorithms_testing, parameters_testing_copy, stream));
    if (batchclassification == NULL) {
      Message::Error("Failed to execute testing pipeline.");
      StaticCleanUp(*self, outputtype);
      StaticCleanUp(*project, *parenttable);
      return NULL;
    }
    // Display matrix contents in log file.
    if (!SaveLogEntry(stream, *batchclassification)) {
      StaticCleanUp(*self, outputtype);
      StaticCleanUp(*project, *parenttable);
      return NULL;
    }
    // Collect statistics for this loop.
    accuracies.push_back(batchclassification->GetConfusionMatrix().GetDiagonalRatio());
    if (batchclassification->GetROCArea() != Undefined::Float())
      rocareas.push_back(batchclassification->GetROCArea());
    if (batchclassification->GetROCStandardError() != Undefined::Float())
      rocstderrs.push_back(batchclassification->GetROCStandardError());
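    // Note: the diagonal ratio of the confusion matrix is the fraction of objects
    // classified correctly, i.e., the overall accuracy. The ROC area and its standard
    // error are only collected when defined (e.g., when a ROC analysis was actually
    // performed in this iteration).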
    // Try to adhere to specified output type.
    if (batchclassification->IsA(outputtype))
      output = batchclassification.GetPointer();
    else if (rules->IsA(outputtype))
      output = rules.GetPointer();
    else if (parenttable->IsA(outputtype))
      output = parenttable.GetPointer();
    else if (structure.IsA(outputtype))
      output = &structure;
    // Clean up temporary project stuff.
    StaticCleanUp(*project, *parenttable);
  }
  // Clean up.
  StaticCleanUp(*self, outputtype);
  // Display statistics in log file.
  if (!SaveLogStatistics(stream, accuracies, "Accuracy"))
    return NULL;
  if (!rocareas.empty() && !SaveLogStatistics(stream, rocareas, "ROC.AUC"))
    return NULL;
  if (!rocstderrs.empty() && !SaveLogStatistics(stream, rocstderrs, "ROC.AUC.SE"))
    return NULL;
  if (output == NULL) {
    if (outputtype == Undefined::Id())
      output = &structure;
    else
      Message::Error("Unable to return a structure of type " + IdHolder::GetClassname(GetOutputType()) + ".");
  }
  return output.Release();
}
//-------------------------------------------------------------------
// Methods inherited from SerialExecutor.
//===================================================================
//-------------------------------------------------------------------
// Method........: SetSpecialParameters
// Author........: Aleksander Øhrn
// Date..........:
// Description...: Used in the testing pipeline.
//                 If the algorithm is a batch classifier, then set
//                 rules produced in training pipeline.
// Comments......:
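//                 Note: GetRules() returns the rule set that the main loop above
//                 stored via SetRules(), i.e., the rules generated by the training
//                 pipeline of the current iteration.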
// Revisions.....:
//===================================================================
bool
SerialExecutorLoop::SetSpecialParameters(Algorithm &algorithm, const String &parameters) const {
  if (!algorithm.IsA(BATCHCLASSIFIER))
    return true;
  // Cast to verified type.
  Handle<BatchClassifier> batchclassifier = dynamic_cast(BatchClassifier *, &algorithm);
  // Set rules.
  if (!batchclassifier->SetRules(GetRules().GetPointer()))
    return false;
  // Set parameters again, if some of them are for the embedded classifier.
  algorithm.SetParameters(parameters);
  return true;
}
//-------------------------------------------------------------------
// New virtual methods.
//===================================================================
//-------------------------------------------------------------------
// Method........: SplitCommands
// Author........: Aleksander Øhrn
// Date..........:
// Description...: Splits the input commands/parameters and returns
//                 (in-place) two command/parameter sets.
// Comments......:
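//                 Illustration: with algorithms = {A0, A1, A2, A3} and index = 2,
//                 the split yields algorithms1 = {A0, A1} and algorithms2 = {A2, A3};
//                 the parameter vectors are split the same way. An index equal to
//                 algorithms.size() is legal and places all commands in the first set.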
// Revisions.....:
//===================================================================
bool
SerialExecutorLoop::SplitCommands(int index, const Algorithm::Handles &algorithms, const Vector(String) &parameters,
                                  Algorithm::Handles &algorithms1, Vector(String) &parameters1,
                                  Algorithm::Handles &algorithms2, Vector(String) &parameters2) const {
  // Clear vectors.
  algorithms1.erase(algorithms1.begin(), algorithms1.end());
  algorithms2.erase(algorithms2.begin(), algorithms2.end());
  parameters1.erase(parameters1.begin(), parameters1.end());
  parameters2.erase(parameters2.begin(), parameters2.end());
  // Check input.
  if (algorithms.size() != parameters.size())
    return false;
  // Check index validity.
  if (index < 0 || index > algorithms.size())
    return false;
  int i;
  // Do split.
  for (i = 0; i < algorithms.size(); i++) {
    if (i < index) {
      algorithms1.push_back(algorithms[i]);
      parameters1.push_back(parameters[i]);
    }
    else {
      algorithms2.push_back(algorithms[i]);
      parameters2.push_back(parameters[i]);
    }
  }
  return true;
}
//-------------------------------------------------------------------
// Method........: InitializeSamplingScheme
// Author........: Aleksander Øhrn
// Date..........:
// Description...: Initializes internals used for partitioning
//                 data into training and testing sets.
//
//                 Default implementation does nothing.
// Comments......:
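//                 Subclasses that implement a concrete sampling scheme (e.g., one
//                 based on cross-validation) are expected to override this method to
//                 set up the partitioning before the iteration loop starts.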
// Revisions.....:
//===================================================================
bool
SerialExecutorLoop::InitializeSamplingScheme(const DecisionTable &/*table*/, bool /*masked*/, const RNG &/*rng*/) {
  return true;
}
//-------------------------------------------------------------------
// Method........: SaveLogEntry
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......:
// Revisions.....: A