// File: balltreeclass.cc
// (The code-viewer page label "Font size:" appeared here; it is not part of the source.)
//////////////////////////////////////////////////////////////////////////////////////
// BallTreeClass -- class definitions for a BallTree (actually KD-tree)
// object, primarily for use in matlab MEX files.
//
// See BallTree.h for the class definition.
//
//////////////////////////////////////////////////////////////////////////////////////
//
// Written by Alex Ihler and Mike Mandel
// Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt
//
//////////////////////////////////////////////////////////////////////////////////////
#define MEX
#include <math.h>
#include "BallTree.h"
#include <utility>
#include <map>
// Field names associated with the tree's data members, in a fixed order.
// NOTE(review): presumably these are the field names of the MATLAB struct
// that carries the tree in and out of MEX files -- confirm in BallTree.h.
const char* BallTree::FIELD_NAMES[] = {
  "D",
  "N",
  "centers",
  "ranges",
  "weights",
  "lower",
  "upper",
  "leftch",
  "rightch",
  "perm"
};
// Number of entries in FIELD_NAMES (evaluates to 10); derived from the
// array itself so the two definitions cannot drift apart.
const int BallTree::nfields =
  static_cast<int>(sizeof(FIELD_NAMES) / sizeof(FIELD_NAMES[0]));
// Recursively build the internal node `root` so that it covers the
// leaves low..high (inclusive), working from the top of the tree down.
//
// The leaves are reordered so that the median along the most spread
// coordinate sits at the middle index; the lower half becomes the left
// sub-tree and the upper half the right sub-tree (sizes differ by at
// most one). A half consisting of a single leaf is referenced directly
// instead of getting an internal node of its own. New internal nodes
// are taken from the running counter `next`, left before right.
// Once both children exist, calcStats(root) fills in this node's
// statistics from them.
void BallTree::buildBall(BallTree::index low, BallTree::index high, BallTree::index root)
{
  // Every node records the span of leaves it covers.
  lowest_leaf[root]  = low;
  highest_leaf[root] = high;

  // Special case for N=1 trees: the root covers exactly one leaf.
  if (low == high) {
    left_child[root]  = low;
    // Temporarily alias the right child to the same leaf so calcStats
    // copies that leaf's statistics, then mark the right child as
    // absent. Kludgey, but it keeps calcStats free of special cases.
    right_child[root] = high;
    calcStats(root);
    right_child[root] = NO_CHILD;
    return;
  }

  // Dimension of widest spread, and the median position to split at.
  const BallTree::index axis = most_spread_coord(low, high);
  const BallTree::index mid  = (low + high) / 2;

  // Reorder the leaves so that [low, mid] and [mid+1, high] are separated
  // along `axis`. (partition() is an alternative, but it copes poorly
  // with repeated values and does not give balanced halves.)
  select(axis, mid, low, high);

  // A one-leaf half is pointed to directly; otherwise allocate a fresh
  // internal node. The left allocation must precede the right one.
  const BallTree::index leftNode  = (mid <= low)      ? low  : next++;
  const BallTree::index rightNode = (mid + 1 >= high) ? high : next++;

  left_child[root]  = leftNode;
  right_child[root] = rightNode;

  // Recurse only into halves that received an internal node.
  if (leftNode != low)   buildBall(low, mid, leftNode);
  if (rightNode != high) buildBall(mid + 1, high, rightNode);

  calcStats(root);
}
// Find the dimension along which the leaves low..high (inclusive) have
// the greatest spread.
//
// For each dimension the sum of squared deviations from the mean is
// computed over the leaf centers; since every dimension uses the same
// number of points, comparing these sums is equivalent to comparing
// variances. Returns the index of the winning dimension, or 0 when all
// spreads are zero (or the range is empty).
//
// Fix: the previous loops stopped one leaf short of `high` and divided
// by (high - low), so the last leaf was never taken into account even
// though the range is documented (and used by buildBall) as inclusive.
unsigned long BallTree::most_spread_coord(BallTree::index low, BallTree::index high) const
{
  if (high < low) return 0;                      // empty range: nothing to measure

  const BallTree::index count = high - low + 1;  // number of leaves, inclusive
  BallTree::index max_dim = 0;
  double max_variance = 0;

  for (BallTree::index dimension = 0; dimension < dims; dimension++) {
    // Mean of this coordinate over the leaves.
    double mean = 0;
    for (BallTree::index leaf = low; leaf <= high; leaf++)
      mean += centers[dims*leaf + dimension];
    mean /= count;

    // Sum of squared deviations from that mean.
    double variance = 0;
    for (BallTree::index leaf = low; leaf <= high; leaf++) {
      const double diff = centers[dims*leaf + dimension] - mean;
      variance += diff * diff;
    }

    // Strict comparison: ties keep the lowest-numbered dimension.
    if (variance > max_variance) {
      max_variance = variance;
      max_dim = dimension;
    }
  }
  return max_dim;
}
// Hoare-style (unrandomized) partition, as used by quicksort in CLR.
// Reorders the leaves low..high (inclusive) around the value that the
// leaf at `low` has in the given dimension. Only leaves are relabelled
// (via swap); non-leaf nodes are not touched.
//
// Returns an index s such that every leaf in [low, s] is <= the pivot
// value and every leaf in [s+1, high] is >= it in `dimension`. For
// low < high the result satisfies low <= s < high, so both sides are
// non-empty. If low >= high nothing is moved and `high` is returned.
//
// Fix: the previous version compared against centers[pivot] with ">="
// and no bound, so the downward scan walked past `low` (wrapping the
// unsigned index at 0) whenever the pivot was the smallest element, and
// it swapped once more after the two scans had crossed, undoing the
// partition (e.g. {2,1} came back unchanged with split 0).
unsigned long BallTree::partition(unsigned long dimension, unsigned long low,
                                  unsigned long high)
{
  if (low >= high) return high;

  // Copy the pivot value: the leaf holding it moves during the swaps.
  // Not randomized; a random leaf could be swapped to `low` first.
  const double pivotValue = centers[dims*low + dimension];

  unsigned long i = low;
  unsigned long j = high;
  for (;;) {
    // Scan down to an element that belongs in the lower part. An element
    // <= pivotValue always exists at or above `low`, so j never drops
    // below `low`.
    while (centers[dims*j + dimension] > pivotValue) j--;
    // Scan up to an element that belongs in the upper part. An element
    // >= pivotValue always exists at or below the previous j.
    while (centers[dims*i + dimension] < pivotValue) i++;

    if (i >= j) return j;   // scans met or crossed: partition complete

    swap(i, j);
    i++;
    j--;
  }
}
// Quickselect: reorder the leaves low..high (inclusive) so that the leaf
// ending up at index `position` is the one that would be there if the
// range were sorted along `dimension`; leaves before it are no greater
// and leaves after it are no smaller in that dimension. With `position`
// at the middle of the range this splits the data into two sets that are
// equal-sized (or as near as possible). Only leaves are moved (via swap).
//
// Fix: the previous version always executed "high = m-1" when the pivot
// landed at or above `position`. With unsigned indices this wrapped to a
// huge value when m == 0, and the loop then continued far out of bounds.
// Returning as soon as the pivot lands on `position` avoids the
// decrement in that case and is otherwise equivalent.
void BallTree::select(unsigned long dimension, unsigned long position,
                      unsigned long low, unsigned long high)
{
  while (low < high) {
    // Use the middle leaf as pivot and park it at `low`.
    const unsigned long mid = (low + high) / 2;
    swap(mid, low);

    // Gather every leaf strictly smaller than the pivot directly after it;
    // `m` tracks the last leaf of that group.
    unsigned long m = low;
    for (unsigned long i = low + 1; i <= high; i++) {
      if (centers[dimension + dims*i] < centers[dimension + dims*low]) {
        m++;
        swap(m, i);
      }
    }
    // Drop the pivot into its final sorted slot.
    swap(low, m);

    if (m == position) return;          // target slot is settled
    if (m < position)  low  = m + 1;    // continue in the upper part
    else               high = m - 1;    // m > position >= 0, so m-1 cannot wrap
  }
}
// Swap the ith leaf with the jth leaf. Actually, only swap the
// weights, permutation, and centers, so only for swapping
// leaves. Will not swap ranges correctly and will not swap children
// correctly.
void BallTree::swap(unsigned long i, unsigned long j)
{
unsigned long k;
double tmp;
if (i==j) return;
// swap weights
tmp = weights[i]; weights[i] = weights[j]; weights[j] = tmp;
// swap perm
k = permutation[i]; permutation[i] = permutation[j]; permutation[j] = k;
// swap centers
i *= dims; j *= dims;
for(k=0; k<dims; i++,j++,k++) {
tmp = centers[i]; centers[i] = centers[j]; centers[j] = tmp;
}
}
//
// Compute the statistics (center, ranges, weight) of node "root" from
// the statistics of its left and right children. Does nothing unless
// both children are valid indices, i.e. unless root is a parent node.
//
void BallTree::calcStats(BallTree::index root)
{
  const BallTree::index leftI  = left(root);
  const BallTree::index rightI = right(root);
  if (!(validIndex(leftI) && validIndex(rightI))) return;  // not a parent node

  // Per dimension, the node's box is the smallest axis-aligned box that
  // contains both children's boxes (center +/- range).
  for (index d = 0; d < dims; d++) {
    const double leftHi  = center(leftI)[d]  + range(leftI)[d];
    const double rightHi = center(rightI)[d] + range(rightI)[d];
    const double leftLo  = center(leftI)[d]  - range(leftI)[d];
    const double rightLo = center(rightI)[d] - range(rightI)[d];

    const double hi = (leftHi > rightHi) ? leftHi : rightHi;
    const double lo = (leftLo < rightLo) ? leftLo : rightLo;

    centers[root*dims + d] = (hi + lo) / 2;
    ranges[root*dims + d]  = (hi - lo) / 2;
  }

  // When both children are the same ball (only expected when the caller
  // deliberately passes the same node twice), count its weight once.
  weights[root] = (leftI != rightI) ? weights[leftI] + weights[rightI]
                                    : weights[leftI];
}
// Public method to build the tree, just calls the private method with
// the proper starting arguments.
void BallTree::buildTree()
{
BallTree::index i,j;
for (j=0, i=num_points; j<num_points; i++,j++) {
for(index k=0; k<dims; k++)
ranges[i*dims+k] = 0;
lowest_leaf[i] = highest_leaf[i] = i;
left_child[i] = i;
right_child[i] = NO_CHILD;
permutation[i] = j;
}
next = 1;
buildBall(num_points, 2*num_points - 1, 0);
}
// Figure out which of two children in this tree is closest to a given
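// ---------------------------------------------------------------------
// The text below is not part of the source file; it is keyboard-shortcut
// help captured from the code-viewer web page.
//   Copy code:            Ctrl + C
//   Search code:          Ctrl + F
//   Full-screen mode:     F11
//   Toggle theme:         Ctrl + Shift + D
//   Show shortcuts:       ?
//   Increase font size:   Ctrl + =
//   Decrease font size:   Ctrl + -
// ---------------------------------------------------------------------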