⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 balltreeclass.cc

📁 Non-parametric density estimation
💻 CC
📖 第 1 页 / 共 2 页
字号:
//////////////////////////////////////////////////////////////////////////////////////
// BallTreeClass  --  class definitions for a BallTree (actually KD-tree) 
//                    object, primarily for use in matlab MEX files.
//
// See BallTree.h for the class definition.
//
//////////////////////////////////////////////////////////////////////////////////////
//
// Written by Alex Ihler and Mike Mandel
// Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt
//
//////////////////////////////////////////////////////////////////////////////////////

#define MEX
#include <math.h>
#include "BallTree.h"
#include <utility>
#include <map>

// Names of the fields that make up a BallTree, in a fixed order.
// (Presumably the field names of the matlab structure used by the MEX
// interface -- confirm against BallTree.h.)
const char* BallTree::FIELD_NAMES[] = {
  "D",
  "N",
  "centers",
  "ranges",
  "weights",
  "lower",
  "upper",
  "leftch",
  "rightch",
  "perm"
};

// Number of entries in FIELD_NAMES (10), derived from the array itself
// so the two cannot drift apart.
const int BallTree::nfields =
    sizeof(BallTree::FIELD_NAMES) / sizeof(BallTree::FIELD_NAMES[0]);

// Recursively build the node "root" so that it covers the leaves
// low..high (inclusive).  The leaves are split at the median of their
// most spread coordinate, a sub-tree is built on each half, and the
// statistics of "root" are then computed from its two children.
//
//   low, high : first and last leaf index covered by this node
//   root      : index of the node being filled in
void BallTree::buildBall(BallTree::index low, BallTree::index high, BallTree::index root)
{
  // Every node records the range of leaves below it, whatever its shape.
  lowest_leaf[root]  = low;
  highest_leaf[root] = high;

  // Single-leaf tree (N=1): the node has only one real child.
  if (low == high) {
    left_child[root] = low;

    // calcStats needs two valid children, so temporarily alias the
    // right child to the same leaf, then restore the NO_CHILD marker.
    right_child[root] = high;
    calcStats(root);
    right_child[root] = NO_CHILD;
    return;
  }

  // Split along the dimension with the widest spread, placing the
  // median at "middle" so both halves hold the same number of leaves
  // (+-1).  (A plain partition() is not used here: it copes badly with
  // repeated values and does not give balanced halves.)
  const BallTree::index axis   = most_spread_coord(low, high);
  const BallTree::index middle = (low + high) / 2;
  select(axis, middle, low, high);

  // A half consisting of a single leaf is referenced directly; any
  // larger half gets a fresh internal node taken from "next" (buildTree
  // starts that counter at 1, the root being node 0).  The left half is
  // served first so node numbering stays the same.
  BallTree::index leftNode = low;
  if (middle > low)
    leftNode = next++;

  BallTree::index rightNode = high;
  if (middle + 1 < high)
    rightNode = next++;

  left_child[root]  = leftNode;
  right_child[root] = rightNode;

  // Recurse only into halves that received a new internal node.
  if (leftNode != low)
    buildBall(low, middle, leftNode);
  if (rightNode != high)
    buildBall(middle + 1, high, rightNode);

  // Children are complete; derive this node's center, range and weight.
  calcStats(root);
}

// Find the dimension along which the leaves low..high (both inclusive)
// have the greatest variance.
//
//   low, high : first and last leaf index to examine (low <= high)
//   returns   : index of the dimension with the largest spread; 0 if
//               every dimension has zero spread (or dims is 0).
//
// The quantity compared is the sum of squared deviations from the mean.
// It is not divided by the number of leaves, since that count is the
// same for every dimension and does not change which one is largest.
unsigned long BallTree::most_spread_coord(BallTree::index low, BallTree::index high) const
{
  BallTree::index dimension, leaf;
  BallTree::index max_dim = 0;
  double max_variance = 0;

  // number of leaves in the inclusive range low..high
  const double count = static_cast<double>(high - low + 1);

  for(dimension = 0; dimension<dims; dimension++) {
    // mean of this coordinate over all leaves, including leaf "high"
    double mean = 0;
    for(leaf = low; leaf <= high; leaf++)
      mean += centers[dims*leaf + dimension];
    mean /= count;

    // sum of squared deviations from that mean
    double variance = 0;
    for(leaf = low; leaf <= high; leaf++) {
      const double diff = centers[dims*leaf + dimension] - mean;
      variance += diff * diff;
    }

    // strict '>' keeps the lowest-numbered dimension on ties
    if(variance > max_variance) {
      max_variance = variance;
      max_dim = dimension;
    }
  }

  return max_dim;
}


// Hoare-style partition (as in CLR's quicksort) of the leaves low..high
// inclusive, in the given dimension.  The pivot is the value of leaf
// "low" on entry (not randomized).  Does not affect non-leaf nodes, but
// does relabel (swap) the leaves between low and high.
//
//   dimension : coordinate to compare on
//   low, high : first and last leaf index of the range
//   returns   : an index s with low <= s < high (when low < high) such
//               that every leaf in low..s is <= the pivot value and
//               every leaf in s+1..high is >= it.  When the range holds
//               fewer than two leaves, "high" is returned unchanged.
unsigned long BallTree::partition(unsigned long dimension, unsigned long low, 
				  unsigned long high) 
{
  if (low >= high) return high;        // nothing to partition

  // Remember the pivot *value*: the leaf it came from gets moved by the
  // swaps below, so comparing against an index would be unreliable.
  const double pivot = centers[dims*low + dimension];

  unsigned long i = low, j = high;
  for (;;) {
    // Both scans are bounded: the pivot itself (first pass) or a
    // previously swapped element (later passes) always stops them
    // inside low..high, so the unsigned indices cannot wrap.
    while (centers[dims*j + dimension] > pivot)
      j--;
    while (centers[dims*i + dimension] < pivot)
      i++;

    if (i >= j) return j;              // scans met or crossed: done

    swap(i, j);
    i++;
    j--;                               // safe: j > i >= low before the decrement
  }
}


// Partition the leaves low..high (inclusive) into two sets (equal-sized
// or as near as possible when "position" is the midpoint), one of which
// is uniformly no greater than the other in the given dimension.
//
// On return the leaf at index "position" holds the value it would have
// if the range were sorted on that dimension; leaves before it are not
// greater and leaves after it are not smaller.  Leaves are relabeled
// via swap().
//
//   dimension : coordinate to compare on
//   position  : target index; assumed to lie within low..high
//   low, high : first and last leaf index of the range
void BallTree::select(unsigned long dimension, unsigned long position,
		      unsigned long low, unsigned long high)
{
  while (low < high) {
    // Use the middle leaf as pivot and park it at the front.
    // (low + (high-low)/2 avoids overflow of low+high.)
    const unsigned long mid = low + (high - low)/2;
    swap(mid, low);

    // The pivot stays at "low" during the scan (only indices > low are
    // swapped), so its value can be read once.
    const double pivot = centers[dimension + dims*low];

    // Gather every leaf smaller than the pivot into low+1..boundary.
    unsigned long boundary = low;
    for (unsigned long i = low+1; i <= high; i++) {
      if (centers[dimension + dims*i] < pivot) {
        boundary++;
        swap(boundary, i);
      }
    }

    // Drop the pivot into its final sorted slot.
    swap(low, boundary);

    // Pivot landed exactly on the target: done.  Returning here also
    // avoids computing boundary-1 when boundary == 0, which would wrap
    // around for unsigned indices.
    if (boundary == position) return;

    if (boundary < position) low  = boundary + 1;   // target is to the right
    else                     high = boundary - 1;   // target is to the left (boundary > position >= 0)
  }    
}


// Exchange leaf i with leaf j.  Only the per-leaf weight, permutation
// entry and center coordinates are exchanged, so this is meant for
// leaves only: ranges and child links are left untouched and would be
// wrong if two internal nodes were swapped this way.
void BallTree::swap(unsigned long i, unsigned long j) 
{
  if (i == j)
    return;                       // swapping a leaf with itself is a no-op

  std::swap(weights[i], weights[j]);
  std::swap(permutation[i], permutation[j]);

  // centers holds "dims" consecutive coordinates per leaf
  const unsigned long rowI = i * dims;
  const unsigned long rowJ = j * dims;
  for (unsigned long d = 0; d < dims; ++d)
    std::swap(centers[rowI + d], centers[rowJ + d]);
}

//
// Compute the center, range and weight of node "root" from the
//   corresponding values of its left and right children.  Does nothing
//   unless both children are valid indices.
//
void BallTree::calcStats(BallTree::index root)
{
  const BallTree::index lhs = left(root);
  const BallTree::index rhs = right(root);

  // Not a parent node with two usable children: nothing to compute.
  if (!(validIndex(lhs) && validIndex(rhs)))
    return;

  // Per dimension, enclose both children's intervals
  // [center - range, center + range] in the smallest single interval.
  for (BallTree::index d = 0; d < dims; ++d) {
    const double lhsUpper = center(lhs)[d] + range(lhs)[d];
    const double rhsUpper = center(rhs)[d] + range(rhs)[d];
    const double lhsLower = center(lhs)[d] - range(lhs)[d];
    const double rhsLower = center(rhs)[d] - range(rhs)[d];

    const double upper = (lhsUpper > rhsUpper) ? lhsUpper : rhsUpper;
    const double lower = (lhsLower < rhsLower) ? lhsLower : rhsLower;

    centers[root*dims + d] = (upper + lower) / 2;   // midpoint of the interval
    ranges[root*dims + d]  = (upper - lower) / 2;   // half-width of the interval
  }

  // When both children are the same node (buildBall does this for a
  // single-leaf tree), count that node's weight only once.
  weights[root] = (lhs == rhs) ? weights[lhs]
                               : weights[lhs] + weights[rhs];
}
  

// Public method to build the tree: initializes the leaf entries and
// then calls the private recursive builder with the proper starting
// arguments.
//
// Leaves occupy indices num_points .. 2*num_points-1; the root is node
// 0 and further internal nodes are handed out starting at index 1.
// Leaf centers and weights are not written here -- presumably they are
// filled in before this is called (confirm against the constructor).
void BallTree::buildTree()
{
  BallTree::index i,j;
  for (j=0, i=num_points; j<num_points; i++,j++) {
    // a leaf is a single point, so its ball has zero extent
    for(index k=0; k<dims; k++)
      ranges[i*dims+k] = 0;

    lowest_leaf[i] = highest_leaf[i] = i; 
    left_child[i] = i; 
    right_child[i] = NO_CHILD;
    permutation[i] = j;         // leaf i initially holds the j-th input point
  }
  next = 1;                     // node 0 is the root; next free node is 1

  // An empty tree has no leaves to build on.  Without this guard,
  // 2*num_points - 1 would wrap around and buildBall would be handed a
  // bogus leaf range.
  if (num_points == 0) return;

  buildBall(num_points, 2*num_points - 1, 0);
}

// Figure out which of two children in this tree is closest to a given

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -