📄 densitytreefactory.cpp
字号:
#include "DensityTreeFactory.hpp"
#include "DensityTreeInternal.hpp"
#include "DensityTreeLeaf.hpp"

using namespace indii::ml::aux;

/**
 * Construct a factory, deriving both thresholds from a single size
 * parameter.
 *
 * The sample threshold is initialised from sqrt(P) and the depth threshold
 * from 0.25*log(P)/log(2), i.e. a quarter of the base-2 logarithm of P.
 *
 * @param P Size parameter the thresholds are derived from (presumably the
 * expected number of samples -- confirm against the header).
 * @param rho Weight given to the parent's density when a non-root node's
 * density is computed. Must lie in [0,1].
 * @param strategy Selects which splitting rule split() dispatches to.
 */
DensityTreeFactory::DensityTreeFactory(const unsigned int P,
    const double rho, const DensityTreeSplitStrategy strategy) :
    /* explicit double arguments: sqrt/log on an unsigned int is an ambiguous
     * overload on pre-C++11 compilers */
    sampleThreshold(sqrt(static_cast<double>(P))),
    depthThreshold(0.25*log(static_cast<double>(P))/log(2.0)),
    rho(rho),
    strategy(strategy) {
  /* pre-condition */
  assert (rho >= 0.0 && rho <= 1.0);
  assert (sampleThreshold >= 2);
  assert (depthThreshold >= 1);
}

/**
 * Construct a factory with explicitly given thresholds.
 *
 * @param sampleThreshold Minimum number of components a node must hold to
 * be split further. Must be at least 2.
 * @param depthThreshold Depth limit; a node at depth d is split only if
 * d + 1 is below this value. Must be at least 1.
 * @param rho Weight given to the parent's density when a non-root node's
 * density is computed. Must lie in [0,1].
 * @param strategy Selects which splitting rule split() dispatches to.
 */
DensityTreeFactory::DensityTreeFactory(const unsigned int sampleThreshold,
    const unsigned int depthThreshold, const double rho,
    const DensityTreeSplitStrategy strategy) :
    sampleThreshold(sampleThreshold),
    depthThreshold(depthThreshold),
    rho(rho),
    strategy(strategy) {
  /* pre-condition */
  assert (rho >= 0.0 && rho <= 1.0);
  assert (sampleThreshold >= 2);
  assert (depthThreshold >= 1);
}

/**
 * Destructor. Nothing to release.
 */
DensityTreeFactory::~DensityTreeFactory() {
}

/**
 * Build a density tree over a weighted sample set.
 *
 * @param p Weighted sample set.
 *
 * @return Root node of the new tree, or NULL if p has one component or
 * fewer. Nodes are allocated with new; the caller is presumably responsible
 * for releasing the returned tree -- confirm against the node classes.
 */
DensityTreeNode* DensityTreeFactory::create(DiracMixturePdf& p) const {
  if (p.getNumComponents() <= 1) {
    /* degenerate case */
    return NULL;
  }

  /* bounds at extreme points */
  Bounds bounds = bound(p);
  BoundedDiracMixturePdf b = { p, &bounds.lower, &bounds.upper };

  /* create root node; remaining arguments take the defaults that are
   * presumably declared in the header */
  return create(b);
}

/**
 * Recursively build the subtree for a bounded weighted sample set.
 *
 * A node is split while it holds at least sampleThreshold components and
 * depth + 1 is below depthThreshold; otherwise it becomes a leaf.
 *
 * @param b Sample set together with pointers to the lower and upper corner
 * of the region it occupies.
 * @param mixDensity Density of the parent node. Ignored at the root.
 * @param depth Depth of the node being created; 0 denotes the root.
 *
 * @return Newly allocated node (internal or leaf).
 */
DensityTreeNode* DensityTreeFactory::create(BoundedDiracMixturePdf& b,
    const double mixDensity, const unsigned int depth) const {
  /* calculate volume at node */
  /* NOTE(review): if this is the uBLAS norm_1, it is the sum of the side
   * lengths of the region rather than their product -- confirm this is the
   * intended measure. */
  const double volume = norm_1(*b.upper - *b.lower);
  assert (volume > 0.0);

  /* calculate density at node */
  double density;
  if (depth == 0) {
    /* root node: nothing to mix with */
    density = b.p.getTotalWeight() / volume;
  } else {
    /* blend this node's own density with that of its parent */
    density = (1.0 - rho) * b.p.getTotalWeight() / volume
        + rho * mixDensity;
  }

  /* determine type of node */
  DensityTreeNode* node;
  if (b.p.getNumComponents() >= sampleThreshold
      && depth + 1 < depthThreshold) {
    /* split weighted sample set */
    SplitDiracMixturePdf s = split(b);

    /* the children share the parent's region, cut at the split value along
     * the split dimension */
    vector leftUpper(*b.upper);
    vector rightLower(*b.lower);
    leftUpper(s.index) = s.value;
    rightLower(s.index) = s.value;

    BoundedDiracMixturePdf leftBranch = { s.left, b.lower, &leftUpper };
    BoundedDiracMixturePdf rightBranch = { s.right, &rightLower, b.upper };

    DensityTreeNode* left = create(leftBranch, density, depth + 1);
    DensityTreeNode* right = create(rightBranch, density, depth + 1);

    /* make internal node */
    node = new DensityTreeInternal(*b.lower, *b.upper, s.index, s.value,
        left, right);
  } else {
    /* make leaf node */
    node = new DensityTreeLeaf(*b.lower, *b.upper, density*volume, volume);
  }

  return node;
}

/**
 * Split a bounded sample set using the strategy this factory was
 * constructed with.
 *
 * @param b Bounded sample set to split.
 *
 * @return The two halves together with the split dimension and value. Any
 * strategy other than SPLIT_VARIANCE and SPLIT_LENGTH falls through to the
 * random split.
 */
DensityTreeFactory::SplitDiracMixturePdf DensityTreeFactory::split(
    BoundedDiracMixturePdf& b) const {
  switch (strategy) {
    case SPLIT_VARIANCE:
      return splitVariance(b.p);
    case SPLIT_LENGTH:
      return splitLength(b);
    default:
      return splitRandom(b);
  }
}

/**
 * Split along the dimension of highest variance, at the mean of that
 * dimension.
 *
 * @param p Sample set to split. Must hold at least two components.
 *
 * @return The two halves together with the split dimension and value.
 */
DensityTreeFactory::SplitDiracMixturePdf DensityTreeFactory::splitVariance(
    DiracMixturePdf& p) {
  /* pre-condition */
  assert (p.getNumComponents() >= 2); // variance must exist

  const vector& mu = p.getExpectation();
  vector sigma(mu.size()); // only need variance, not whole covariance

  /* calculate variance as the weighted mean of squares less the squared
   * mean; clear() is expected to zero the accumulator first */
  DiracMixturePdf::weighted_component_const_iterator iter, end;
  sigma.clear();
  iter = p.getComponents().begin();
  end = p.getComponents().end();
  for (; iter != end; ++iter) {
    noalias(sigma) += iter->w * element_prod(iter->x, iter->x);
  }
  sigma /= p.getTotalWeight();
  noalias(sigma) -= element_prod(mu, mu);

  /* find dimension of highest variance; the strict comparison keeps the
   * first such dimension on ties */
  unsigned int index = 0, i;
  for (i = 1; i < sigma.size(); i++) {
    if (sigma(i) > sigma(index)) {
      index = i;
    }
  }

  return split(p, index, mu(index));
}

/**
 * Split along the dimension over which the samples have the greatest
 * extent, at the midpoint of that extent.
 *
 * The extent is measured from the samples themselves via bound(), not from
 * the corners carried in b.
 *
 * @param b Bounded sample set to split.
 *
 * @return The two halves together with the split dimension and value.
 */
