⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 kde.hpp

📁 dysii is a C++ library for distributed probabilistic inference and learning in large-scale dynamical systems.
💻 HPP
📖 第 1 页 / 共 2 页
字号:
        targetNode.difference(queryNode, x);
        if (K(N(x)) > 0.0) {
          if (queryNode.isInternal()) {
            if (targetNode.isInternal()) {
              /* split both query and target nodes */
              queryNodes.push(queryNode.getLeft());
              targetNodes.push(targetNode.getLeft());

              queryNodes.push(queryNode.getLeft());
              targetNodes.push(targetNode.getRight());

              queryNodes.push(queryNode.getRight());
              targetNodes.push(targetNode.getLeft());

              queryNodes.push(queryNode.getRight());
              targetNodes.push(targetNode.getRight());
            } else {
              /* split query node only */
              queryNodes.push(queryNode.getLeft());
              targetNodes.push(&targetNode);

              queryNodes.push(queryNode.getRight());
              targetNodes.push(&targetNode);
            }
          } else {
            /* split target node only */
            queryNodes.push(&queryNode);
            targetNodes.push(targetNode.getLeft());

            queryNodes.push(&queryNode);
            targetNodes.push(targetNode.getRight());
          }
        }
      }
    }

    if (normalise) {
      result /= p.getTotalWeight();
    }
  }

  return result;
}

/**
 * Accumulate dualTreeDensity() results over every process of the MPI
 * communicator.
 *
 * The local result is computed first, then the query tree's data, the query
 * tree itself and the running result are passed to rotate(); this is repeated
 * until dualTreeDensity() has been evaluated world.size() times and rotate()
 * has been applied world.size() times.
 *
 * NOTE(review): rotate() is defined elsewhere. Presumably it hands its
 * argument to the neighbouring process in a ring, so that after world.size()
 * rounds everything is back on its original process -- confirm.
 *
 * @param queryTree Query tree. Passed to rotate(), together with its data,
 * so it is not left untouched during the call.
 * @param targetTree Target tree. Never passed to rotate() here.
 * @param ws Weight matrix forwarded unchanged to every dualTreeDensity()
 * call.
 * @param N Forwarded to dualTreeDensity() (presumably the norm -- confirm).
 * @param K Forwarded to dualTreeDensity() (presumably the kernel --
 * confirm).
 * @param normalise If true, the accumulated result is divided by
 * targetTree.getData()->getDistributedTotalWeight() before returning.
 *
 * @return The accumulated result matrix.
 */
template <class NT, class KT, class PT>
indii::ml::aux::matrix indii::ml::aux::distributedDualTreeDensity(
    PT& queryTree, PT& targetTree, const matrix& ws,
    const NT& N, const KT& K, const bool normalise) {
  boost::mpi::communicator world;
  const unsigned int size = world.size();
  unsigned int i;

  /* first evaluation, un-normalised; normalisation is applied once at the
   * end, after all contributions have been summed */
  matrix result(dualTreeDensity(queryTree, targetTree, ws, N, K, false));
  rotate(*queryTree.getData());
  rotate(queryTree);
  rotate(result);

  /* remaining size - 1 evaluations; result is rotated alongside the query
   * tree and its data after every evaluation */
  for (i = 1; i < size; i++) {
    noalias(result) += dualTreeDensity(queryTree, targetTree, ws, N, K,
        false);
    rotate(*queryTree.getData());
    rotate(queryTree);
    rotate(result);
  }

  if (normalise) {
    result /= targetTree.getData()->getDistributedTotalWeight();
  }

  return result;
}

/**
 * Evaluate a tree against itself, using one traversal for both the query
 * and the target role.
 *
 * For every pair of component indices (i, j) that the traversal reaches,
 * d = K(N(p.get(i) - p.get(j))) is computed and d * column(ws, j) is added
 * to column i of the result. When the pair carries the doCross flag, the
 * mirrored contribution d * column(ws, i) is added to column j as well, which
 * is how the right-left half of a symmetric split is covered without being
 * traversed.
 *
 * @param tree Tree to evaluate; its data supplies the components.
 * @param ws Weight matrix; must have one column per component of the tree's
 * data (asserted).
 * @param N Functor applied to a difference vector; its result is handed to
 * K (presumably the norm -- confirm).
 * @param K Functor applied to the result of N (presumably the kernel --
 * confirm).
 * @param normalise If true, the result is divided by the total weight of the
 * tree's data. This happens only when the tree has a root.
 *
 * @return Matrix with ws.size1() rows and one column per component. All
 * zeros if the tree has no root.
 */
template <class NT, class KT, class PT>
indii::ml::aux::matrix indii::ml::aux::selfTreeDensity(PT& tree,
    const matrix& ws, const NT& N, const KT& K, const bool normalise) {
  /* pre-condition */
  assert (ws.size2() == tree.getData()->getSize());

  DiracMixturePdf& p = *tree.getData();
  PartitionTreeNode* root = tree.getRoot();
  matrix result(ws.size1(), p.getSize());
  result.clear();

  if (root != NULL) {
    /* three parallel stacks: entries at the same depth belong together, so
     * every push and pop below is done on all three */
    std::stack<PartitionTreeNode*> queryNodes, targetNodes;
    std::stack<bool> doCrosses; // for query equals target tree optimisations
    vector x(p.getDimensions());
    unsigned int i, j;
    double w, d; /* NOTE(review): w is never used in this function */
    bool doCross;

    queryNodes.push(root);
    targetNodes.push(root);
    doCrosses.push(false);

    while (!queryNodes.empty()) {
      PartitionTreeNode& queryNode = *queryNodes.top();
      queryNodes.pop();
      PartitionTreeNode& targetNode = *targetNodes.top();
      targetNodes.pop();
      doCross = doCrosses.top();
      doCrosses.pop();

      /* the first four branches handle pairs with no internal node: leaf
       * nodes expose one index through getIndex(), prune nodes expose a list
       * through getIndices() */
      if (queryNode.isLeaf() && targetNode.isLeaf()) {
        i = queryNode.getIndex();
        j = targetNode.getIndex();
        noalias(x) = p.get(i) - p.get(j);
        d = K(N(x));
        if (doCross) {
          noalias(column(result,j)) += d * column(ws,i);
        }
        noalias(column(result,i)) += d * column(ws,j);
      } else if (queryNode.isLeaf() && targetNode.isPrune()) {
        i = queryNode.getIndex();
        const std::vector<unsigned int>& js = targetNode.getIndices();
        /* here j is a position in js, not a component index */
        for (j = 0; j < js.size(); j++) {
          noalias(x) = p.get(i) - p.get(js[j]);
          d = K(N(x));
          if (doCross) {
            noalias(column(result,js[j])) += d * column(ws,i);
          }
          noalias(column(result,i)) += d * column(ws,js[j]);
        }
      } else if (queryNode.isPrune() && targetNode.isLeaf()) {
        const std::vector<unsigned int>& is = queryNode.getIndices();
        j = targetNode.getIndex();
        /* here i is a position in is, not a component index */
        for (i = 0; i < is.size(); i++) {
          noalias(x) = p.get(is[i]) - p.get(j);
          d = K(N(x));
          if (doCross) {
            noalias(column(result,j)) += d * column(ws,is[i]);
          }
          noalias(column(result,is[i])) += d * column(ws,j);
        }
      } else if (queryNode.isPrune() && targetNode.isPrune()) {
        const std::vector<unsigned int>& is = queryNode.getIndices();
        const std::vector<unsigned int>& js = targetNode.getIndices();
        /* all combinations of the two index lists */
        for (i = 0; i < is.size(); i++) {
          for (j = 0; j < js.size(); j++) {
            noalias(x) = p.get(is[i]) - p.get(js[j]);
            d = K(N(x));
            if (doCross) {
              noalias(column(result,js[j])) += d * column(ws,is[i]);
            }
            noalias(column(result,is[i])) += d * column(ws,js[j]);
          }
        }
      } else {
        /* should we recurse? */
        /* the pair is dropped, with nothing pushed, unless K(N(x)) is
         * positive for the x filled in by difference(); presumably x bounds
         * the separation of the two nodes -- confirm against
         * PartitionTreeNode */
        targetNode.difference(queryNode, x);
        if (K(N(x)) > 0.0) {
          if (queryNode.isInternal()) {
            if (targetNode.isInternal()) {
              /* split both query and target nodes */
              queryNodes.push(queryNode.getLeft());
              targetNodes.push(targetNode.getLeft());
              doCrosses.push(doCross);

              queryNodes.push(queryNode.getLeft());
              targetNodes.push(targetNode.getRight());
              if (&queryNode == &targetNode) {
                /* symmetric, so just double left-right evaluation */
                /* the right-left pair is not pushed in this case; the flag
                 * makes the leaf/prune branches add its contribution */
                doCrosses.push(true);
              } else {
                /* asymmetric, so evaluate right-left separately */
                doCrosses.push(doCross);

                queryNodes.push(queryNode.getRight());
                targetNodes.push(targetNode.getLeft());
                doCrosses.push(doCross);
              }

              queryNodes.push(queryNode.getRight());
              targetNodes.push(targetNode.getRight());
              doCrosses.push(doCross);
            } else {
              /* split query node only */
              queryNodes.push(queryNode.getLeft());
              targetNodes.push(&targetNode);
              doCrosses.push(doCross);

              queryNodes.push(queryNode.getRight());
              targetNodes.push(&targetNode);
              doCrosses.push(doCross);
            }
          } else {
            /* split target node only */
            queryNodes.push(&queryNode);
            targetNodes.push(targetNode.getLeft());
            doCrosses.push(doCross);

            queryNodes.push(&queryNode);
            targetNodes.push(targetNode.getRight());
            doCrosses.push(doCross);
          }
        }
      }
    }

    if (normalise) {
      result /= p.getTotalWeight();
    }
  }

  return result;
}

/**
 * Distributed counterpart of selfTreeDensity().
 *
 * The local self evaluation is computed first. With more than one process, a
 * private copy of the tree (with its own copy of the data and of ws) is then
 * passed to rotate() repeatedly and evaluated against the local tree:
 *
 * - (size - 1) / 2 rounds use crossTreeDensity(), which accumulates into both
 *   the local result and a second matrix, result2, that is rotated along
 *   with the copy;
 * - if size - 1 is odd, one further round uses dualTreeDensity(), which adds
 *   to the local result only.
 *
 * Finally result2 is passed to rotate() with the count size - i and added to
 * the local result.
 *
 * NOTE(review): rotate() is defined elsewhere. Presumably it moves its
 * argument around a ring of processes, and the two-argument form moves it
 * by the given number of steps -- confirm.
 *
 * @param tree Local tree. Only its clone is passed to rotate().
 * @param ws Weight matrix for the local tree. Only its copy is passed to
 * rotate().
 * @param N Forwarded to the evaluation functions (presumably the norm --
 * confirm).
 * @param K Forwarded to the evaluation functions (presumably the kernel --
 * confirm).
 * @param normalise If true, the result is divided by
 * tree.getData()->getDistributedTotalWeight() before returning.
 *
 * @return The accumulated result matrix.
 */
template <class NT, class KT, class PT>
indii::ml::aux::matrix indii::ml::aux::distributedSelfTreeDensity(
    PT& tree, const matrix& ws, const NT& N, const KT& K,
    const bool normalise) {
  boost::mpi::communicator world;
  const unsigned int size = world.size();

  /* un-normalised; normalisation is applied once at the end */
  matrix result(selfTreeDensity(tree, ws, N, K, false));

  if (size > 1) {
    /* cross densities */
    unsigned int crosses = (size - 1) / 2;
    bool leftover = (size - 1) % 2 > 0;
    unsigned int i; /* declared outside the loop: read again after it */

    matrix ws2(ws);
    matrix result2(result.size1(), result.size2());
    result2.clear();

    /* NOTE(review): the dynamic_cast result is dereferenced without a null
     * check, and tree2 is released only by the delete at the end of this
     * block, so it is not released if anything in between throws */
    PT* tree2 = dynamic_cast<PT*>(tree.clone());
    /* q must stay alive for as long as tree2 refers to it; both are confined
     * to this block */
    DiracMixturePdf q(*tree.getData());
    tree2->setData(&q);

    for (i = 0; i < crosses; i++) {
      rotate(*tree2->getData());
      rotate(*tree2);
      rotate(result2);
      rotate(ws2);

      /* clear = false keeps what result and result2 already hold;
       * normalise = false defers normalisation to the end */
      crossTreeDensity(tree, *tree2, ws, ws2, N, K, result, result2,
          false, false);
    }

    if (leftover) {
      /* result2 is not rotated in this round and receives nothing from it */
      rotate(*tree2->getData());
      rotate(*tree2);
      rotate(ws2);

      noalias(result) += dualTreeDensity(tree, *tree2, ws2, N, K, false);
    }

    /* return results to original node */
    /* i equals crosses here, the number of times result2 was rotated */
    rotate(result2, size - i);
    noalias(result) += result2;

    delete tree2;
  }

  if (normalise) {
    result /= tree.getData()->getDistributedTotalWeight();
  }

  return result;
}

/**
 * Evaluate two trees against each other in a single traversal, accumulating
 * into two result matrices at once.
 *
 * For every pair of component indices (i, j) that the traversal reaches, with
 * i from tree1 and j from tree2, d = K(N(p1.get(i) - p2.get(j))) is computed,
 * d * column(ws2, j) is added to column i of result1 and d * column(ws1, i)
 * is added to column j of result2.
 *
 * The asserts require each weight and result matrix to have one column per
 * component of its own tree's data, each result matrix to have as many rows
 * as the same-numbered weight matrix, and both trees' data to have the same
 * number of dimensions.
 *
 * @param tree1 First tree.
 * @param tree2 Second tree.
 * @param ws1 Weight matrix for the first tree.
 * @param ws2 Weight matrix for the second tree.
 * @param N Functor applied to a difference vector; its result is handed to
 * K (presumably the norm -- confirm).
 * @param K Functor applied to the result of N (presumably the kernel --
 * confirm).
 * @param result1 Accumulator with one column per component of the first
 * tree.
 * @param result2 Accumulator with one column per component of the second
 * tree.
 * @param clear If true, both accumulators are zeroed first; otherwise the
 * contributions are added to their current contents.
 * @param normalise If true, result1 is divided by the total weight of the
 * first tree's data and result2 by that of the second. This happens only when
 * both trees have a root.
 */
template <class NT, class KT, class PT>
void indii::ml::aux::crossTreeDensity(
    PT& tree1, PT& tree2, const matrix& ws1,
    const matrix& ws2, const NT& N, const KT& K, matrix& result1,
    matrix& result2, const bool clear, const bool normalise) {
  /* pre-condition */
  assert (ws1.size2() == tree1.getData()->getSize());
  assert (ws2.size2() == tree2.getData()->getSize());
  assert (result1.size2() == tree1.getData()->getSize());
  assert (result2.size2() == tree2.getData()->getSize());
  assert (result1.size1() == ws1.size1());
  assert (result2.size1() == ws2.size1());
  assert (tree1.getData()->getDimensions() ==
      tree2.getData()->getDimensions());

  DiracMixturePdf& p1 = *tree1.getData();
  DiracMixturePdf& p2 = *tree2.getData();
  PartitionTreeNode* root1 = tree1.getRoot();
  PartitionTreeNode* root2 = tree2.getRoot();

  if (clear) {
    result1.clear();
    result2.clear();
  }

  if (root1 != NULL && root2 != NULL) {
    /* two parallel stacks: entries at the same depth form one pair, so every
     * push and pop below is done on both */
    std::stack<PartitionTreeNode*> nodes1, nodes2;
    vector x(p1.getDimensions());
    unsigned int i, j;
    double w, d; /* NOTE(review): w is never used in this function */

    nodes1.push(root1);
    nodes2.push(root2);

    while (!nodes1.empty()) {
      PartitionTreeNode& node1 = *nodes1.top();
      nodes1.pop();
      PartitionTreeNode& node2 = *nodes2.top();
      nodes2.pop();

      /* the first four branches handle pairs with no internal node: leaf
       * nodes expose one index through getIndex(), prune nodes expose a list
       * through getIndices() */
      if (node1.isLeaf() && node2.isLeaf()) {
        i = node1.getIndex();
        j = node2.getIndex();
        noalias(x) = p1.get(i) - p2.get(j);
        d = K(N(x));
        noalias(column(result1,i)) += d * column(ws2,j);
        noalias(column(result2,j)) += d * column(ws1,i);
      } else if (node1.isLeaf() && node2.isPrune()) {
        i = node1.getIndex();
        const std::vector<unsigned int>& js = node2.getIndices();
        /* here j is a position in js, not a component index */
        for (j = 0; j < js.size(); j++) {
          noalias(x) = p1.get(i) - p2.get(js[j]);
          d = K(N(x));
          noalias(column(result1,i)) += d * column(ws2,js[j]);
          noalias(column(result2,js[j])) += d * column(ws1,i);
        }
      } else if (node1.isPrune() && node2.isLeaf()) {
        const std::vector<unsigned int>& is = node1.getIndices();
        j = node2.getIndex();
        /* here i is a position in is, not a component index */
        for (i = 0; i < is.size(); i++) {
          noalias(x) = p1.get(is[i]) - p2.get(j);
          d = K(N(x));
          noalias(column(result1,is[i])) += d * column(ws2,j);
          noalias(column(result2,j)) += d * column(ws1,is[i]);
        }
      } else if (node1.isPrune() && node2.isPrune()) {
        const std::vector<unsigned int>& is = node1.getIndices();
        const std::vector<unsigned int>& js = node2.getIndices();
        /* all combinations of the two index lists */
        for (i = 0; i < is.size(); i++) {
          for (j = 0; j < js.size(); j++) {
            noalias(x) = p1.get(is[i]) - p2.get(js[j]);
            d = K(N(x));
            noalias(column(result1,is[i])) += d * column(ws2,js[j]);
            noalias(column(result2,js[j])) += d * column(ws1,is[i]);
          }
        }
      } else {
        /* should we recurse? */
        /* the pair is dropped, with nothing pushed, unless K(N(x)) is
         * positive for the x filled in by difference(); presumably x bounds
         * the separation of the two nodes -- confirm against
         * PartitionTreeNode */
        node2.difference(node1, x);
        if (K(N(x)) > 0.0) {
          if (node1.isInternal()) {
            if (node2.isInternal()) {
              /* split both query and target nodes */
              nodes1.push(node1.getLeft());
              nodes2.push(node2.getLeft());

              nodes1.push(node1.getLeft());
              nodes2.push(node2.getRight());

              nodes1.push(node1.getRight());
              nodes2.push(node2.getLeft());

              nodes1.push(node1.getRight());
              nodes2.push(node2.getRight());
            } else {
              /* split query node only */
              nodes1.push(node1.getLeft());
              nodes2.push(&node2);

              nodes1.push(node1.getRight());
              nodes2.push(&node2);
            }
          } else {
            /* split target node only */
            nodes1.push(&node1);
            nodes2.push(node2.getLeft());

            nodes1.push(&node1);
            nodes2.push(node2.getRight());
          }
        }
      }
    }

    if (normalise) {
      /* NOTE(review): result1 accumulates columns of ws2 but is divided by
       * the total weight of p1 (and result2, built from ws1, by that of p2);
       * confirm the normalisers are not meant to be swapped */
      result1 /= p1.getTotalWeight();
      result2 /= p2.getTotalWeight();
    }
  }
}

#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -