📄 kde.hpp
字号:
targetNode.difference(queryNode, x); if (K(N(x)) > 0.0) { if (queryNode.isInternal()) { if (targetNode.isInternal()) { /* split both query and target nodes */ queryNodes.push(queryNode.getLeft()); targetNodes.push(targetNode.getLeft()); queryNodes.push(queryNode.getLeft()); targetNodes.push(targetNode.getRight()); queryNodes.push(queryNode.getRight()); targetNodes.push(targetNode.getLeft()); queryNodes.push(queryNode.getRight()); targetNodes.push(targetNode.getRight()); } else { /* split query node only */ queryNodes.push(queryNode.getLeft()); targetNodes.push(&targetNode); queryNodes.push(queryNode.getRight()); targetNodes.push(&targetNode); } } else { /* split target node only */ queryNodes.push(&queryNode); targetNodes.push(targetNode.getLeft()); queryNodes.push(&queryNode); targetNodes.push(targetNode.getRight()); } } } } if (normalise) { result /= p.getTotalWeight(); } } return result;}template <class NT, class KT, class PT>indii::ml::aux::matrix indii::ml::aux::distributedDualTreeDensity( PT& queryTree, PT& targetTree, const matrix& ws, const NT& N, const KT& K, const bool normalise) { boost::mpi::communicator world; const unsigned int size = world.size(); unsigned int i; matrix result(dualTreeDensity(queryTree, targetTree, ws, N, K, false)); rotate(*queryTree.getData()); rotate(queryTree); rotate(result); for (i = 1; i < size; i++) { noalias(result) += dualTreeDensity(queryTree, targetTree, ws, N, K, false); rotate(*queryTree.getData()); rotate(queryTree); rotate(result); } if (normalise) { result /= targetTree.getData()->getDistributedTotalWeight(); } return result;}template <class NT, class KT, class PT>indii::ml::aux::matrix indii::ml::aux::selfTreeDensity(PT& tree, const matrix& ws, const NT& N, const KT& K, const bool normalise) { /* pre-condition */ assert (ws.size2() == tree.getData()->getSize()); DiracMixturePdf& p = *tree.getData(); PartitionTreeNode* root = tree.getRoot(); matrix result(ws.size1(), p.getSize()); result.clear(); if (root != NULL) { std::stack<PartitionTreeNode*> queryNodes, targetNodes; std::stack<bool> doCrosses; // for query equals target tree optimisations vector x(p.getDimensions()); unsigned int i, j; double w, d; bool doCross; queryNodes.push(root); targetNodes.push(root); doCrosses.push(false); while (!queryNodes.empty()) { PartitionTreeNode& queryNode = *queryNodes.top(); queryNodes.pop(); PartitionTreeNode& targetNode = *targetNodes.top(); targetNodes.pop(); doCross = doCrosses.top(); doCrosses.pop(); if (queryNode.isLeaf() && targetNode.isLeaf()) { i = queryNode.getIndex(); j = targetNode.getIndex(); noalias(x) = p.get(i) - p.get(j); d = K(N(x)); if (doCross) { noalias(column(result,j)) += d * column(ws,i); } noalias(column(result,i)) += d * column(ws,j); } else if (queryNode.isLeaf() && targetNode.isPrune()) { i = queryNode.getIndex(); const std::vector<unsigned int>& js = targetNode.getIndices(); for (j = 0; j < js.size(); j++) { noalias(x) = p.get(i) - p.get(js[j]); d = K(N(x)); if (doCross) { noalias(column(result,js[j])) += d * column(ws,i); } noalias(column(result,i)) += d * column(ws,js[j]); } } else if (queryNode.isPrune() && targetNode.isLeaf()) { const std::vector<unsigned int>& is = queryNode.getIndices(); j = targetNode.getIndex(); for (i = 0; i < is.size(); i++) { noalias(x) = p.get(is[i]) - p.get(j); d = K(N(x)); if (doCross) { noalias(column(result,j)) += d * column(ws,is[i]); } noalias(column(result,is[i])) += d * column(ws,j); } } else if (queryNode.isPrune() && targetNode.isPrune()) { const std::vector<unsigned int>& is = queryNode.getIndices(); const std::vector<unsigned int>& js = targetNode.getIndices(); for (i = 0; i < is.size(); i++) { for (j = 0; j < js.size(); j++) { noalias(x) = p.get(is[i]) - p.get(js[j]); d = K(N(x)); if (doCross) { noalias(column(result,js[j])) += d * column(ws,is[i]); } noalias(column(result,is[i])) += d * column(ws,js[j]); } } } else { /* should we recurse? */ targetNode.difference(queryNode, x); if (K(N(x)) > 0.0) { if (queryNode.isInternal()) { if (targetNode.isInternal()) { /* split both query and target nodes */ queryNodes.push(queryNode.getLeft()); targetNodes.push(targetNode.getLeft()); doCrosses.push(doCross); queryNodes.push(queryNode.getLeft()); targetNodes.push(targetNode.getRight()); if (&queryNode == &targetNode) { /* symmetric, so just double left-right evaluation */ doCrosses.push(true); } else { /* asymmetric, so evaluate right-left separately */ doCrosses.push(doCross); queryNodes.push(queryNode.getRight()); targetNodes.push(targetNode.getLeft()); doCrosses.push(doCross); } queryNodes.push(queryNode.getRight()); targetNodes.push(targetNode.getRight()); doCrosses.push(doCross); } else { /* split query node only */ queryNodes.push(queryNode.getLeft()); targetNodes.push(&targetNode); doCrosses.push(doCross); queryNodes.push(queryNode.getRight()); targetNodes.push(&targetNode); doCrosses.push(doCross); } } else { /* split target node only */ queryNodes.push(&queryNode); targetNodes.push(targetNode.getLeft()); doCrosses.push(doCross); queryNodes.push(&queryNode); targetNodes.push(targetNode.getRight()); doCrosses.push(doCross); } } } } if (normalise) { result /= p.getTotalWeight(); } } return result;}template <class NT, class KT, class PT>indii::ml::aux::matrix indii::ml::aux::distributedSelfTreeDensity( PT& tree, const matrix& ws, const NT& N, const KT& K, const bool normalise) { boost::mpi::communicator world; const unsigned int size = world.size(); matrix result(selfTreeDensity(tree, ws, N, K, false)); if (size > 1) { /* cross densities */ unsigned int crosses = (size - 1) / 2; bool leftover = (size - 1) % 2 > 0; unsigned int i; matrix ws2(ws); matrix result2(result.size1(), result.size2()); result2.clear(); PT* tree2 = dynamic_cast<PT*>(tree.clone()); DiracMixturePdf q(*tree.getData()); tree2->setData(&q); for (i = 0; i < crosses; i++) { rotate(*tree2->getData()); rotate(*tree2); rotate(result2); rotate(ws2); crossTreeDensity(tree, *tree2, ws, ws2, N, K, result, result2, false, false); } if (leftover) { rotate(*tree2->getData()); rotate(*tree2); rotate(ws2); noalias(result) += dualTreeDensity(tree, *tree2, ws2, N, K, false); } /* return results to original node */ rotate(result2, size - i); noalias(result) += result2; delete tree2; } if (normalise) { result /= tree.getData()->getDistributedTotalWeight(); } return result;}template <class NT, class KT, class PT>void indii::ml::aux::crossTreeDensity( PT& tree1, PT& tree2, const matrix& ws1, const matrix& ws2, const NT& N, const KT& K, matrix& result1, matrix& result2, const bool clear, const bool normalise) { /* pre-condition */ assert (ws1.size2() == tree1.getData()->getSize()); assert (ws2.size2() == tree2.getData()->getSize()); assert (result1.size2() == tree1.getData()->getSize()); assert (result2.size2() == tree2.getData()->getSize()); assert (result1.size1() == ws1.size1()); assert (result2.size1() == ws2.size1()); assert (tree1.getData()->getDimensions() == tree2.getData()->getDimensions()); DiracMixturePdf& p1 = *tree1.getData(); DiracMixturePdf& p2 = *tree2.getData(); PartitionTreeNode* root1 = tree1.getRoot(); PartitionTreeNode* root2 = tree2.getRoot(); if (clear) { result1.clear(); result2.clear(); } if (root1 != NULL && root2 != NULL) { std::stack<PartitionTreeNode*> nodes1, nodes2; vector x(p1.getDimensions()); unsigned int i, j; double w, d; nodes1.push(root1); nodes2.push(root2); while (!nodes1.empty()) { PartitionTreeNode& node1 = *nodes1.top(); nodes1.pop(); PartitionTreeNode& node2 = *nodes2.top(); nodes2.pop(); if (node1.isLeaf() && node2.isLeaf()) { i = node1.getIndex(); j = node2.getIndex(); noalias(x) = p1.get(i) - p2.get(j); d = K(N(x)); noalias(column(result1,i)) += d * column(ws2,j); noalias(column(result2,j)) += d * column(ws1,i); } else if (node1.isLeaf() && node2.isPrune()) { i = node1.getIndex(); const std::vector<unsigned int>& js = node2.getIndices(); for (j = 0; j < js.size(); j++) { noalias(x) = p1.get(i) - p2.get(js[j]); d = K(N(x)); noalias(column(result1,i)) += d * column(ws2,js[j]); noalias(column(result2,js[j])) += d * column(ws1,i); } } else if (node1.isPrune() && node2.isLeaf()) { const std::vector<unsigned int>& is = node1.getIndices(); j = node2.getIndex(); for (i = 0; i < is.size(); i++) { noalias(x) = p1.get(is[i]) - p2.get(j); d = K(N(x)); noalias(column(result1,is[i])) += d * column(ws2,j); noalias(column(result2,j)) += d * column(ws1,is[i]); } } else if (node1.isPrune() && node2.isPrune()) { const std::vector<unsigned int>& is = node1.getIndices(); const std::vector<unsigned int>& js = node2.getIndices(); for (i = 0; i < is.size(); i++) { for (j = 0; j < js.size(); j++) { noalias(x) = p1.get(is[i]) - p2.get(js[j]); d = K(N(x)); noalias(column(result1,is[i])) += d * column(ws2,js[j]); noalias(column(result2,js[j])) += d * column(ws1,is[i]); } } } else { /* should we recurse? */ node2.difference(node1, x); if (K(N(x)) > 0.0) { if (node1.isInternal()) { if (node2.isInternal()) { /* split both query and target nodes */ nodes1.push(node1.getLeft()); nodes2.push(node2.getLeft()); nodes1.push(node1.getLeft()); nodes2.push(node2.getRight()); nodes1.push(node1.getRight()); nodes2.push(node2.getLeft()); nodes1.push(node1.getRight()); nodes2.push(node2.getRight()); } else { /* split query node only */ nodes1.push(node1.getLeft()); nodes2.push(&node2); nodes1.push(node1.getRight()); nodes2.push(&node2); } } else { /* split target node only */ nodes1.push(&node1); nodes2.push(node2.getLeft()); nodes1.push(&node1); nodes2.push(node2.getRight()); } } } } if (normalise) { result1 /= p1.getTotalWeight(); result2 /= p2.getTotalWeight(); } }}#endif
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -