📄 decisiontree.h
字号:
#ifndef DECISIONTREE_H#define DECISIONTREE_H#include <list>#include <iostream>#include "Range.h"#include "VarSet.h"#include <math.h>// Gaussian random variable func, from NRCfloat gasdev();#ifndef PI#define PI 3.14159265#endifnamespace DT {class Vertex;class Branch;class Leaf;class Multinomial;class BinGaussian;};using namespace DT;class DecisionTree{ int var; int maxVal; list<int> inputs; list<const Leaf*> leafList; Vertex* currVertex; Branch* currBranch; Leaf* currLeaf; Multinomial* currMultinomial; Vertex* headVertex; double currMean; double currSD; double currProbMissing;public: DecisionTree(int us, int numVals) :var(us), maxVal(numVals), currVertex(NULL), currBranch(NULL), currLeaf(NULL), headVertex(NULL) { /* NOP */ } ~DecisionTree(); double getProb(double state, double* allStates) const; double getLogProb(double state, double* allStates) const; double sample(double *allStates) const; // Get a list of all values on which this variable is split list<double> getSplits(int var) const; list<const Leaf*> DecisionTree::getLeafList() const { return leafList; } int getRange() const { return maxVal; } void addInput(int inputVar) { inputs.push_back(inputVar); } void beginVertex(int splitVar); void endVertex(); void beginBranch(); void endBranch(); void endValues(list<Range> values); void beginMultinomial(); void endProbs(double* probs); void endMean(double mean) { currMean = mean; } void endSD(double SD) { currSD = SD; } void endProbMissing(double probMissing) { currProbMissing = probMissing; } void endBinGaussian();//private: const Leaf* getLeaf(double* allStates) const; friend class Potential; friend class Distribution;};namespace DT {class Vertex{ int split; list<Branch*> children; Branch* parent; Vertex(int splitVar, Branch* par) :split(splitVar), parent(par) { /* NOP */ } // Delete children in destructor ~Vertex(); void addChild(Branch* child) { children.push_back(child); } int getSplitVar() const { return split; } Branch* getParent() { return parent; } void getSplits(int var, list<double>& splits); Leaf* getLeaf(double* allStates); friend class DecisionTree; friend class Branch;};class Branch{ list<Range> values; Vertex* parent; Vertex* childVertex; Leaf* childLeaf; Branch(Vertex* p) :parent(p), childVertex(NULL), childLeaf(NULL) { /* NOP */ } // Delete children in destructor (but not parent!) ~Branch(); void setValues(list<Range> newVals) { values = newVals; } void setLeaf(Leaf* leaf) { childLeaf = leaf; } void setVertex(Vertex* vertex) { childVertex = vertex; } Vertex* getParent() { return parent; } bool inRange(double val); Leaf* getLeaf(double* allStates); void getSplits(int var, list<double>& splits); friend class DecisionTree; friend class Vertex;};class Leaf{public: virtual double getProb(double val) const = 0; virtual double getLogProb(double val) const = 0; virtual double sample() const = 0;};class Multinomial : public Leaf{ int maxVal; double* probs;public: Multinomial(int numVals) :maxVal(numVals), probs(0) { /* NOP */ }; virtual ~Multinomial() { delete probs; } void setProbs(double* newProbs) { probs = newProbs; } double getProb(double val) const { return probs[(int)val]; } double getLogProb(double val) const { return log(probs[(int)val]); } double sample() const { double r = (double)rand()/RAND_MAX; double totalProb = 0.0; int i; for (i = 0; totalProb < r && i < maxVal; i++) { totalProb += probs[i]; } return (i - 1); }};class BinGaussian : public Leaf{public: double mean; double SD; double probMissing;public: BinGaussian(double newMean, double newSD, double newProbMissing) : mean(newMean), SD(newSD), probMissing(newProbMissing) { /* NOP */ } double getProb(double val) const { // Cache probs static double lastVal = -HUGE; static double lastProb = 0; if (val == VarSet::UNKNOWN) { // DEBUG if (probMissing == 0.0) { cout << "binGaussian returning zero prob! (probMissing)\n"; } return probMissing; } else if (val == lastVal) { // DEBUG if (lastProb == 0.0) { cout << "binGaussian returning zero prob! (lastVal)\n"; } return lastProb; } else { double prob = 1.0/(SD * sqrt(2*PI)); prob *= exp(-0.5*(val - mean)*(val - mean)/(SD*SD)); prob *= (1.0 - probMissing); lastVal = val; lastProb = prob; // DEBUG if (prob == 0.0) { cout << "binGaussian returning zero prob!\n"; } return prob; } } double getLogProb(double val) const { // Cache probs static double lastVal = -HUGE; static double lastProb = 0; if (val == VarSet::UNKNOWN) { return log(probMissing); } else if (val == lastVal) { return lastProb; } else { double lp = -log(SD * sqrt(2*PI)); lp += (-0.5*(val - mean)*(val - mean)/(SD*SD)); lp += log(1.0 - probMissing); lastVal = val; lastProb = lp; // DEBUG if (isnan(lp)) { cout << "binGaussian returning nan log-prob!\n"; } return lp; } } double sample() const { double p = (double)rand()/RAND_MAX; if (p < probMissing) { return VarSet::UNKNOWN; } else { return SD*gasdev() + mean; } }};}; /* End namespace DT */#endif // ndef DECISIONTREE_H
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -