⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 balltreeclass.cc

📁 Non-parametric density estimation
💻 CC
📖 第 1 页 / 共 2 页
字号:
// ball in another tree.  Returns the index in this tree of the closer
// child.
BallTree::index BallTree::closer(BallTree::index myLeft, BallTree::index myRight,
                                 const BallTree& otherTree,
                                 BallTree::index otherRoot) const
{
  // Compare squared Euclidean distances from the other tree's ball center
  // to the centers of our two candidate balls; no sqrt is needed for a
  // comparison.
  const double* target      = otherTree.center(otherRoot);
  const double* leftCenter  = center(myLeft);
  const double* rightCenter = center(myRight);

  double leftSq = 0, rightSq = 0;
  for (unsigned int d = 0; d < dims; ++d) {
    const double dl = target[d] - leftCenter[d];
    const double dr = target[d] - rightCenter[d];
    leftSq  += dl * dl;
    rightSq += dr * dr;
  }

  // Strict "<": ties (and unordered comparisons) go to myRight.
  return (leftSq < rightSq) ? myLeft : myRight;
}

//
// Perform a *slight* adjustment of the tree: move the points by delta, but
//   don't rebuild the tree structure; only the node statistics are redone.
//
//   delta : per-point displacement, dims consecutive values per point,
//           indexed as delta[getIndexOf(leaf)*dims + k] -- presumably the
//           original (pre-permutation) point order; confirm with callers.
//
// An empty tree (num_points == 0) is left untouched; without this guard the
// leaf bounds of an empty tree would be read and `num_points - 1` would wrap.
//
void BallTree::movePoints(double* delta)
{
  if (num_points == 0)
    return;

  index i;
  // first adjust leaf locations by delta
  for (i = leafFirst(root()); i <= leafLast(root()); i++) {
    const double* shift = delta + getIndexOf(i) * dims;
    for (unsigned int k = 0; k < dims; k++)
      centers[dims*i + k] += shift[k];
  }

  // then recompute stats of nodes num_points-1 .. 1 (the non-root internal
  // nodes; processed from high index to low -- presumably children before
  // parents, confirm against buildTree)
  for (i = num_points - 1; i != 0; i--)
    calcStats(i);
  calcStats(root());           //   and finally the root node
}

// Replace every leaf weight with the matching entry of newWeights and then
// recompute the node statistics.
//   newWeights : num_points weights, indexed through getIndexOf(leaf) --
//                presumably the original (pre-permutation) point order.
//                The caller must guarantee it holds num_points entries.
// An empty tree (num_points == 0) is left untouched; without this guard
// `num_points - 1` below would wrap around and calcStats would be called on
// out-of-range nodes.
void BallTree::changeWeights(const double *newWeights) {
  if (num_points == 0)
    return;

  // indices num_points .. 2*num_points-1 are the leaf slots
  for (index i = num_points; i < num_points*2; i++)
    weights[i] = newWeights[ getIndexOf(i) ];

  // refresh non-root internal nodes from high index to low, then the root
  for (index i = num_points-1; i != 0; i--)
    calcStats(i);
  calcStats(root());
}



/////////////////////// k nearest neighbors functions ///////////////////////

// Squared lower bound on the Euclidean distance from `point` to the box of
// myBall (center +/- range in each dimension).  A dimension in which the
// point lies within the box contributes nothing.  Returns distance squared.
double BallTree::minDist(index myBall, const double* point) const
{
  const double* c = center(myBall);
  const double* r = range(myBall);

  double total = 0;
  for (index d = 0; d < dims; ++d) {
    const double gap = fabs(c[d] - point[d]) - r[d];
    if (gap > 0)            // outside the box along this axis
      total += gap * gap;
  }
  return total;
}
// Squared distance from `point` to the farthest corner of myBall's box
// (center +/- range in each dimension): an upper bound on the squared
// distance to anything inside the ball, assuming non-negative ranges.
double BallTree::maxDist(index myBall, const double* point) const
{
  const double* c = center(myBall);
  const double* r = range(myBall);

  double total = 0;
  for (index d = 0; d < dims; ++d) {
    const double reach = fabs(c[d] - point[d]) + r[d];
    total += reach * reach;
  }
  return total;
}


// Priority queue used by kNearestNeighbors: key is a ball's squared minimum
// distance to the current query (from minDist), value is the ball's node
// index.  A multimap lets balls with equal keys coexist; begin() is always
// the ball with the smallest key.
typedef std::multimap<double, BallTree::index> myMap;

// k-nearest-neighbor search for a batch of query points.
//
//   nns    : output, K x N matrix (k consecutive entries per query).  Holds
//            the indices (mapped through getIndexOf) of the neighbors found
//            for each query, in the order they were found.  Unused slots
//            are set to NO_CHILD.
//   dists  : output, 1 x N vector.  For each query, sqrt(maxDist) from the
//            query to the last ball whose points were appended to nns.
//            This is exact if that ball is a leaf with zero range
//            (presumably always true for leaves -- confirm), and an upper
//            bound otherwise.  Set to 0 when no neighbor was found (k == 0
//            or an empty tree).
//   points : input, D x N matrix of query points (dims consecutive values
//            per query).
//   N      : number of queries; nothing is written when N <= 0.
//   k      : neighbors requested per query; negative values are treated as 0.
//
// Balls are expanded in increasing order of minDist, so a leaf reaching the
// front of the queue is the closest remaining point.  A front ball whose
// farthest point (maxDist) is strictly nearer than the next ball's minDist,
// and which holds fewer points than are still needed, is taken whole
// without being expanded.
void BallTree::kNearestNeighbors(index *nns, double *dists, const double *points, 
				 int N, int k) 
  const
{
  if (N <= 0)
    return;               // no queries; also keeps N*dims from wrapping
  if (k < 0)
    k = 0;                // a negative k would make the NO_CHILD fill overrun nns

  myMap m;                // squared minDist -> ball index
  int leavesDone = 0;     // number of queries processed so far

  for(index target = 0; target < N*dims; target += dims, ++leavesDone) {
    const double* query = points + target;
    int nnsSoFar = 0;
    index lastBall = root();   // last ball appended to nns; read only if nnsSoFar > 0

    // Nothing to search when no neighbors are requested or the tree is empty.
    if (k > 0 && num_points > 0)
      m.insert(myMap::value_type(minDist(root(), query), root()));

    while(! m.empty() && nnsSoFar < k) {
      index current = m.begin()->second;
      m.erase(m.begin());

      if(isLeaf(current)) {
        // since the queue is sorted by minDist, a leaf at the front must be
        // the next nearest neighbor
        nns[leavesDone*k + nnsSoFar++] = current;
        lastBall = current;
      } else {
        // push both children
        m.insert(myMap::value_type(minDist(left_child[current], query),
                                   left_child[current]));
        m.insert(myMap::value_type(minDist(right_child[current], query),
                                   right_child[current]));
      }

      // Take whole balls from the front of the queue while all of the front
      // ball's points are strictly nearer than anything in the next ball and
      // it does not hold more points than are still needed.  The comparison
      // uses "<" rather than "<=" so that dists stays right.  Stop at the
      // first ball that fails.
      myMap::iterator it = m.begin();
      while (it != m.end()) {
        myMap::iterator front = it++;
        const index ball = front->second;
        if (it == m.end()
            || !(maxDist(ball, query) < it->first)
            || !(Npts(ball) < k - nnsSoFar))
          break;
        lastBall = ball;
        for(index i = leafFirst(ball); i <= leafLast(ball); i++)
          nns[leavesDone*k + nnsSoFar++] = i;
        m.erase(front);     // `it` stays valid: multimap erase only invalidates `front`
      }
    } // end single query

    // clear out the remaining balls
    m.clear();

    dists[leavesDone] = (nnsSoFar > 0) ? sqrt(maxDist(lastBall, query)) : 0.0;

    index i;
    for(i=leavesDone*k; i < leavesDone*k+nnsSoFar; i++)
      nns[i] = getIndexOf(nns[i]);
    for(i=leavesDone*k+nnsSoFar; i<(leavesDone+1)*k; i++)
      nns[i] = NO_CHILD;
  } // end all queries
}

/////////////////////////////// matlab functions ////////////////////////////

#ifdef MEX

// Default constructor.  Deliberately leaves the array pointers and sizes
// unset so that the Matlab loading/creation code (loadFromMatlab,
// createInMatlab) can fill them in afterwards.  Only `next` is set here.
BallTree::BallTree()
  : next(1)
{
  // intentionally empty
}

// Wrap the arrays of an existing Matlab ball-tree structure without copying
// them.  Every pointer set below aliases data owned by the structure's
// fields, so the structure presumably has to outlive this object.
//
// NOTE(review): matlabMakeStruct creates "lower", "upper", "leftch",
// "rightch" and "perm" as mxUINT32_CLASS arrays, but they are read here as
// unsigned long*.  Where unsigned long is 64 bits the element sizes differ;
// confirm against the member declarations.
BallTree::BallTree(const mxArray* structure) 
{
  // scalar sizes
  dims       = static_cast<unsigned int>(mxGetScalar(mxGetField(structure, 0, "D")));
  num_points = static_cast<unsigned long>(mxGetScalar(mxGetField(structure, 0, "N")));

  // per-node floating-point data
  centers = mxGetPr(mxGetField(structure, 0, "centers"));
  ranges  = mxGetPr(mxGetField(structure, 0, "ranges"));
  weights = mxGetPr(mxGetField(structure, 0, "weights"));

  // per-node integer data: leaf bounds, children, permutation
  lowest_leaf  = static_cast<unsigned long*>(mxGetData(mxGetField(structure, 0, "lower")));
  highest_leaf = static_cast<unsigned long*>(mxGetData(mxGetField(structure, 0, "upper")));
  left_child   = static_cast<unsigned long*>(mxGetData(mxGetField(structure, 0, "leftch")));
  right_child  = static_cast<unsigned long*>(mxGetData(mxGetField(structure, 0, "rightch")));
  permutation  = static_cast<unsigned long*>(mxGetData(mxGetField(structure, 0, "perm")));

  next = 1;   // value unimportant for a loaded tree
}

// Create a new Matlab ball-tree structure for the given points and weights.
// The structure is allocated by matlabMakeStruct.  The tree is then built
// in place through a BallTree wrapper, whose pointers alias the structure's
// arrays; the build is skipped when there are no points.  Returns the new
// structure.
mxArray* BallTree::createInMatlab(const mxArray* _pointsMatrix, const mxArray* _weightsMatrix)
{
  mxArray* result = matlabMakeStruct(_pointsMatrix, _weightsMatrix);

  BallTree tree(result);
  if (tree.Npts() > 0)
    tree.buildTree();

  return result;
}

// Allocate a new 1x1 Matlab structure holding an unbuilt ball tree for the
// Nd x Np matrix _pointsMatrix and the weights in _weightsMatrix.
//
// Every per-node array gets 2*Np columns.  The first Np columns are left as
// allocated (zeroed by the mxCreate* calls).  The points and weights are
// copied into the second half (columns Np .. 2*Np-1).
//
// NOTE(review): _weightsMatrix is assumed to hold at least Np doubles and
// _pointsMatrix to be of class double; neither is checked here.
// NOTE(review): the integer arrays are mxUINT32_CLASS, while the
// BallTree(const mxArray*) constructor reads them as unsigned long*.
// Confirm the element sizes agree on the target platform.
mxArray* BallTree::matlabMakeStruct(const mxArray* _pointsMatrix, const mxArray* _weightsMatrix)
{
  // shape and data of the inputs
  unsigned int  Nd = mxGetM(_pointsMatrix);
  unsigned long Np = mxGetN(_pointsMatrix);
  const double* srcPoints  = static_cast<const double*>(mxGetData(_pointsMatrix));
  const double* srcWeights = static_cast<const double*>(mxGetData(_weightsMatrix));

  mxArray* structure = mxCreateStructMatrix(1, 1, nfields, FIELD_NAMES);

  // scalar sizes
  mxSetField(structure, 0, "D",       mxCreateDoubleScalar(static_cast<double>(Nd)));
  mxSetField(structure, 0, "N",       mxCreateDoubleScalar(static_cast<double>(Np)));

  // per-node floating-point arrays
  mxSetField(structure, 0, "centers", mxCreateDoubleMatrix(Nd, 2*Np, mxREAL));
  mxSetField(structure, 0, "ranges",  mxCreateDoubleMatrix(Nd, 2*Np, mxREAL));
  mxSetField(structure, 0, "weights", mxCreateDoubleMatrix(1, 2*Np, mxREAL));

  // per-node integer arrays
  mxSetField(structure, 0, "lower",   mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "upper",   mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "leftch",  mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "rightch", mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "perm",    mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));

  // copy the inputs into the second half of the centers / weights arrays
  double* dstCenters = static_cast<double*>(mxGetData(mxGetField(structure, 0, "centers"))) + Nd*Np;
  double* dstWeights = static_cast<double*>(mxGetData(mxGetField(structure, 0, "weights"))) + Np;

  const unsigned long nCoords = Nd*Np;
  for (unsigned long c = 0; c < nCoords; ++c)
    dstCenters[c] = srcPoints[c];
  for (unsigned long p = 0; p < Np; ++p)
    dstWeights[p] = srcWeights[p];

  return structure;
}

#endif

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -