// catdist.java
set_preferred_category(singleCat);
}
/** Specifies a single category to prefer if it is ever involved in a tie. If
 * a tie occurs which does not involve the preferred category, the first
 * category (but never unknown) will be chosen.
 * @param cat The index of the category to be preferred.
 */
public void set_preferred_category(int cat) {
    // Reject categories outside [UNKNOWN, UNKNOWN + num_label_values].
    boolean belowRange = cat < Globals.UNKNOWN_CATEGORY_VAL;
    boolean aboveRange =
        cat > Globals.UNKNOWN_CATEGORY_VAL + schema.num_label_values();
    if (belowRange || aboveRange)
        Error.fatalErr("CatDist::set_preferred_category: specified category "
            +cat+ " is out of range");
    // Build the ordering vector: the preferred category ranks first (0),
    // the remaining categories are ranked in first-to-last order, and the
    // unknown category is ranked last unless it is itself preferred.
    int nextRank = 0;
    tiebreakingOrder[cat] = nextRank++;
    for (int i = Globals.FIRST_CATEGORY_VAL; i < tiebreakingOrder.length; i++) {
        if (i == cat)
            continue;
        tiebreakingOrder[i] = nextRank++;
    }
    if (cat != Globals.UNKNOWN_CATEGORY_VAL)
        tiebreakingOrder[Globals.UNKNOWN_CATEGORY_VAL] = nextRank++;
    // Every slot (all labels plus unknown) must have received a rank.
    MLJ.ASSERT(nextRank == schema.num_label_values() + 1, "CatDist::set_preferred_category: val == schema.num_label_values() + 1");
}
/** Builds a tie breaking order from a weight distribution. If there is a tie
 * among weights, the first one will have a better (i.e. lower) tie breaking
 * rank.
 * @return The tie breaking order of ranks.
 * @param weightDistribution The distribution of weights for label
 * categories.
 */
static public int[] tiebreaking_order(double[] weightDistribution) {
    double[] remaining = (double[]) weightDistribution.clone();
    int[] ranks = new int[remaining.length];
    // If the unknown category sits at slot 0 and carries (approximately)
    // zero weight, mark it consumed up front so it is ranked last.
    if (0 == Globals.UNKNOWN_CATEGORY_VAL
        && MLJ.approx_equal(remaining[0], 0.0))
        remaining[0] = -1;
    // Repeatedly pull out the heaviest remaining slot; earlier slots win
    // ties because the max scan starts from the front.
    int nextRank = 0;
    for (int pass = 0; pass < ranks.length; pass++) {
        IntRef heaviest = new IntRef(0);
        MLJArray.max(heaviest, remaining);
        ranks[heaviest.value] = nextRank++;
        remaining[heaviest.value] = -1;   // consumed
    }
    MLJ.ASSERT(nextRank == ranks.length, "CatDist::tiebreaking_order: nextIndex == order.length");
    return ranks;
}
/** Returns the best category according to the weight distribution. If a loss
 * matrix is defined, the distribution will be multiplied by the loss matrix to
 * produce a vector of expected losses. The best category is the one with the
 * smallest expected loss; otherwise the highest-probability category wins,
 * with ties resolved by the lower tiebreakingOrder rank.
 * @return An AugCategory containing information about the best category found.
 */
public AugCategory best_category() {
    // having no values at all is an error
    MLJ.ASSERT(dist.length > 0, "CatDist::best_category: dist.length > 0");
    double bestProb = -1;
    int bestCat = -1;
    // If a loss matrix is defined, multiply it by the scoring vector
    // to obtain the loss vector.
    if (schema.has_loss_matrix()) {
        double[][] lossMatrix = schema.get_loss_matrix();
        double[] lossVector = new double[schema.num_label_values() + 1];
        multiply_losses(lossMatrix, dist, lossVector);
        // these are LOSSES, so pick the smallest one
        bestProb = Double.MAX_VALUE;
        for (int i = 0; i < lossVector.length; i++) {
            // bestCat >= 0 guards the tie branch so a tie against the
            // initial sentinel can never index tiebreakingOrder with -1.
            if (bestCat >= 0 && MLJ.approx_equal(lossVector[i], bestProb)) {
                if (tiebreakingOrder[i] < tiebreakingOrder[bestCat]) {
                    bestProb = lossVector[i];
                    bestCat = i;
                }
            }
            else if (lossVector[i] < bestProb) {
                bestProb = lossVector[i];
                bestCat = i;
            }
        }
        // Arrays.toString prints the vector's contents; bare array
        // concatenation would only print its object identity.
        logOptions.LOG(4, "Probs: " +scoresToString() + ". Loss vector: "
            +java.util.Arrays.toString(lossVector)+ ". Picked " +bestCat+ '\n');
    }
    // Otherwise, just pick the highest probability in the distribution.
    // In the event of a tie, prefer the category with the LOWER
    // tiebreakingOrder.
    else {
        for (int i = 0; i < dist.length; i++) {
            if (bestCat >= 0 && MLJ.approx_equal(dist[i], bestProb)) {
                // we have a tie.
                if (tiebreakingOrder[i] < tiebreakingOrder[bestCat]) {
                    bestProb = dist[i];
                    bestCat = i;
                }
            }
            else if (dist[i] > bestProb) {
                // always pick this category--its the best so far
                bestProb = dist[i];
                bestCat = i;
            }
        }
        logOptions.LOG(4, "Probs: " +scoresToString() + ". Picked " +bestCat+ '\n');
        if (bestCat == Globals.UNKNOWN_CATEGORY_VAL &&
            !GlobalOptions.allowUnknownPredictions) {
            Error.err("CatDist::best_category: attempting to predict "
                + "UNKNOWN without a loss matrix set. Set "
                + "ALLOW_UNKNOWN_PREDICTIONS to Yes to deactivate this "
                + "error.");
            // Arrays.toString: show contents, not array identity.
            Error.err("Probabilities: " +java.util.Arrays.toString(dist)+ '\n');
            Error.err("Tie breaking order: " +java.util.Arrays.toString(tiebreakingOrder)+ '\n');
            Error.fatalErr("");
        }
    }
    if (bestCat == Globals.UNKNOWN_CATEGORY_VAL)
        return Globals.UNKNOWN_AUG_CATEGORY;
    else
        return new AugCategory(bestCat, schema.nominal_label_info().get_value(bestCat));
}
/** Sets the distribution scores for the current distribution.
 * @param fCounts The frequency counts of categories found.
 * @param cType Type of correction to perform. Range is CatDist.none,
 * CatDist.laplace, CatDist.evidence.
 * @param cParam Correction parameter. Must be equal to or greater than 0.
 */
public void set_scores(double[] fCounts,
int cType, double cParam) {
    if (fCounts.length != dist.length)
        Error.fatalErr("CatDist::set_scores: size of frequency counts array ("
            +fCounts.length+ ") does not match number of categories "
            + "in data (including unknown) (" +dist.length+ "). "
            + "It is possible you are using the wrong version of "
            + "the CatDist constructor");
    int numLabelVals = schema.num_label_values();
    // compute the total sum of the counts. Negative frequency counts
    // are not permitted, but slightly negative ones (negative only by
    // error) will be clamped to zero.
    // NOTE(review): the clamped value only feeds `total`; the raw
    // fCounts[i] is still used in the normalization below -- confirm
    // this asymmetry is intended.
    double total = 0;
    for (int i = 0; i < fCounts.length; i++) {
        DoubleRef val = new DoubleRef(fCounts[i]);
        MLJ.clamp_above(val, 0.0, "CatDist::CatDist: negative frequency counts "
            + "are not permitted" , fCounts.length);
        total += val.value;
    }
    // If all counts are zero, make an all-even distribution,
    // but give the unknown class zero weight.
    if (MLJ.approx_equal(total, 0.0)) {
        double evenProb = 1.0/ (dist.length -1);
        dist[0] = 0.0;
        for (int i = 1; i < dist.length; i++)
            dist[i] = evenProb;
        total = 1.0;
    }
    // compute probabilities. The method depends on the correction type
    else {
        switch (cType) {
            // frequency counts: normalize the counts.
            case none: {
                MLJ.ASSERT(!MLJ.approx_equal(total, 0.0) , "CatDist::set_scores: !MLJ.approx_equal(total, 0.0)");
                for (int i = 1; i < dist.length; i++)
                    dist[i] = fCounts[i]/ total;
            }
            break;
            // Laplace correction: each count is equal to
            //   (fCount + cParam) / (total + numLabelVals * cParam)
            // zero cParam means use 1/total as the correction factor.
            case laplace: {
                if (cParam < 0.0)
                    Error.fatalErr("CatDist::CatDist: negative correction parameter "
                        + "(cParam) values are not permitted for laplace correction");
                MLJ.verify_strictly_greater(total + cParam, 0, "CatDist::CatDist: "
                    + " total + cParam too close to zero");
                double finalCorrection =(cParam == 0.0) ?(1.0 / total) : cParam;
                // The divisor does not depend on i: compute and check it
                // once instead of on every loop iteration.
                double divisor = total + numLabelVals * finalCorrection;
                if (MLJ.approx_equal(divisor, 0.0))
                    Error.fatalErr("CatDist::CatDist: divisor too close to zero");
                for (int i = 1; i < dist.length; i++)
                    dist[i] =(fCounts[i] + finalCorrection)/ divisor;
            }
            break;
            // Evidence projection algorithm:
            case evidence: {
                if (cParam <= 0.0)
                    Error.fatalErr("CatDist::CatDist: negative or zero correction parameter "
                        + "(cParam) values are not permitted for evidence projection");
                // fCounts may have a different bound than dist, so copy
                // element-wise (arraycopy) rather than assigning the reference.
                System.arraycopy(fCounts, 0, dist, 0, fCounts.length);
                // correct
                apply_evidence_projection(dist, cParam, true);
            }
            break;
            default:
                MLJ.Abort();
        }
        // Assign unknown the remainder of the distribution
        double newTotal = 0;
        for (int i = 1; i < dist.length; i++)
            newTotal += dist[i];
        dist[0] =(1.0 - newTotal);
        // If unknown is near zero, pin it to zero and renormalize
        // the other probabilities. This will prevent the unknown
        // from picking up probability mass due to numerical errors
        if (MLJ.approx_equal(dist[0], 0.0)) {
            dist[0] = 0;
            for (int i = 1; i < dist.length; i++)
                dist[i] /= newTotal;
        }
        // WARNING: Do not attempt to pin near-zero and near-one
        // values here--this can cause the distribution not
        // to sum to 1.0!
    }
}