// File..........: calibrationcurve.cpp
//-------------------------------------------------------------------
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Revisions.....:
//===================================================================
#include <stdafx.h> // Precompiled headers.
#include <copyright.h>
#include <kernel/structures/calibrationcurve.h>
#include <kernel/structures/classification.h>
#include <kernel/utilities/mathkit.h>
#include <kernel/utilities/systemkit.h>
#include <kernel/utilities/permuter.h>
#include <kernel/utilities/iokit.h>
#include <kernel/basic/vector.h>
#include <kernel/basic/message.h>
#include <kernel/system/fstream.h>
#include <common/configuration.h>
//-------------------------------------------------------------------
// Methods for class CalibrationCurve.
//===================================================================
//-------------------------------------------------------------------
// Constructors/destructor.
//===================================================================
CalibrationCurve::CalibrationCurve() {
  // Both the decision class and the decision attribute start out flagged
  // with the Undefined::Integer() value.
  this->decision_class_     = Undefined::Integer();
  this->decision_attribute_ = Undefined::Integer();
}
CalibrationCurve::CalibrationCurve(const CalibrationCurve &in) {
  // NOTE(review): no base-class initializer is given here, so the parent
  // part is default-constructed rather than copied -- confirm this is intended.

  // Scalar bookkeeping.
  this->decision_attribute_ = in.decision_attribute_;
  this->decision_class_     = in.decision_class_;

  // Data stored per (target, output) pair.
  this->targets_ = in.targets_;
  this->outputs_ = in.outputs_;
  this->indices_ = in.indices_;

  // Data stored per group.
  this->targets_mean_   = in.targets_mean_;
  this->outputs_mean_   = in.outputs_mean_;
  this->indices_summed_ = in.indices_summed_;
  this->group_sizes_    = in.group_sizes_;
}
CalibrationCurve::~CalibrationCurve() {
  // No explicit cleanup is performed here.
}
//-------------------------------------------------------------------
// Methods inherited from Identifier.
//===================================================================
// Invokes the IMPLEMENTIDMETHODS macro with this class, the CALIBRATIONCURVE
// token and the BinaryOutcomeCurve name as arguments.
// NOTE(review): the macro definition is not visible in this file; presumably
// it generates the Identifier-related methods for the class -- confirm
// against the macro's definition.
IMPLEMENTIDMETHODS(CalibrationCurve, CALIBRATIONCURVE, BinaryOutcomeCurve)
//-------------------------------------------------------------------
// Methods inherited from Persistent.
//===================================================================
//-------------------------------------------------------------------
// Method........: Load
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......: BinaryOutcomeCurve::Load method not compatible with
// CalibrationCurve::Save method.
// Revisions.....:
//===================================================================
bool
CalibrationCurve::Load(ifstream &stream) {
  // Loading is handed over, unchanged, to the parent class implementation;
  // its result is returned as-is.
  const bool loaded = BinaryOutcomeCurve::Load(stream);
  return loaded;
}
//-------------------------------------------------------------------
// Method........: Save
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......:
// Revisions.....:
//===================================================================
bool
CalibrationCurve::Save(ofstream &stream) const {
  // Forward to the three-argument Save overload, passing the undefined
  // id and the undefined string for the two extra arguments.
  const bool saved = Save(stream, Undefined::Id(), Undefined::String());
  return saved;
}
//-------------------------------------------------------------------
// Methods inherited from Structure.
//===================================================================
//-------------------------------------------------------------------
// Method........: Duplicate
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......:
// Revisions.....:
//===================================================================
Structure *
CalibrationCurve::Duplicate() const {
  // Copy-construct a heap-allocated twin of this curve and hand back the
  // raw pointer. NOTE(review): the caller presumably takes ownership --
  // confirm against the Structure contract.
  CalibrationCurve *duplicate = new CalibrationCurve(*this);
  return duplicate;
}
//-------------------------------------------------------------------
// Method........: Clear
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......:
// Revisions.....:
//===================================================================
void
CalibrationCurve::Clear() {
// Erase data.
targets_.erase(targets_.begin(), targets_.end());
outputs_.erase(outputs_.begin(), outputs_.end());
indices_.erase(indices_.begin(), indices_.end());
targets_mean_.erase(targets_mean_.begin(), targets_mean_.end());
outputs_mean_.erase(outputs_mean_.begin(), outputs_mean_.end());
indices_summed_.erase(indices_summed_.begin(), indices_summed_.end());
group_sizes_.erase(group_sizes_.begin(), group_sizes_.end());
// Invalidate member variables.
decision_attribute_ = Undefined::Integer();
decision_class_ = Undefined::Integer();
}
//-------------------------------------------------------------------
// Methods inherited from BinaryOutcomeCurve
//===================================================================
//-------------------------------------------------------------------
// Method........: Create
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......:
// Revisions.....:
//===================================================================
bool
CalibrationCurve::Create(const Vector(int) &targets, const Vector(float) &outputs) {
  // Forward to a Create overload taking one extra integer argument, fixed
  // at 1. NOTE(review): with the overloads visible in this file that
  // argument lands in the decision_class parameter, with the remaining
  // parameters presumably defaulted in the header -- confirm.
  const bool created = Create(targets, outputs, 1);
  return created;
}
//-------------------------------------------------------------------
// Local methods.
//===================================================================
//-------------------------------------------------------------------
// Method........: Create
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......: Index of object assumed equal to index of
// classification.
// Revisions.....:
//===================================================================
bool
CalibrationCurve::Create(const Vector(ICPair) &pairs, int decision_class, int no_groups, bool progress) {
Message message;
int no_pairs = pairs.size();
// Is vector empty?
if (no_pairs == 0) {
Message::Error("No classifications to process.");
return false;
}
Vector(int) targets; targets.reserve(no_pairs);
Vector(float) outputs; outputs.reserve(no_pairs);
Vector(int) indices; indices.reserve(no_pairs);
int i;
// Create (actual class, model output) vectors.
for (i = 0; i < no_pairs; i++) {
// Notify user of progress?
if (progress) {
if (!message.Progress("Generating calibration data...", i, no_pairs))
return false;
}
// Skip invalid predictions.
if (pairs[i].second == NULL) {
Message::Warning("Invalid or no prediction value for object " + String::Format(i + 1) + ", skipped pair.", false);
continue;
}
// Binarize.
targets.push_back((pairs[i].first == decision_class) ? 1 : 0);
outputs.push_back(pairs[i].second->GetBinaryOutcomeCoefficient(decision_class));
indices.push_back(i);
}
return Create(targets, outputs, indices, decision_class, no_groups, progress);
}
//-------------------------------------------------------------------
// Method........: Create
// Author........: Aleksander Øhrn
// Date..........:
// Description...:
// Comments......: Index of object assumed equal to index of pair.
// Revisions.....:
//===================================================================
bool
CalibrationCurve::Create(const Vector(int) &targets, const Vector(float) &outputs, int decision_class, int no_groups, bool progress) {
Vector(int) indices;
indices.reserve(targets.size());
int i;
// Create index vector.
for (i = 0; i < targets.size(); i++)
indices.push_back(i);
return Create(targets, outputs, indices, decision_class, no_groups, progress);
}
//-------------------------------------------------------------------
// Method........: Create
// Author........: Aleksander Øhrn
// Date..........:
// Description...: Creates points on a calibration curve
// from a list of pairs (a, b), where:
//
// a = actual/target class, binarized wrt the
// specified decision class.
// b = model's output (predicted value) for the
// specified decision class.
//
// Comments......: a in {0, 1}, 0 <= b <= 1, b = Pr(a = 1).
// Revisions.....:
//===================================================================
bool
CalibrationCurve::Create(const Vector(int) &targets, const Vector(float) &outputs, const Vector(int) &indices, int decision_class, int no_groups, bool progress) {
Message message;
// Verify dimensions.
if (targets.size() != outputs.size())
return false;
if (targets.size() != indices.size())
return false;
int no_pairs = targets.size();
// Is vector empty?
if (no_pairs == 0) {
Message::Error("No pairs to process.");
return false;
}
// Valid number of groups?
if (no_groups <= 0) {
Message::Error("Invalid number of groups.");
return false;
}
if (no_groups > no_pairs) {
Message::Warning("Too many groups, using maximum.", false);
no_groups = no_pairs;
}
// Compute average group size.
float average_size = (no_pairs == no_groups) ? 1.0f : static_cast(float, no_pairs) / no_groups;
if (average_size < 5.0)
Message::Warning("Average group size (" + String::Format(average_size) + ") is small.", false);
int i, j;
// Verify binarity.
for (i = 0; i < no_pairs; i++) {
if (targets[i] != 0 && targets[i] != 1) {
Message::Error("Element in target vector is not 0 or 1.");
return false;
}
}
// Initialize data vectors.
targets_ = targets;
outputs_ = outputs;
indices_ = indices;
targets_mean_.erase(targets_mean_.begin(), targets_mean_.end());
outputs_mean_.erase(outputs_mean_.begin(), outputs_mean_.end());
indices_summed_.erase(indices_summed_.begin(), indices_summed_.end());
group_sizes_.erase(group_sizes_.begin(), group_sizes_.end());
targets_mean_.reserve(no_groups);
outputs_mean_.reserve(no_groups);
indices_summed_.reserve(no_groups);
group_sizes_.reserve(no_groups);
Vector(int) permutation;
Permuter<float> permuter;
// Sort vectors by model output.
permuter.Permute(outputs_, permutation);
MathKit::Permute(targets_, permutation);
MathKit::Permute(outputs_, permutation);
MathKit::Permute(indices_, permutation);
// Compute group sums.
for (i = 0; i < no_groups; i++) {
// Notify user of progress?
if (progress) {
if (!message.Progress("Computing calibration group sums...", i, no_groups))
return false;
}
// Determine start and end of group.
int start = MathKit::Round(i * average_size);
int size = MathKit::Round((i + 1) * average_size) - start;
int target_sum = 0;
float output_sum = 0;
int index_sum = 0;
// Compute group sums.
for (j = 0; j < size; j++) {
target_sum += targets_[start + j];
output_sum += outputs_[start + j];
index_sum += indices_[start + j];
}
indices_summed_.push_back(index_sum);
group_sizes_.push_back(size);
// Avoid numerical quirks.
if (size == 1) {
targets_mean_.push_back(static_cast(float, target_sum));
outputs_mean_.push_back(output_sum);
}
else {
targets_mean_.push_back(static_cast(float, target_sum) / size);
outputs_mean_.push_back(output_sum / size);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -