// tdidt.cpp
PDistribution decision = vote(node, refexam, branchWeights);
if (decision)
return decision->highestProbValue(exam);
}
// couldn't classify, so we'll return something a priori
return findNodeValue(node, refexam);
}
/* Class distribution for 'exam', starting the descent at the tree root.
   The example is converted into the classifier's domain only when the
   classifier has a domain and it differs from the example's. */
PDistribution TTreeClassifier::classDistribution(const TExample &exam)
{
  checkProperty(descender);

  const bool sameDomain = !domain || (exam.domain == domain);
  if (sameDomain)
    return classDistribution(tree, exam);

  return classDistribution(tree, TExample(domain, exam));
}
/* Class distribution for 'exam', descending from 'node'.
   - If the descender reports branch weights, the branches' distributions are
     combined by vote().
   - Otherwise the reached node's classifier is used, if it has one.
   - As a last resort, a copy of the distribution found for that node is returned. */
PDistribution TTreeClassifier::classDistribution(PTreeNode node, const TExample &exam)
{
  PDiscDistribution branchWeights;
  PTreeNode reached = descender->call(node, exam, branchWeights);

  if (branchWeights)
    return vote(reached, exam, branchWeights);

  if (reached->nodeClassifier)
    return reached->nodeClassifier->classDistribution(exam);

  return CLONE(TDistribution, findNodeDistribution(reached, exam));
}
/* Weighted vote over the branches of 'node'.
   Each non-null branch with a non-zero weight in 'branchWeights' contributes
   its own class distribution for 'exam'. That distribution is normalized and
   then scaled by the branch weight. The accumulated sum is normalized before
   it is returned. Branch weights and branches are walked in lock-step. */
PDistribution TTreeClassifier::vote(PTreeNode node, const TExample &exam, PDiscDistribution branchWeights)
{
  PDistribution result = TDistribution::create(classVar);
  TDistribution &sum = result.getReference();

  TTreeNodeList::const_iterator branch(node->branches->begin());
  for (TDiscDistribution::const_iterator weight(branchWeights->begin()), wend(branchWeights->end());
       weight != wend;
       ++weight, ++branch) {
    if (!*weight || !*branch)
      continue;

    PDistribution branchDistr = classDistribution(*branch, exam);
    if (!branchDistr)
      continue;

    branchDistr->normalize();
    branchDistr->operator *= (*weight);
    sum += branchDistr;
  }

  sum.normalize();
  return result;
}
void TTreeClassifier::predictionAndDistribution(const TExample &exam, TValue &val, PDistribution &distr)
{
checkProperty(descender);
TExample convex = (exam.domain != domain) ? TExample(domain, exam) : TExample();
const TExample &refexam = (exam.domain != domain) ? convex : exam;
PDiscDistribution splitDecision;
PTreeNode node = descender->call(tree, refexam, splitDecision);
if (!splitDecision) {
if (node->nodeClassifier)
node->nodeClassifier->predictionAndDistribution(refexam, val, distr);
else
distr = CLONE(TDistribution, findNodeDistribution(node, refexam));
}
else {
distr = vote(node, refexam, splitDecision);
val = distr->highestProbValue(exam);
}
}
/* Descends while the branch selector yields a known value that indexes a
   non-null branch.

   On an unknown, out-of-range or null-branch value, the descent stops at the
   current node. No branch weights are ever returned ('distr' is reset).

   Every branch index in [0, size()) is a valid destination. This descender
   reserves no trailing "unknown" branch, so the bound is size(), not size()-1.
   The old size()-1 made the last real branch unreachable. */
PTreeNode TTreeDescender_UnknownToNode::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue val = node->branchSelector->call(ex);
    const int nBranches = int(node->branches->size());
    if (val.isSpecial() || (val.intV<0) || (val.intV>=nBranches) || !node->branches->at(val.intV))
      break;
    node = node->branches->at(val.intV);
  }
  distr = PDiscDistribution();
  return node;
}
/* Descends along known values. The last branch of each node is treated as the
   branch for unknown values.

   An unknown, out-of-range (index >= size()-1) or null-branch value sends the
   example into that last branch. The descent stops at the current node when:
   - that last branch is null (previously the next iteration dereferenced a
     null node), or
   - the node has an empty branch list (back() on an empty list is undefined).

   No branch weights are returned ('distr' is reset). */
PTreeNode TTreeDescender_UnknownToBranch::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches && node->branches->size()) {
    TValue val = node->branchSelector->call(ex);
    const int nKnownBranches = int(node->branches->size()) - 1;

    PTreeNode next;
    if (!val.isSpecial() && (val.intV >= 0) && (val.intV < nKnownBranches))
      next = node->branches->at(val.intV);
    if (!next)
      next = node->branches->back();
    if (!next)
      break;

    node = next;
  }
  distr = PDiscDistribution();
  return node;
}
/* Returns the index of a non-null branch selected by 'roff', or -1 if all
   branches are null.

   Let k be the number of non-null branches. The result is the
   (roff mod k)-th non-null branch, counted from 0, so every non-null branch
   can be chosen. Callers in this file pass ex.sumValues(), so the choice
   depends only on the example.

   Previously the modulus was (k+1): a residue of 0 returned -1 although
   non-null branches existed, and a negative 'roff' made the scan run past
   end(). */
int randomNonNull(const PTreeNodeList &branches, const int &roff)
{
  int nNonNull = 0;
  TTreeNodeList::const_iterator ni(branches->begin()), ne(branches->end());
  for (; ni != ne; ni++)
    if (*ni)
      nNonNull++;

  if (!nNonNull)
    return -1;

  // Number of non-null branches to skip, normalized into [0, nNonNull) even for negative roff.
  int skip = roff % nNonNull;
  if (skip < 0)
    skip += nNonNull;

  int index = 0;
  for (ni = branches->begin(); ni != ne; ni++, index++)
    if (*ni && !skip--)
      return index;

  return -1;  // not reached: skip < nNonNull
}
/* Descends along known values.

   For an unknown or out-of-range value, the branch with the highest entry in
   node->branchSizes is taken. If that is unavailable or the chosen branch is
   null, a non-null branch is picked by randomNonNull(), using ex.sumValues()
   as the offset. The descent stops when no non-null branch exists.

   No branch weights are returned ('distr' is reset). */
PTreeNode TTreeDescender_UnknownToCommonBranch::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue val = node->branchSelector->call(ex);
    const int nBranches = int(node->branches->size());

    int ind;
    if (!val.isSpecial() && (val.intV >= 0) && (val.intV < nBranches))
      ind = val.intV;
    else if (node->branchSizes)
      ind = node->branchSizes->highestProbIntIndex(ex);
    else
      ind = -1;

    if ((ind < 0) || !node->branches->at(ind)) {
      ind = randomNonNull(node->branches, ex.sumValues());
      if (ind < 0)
        break;
    }

    node = node->branches->at(ind);
  }

  distr = PDiscDistribution();
  return node;
}
/* Descends along known values.

   For an unknown value, the branch selector's own distribution (if it is a
   TDiscDistribution) picks the most probable branch. When the resulting index
   is negative, out of range, or names a null branch, a non-null branch is
   picked by randomNonNull(), using ex.sumValues() as the offset. The descent
   stops when no non-null branch exists.

   The upper bound is now checked for the distribution-derived index as well.
   A selector distribution longer than the branch list used to index past the
   end of node->branches.

   No branch weights are returned ('distr' is reset). */
PTreeNode TTreeDescender_UnknownToCommonSelector::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue val = node->branchSelector->call(ex);
    const int nBranches = int(node->branches->size());

    int ind;
    if (val.isSpecial()) {
      TDiscDistribution *valdistr = val.svalV.AS(TDiscDistribution);
      ind = valdistr ? valdistr->highestProbIntIndex(ex) : -1;
    }
    else
      ind = val.intV;

    if ((ind < 0) || (ind >= nBranches) || !node->branches->at(ind)) {
      ind = randomNonNull(node->branches, ex.sumValues());
      if (ind < 0)
        break;
    }

    node = node->branches->at(ind);
  }

  distr = PDiscDistribution();
  return node;
}
/* Descends along known values.

   At the first node where the value cannot be followed (unknown, out of range,
   or a null branch), the descent stops there. node->branchSizes is returned in
   'distr' as the weights for voting over that node's branches.

   If a node without a selector or branches is reached, 'distr' is reset. */
PTreeNode TTreeDescender_UnknownMergeAsBranchSizes::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue val = node->branchSelector->call(ex);
    const bool canFollow = !val.isSpecial()
                        && (val.intV >= 0)
                        && (val.intV < int(node->branches->size()))
                        && node->branches->at(val.intV);
    if (!canFollow) {
      distr = node->branchSizes;
      return node;
    }
    node = node->branches->at(val.intV);
  }

  distr = PDiscDistribution();
  return node;
}
/* Descends along known values.

   At the first node where the value cannot be followed (unknown, out of range,
   or a null branch), the descent stops there. If the selector's returned value
   carries a TDiscDistribution, it is passed back in 'distr' as voting weights;
   otherwise 'distr' is reset.

   If a node without a selector or branches is reached, 'distr' is reset. */
PTreeNode TTreeDescender_UnknownMergeAsSelector::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue val = node->branchSelector->call(ex);
    const bool canFollow = !val.isSpecial()
                        && (val.intV >= 0)
                        && (val.intV < int(node->branches->size()))
                        && node->branches->at(val.intV);
    if (canFollow) {
      node = node->branches->at(val.intV);
      continue;
    }

    if (val.svalV && val.svalV.is_derived_from(TDiscDistribution))
      distr = val.svalV;
    else
      distr = PDiscDistribution();
    return node;
  }

  distr = PDiscDistribution();
  return node;
}
/* Entry point: returns a pruned copy of the tree rooted at 'root'.
   The root's set of majority values is not needed by the caller, so it is
   collected into a scratch vector and discarded. */
PTreeNode TTreePruner_SameMajority::operator()(PTreeNode root)
{
  vector<bool> unusedBestValues;
  return operator()(root, unusedBestValues);
}
/* Argument 'bestValues' receives the class values that are majority values in
   every leaf of the subtree: bestValues[i] is true when value i is one of them.

   While iterating through branches, an intersection is computed. Branches may
   return different sizes of 'bestValues'. A position past the end of a vector
   is not a majority value for that branch, so the intersection is as long as
   the shortest of the reported bestValues.

   If the intersection contains a true entry, all branches share a majority
   value. The node is then collapsed into a leaf in the returned copy. */
PTreeNode TTreePruner_SameMajority::operator()(PTreeNode node, vector<bool> &bestValues)
{
  PTreeNode newNode = CLONE(TTreeNode, node);

  if (node->branchSelector) {
    newNode->branches = mlnew TTreeNodeList();
    bool seenBranch = false;

    PITERATE(TTreeNodeList, bi, node->branches) {
      if (!*bi) {
        newNode->branches->push_back(PTreeNode());
      }
      else if (!seenBranch) {
        // The first non-null branch initializes the intersection directly.
        newNode->branches->push_back(operator()(*bi, bestValues));
        seenBranch = true;
      }
      else {
        vector<bool> subBest;
        newNode->branches->push_back(operator()(*bi, subBest));

        // Keep only the first subBest.size() entries. The previous erase kept
        // size()-subBest.size() entries, which was the wrong length and could
        // make the loop below read past the end of subBest.
        if (subBest.size() < bestValues.size())
          bestValues.resize(subBest.size());

        vector<bool>::iterator bvi(bestValues.begin());
        vector<bool>::const_iterator sbi(subBest.begin());
        for (; bvi != bestValues.end(); bvi++, sbi++)
          *bvi = *bvi && *sbi;
      }
    }

    // Some value is a majority in every branch: the split never changes the
    // prediction, so turn the copy into a leaf.
    bool sharedMajority = false;
    for (vector<bool>::const_iterator pi(bestValues.begin()); !sharedMajority && (pi != bestValues.end()); pi++)
      sharedMajority = *pi;

    if (sharedMajority) {
      newNode->branches = PTreeNodeList();
      newNode->branchDescriptions = PStringList();
      newNode->branchSelector = PClassifier();
      newNode->branchSizes = PDiscDistribution();
    }
  }

  else {
    // Leaf: report the majority values of its default classifier's discrete
    // distribution, if there is one; otherwise 'bestValues' stays as it was.
    TDefaultClassifier *maj = node->nodeClassifier.AS(TDefaultClassifier);
    if (maj) {
      TDiscDistribution *ddist = maj->defaultDistribution.AS(TDiscDistribution);
      if (ddist) {
        // Find the maximum and the LAST position holding it ('>=' keeps moving bb).
        float bestF = -1;
        TDiscDistribution::const_iterator bi(ddist->begin()), bb = bi, be(ddist->end());
        for (; bi != be; bi++)
          if (*bi >= bestF) {
            bb = bi;
            bestF = *bb;
          }

        // The loop runs to one before the last maximum; that last one is always true.
        for (bi = ddist->begin(); bi != bb; bi++)
          bestValues.push_back(*bi == bestF);
        bestValues.push_back(true);
      }
    }
  }

  return newNode;
}
// Stores the m parameter of the m-estimate. Its sign is validated when pruning, not here.
TTreePruner_m::TTreePruner_m(const float &am) : m(am)
{
}
/* Entry point: returns an m-estimate-pruned copy of the tree rooted at 'root'.

   The prior comes from the root's class distribution (or the class marginal of
   its contingency):
   - discrete class: the recursive pruner gets m * p(c) for every class c;
   - continuous class: it gets m times the root distribution's error().

   Raises if m is negative, if the root stores no class distribution, or if the
   distribution is of neither kind. */
PTreeNode TTreePruner_m::operator()(PTreeNode root)
{
  if (m<0.0)
    raiseError("'m' should be positive");

  PDistribution dist;
  if (root->distribution)
    dist = root->distribution;
  else if (root->contingency && root->contingency->classes)
    dist = root->contingency->classes;
  else
    raiseError("the node does not store class distribution (check your flags for TreeLearner)");

  if (TDiscDistribution *ddist = dist.AS(TDiscDistribution)) {
    const float mba = m/ddist->abs;
    vector<float> m_by_p;
    for (TDiscDistribution::const_iterator di(ddist->begin()), de(ddist->end()); di != de; di++)
      m_by_p.push_back(*di*mba);

    PTreeNode pruned;
    operator()(root, m_by_p, pruned);
    return pruned;
  }

  if (TContDistribution *cdist = dist.AS(TContDistribution)) {
    PTreeNode pruned;
    operator()(root, cdist->error() * m, pruned);
    return pruned;
  }

  raiseError("class distribution of unknown type (neither discrete nor continuous)");
  return PTreeNode();
}
float TTreePruner_m::estimateError(const PTreeNode &node, const vector<float> &m_by_p) const
{
const TDiscDistribution *dist;
if (node->distribution)
dist = node->distribution.AS(TDiscDistribution);
else if (node->contingency)
dist = node->contingency->classes.AS(TDiscDistribution);
else
raiseError("the node does not store class distribution (check your flags for TreeLearner)");
if (!dist)
raiseError("invalid class distribution (DiscDistribution expected)");
if ((dist->abs < 1e-10) || (dist->abs+m < 1e-10))
return 0.0;
float maxe = 0.0;
vector<float>::const_iterator mi(m_by_p.begin());
for(TDiscDistribution::const_iterator di(dist->begin()), de(dist->end()); di!=de; di++, mi++) {
float thise = *di + *mi;
if (thise>maxe)
maxe = thise;
}
return 1.0 - maxe/(dist->abs+m);
}
/* Static (leaf) error estimate of 'node' for a continuous class.

   The node's own error() weighted by its mass N is blended with the prior term
   'm_by_se' (m times the root's error), then divided by N + m.

   Returns 0 when N or N + m is zero. Raises if the node stores no class
   distribution, or if it is not a TContDistribution. */
float TTreePruner_m::estimateError(const PTreeNode &node, const float &m_by_se) const
{
  if (!node->distribution && !node->contingency)
    raiseError("the node does not store class distribution (check your flags for TreeLearner)");

  const TContDistribution *dist = node->distribution
                                  ? node->distribution.AS(TContDistribution)
                                  : node->contingency->classes.AS(TContDistribution);
  if (!dist)
    raiseError("invalid class distribution (ContDistribution expected)");

  const bool degenerate = (dist->abs==0.0) || (dist->abs+m==0.0);
  return degenerate
         ? 0.0
         : (dist->abs*dist->error() + m_by_se) / (dist->abs+m);
}
#ifdef _MSC_VER
#pragma optimize("p", off)
#endif
/* Recursively prunes the subtree rooted at 'node' into 'newNode' (a copy) and
   returns the subtree's estimated error.

   - Leaf: the error is the node's static m-estimate error.
   - Internal node: the backed-up error is the branch-size-weighted mean of the
     pruned subtrees' errors.

   An internal node is collapsed into a leaf when:
   - its static error is strictly lower than the backed-up error, or
   - its non-null branches carry zero total weight. Then sumerr/sumweights is
     0/0 = NaN. Previously the NaN kept the subtree and, once added into every
     ancestor's sum, prevented any pruning above it. */
template<class T>
float TTreePruner_m::operator()(PTreeNode node, const T &m_by_p, PTreeNode &newNode) const
{
  newNode = CLONE(TTreeNode, node);
  if (!node->branchSelector)
    return estimateError(node, m_by_p);

  newNode->branches = mlnew TTreeNodeList(node->branches->size());
  float sumerr = 0, sumweights = 0;
  TDiscDistribution::const_iterator bwi (node->branchSizes->begin());
  TTreeNodeList::iterator oi (node->branches->begin()), oe (node->branches->end());
  TTreeNodeList::iterator bi (newNode->branches->begin());
  for (; oi!=oe; oi++, bi++, bwi++)
    if (*oi) {
      sumerr += *bwi * operator()(*oi, m_by_p, *bi);
      sumweights += *bwi;
    }

  const float staticError = estimateError(node, m_by_p);
  if (sumweights > 0) {
    const float backupError = sumerr/sumweights;
    if (!(staticError < backupError))
      return backupError;  // keep the (pruned) subtree
  }

  newNode->branches = PTreeNodeList();
  newNode->branchDescriptions = PStringList();
  newNode->branchSelector = PClassifier();
  newNode->branchSizes = PDiscDistribution();
  return staticError;
}
#ifdef _MSC_VER
#pragma optimize("p", on)
#endif