// File: balltreedensityclass.cc (excerpt)
// [web-viewer residue, translated from Chinese: "Font size:"]
} else if (Npts(dRoot)*atTree.Npts(aRoot)<=DirectSize){ // DIRECT EVALUATION
llGradDirect(dRoot,atTree,aRoot,gradWRT);
} else {
close = atTree.closer( atTree.left(aRoot), atTree.right(aRoot), *this, left(dRoot));
if (left(dRoot) != NO_CHILD && close != NO_CHILD)
llGradRecurse(left(dRoot),atTree,close,tolGrad,gradWRT);
far = (close == atTree.left(aRoot)) ? atTree.right(aRoot) : atTree.left(aRoot);
if (left(dRoot) != NO_CHILD && far != NO_CHILD)
llGradRecurse(left(dRoot),atTree,far,tolGrad,gradWRT);
close = atTree.closer( atTree.left(aRoot), atTree.right(aRoot), *this, right(dRoot));
if (right(dRoot) != NO_CHILD && close != NO_CHILD)
llGradRecurse(right(dRoot),atTree,close,tolGrad,gradWRT);
far = (close == atTree.left(aRoot)) ? atTree.right(aRoot) : atTree.left(aRoot);
if (right(dRoot) != NO_CHILD && far != NO_CHILD)
llGradRecurse(right(dRoot),atTree,far,tolGrad,gradWRT);
}
}
////////////////////////////////////////////////////////////////////////////////////
// L = sum_i wi log p(yi) = sum_i wi log[ sum_j wj K(yi-xj) ]   (L is already the log-likelihood)
// => dL/dxj[k] = - sum_i wi (1/p(yi)) wj K'(xj-yi)
//    dL/dyi[k] =   wi (1/p(yi)) sum_j wj K'(xj-yi)             (same K')
//
////////////////////////////////////////////////////////////////////////////////////
// llGrad: gradient of the weighted log-likelihood L = sum_i wi log p(yi) of the
// points in `locations` under this density.
//   locations : tree of evaluation points y_i.  When it is this very tree, the
//               leave-one-out normalization below is applied.
//   _gradD    : output buffer for the gradient wrt this density's parameters;
//               may be NULL to skip it.
//   _gradA    : output buffer for the gradient wrt the locations; may be NULL.
//   tolEval   : tolerance for evaluating p(y_i); 2*tolEval is passed to evaluate().
//   tolGrad   : gradient tolerance; its square is handed to the recursions.
//   gradWRT   : WRTWeight selects the weight-gradient recursion; any other value
//               uses llGradRecurse.
// The recursion helpers accumulate into the buffers with += / -=.  This function
// does not clear them, so presumably the caller zero-initializes them.
// Scratch state (gradD, gradA, min, max, pMin, pMax, ...) is stored in member
// variables, so concurrent calls on the same object are not safe.
void BallTreeDensity::llGrad(const BallTree& locations, double* _gradD, double* _gradA, double tolEval, double tolGrad, Gradient gradWRT) const
{
  BallTree::index j;
  unsigned long k;
  gradD = _gradD; gradA = _gradA;
  min = new double[locations.Ndim()]; max = new double[locations.Ndim()];
  // One slot per node of `locations` (2*Npts entries), filled by evaluate().
  pMin = new double[2*locations.Npts()];
  pMax = new double[2*locations.Npts()];
  for (j=0;j<2*locations.Npts();j++) pMin[j] = pMax[j] = 0;
#ifdef NEWVERSION
  pAdd = new double*[1]; pAdd[0] = new double[2*locations.Npts()];
  pErr = new double[2*locations.Npts()];
  for (j=0;j<2*locations.Npts();j++) pAdd[0][j] = pErr[j] = 0;
#endif
  evaluate(root(), locations, locations.root(), 2*tolEval);
#ifdef NEWVERSION
  pushDownAll(locations);
#endif
  if (this == &locations) {             // fix leave-one-out normalization
    // Both bounds must be rescaled for every leaf.  The braces matter: without
    // them only pMax is scaled inside the loop, and the pMin update runs once
    // afterwards at j == leafLast+1, one past the end of the array.
    for (j=leafFirst(root()); j<=leafLast(root()); j++) {
      pMax[j] /= (1-weight(j));
      pMin[j] /= (1-weight(j));
    }
  }
  if(gradWRT == WRTWeight)
    llGradWRecurse(root(),locations,locations.root(), tolGrad*tolGrad);
  else
    llGradRecurse(root(),locations,locations.root(), tolGrad*tolGrad, gradWRT);
  if (this == &locations) {             // fix leave-one-out normalization
    // NOTE(review): this assumes Ndim() entries per point.  The WRTWeight
    // helpers write one entry per point (gradD[getIndexOf(j)]), so confirm the
    // buffer layout the caller uses for weight gradients.
    for (j=leafFirst(root()); j<=leafLast(root()); j++) {
      unsigned long Nj = Ndim() * getIndexOf(j);
      for (k=0;k<Ndim();k++) {
        if (gradD) gradD[Nj+k] /= (1-weight(j));
        if (gradA) gradA[Nj+k] /= (1-weight(j));
      }
    }
  }
  delete[] min; delete[] max;
  delete[] pMax; delete[] pMin;
#ifdef NEWVERSION
  delete[] pAdd[0]; delete[] pAdd; delete[] pErr;
#endif
}
////////////////////////////////////////////////////////////////////////////////////
// Gradient wrt WEIGHT
// DIRECT VERSION:
// Visits every (location leaf, density leaf) pair below aRoot / dRoot.  For
// small node pairs this exhaustive N^2 sweep is cheaper than recursing further.
// Each pair's kernel-derivative value is taken as the midpoint of max[0]/min[0],
// as left by dKdX_p().  It is subtracted from gradD (scaled by the location's
// weight) and added to gradA (scaled by the density point's weight).
////////////////////////////////////////////////////////////////////////////////////
void BallTreeDensity::llGradWDirect(BallTree::index dRoot, const BallTree& atTree,
                                    BallTree::index aRoot) const
{
  for (BallTree::index a = atTree.leafFirst(aRoot); a <= atTree.leafLast(aRoot); ++a) {
    for (BallTree::index d = leafFirst(dRoot); d <= leafLast(dRoot); ++d) {
      // "true" marks a leaf-to-leaf evaluation
      dKdX_p(d, atTree, a, true, WRTWeight);
      const double boundSum = max[0] + min[0];
      if (gradD) gradD[getIndexOf(d)]        -= atTree.weight(a) * boundSum / 2;
      if (gradA) gradA[atTree.getIndexOf(a)] += weight(d)        * boundSum / 2;
    }
  }
}
////////////////////////////////////////////////////////////////////////////////////
// Gradient wrt WEIGHT
// RECURSIVE VERSION:
// Bounds the kernel derivative between the node pair (dRoot, aRoot).
//  - If the squared spread of the bounds is within tolGrad, the midpoint is
//    applied to every leaf under both nodes at once.
//  - Small pairs (<= 100 leaf pairs) go to llGradWDirect.
//  - Otherwise each child of dRoot is paired with the nearer child of aRoot,
//    then with the farther one.
////////////////////////////////////////////////////////////////////////////////////
void BallTreeDensity::llGradWRecurse(BallTree::index dRoot,const BallTree& atTree,
                                     BallTree::index aRoot, double tolGrad) const
{
  dKdX_p(dRoot,atTree,aRoot,false,WRTWeight);   // "false": nodes need not be leaves
  const double spread = max[0] - min[0];

  if (spread * spread <= tolGrad) {             // bounds tight enough: approximate
    const double boundSum = max[0] + min[0];
    if (gradD) {
      for (BallTree::index d = leafFirst(dRoot); d <= leafLast(dRoot); ++d)
        gradD[getIndexOf(d)] -= atTree.weight(aRoot) * boundSum / 2;
    }
    if (gradA) {
      for (BallTree::index a = atTree.leafFirst(aRoot); a <= atTree.leafLast(aRoot); ++a)
        gradA[atTree.getIndexOf(a)] += weight(dRoot) * boundSum / 2;
    }
    return;
  }

  // NOTE(review): llGradRecurse uses DirectSize for this cutoff; confirm whether
  // the literal 100 here is meant to be the same constant.
  if (Npts(dRoot)*atTree.Npts(aRoot) <= 100) {  // DIRECT EVALUATION
    llGradWDirect(dRoot,atTree,aRoot);
    return;
  }

  const BallTree::index dKids[2] = { left(dRoot), right(dRoot) };
  for (int c = 0; c < 2; ++c) {
    const BallTree::index dKid = dKids[c];
    const BallTree::index nearA =
        atTree.closer(atTree.left(aRoot), atTree.right(aRoot), *this, dKid);
    if (dKid != NO_CHILD && nearA != NO_CHILD)
      llGradWRecurse(dKid, atTree, nearA, tolGrad);
    const BallTree::index farA =
        (nearA == atTree.left(aRoot)) ? atTree.right(aRoot) : atTree.left(aRoot);
    if (dKid != NO_CHILD && farA != NO_CHILD)
      llGradWRecurse(dKid, atTree, farA, tolGrad);
  }
}
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
//
// CONSTRUCTION METHODS
//
///////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
#ifdef MEX
// Wrap the arrays of an existing MATLAB density struct without copying them.
// The tree fields are read by the BallTree base; "means", "bandwidth" and
// "type" are read here.  A bandwidth field with 6*num_points columns means
// per-point bandwidths, with min/max tracking arrays laid out after the values.
// Otherwise the bandwidth is shared, and min and max both alias the first leaf.
BallTreeDensity::BallTreeDensity(const mxArray* structure) : BallTree(structure) {
  const mxArray* bwField = mxGetField(structure, 0, "bandwidth");
  means     = mxGetPr(mxGetField(structure, 0, "means"));
  bandwidth = (double*) mxGetData(bwField);
  type      = (BallTreeDensity::KernelType) mxGetScalar(mxGetField(structure, 0, "type"));

  const bool perPointBW = (mxGetN(bwField) == 6*num_points);
  if (perPointBW) {
    multibandwidth = 1;
    bandwidthMax = bandwidth + 2*num_points*dims;      // values differ per point,
    bandwidthMin = bandwidthMax + 2*num_points*dims;   // so track min/max per node
  } else {
    multibandwidth = 0;
    bandwidthMax = bandwidthMin = bandwidth + num_points*dims;  // min == max == any leaf
  }
}
// Create a new MATLAB density struct from the given points, weights and
// bandwidths, then build its ball tree (skipped when there are no points).
// Returns the new struct.
mxArray* BallTreeDensity::createInMatlab(const mxArray* _pointsMatrix, const mxArray* _weightsMatrix,
                                         const mxArray* _bwMatrix,BallTreeDensity::KernelType _type)
{
  mxArray* result = matlabMakeStruct(_pointsMatrix, _weightsMatrix, _bwMatrix, _type);
  BallTreeDensity builder(result);   // wraps result's arrays; presumably buildTree() fills them in place
  if (builder.Npts() > 0) {
    builder.buildTree();
  }
  return result;
}
// Create a new MATLAB struct for a density.  It starts from the BallTree struct
// and appends the fields "means" (D x 2N), "bandwidth" (D x 2N if _bwMatrix has
// one column, else D x 6N) and "type".  The leaf columns N..2N-1 of means and
// bandwidth are then seeded from "centers" and _bwMatrix.  Internal-node
// columns are left for the tree build.
mxArray* BallTreeDensity::matlabMakeStruct(const mxArray* _pointsMatrix, const mxArray* _weightsMatrix,
                                           const mxArray* _bwMatrix,BallTreeDensity::KernelType _type)
{
  mxArray* structure = BallTree::matlabMakeStruct(_pointsMatrix, _weightsMatrix);
  const unsigned long Nd = (unsigned long) mxGetScalar(mxGetField(structure,0,"D"));
  const unsigned long Np = (unsigned long) mxGetScalar(mxGetField(structure,0,"N"));
  const bool sharedBW = (mxGetN(_bwMatrix) == 1);
  const unsigned long leafStart = Nd*Np;     // first entry of the leaf columns

  // Field creation order is kept: it determines the struct's field order.
  mxAddField(structure, "means");
  mxSetField(structure, 0, "means", mxCreateDoubleMatrix(Nd, 2*Np, mxREAL));
  mxAddField(structure, "bandwidth");
  mxSetField(structure, 0, "bandwidth",
             mxCreateDoubleMatrix(Nd, (sharedBW ? 2 : 6)*Np, mxREAL));
  mxAddField(structure, "type");
  mxSetField(structure, 0, "type", mxCreateDoubleScalar((double)_type));

  double* meanData   = (double *) mxGetData(mxGetField(structure, 0, "means"));
  double* pointData  = (double *) mxGetData(mxGetField(structure, 0, "centers"));
  double* bwData     = (double *) mxGetData(mxGetField(structure, 0, "bandwidth"));
  const double* bwIn = (const double *) mxGetData(_bwMatrix);

  for (unsigned long off = 0; off < leafStart; ++off)
    meanData[leafStart + off] = pointData[leafStart + off];

  if (sharedBW) {
    // one bandwidth column, replicated for every leaf
    for (unsigned long off = 0; off < leafStart; ++off)
      bwData[leafStart + off] = bwIn[off % Nd];
  } else {
    // per-point bandwidths; min/max tracking arrays follow the values
    double* bwMax = bwData + 2*Np*Nd;
    double* bwMin = bwMax + 2*Np*Nd;
    for (unsigned long off = 0; off < leafStart; ++off) {
      const unsigned long col = leafStart + off;
      bwMax[col] = bwMin[col] = bwData[col] = bwIn[off];
    }
  }
  return structure;
}
#endif
// Replace the leaf bandwidths and recompute the statistics of every internal node.
//   newBWs : bandwidth values, dims entries per column.
//   N      : number of columns in newBWs.  It must be 1 for a density with a
//            shared bandwidth, or num_points for a density with per-point
//            bandwidths.
// Returns true on success.  Returns false, leaving the density untouched, when
// N does not match the density's bandwidth layout.
bool BallTreeDensity::updateBW(const double* newBWs, index N)
{
  // The bandwidth storage layout is fixed at construction (see the MEX
  // constructor / matlabMakeStruct).  A shared bandwidth has no min/max
  // tracking arrays, and per-point bandwidths need one column per point.
  // Any other N would write past the shared-bandwidth array or over-read
  // newBWs.  Deciding by the layout also lets a 1-point density with a
  // shared bandwidth accept N == 1.
  const bool layoutMatches = (multibandwidth != 0) ? (N == num_points) : (N == 1);
  if (!layoutMatches)
    return false;

  index i, j;
  // pointers all stay the same, just copy data over
  if (multibandwidth != 0) {
    double *bwMax = bandwidth + 2*num_points*dims;
    double *bwMin = bwMax + 2*num_points*dims;
    for (j=0,i=dims*num_points; j<dims*num_points; i++,j++)
      bwMax[i] = bwMin[i] = bandwidth[i] = newBWs[j];
  } else {
    for (j=0,i=dims*num_points; j<dims*num_points; i++,j++)
      bandwidth[i] = newBWs[j%dims];
  }

  // Recompute non-leaf nodes num_points-1 down to 1, then root().  The guard
  // avoids unsigned wrap-around of num_points-1 for an empty density.
  if (num_points > 0) {
    for (i=num_points-1; i != 0; i--)
      calcStats(i);
    calcStats(root());
  }
  return true;
}
// Recompute the kernel statistics of internal node `root` from its two children
// by moment matching.  The parent mean is the weight-averaged child mean, and
// the parent variance is the weighted E[x^2] minus mean^2.
// The kernel type decides how bandwidth values map to variances:
//   Gaussian      : values are combined directly as variances
//   Laplacian     : variance = 2*bw^2   (converted back with bw = sqrt(var/2))
//   Epanetchnikov : variance = 0.2*bw^2 (converted back with bw = sqrt(5*var))
// For non-uniform bandwidths, the per-dimension max/min bounds are propagated too.
// Nodes without two valid children only get BallTree::calcStats.
void BallTreeDensity::calcStats(BallTree::index root)
{
  BallTree::calcStats(root);
  BallTree::index leftI = left(root), rightI = right(root);   // get children indices
  if (!validIndex(leftI) || !validIndex(rightI)) return;       // not a parent node

  const BallTree::index Ni = dims*root, NiL = dims*leftI, NiR = dims*rightI;
  // Normalized child weights; DBL_EPSILON avoids 0/0 when both weights are 0.
  double wtL = weight(leftI), wtR = weight(rightI);
  const double wtT = wtL + wtR + DBL_EPSILON;
  wtL /= wtT; wtR /= wtT;

  if (!bwUniform()) {
    for (unsigned int k = 0; k < dims; k++) {
      bandwidthMax[Ni+k] = (bandwidthMax[NiL+k] > bandwidthMax[NiR+k])
                           ? bandwidthMax[NiL+k] : bandwidthMax[NiR+k];
      bandwidthMin[Ni+k] = (bandwidthMin[NiL+k] < bandwidthMin[NiR+k])
                           ? bandwidthMin[NiL+k] : bandwidthMin[NiR+k];
    }
  }

  // Each combined variance is a difference of nearly equal terms, so rounding
  // can leave it slightly negative.  Clamp negatives to 0 so sqrt() cannot
  // return NaN and no negative variance is stored.
  switch(type) {
  case Gaussian:
    for (unsigned int k = 0; k < dims; k++) {
      means[Ni+k] = wtL * means[NiL+k] + wtR * means[NiR+k];
      const double var = wtL* (bandwidth[NiL+k] + means[NiL+k]*means[NiL+k]) +
                         wtR* (bandwidth[NiR+k] + means[NiR+k]*means[NiR+k]) -
                         means[Ni+k]*means[Ni+k];
      bandwidth[Ni+k] = (var < 0) ? 0.0 : var;
    }
    break;
  case Laplacian:
    for (unsigned int k = 0; k < dims; k++) {
      means[Ni+k] = wtL * means[NiL+k] + wtR * means[NiR+k];
      double var = wtL* (2*bandwidth[NiL+k]*bandwidth[NiL+k] + means[NiL+k]*means[NiL+k]) +
                   wtR* (2*bandwidth[NiR+k]*bandwidth[NiR+k] + means[NiR+k]*means[NiR+k]) -
                   means[Ni+k]*means[Ni+k];       // compute in terms of variance
      if (var < 0) var = 0;
      bandwidth[Ni+k] = sqrt(.5*var);             // then convert back to normal BW rep.
    }
    break;
  case Epanetchnikov:
    for (unsigned int k = 0; k < dims; k++) {
      means[Ni+k] = wtL * means[NiL+k] + wtR * means[NiR+k];
      double var = wtL* (.2*bandwidth[NiL+k]*bandwidth[NiL+k] + means[NiL+k]*means[NiL+k]) +
                   wtR* (.2*bandwidth[NiR+k]*bandwidth[NiR+k] + means[NiR+k]*means[NiR+k]) -
                   means[Ni+k]*means[Ni+k];       // compute in terms of variance
      if (var < 0) var = 0;
      bandwidth[Ni+k] = sqrt(5*var);              // then convert back to normal BW rep.
    }
    break;
  }
}
// Swap the ith leaf with the jth leaf.
void BallTreeDensity::swap(unsigned long i, unsigned long j)
{
if (i==j) return;
BallTree::swap(i,j);
i *= dims; j *= dims;
for(unsigned int k=0; k<dims; i++,j++,k++) {
double tmp;
tmp = means[i]; means[i] = means[j]; means[j] = tmp;
tmp = bandwidth[i]; bandwidth[i] = bandwidth[j]; bandwidth[j] = tmp;
if (!bwUniform()) {
tmp = bandwidthMax[i];bandwidthMax[i]=bandwidthMax[j];bandwidthMax[j]=tmp;
tmp = bandwidthMin[i];bandwidthMin[i]=bandwidthMin[j];bandwidthMin[j]=tmp;
}
}
}
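// [web-viewer residue, translated from Chinese; not part of the source]
//   Keyboard shortcuts
//   Copy code:          Ctrl + C
//   Search code:        Ctrl + F
//   Full-screen mode:   F11
//   Toggle theme:       Ctrl + Shift + D
//   Show shortcuts:     ?
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -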