// balltreeclass.cc
// Given two balls of this tree (myLeft, myRight), decide which one has its
// center nearer (squared Euclidean distance) to the center of ball otherRoot
// in another tree. Returns the index in this tree of the closer child; on an
// exact tie myRight is returned.
BallTree::index BallTree::closer(BallTree::index myLeft, BallTree::index myRight, const BallTree& otherTree,
                                 BallTree::index otherRoot) const
{
  double leftSq = 0, rightSq = 0;
  for (int d = 0; d < dims; ++d) {
    const double diffLeft  = otherTree.center(otherRoot)[d] - center(myLeft)[d];
    const double diffRight = otherTree.center(otherRoot)[d] - center(myRight)[d];
    leftSq  += diffLeft * diffLeft;
    rightSq += diffRight * diffRight;
  }
  return (leftSq < rightSq) ? myLeft : myRight;
}
//
// Perform a *slight* adjustment of the tree: move the points by delta, but
// don't reform the whole tree; just fix up the statistics.
//
// delta is read as delta[getIndexOf(leaf)*dims + k]: one dims-long offset per
// original point (presumably a D x N column-major matrix — confirm against
// callers). Leaf centers are shifted first; then calcStats is re-run on the
// internal nodes in decreasing index order, and finally on the root.
// An empty tree (num_points == 0) is left untouched.
//
void BallTree::movePoints(double* delta)
{
  // Without this guard the unsigned countdown below would start at
  // num_points-1 == the maximum index value, and the leaf loop would read
  // zero-length arrays.
  if (num_points == 0)
    return;

  index i;
  // first adjust leaf locations by delta
  for (i = leafFirst(root()); i <= leafLast(root()); i++)
    for (unsigned int k = 0; k < dims; k++)
      centers[dims*i + k] += delta[ getIndexOf(i)*dims + k ];
  // then recompute stats of the parent nodes
  for (i = num_points-1; i != 0; i--)
    calcStats(i);
  // and finally the root node
  calcStats(root());
}
// Replace the weight of every leaf and recompute the internal-node statistics.
// Assumes newWeights is the right size (num_points) and is indexed by original
// point number: leaf i (i in [num_points, 2*num_points)) takes
// newWeights[getIndexOf(i)]. An empty tree is left untouched.
void BallTree::changeWeights(const double *newWeights) {
  // Guard against the unsigned underflow of num_points-1 below.
  if (num_points == 0)
    return;
  for (index i = num_points; i < num_points*2; i++)
    weights[i] = newWeights[ getIndexOf(i) ];
  for (index i = num_points-1; i != 0; i--)
    calcStats(i);
  calcStats(root());
}
/////////////////////// k nearest neighbors functions ///////////////////////
// Lower bound on the squared distance from `point` to ball myBall. For each
// dimension the gap is |center - point| minus that dimension's range; only
// non-negative gaps (point outside the extent in that dimension) contribute.
// returns distance squared
double BallTree::minDist(index myBall, const double* point) const
{
  double total = 0;
  for (index d = 0; d < dims; ++d) {
    const double gap = fabs(center(myBall)[d] - point[d]) - range(myBall)[d];
    if (gap >= 0)
      total += gap * gap;
  }
  return total;
}
// Upper bound on the squared distance from `point` to ball myBall: for each
// dimension, |center - point| plus that dimension's range, squared and summed.
double BallTree::maxDist(index myBall, const double* point) const
{
  double total = 0;
  for (index d = 0; d < dims; ++d) {
    const double reach = fabs(center(myBall)[d] - point[d]) + range(myBall)[d];
    total += reach * reach;
  }
  return total;
}
typedef std::multimap<double, BallTree::index> myMap;
// Find the k nearest neighbors (among this tree's points) of each of N query
// points with a best-first search ordered by minDist.
//   nns    K x N output, column-major: entries [t*k, t*k + k) belong to query
//          t. They receive original point indices (via getIndexOf) in the
//          order the neighbors were accepted; unfilled slots get NO_CHILD.
//   dists  1 x N output: dists[t] = sqrt(maxDist(lastBall, query t)), where
//          lastBall is the last leaf or ball accepted for query t.
//   points D x N input, column-major: query t starts at points + t*dims.
// With k <= 0 nothing is accepted and dists[t] is the root's maxDist bound.
void BallTree::kNearestNeighbors(index *nns, double *dists, const double *points,
                                 int N, int k)
const
{
  myMap m;  // frontier: balls keyed by their minDist to the current query
  int leavesDone = 0;
  for(index target = 0; target < N*dims; target += dims, ++leavesDone) {
    const double* point = points + target;
    int nnsSoFar = 0;
    // For k > 0 this is always overwritten before it is read; initializing it
    // gives k <= 0 a defined result instead of reading an uninitialized index.
    index lastBall = root();
    m.insert(myMap::value_type(minDist(root(), point), root()));
    // examining balls in order of min dist means that when you see a
    // leaf, it is the closest point left
    while(! m.empty() && nnsSoFar < k) {
      index current = m.begin()->second;
      m.erase(m.begin());
      if(isLeaf(current)) {
        // since the nodes are sorted by minDist, a leaf at the front of
        // the queue must be the next nearest neighbor
        nns[leavesDone*k + nnsSoFar++] = current;
        lastBall = current;
      } else {
        // internal node: replace it by both of its children
        m.insert(myMap::value_type(minDist(left_child[current], point),
                                   left_child[current]));
        m.insert(myMap::value_type(minDist(right_child[current], point),
                                   right_child[current]));
      }
      // Bulk acceptance at the front of the queue: if the front ball's
      // farthest extent (maxDist) is nearer than the next ball's nearest
      // possible point (its minDist key), and all of its points fit in the
      // remaining k - nnsSoFar slots, take them all at once. Note "<" rather
      // than "<=" so that *dists will be right. Once a ball fails the test,
      // nothing later in the frontier can affect the result, so stop there.
      myMap::iterator it = m.begin();
      while(it != m.end()) {
        current = it->second;
        double farthest = maxDist(current, point);
        myMap::iterator ball = it++;
        if(it != m.end() && farthest < it->first && Npts(current) < k - nnsSoFar) {
          lastBall = current;
          for(index i=leafFirst(current); i <= leafLast(current); i++)
            nns[leavesDone*k + nnsSoFar++] = i;
          m.erase(ball);  // only `ball` is invalidated; `it` stays valid
        } else {
          break;
        }
      }
    } // end single nearest neighbor search
    // clear out the remaining balls before the next query
    m.clear();
    dists[leavesDone] = sqrt(maxDist(lastBall, point));
    // translate tree leaf indices to original point indices; pad the rest
    index i;
    for(i=leavesDone*k; i < leavesDone*k+nnsSoFar; i++)
      nns[i] = getIndexOf(nns[i]);
    for(i=leavesDone*k+nnsSoFar; i<(leavesDone+1)*k; i++)
      nns[i] = NO_CHILD;
  } // end all nearest neighbors
}
/////////////////////////////// matlab functions ////////////////////////////
#ifdef MEX
// Default constructor: leaves the data members unset so that they can be
// filled in by the loadFromMatlab and createInMatlab functions. Only `next`
// is initialized (to 1).
BallTree::BallTree()
  : next(1)
{
}
// Load the arrays already allocated in matlab from the given
// structure. No data is copied: each pointer member is set to the data of
// the corresponding field of `structure` (D, N, centers, ranges, weights,
// lower, upper, leftch, rightch, perm), so that mxArray must stay alive for
// as long as this object uses it.
// NOTE(review): no field is checked for existence; mxGetField returns NULL
// for a missing field — confirm callers only pass structs built by
// matlabMakeStruct.
// NOTE(review): matlabMakeStruct creates lower/upper/leftch/rightch/perm as
// mxUINT32_CLASS, but they are read here through unsigned long*. Where
// unsigned long is 64-bit (e.g. LP64 Linux/macOS) the element widths
// disagree — confirm the member types and target platform.
BallTree::BallTree(const mxArray* structure)
{
dims = (unsigned int) mxGetScalar(mxGetField(structure,0,"D")); // get the dimensions
num_points = (unsigned long) mxGetScalar(mxGetField(structure,0,"N")); // get the number of points
// per-node geometry and weights (double arrays)
centers = mxGetPr(mxGetField(structure,0,"centers"));
ranges = mxGetPr(mxGetField(structure,0,"ranges"));
weights = mxGetPr(mxGetField(structure,0,"weights"));
// per-node index arrays (see width NOTE above)
lowest_leaf = (unsigned long*) mxGetData(mxGetField(structure,0,"lower"));
highest_leaf= (unsigned long*) mxGetData(mxGetField(structure,0,"upper"));
left_child = (unsigned long*) mxGetData(mxGetField(structure,0,"leftch"));
right_child = (unsigned long*) mxGetData(mxGetField(structure,0,"rightch"));
permutation = (unsigned long*) mxGetData(mxGetField(structure,0,"perm"));
next = 1; // unimportant; same value the default constructor uses
}
// Create a new MATLAB struct (via matlabMakeStruct) holding storage for a
// ball tree over the given points and weights, build the tree in place when
// there is at least one point, and return the struct.
mxArray* BallTree::createInMatlab(const mxArray* _pointsMatrix, const mxArray* _weightsMatrix)
{
  mxArray* structure = matlabMakeStruct(_pointsMatrix, _weightsMatrix);
  // Wrap the freshly allocated arrays; the tree writes directly into them.
  BallTree tree(structure);
  if (!(tree.Npts() > 0))
    return structure;  // nothing to build for an empty point set
  tree.buildTree();
  return structure;
}
// Allocate and return a new 1x1 MATLAB struct holding the storage for a ball
// tree over the columns of _pointsMatrix (Nd rows = dimensions, Np columns =
// points). Every per-node array has room for 2*Np nodes. The input points are
// copied into the second half of "centers" (columns Np..2Np-1) and the input
// weights into the second half of "weights". The other arrays are left as
// MATLAB allocated them, for buildTree() to fill in.
// NOTE(review): _weightsMatrix is assumed to hold at least Np doubles; its
// size is never checked, so a shorter array would be over-read.
// NOTE(review): the index arrays are created as uint32 but are read through
// unsigned long* by BallTree(const mxArray*) — confirm the widths agree on
// the target platform.
mxArray* BallTree::matlabMakeStruct(const mxArray* _pointsMatrix, const mxArray* _weightsMatrix)
{
  // sizes and raw data of the input arguments
  const unsigned int Nd = mxGetM(_pointsMatrix);
  const unsigned long Np = mxGetN(_pointsMatrix);
  const double* srcPoints  = (double*)mxGetData(_pointsMatrix);
  const double* srcWeights = (double*)mxGetData(_weightsMatrix);

  mxArray* structure = mxCreateStructMatrix(1, 1, nfields, FIELD_NAMES);
  // scalar sizes
  mxSetField(structure, 0, "D", mxCreateDoubleScalar((double) Nd));
  mxSetField(structure, 0, "N", mxCreateDoubleScalar((double) Np));
  // per-node geometry and weights
  mxSetField(structure, 0, "centers", mxCreateDoubleMatrix(Nd, 2*Np, mxREAL));
  mxSetField(structure, 0, "ranges",  mxCreateDoubleMatrix(Nd, 2*Np, mxREAL));
  mxSetField(structure, 0, "weights", mxCreateDoubleMatrix(1, 2*Np, mxREAL));
  // per-node index arrays
  mxSetField(structure, 0, "lower",   mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "upper",   mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "leftch",  mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "rightch", mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));
  mxSetField(structure, 0, "perm",    mxCreateNumericMatrix(1, 2*Np, mxUINT32_CLASS, mxREAL));

  // leaves occupy the second half of each per-node array
  double* leafCenters = (double *) mxGetData(mxGetField(structure, 0, "centers")) + Nd*Np;
  for (unsigned long n = 0; n < Nd*Np; ++n)
    leafCenters[n] = srcPoints[n];
  double* leafWeights = (double *) mxGetData(mxGetField(structure, 0, "weights")) + Np;
  for (unsigned long n = 0; n < Np; ++n)
    leafWeights[n] = srcWeights[n];
  return structure;
}
#endif
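// (Web code-viewer residue, not part of this source file: keyboard shortcut
// panel — copy code Ctrl+C; search code Ctrl+F; fullscreen mode F11; toggle
// theme Ctrl+Shift+D; show shortcuts ?; increase font size Ctrl+=;
// decrease font size Ctrl+-.)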