DensityTreeFactory::SplitDiracMixturePdf DensityTreeFactory::splitLength(
    BoundedDiracMixturePdf& b) {
  Bounds bounds = bound(b.p);
  const vector length(bounds.upper - bounds.lower);

  /* find longest dimension; the strict comparison keeps the first such
   * dimension on ties */
  unsigned int index = 0, i;
  for (i = 1; i < length.size(); i++) {
    if (length(i) > length(index)) {
      index = i;
    }
  }

  return split(b.p, index,
      (bounds.upper(index) + bounds.lower(index)) / 2.0);
}

/**
 * Split along a randomly drawn dimension, at the midpoint of the samples'
 * extent in that dimension.
 *
 * @param b Bounded sample set to split.
 *
 * @return The two halves together with the split dimension and value.
 */
DensityTreeFactory::SplitDiracMixturePdf DensityTreeFactory::splitRandom(
    BoundedDiracMixturePdf& b) {
  const unsigned int N = b.p.getDimensions();
  double random = Random::uniform(0, b.p.getDimensions());
  unsigned int index = static_cast<unsigned int>(floor(random));

  /* should the generator ever return its upper bound, the floored value
   * would equal N and fall outside the valid range; fold it back onto the
   * last dimension */
  if (N > 0 && index >= N) {
    index = N - 1;
  }

  /**
   * @todo Really only need to calculate bounds on chosen dimension.
   */
  Bounds bounds = bound(b.p);

  return split(b.p, index,
      (bounds.upper(index) + bounds.lower(index)) / 2.0);
}

/**
 * Partition a sample set about a value along one dimension.
 *
 * @param p Sample set to partition.
 * @param index Dimension to split on. Must be less than the number of
 * dimensions of p.
 * @param value Split value. Components whose coordinate in dimension index
 * is strictly below this go left; all others go right.
 *
 * @return The two halves together with index and value.
 */
DensityTreeFactory::SplitDiracMixturePdf DensityTreeFactory::split(
    DiracMixturePdf& p, const unsigned int index, const double value) {
  /* pre-condition */
  assert (index < p.getDimensions());

  const unsigned int N = p.getDimensions();
  DiracMixturePdf left(N), right(N);

  DiracMixturePdf::weighted_component_const_iterator iter, end;
  iter = p.getComponents().begin();
  end = p.getComponents().end();
  for (; iter != end; ++iter) {
    if (iter->x(index) < value) {
      left.addComponent(*iter);
    } else {
      right.addComponent(*iter);
    }
  }

  SplitDiracMixturePdf result = { left, right, index, value };
  return result;
}

/**
 * Compute the per-dimension minimum and maximum over all components of a
 * sample set.
 *
 * @param p Sample set. Must hold at least one component.
 *
 * @return Lower and upper corner of the tightest axis-aligned box
 * containing every component.
 */
DensityTreeFactory::Bounds DensityTreeFactory::bound(DiracMixturePdf& p) {
  const unsigned int N = p.getDimensions();
  unsigned int i;
  vector lower(N), upper(N);

  DiracMixturePdf::weighted_component_const_iterator iter, end;
  iter = p.getComponents().begin();
  end = p.getComponents().end();
  assert (iter != end);

  /* seed both corners with the first component */
  noalias(lower) = iter->x;
  noalias(upper) = lower;
  ++iter;

  for (; iter != end; ++iter) {
    for (i = 0; i < N; i++) {
      /* lower <= upper always holds, so a coordinate can extend at most one
       * of the two corners */
      if (iter->x(i) < lower(i)) {
        lower(i) = iter->x(i);
      } else if (iter->x(i) > upper(i)) {
        upper(i) = iter->x(i);
      }
    }
  }

  Bounds result = { lower, upper };
  return result;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -