⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 tdidt.cpp

📁 orange源码 数据挖掘技术
💻 CPP
📖 第 1 页 / 共 2 页
字号:
    PDistribution decision = vote(node, refexam, branchWeights);
    if (decision)
      return decision->highestProbValue(exam);
  }

  // couldn't classify, so we'll return something a priori
  return findNodeValue(node, refexam);
}


// Returns the class distribution the tree predicts for 'exam'.
// checkProperty(descender) runs first. The example is handed on as-is when the
// classifier has no domain or the example already uses it; otherwise it is
// converted with TExample(domain, exam) before the tree is descended.
PDistribution TTreeClassifier::classDistribution(const TExample &exam)
{
  checkProperty(descender);

  const bool usableAsIs = !domain || (exam.domain == domain);
  return classDistribution(tree, usableAsIs ? exam : TExample(domain, exam));
}


// Returns the class distribution for 'exam', starting the descent at 'node'.
// The descender either stops at a node (no branch weights) or asks for a vote
// among the branches of the node it stopped at (branch weights returned).
PDistribution TTreeClassifier::classDistribution(PTreeNode node, const TExample &exam)
{
  PDiscDistribution branchWeights;
  node = descender->call(node, exam, branchWeights);

  // The descender requested a weighted vote over the node's branches.
  if (branchWeights)
    return vote(node, exam, branchWeights);

  // A single node was reached: prefer its own classifier.
  if (node->nodeClassifier)
    return node->nodeClassifier->classDistribution(exam);

  // No classifier in the node; fall back to a copy of the distribution
  // located by findNodeDistribution.
  return CLONE(TDistribution, findNodeDistribution(node, exam));
}


// Combines the predictions of the branches of 'node' into one distribution.
//   node          - node whose branches take part in the vote
//   exam          - example being classified
//   branchWeights - weight of each branch, matched to node->branches by position
// Every branch that exists and has a non-zero weight is classified recursively;
// its distribution is normalized, multiplied by the branch weight and added to
// the result. The sum is normalized before it is returned.
PDistribution TTreeClassifier::vote(PTreeNode node, const TExample &exam, PDiscDistribution branchWeights)
{
  PDistribution res = TDistribution::create(classVar);
  TDistribution &ures = res.getReference();
  TDiscDistribution::const_iterator bdi(branchWeights->begin()), bde(branchWeights->end());
  TTreeNodeList::const_iterator bi(node->branches->begin()), be(node->branches->end());

  // The two sequences are walked in lock-step. Stop at the end of the shorter
  // one: the weights may hold more entries than there are branches, and the
  // branch iterator must not run past the end of the list.
  for(; (bdi!=bde) && (bi!=be); bdi++, bi++)
    if (*bdi && *bi) {
      PDistribution subDistr = classDistribution(*bi, exam);
      if (subDistr) {
        subDistr->normalize();
        subDistr->operator *= (*bdi);
        ures += subDistr;
      }
    }
  ures.normalize();
  return res;
}


void TTreeClassifier::predictionAndDistribution(const TExample &exam, TValue &val, PDistribution &distr)
{
  checkProperty(descender);
 
  TExample convex = (exam.domain != domain) ? TExample(domain, exam) : TExample();
  const TExample &refexam = (exam.domain != domain) ? convex : exam;
  PDiscDistribution splitDecision;
  PTreeNode node = descender->call(tree, refexam, splitDecision);
  if (!splitDecision) {
    if (node->nodeClassifier)
      node->nodeClassifier->predictionAndDistribution(refexam, val, distr);
    else
      distr = CLONE(TDistribution, findNodeDistribution(node, refexam));
  }
  else {
    distr = vote(node, refexam, splitDecision);
    val = distr->highestProbValue(exam);
  }
}



// Descends from 'node' following the branch selectors and stops at the first
// node where the example cannot be routed further: the selector returns a
// special value, the index is out of range, or the chosen branch is empty.
// 'distr' is always reset to an empty distribution, i.e. no vote is requested.
PTreeNode TTreeDescender_UnknownToNode::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue selected = node->branchSelector->call(ex);
    if (selected.isSpecial())
      break;

    // NOTE(review): the last branch is excluded from the valid range here,
    // exactly as in TTreeDescender_UnknownToBranch where it serves as the
    // branch for unknowns - confirm that this is intended for this descender.
    const int indexLimit = node->branches->size()-1;
    if ((selected.intV < 0) || (selected.intV >= indexLimit))
      break;

    PTreeNode child = node->branches->at(selected.intV);
    if (!child)
      break;

    node = child;
  }

  distr = PDiscDistribution();
  return node;
}


// Descends from 'node' following the branch selectors. Whenever the example
// cannot be routed by the selector (special value, index outside the regular
// branches, or an empty regular branch), the last branch of the node is taken
// instead. 'distr' is always reset to an empty distribution.
PTreeNode TTreeDescender_UnknownToBranch::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue val = node->branchSelector->call(ex);
    // Regular branches are 0 .. size-2; the last one is the fallback branch.
    const int nBranches = node->branches->size()-1;

    PTreeNode next;
    if (!val.isSpecial() && (val.intV>=0) && (val.intV<nBranches))
      next = node->branches->at(val.intV);

    if (!next && (node->branches->size() > 0))
      next = node->branches->back();

    // Without a usable fallback branch stay at the current node; continuing
    // with a null node would make the loop condition dereference it.
    if (!next)
      break;

    node = next;
  }

  distr = PDiscDistribution();
  return node;
}


// Picks one of the non-null entries of 'branches', using 'roff' as a
// deterministic offset, and returns its index in the list.
// Returns -1 only when the list contains no non-null entry.
//
// The previous form reduced 'roff' modulo (count+1); a remainder of 0 skipped
// the selection loop and yielded -1 even though branches were available, and a
// negative 'roff' produced a negative countdown that walked past the end of
// the list. The offset is now reduced to 0 .. count-1, so every non-null
// branch can be chosen and one is always returned when any exists.
int randomNonNull(const PTreeNodeList &branches, const int &roff)
{
  int nonNullCount = 0;
  TTreeNodeList::const_iterator ni(branches->begin()), ne(branches->end());
  for (; ni!=ne; ni++)
    if (*ni)
      nonNullCount++;

  if (!nonNullCount)
    return -1;

  // '%' keeps the sign of the dividend, so fold negative remainders back
  // into the valid range.
  int toSkip = roff % nonNullCount;
  if (toSkip < 0)
    toSkip += nonNullCount;

  for (ni = branches->begin(); ni!=ne; ni++)
    if (*ni) {
      if (!toSkip)
        return ni - branches->begin();
      toSkip--;
    }

  // Not reachable: toSkip is smaller than the number of non-null entries.
  return -1;
}


// Descends from 'node' following the branch selectors. When the selector gives
// no usable index (special value or index out of range), the index with the
// highest probability in node->branchSizes is used, if branch sizes are
// stored. If there is still no index, or the chosen branch is empty, a
// non-null branch is picked by randomNonNull with ex.sumValues() as offset;
// the descent stops at the current node when that fails too.
// 'distr' is always reset to an empty distribution.
PTreeNode TTreeDescender_UnknownToCommonBranch::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue selected = node->branchSelector->call(ex);

    int branchIdx = -1;
    if (!selected.isSpecial())
      branchIdx = selected.intV;

    const bool outOfRange = (branchIdx < 0) || (branchIdx >= int(node->branches->size()));
    if (outOfRange) {
      if (node->branchSizes)
        branchIdx = node->branchSizes->highestProbIntIndex(ex);
      else
        branchIdx = -1;
    }

    if ((branchIdx < 0) || !node->branches->at(branchIdx)) {
      branchIdx = randomNonNull(node->branches, ex.sumValues());
      if (branchIdx < 0)
        break;
    }

    node = node->branches->at(branchIdx);
  }

  distr = PDiscDistribution();
  return node;
}



// Descends from 'node' following the branch selectors. For a special selector
// value, the branch is the most probable index of the discrete distribution
// attached to that value (val.svalV), if there is one. For a regular value the
// index is used when it is below the number of branches. If no index results,
// or the chosen branch is empty, a non-null branch is picked by randomNonNull
// with ex.sumValues() as offset; the descent stops when that fails too.
// 'distr' is always reset to an empty distribution.
PTreeNode TTreeDescender_UnknownToCommonSelector::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue selected = node->branchSelector->call(ex);

    int branchIdx = -1;
    if (!selected.isSpecial()) {
      if (selected.intV < int(node->branches->size()))
        branchIdx = selected.intV;
    }
    else {
      TDiscDistribution *valueDistr = selected.svalV.AS(TDiscDistribution);
      if (valueDistr)
        branchIdx = valueDistr->highestProbIntIndex(ex);
    }

    if ((branchIdx < 0) || !node->branches->at(branchIdx)) {
      branchIdx = randomNonNull(node->branches, ex.sumValues());
      if (branchIdx < 0)
        break;
    }

    node = node->branches->at(branchIdx);
  }

  distr = PDiscDistribution();
  return node;
}



// Descends from 'node' following the branch selectors. At the first node where
// the example cannot be routed (special value, index out of range, or empty
// branch) the descent stops and 'distr' is set to that node's branchSizes.
// If the descent ends without such a node, 'distr' is reset to an empty
// distribution.
PTreeNode TTreeDescender_UnknownMergeAsBranchSizes::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue selected = node->branchSelector->call(ex);

    const bool indexUsable = !selected.isSpecial()
                          && (selected.intV >= 0)
                          && (selected.intV < int(node->branches->size()));

    if (!indexUsable || !node->branches->at(selected.intV)) {
      distr = node->branchSizes;
      return node;
    }

    node = node->branches->at(selected.intV);
  }

  distr = PDiscDistribution();
  return node;
}



// Descends from 'node' following the branch selectors. At the first node where
// the example cannot be routed (special value, index out of range, or empty
// branch) the descent stops; 'distr' becomes the distribution attached to the
// selector's value (val.svalV) when that is derived from TDiscDistribution,
// and an empty distribution otherwise. If the descent ends without such a
// node, 'distr' is reset to an empty distribution.
PTreeNode TTreeDescender_UnknownMergeAsSelector::operator()(PTreeNode node, const TExample &ex, PDiscDistribution &distr)
{
  while (node->branchSelector && node->branches) {
    TValue selected = node->branchSelector->call(ex);

    const bool indexUsable = !selected.isSpecial()
                          && (selected.intV >= 0)
                          && (selected.intV < int(node->branches->size()));

    if (!indexUsable || !node->branches->at(selected.intV)) {
      distr = PDiscDistribution();
      if (selected.svalV && selected.svalV.is_derived_from(TDiscDistribution))
        distr = selected.svalV;
      return node;
    }

    node = node->branches->at(selected.intV);
  }

  distr = PDiscDistribution();
  return node;
}






// Prunes the tree rooted at 'root' and returns the pruned copy.
// The majority flags collected for the root are of no use to the caller,
// so they are gathered in a throw-away vector.
PTreeNode TTreePruner_SameMajority::operator()(PTreeNode root)
{
  vector<bool> rootMajorities;
  return operator()(root, rootMajorities);
}


/* Returns a pruned copy of the subtree rooted at 'node'.

   Argument 'bestValues' gives values that are majority values for the subtree:
   element i is true if class value i is a majority value in every leaf below
   'node'. While iterating through branches, an intersection is computed.
   Branches may return different sizes of 'bestValues'; the intersection is as
   long as the shortest of the reported bestValues.

   For an internal node, the branches are pruned recursively; if at least one
   value remains a common majority value, the copy is turned into a leaf by
   clearing its branches, branch descriptions, selector and branch sizes.

   For a leaf whose nodeClassifier is a TDefaultClassifier with a discrete
   default distribution, 'bestValues' receives one flag per value up to and
   including the last value with the highest frequency; a flag is true when
   the value's frequency equals that maximum. For any other leaf 'bestValues'
   is left as passed in. */

PTreeNode TTreePruner_SameMajority::operator()(PTreeNode node, vector<bool> &bestValues)
{ 
  PTreeNode newNode = CLONE(TTreeNode, node);

  if (node->branchSelector) {
    newNode->branches = mlnew TTreeNodeList();
    int notfirst = 0;
    PITERATE(TTreeNodeList, bi, node->branches)
      if (*bi) {
        if (notfirst++) {
          vector<bool> subBest;
          newNode->branches->push_back(operator()(*bi, subBest));

          // Cut the running intersection down to the length of the shorter
          // vector, i.e. keep its first subBest.size() elements. (Erasing from
          // 'size difference' onwards kept the wrong number of elements and
          // let the loop below read past the end of subBest.)
          if (subBest.size() < bestValues.size())
            bestValues.erase(bestValues.begin() + subBest.size(), bestValues.end());
          for(vector<bool>::iterator bvi(bestValues.begin()), bve(bestValues.end()), sbi(subBest.begin());
              bvi!=bve; bvi++, sbi++)
            *bvi = *bvi && *sbi;
        }
        else
          // The first non-null branch initializes the intersection directly.
          newNode->branches->push_back(operator()(*bi, bestValues));
      }
      else
        // Keep positions aligned with the original branch list.
        newNode->branches->push_back(PTreeNode());

    // Any surviving 'true' means all leaves below agree on a majority value,
    // so the subtree can be replaced by a leaf.
    vector<bool>::iterator pi(bestValues.begin());
    for( ; (pi!=bestValues.end()) && !*pi; pi++);
    if (pi!=bestValues.end()) {
      newNode->branches = PTreeNodeList();
      newNode->branchDescriptions = PStringList();
      newNode->branchSelector = PClassifier();
      newNode->branchSizes = PDiscDistribution();
    }
  }

  else {
    TDefaultClassifier *maj = node->nodeClassifier.AS(TDefaultClassifier);
    if (maj) {
      TDiscDistribution *ddist = maj->defaultDistribution.AS(TDiscDistribution);
      if (ddist) {
        // Find the highest frequency; '>=' makes 'bb' point at its last occurrence.
        float bestF = -1;
        TDiscDistribution::const_iterator bi(ddist->begin()), bb=bi, be(ddist->end());
        for(; bi!=be; bi++)
          if (*bi>=bestF) {
            bb = bi;
            bestF = *bb;
          }
            
        // The loop runs to one before the last; the last is always true
        for(bi = ddist->begin(); bi!=bb; bi++)
          bestValues.push_back(*bi==bestF);
        bestValues.push_back(true);
      }
    }
  }

  return newNode;
}

  
// Creates a pruner that uses 'am' as the value of the parameter m.
TTreePruner_m::TTreePruner_m(const float &am)
  : m(am)
{
}

// Prunes the tree rooted at 'root' and returns the pruned copy.
// The class distribution of the root (its 'distribution', or else the classes
// of its 'contingency') supplies the prior: for a discrete class every
// frequency is scaled by m/abs, for a continuous class the root's error is
// multiplied by m. The recursive overload does the actual pruning.
// Raises an error when m is negative, when the root stores no class
// distribution, or when the distribution is neither discrete nor continuous.
PTreeNode TTreePruner_m::operator()(PTreeNode root)
{
  if (m<0.0)
    raiseError("'m' should be positive");

  PDistribution rootDist;
  if (root->distribution)
    rootDist = root->distribution;
  else if (root->contingency && root->contingency->classes)
    rootDist = root->contingency->classes;
  else
    raiseError("the node does not store class distribution (check your flags for TreeLearner)");

  PTreeNode pruned;

  // Discrete class: one prior term per class value.
  TDiscDistribution *discrete = rootDist.AS(TDiscDistribution);
  if (discrete) {
    const float mba = m/discrete->abs;
    vector<float> m_by_p;
    for (TDiscDistribution::const_iterator di(discrete->begin()), de(discrete->end()); di!=de; di++)
      m_by_p.push_back(*di*mba);

    operator()(root, m_by_p, pruned);
    return pruned;
  }

  // Continuous class: a single prior term.
  TContDistribution *continuous = rootDist.AS(TContDistribution);
  if (continuous) {
    operator()(root, continuous->error() * m, pruned);
    return pruned;
  }

  raiseError("class distribution of unknown type (neither discrete nor continuous)");
  return PTreeNode();
}


float TTreePruner_m::estimateError(const PTreeNode &node, const vector<float> &m_by_p) const
{ 
  const TDiscDistribution *dist;
  if (node->distribution)
    dist = node->distribution.AS(TDiscDistribution);
  else if (node->contingency)
    dist = node->contingency->classes.AS(TDiscDistribution);
  else
    raiseError("the node does not store class distribution (check your flags for TreeLearner)");

  if (!dist)
    raiseError("invalid class distribution (DiscDistribution expected)");

  if ((dist->abs < 1e-10) || (dist->abs+m < 1e-10))
    return 0.0;

  float maxe = 0.0;
  vector<float>::const_iterator mi(m_by_p.begin());
  for(TDiscDistribution::const_iterator di(dist->begin()), de(dist->end()); di!=de; di++, mi++) {
    float thise = *di + *mi;
    if (thise>maxe)
      maxe = thise;
  }

  return 1.0 - maxe/(dist->abs+m);
}


float TTreePruner_m::estimateError(const PTreeNode &node, const float &m_by_se) const
{ const TContDistribution *dist;
  if (node->distribution)
    dist = node->distribution.AS(TContDistribution);
  else if (node->contingency)
    dist = node->contingency->classes.AS(TContDistribution);
  else
    raiseError("the node does not store class distribution (check your flags for TreeLearner)");
  if (!dist)
    raiseError("invalid class distribution (ContDistribution expected)");

  if ((dist->abs==0.0) || (dist->abs+m==0.0))
    return 0.0;

  return (dist->abs*dist->error() + m_by_se) / (dist->abs+m);
}

#ifdef _MSC_VER
#pragma optimize("p", off)
#endif

// Recursively prunes the subtree rooted at 'node'.
//   node    - root of the subtree to prune (only read)
//   m_by_p  - prior term(s) forwarded to estimateError: a vector of per-class
//             terms for a discrete class, a single float for a continuous one
//   newNode - output: pruned copy of 'node'
// Returns the error estimate of the pruned subtree.
//
// For an internal node the children are pruned first and their errors are
// averaged, weighted by node->branchSizes ("backup error"). If the node's own
// estimate ("static error") is lower, the copy is turned into a leaf by
// clearing its branches, branch descriptions, selector and branch sizes.
template<class T>
float TTreePruner_m::operator()(PTreeNode node, const T &m_by_p, PTreeNode &newNode) const
{ 
  newNode = CLONE(TTreeNode, node);

  if (node->branchSelector) {
    // Pre-sized so that each child can be written into its own slot.
    newNode->branches = mlnew TTreeNodeList(node->branches->size());

    float sumerr = 0, sumweights = 0;
    TDiscDistribution::const_iterator bwi (node->branchSizes->begin());
    TTreeNodeList::iterator oi (node->branches->begin()), oe (node->branches->end());
    TTreeNodeList::iterator bi (newNode->branches->begin());
    for (; oi!=oe; oi++, bi++, bwi++)
      if (*oi) {
        sumerr += *bwi * operator()(*oi, m_by_p, *bi);
        sumweights += *bwi;
      }

    const float staticError = estimateError(node, m_by_p);

    // With no weight below this node the backup error would be 0/0. That NaN
    // would be returned and make every ancestor's backup error NaN as well,
    // so treat such a node as a leaf instead.
    const bool hasBackup = sumweights > 0;
    const float backupError = hasBackup ? sumerr/sumweights : staticError;

    if (!hasBackup || (staticError<backupError)) {
      newNode->branches = PTreeNodeList();
      newNode->branchDescriptions = PStringList();
      newNode->branchSelector = PClassifier();
      newNode->branchSizes = PDiscDistribution();
      return staticError;
    }
    else
      return backupError;
  }

  else
    return estimateError(node, m_by_p);
}

#ifdef _MSC_VER
#pragma optimize("p", on)
#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -