📄 catdist.java
字号:
/**
 * Stores a normalized version of the supplied category weights as this
 * distribution's scores.
 * <p>
 * Slot 0 of the internal score array receives the unknown-category
 * probability. Slots 1..n receive the entries of {@code aDist}, rescaled so
 * that together they sum to {@code 1 - unknownProb}. If the weights sum to
 * approximately zero, every slot after slot 0 is set to 0.0 instead, which
 * avoids dividing by zero.
 *
 * @param unknownProb The probability weight for the unknown category. It is
 *                    passed to {@code MLJ.clamp_to_range} with the bounds
 *                    [0, 1] before its value is read.
 * @param aDist       The distribution weights, one per category. Only read
 *                    by this method. Its length must equal
 *                    {@code schema.num_label_values()}.
 */
public void set_scores(DoubleRef unknownProb, double[] aDist) {
    MLJ.clamp_to_range(unknownProb, 0, 1, "CatDist::CatDist: probability "
            + "selected for unknown values must be in range "
            + "[0.0, 1.0]");

    final int numCategories = schema.num_label_values();
    if (aDist.length != numCategories) {
        Error.fatalErr("CatDist::CatDist: size of distribution array (" +
                aDist.length + ") does not match number of categories in data "
                + "(" + numCategories + ")");
    }

    // Probability mass left over for the known categories.
    final double knownMass = 1.0 - unknownProb.value;

    double weightSum = 0;
    for (int k = 0; k < numCategories; k++) {
        weightSum += aDist[k];
    }

    if (MLJ.approx_equal(weightSum, 0.0)) {
        // Nothing to scale: a zero total would mean dividing by zero.
        for (int slot = 1; slot < dist.length; slot++) {
            dist[slot] = 0.0;
        }
    } else {
        // Category k of aDist lives in slot k + 1; slot 0 is the unknown.
        for (int k = 0; k < numCategories; k++) {
            dist[k + 1] = aDist[k] * knownMass / weightSum;
        }
    }

    dist[0] = unknownProb.value;
}
/**
 * Resets the tiebreaking order to its default.
 * <p>
 * Starting at index {@code Globals.FIRST_CATEGORY_VAL}, each slot of the
 * tiebreaking array is given the next rank counting up from 0. The slot at
 * {@code Globals.UNKNOWN_CATEGORY_VAL} is then given the rank that follows
 * the last one handed out.
 */
public void set_default_tiebreaking() {
    int nextRank = 0;
    int slot = Globals.FIRST_CATEGORY_VAL;
    while (slot < tiebreakingOrder.length) {
        tiebreakingOrder[slot] = nextRank;
        nextRank++;
        slot++;
    }
    // The unknown category is ranked after every known category.
    tiebreakingOrder[Globals.UNKNOWN_CATEGORY_VAL] = nextRank;
    // Sanity checks that are disabled in this source: the rank counter was
    // expected to finish at schema.num_label_values() + 1, and the resulting
    // order was expected to pass check_tiebreaking_order.
}
/**
 * Gives access to the scores held by this distribution.
 *
 * @return The internal score array itself; no copy is made.
 */
public double[] get_scores() {
    return this.dist;
}
/**
 * Multiplies the given distribution by the given loss matrix to produce
 * a vector of expected losses.
 * <p>
 * For every index {@code j} of {@code lossVector}, the method stores the sum
 * of {@code probDist[i] * lossMatrix[i][j]} over all {@code i >= 1}.
 * Index 0 of {@code probDist} (the unknown entry) is deliberately skipped,
 * so row 0 of {@code lossMatrix} is never read.
 *
 * @param lossMatrix The loss matrix, indexed as {@code [row][column]}. It
 *                   must provide rows 1 .. {@code probDist.length - 1},
 *                   each with at least {@code lossVector.length} columns.
 * @param probDist   The probability distribution for which loss is to be
 *                   calculated. Index 0 is ignored.
 * @param lossVector Receives the expected losses; every element is
 *                   overwritten by this method.
 * @throws ArrayIndexOutOfBoundsException if {@code lossMatrix} is smaller
 *                   than described above.
 */
public static void multiply_losses(double[][] lossMatrix,
        double[] probDist,
        double[] lossVector) {
    // Both loops use a strict '<' bound: Java arrays are indexed
    // 0 .. length - 1, so an inclusive bound runs one element past the end.
    for (int lvIndex = 0; lvIndex < lossVector.length; lvIndex++) {
        double expectedLoss = 0;
        // Ignore unknowns in the probability distribution (index 0).
        for (int pdIndex = 1; pdIndex < probDist.length; pdIndex++) {
            expectedLoss += probDist[pdIndex] * lossMatrix[pdIndex][lvIndex];
        }
        lossVector[lvIndex] = expectedLoss;
    }
}
/**
 * Applies the evidence projection algorithm to an array of frequency
 * counts, rewriting the array in place as corrected probabilities.
 * <p>
 * The sum of the counts is treated as the total weight of information. The
 * maximum evidence available is computed as
 * {@code MLJ.log_bin(total * eviFactor + 1)}. Each participating count is
 * projected with {@link #single_evidence_projection}, and the projected
 * values are then rescaled so that they sum to the share of the total that
 * belongs to the participating counts.
 *
 * @param counts         Frequency counts; overwritten with the resulting
 *                       probabilities.
 * @param eviFactor      Scale factor applied to the total count when
 *                       computing the available evidence.
 * @param firstIsUnknown When true, {@code counts[0]} is treated as the
 *                       "unknown" count: it is not projected, it is
 *                       replaced by its plain fraction of the total, and
 *                       the remaining entries share only the rest of the
 *                       probability mass. Requires a non-empty array.
 */
public static void apply_evidence_projection(double[] counts, double eviFactor,
        boolean firstIsUnknown) {
    // An "unknown" slot can only exist if there is at least one element.
    if (firstIsUnknown && counts.length == 0) {
        Error.fatalErr("apply_evidence_projection: you have selected the first "
                + "count to be UNKNOWN, but the counts array has size 0");
    }

    final int firstKnown = firstIsUnknown ? 1 : 0;
    final double[] projected = new double[counts.length];

    final double grandTotal = MLJArray.sum(counts);
    final double knownTotal = firstIsUnknown ? grandTotal - counts[0] : grandTotal;

    // Total evidence available: log(1 + total * eviFactor).
    final double availableEvidence = MLJ.log_bin((grandTotal * eviFactor) + 1.0);

    // Project every participating count and accumulate the sum that is
    // needed for normalization afterwards.
    double projectedSum = 0;
    for (int k = firstKnown; k < counts.length; k++) {
        final double p = single_evidence_projection(counts[k],
                grandTotal,
                availableEvidence);
        projected[k] = p;
        projectedSum += p;
    }

    // Share of the probability mass that does not belong to the unknown.
    final double knownShare = knownTotal / grandTotal;
    MLJ.verify_strictly_greater(projectedSum, 0, "apply_evidence_projection: "
            + "projection total must be non-negative");

    for (int k = firstKnown; k < counts.length; k++) {
        counts[k] = projected[k] * knownShare / projectedSum;
    }

    // The unknown slot keeps its original fraction of the total.
    if (firstIsUnknown) {
        counts[0] = counts[0] / grandTotal;
    }
}
/**
 * Computes one unnormalized evidence projection of a category count,
 * bounded by the maximum evidence available.
 * <p>
 * The count is first turned into the fraction {@code count / total}, which
 * is pinned to exactly 0.0 or 1.0 when {@code MLJ.approx_equal} reports it
 * as close to either. A zero fraction stands for infinite evidence and is
 * projected onto {@code maxEvidence}. Any other fraction has evidence
 * {@code -MLJ.log_bin(fraction)} and is projected to
 * {@code maxEvidence * evidence / (evidence + maxEvidence)}.
 *
 * @param count       The count of a particular category.
 * @param total       The total count of all categories.
 * @param maxEvidence The maximum evidence to project onto.
 * @return 2 raised to the negated projected evidence.
 */
public static double single_evidence_projection(double count, double total, double maxEvidence) {
    double fraction = count / total;

    // Pin values which are near 0.0 or 1.0 so that the exact comparison
    // against 0.0 below is reliable.
    if (MLJ.approx_equal(fraction, 0.0)) {
        fraction = 0.0;
    } else if (MLJ.approx_equal(fraction, 1.0)) {
        fraction = 1.0;
    }

    // Homogeneous-coordinate form: finite evidence has weight 1.0, while
    // infinite evidence (zero fraction) is written as evidence 1.0 with
    // weight 0.0.
    final boolean infiniteEvidence = (fraction == 0.0);
    final double evidence = infiniteEvidence ? 1.0 : -MLJ.log_bin(fraction);
    final double weight = infiniteEvidence ? 0.0 : 1.0;

    // 1D projection; with weight 0.0 the result is exactly maxEvidence.
    final double divisor = evidence + maxEvidence * weight;
    if (MLJ.approx_equal(divisor, 0.0)) {
        Error.fatalErr("single_evidence_projection: divisor too close to zero");
    }
    final double projectedEvidence = maxEvidence * evidence / divisor;

    // Convert the projected evidence back into a probability.
    return Math.pow(2, -projectedEvidence);
}
/**
 * Replaces the tiebreaking order with the given one and then validates it.
 * <p>
 * The array is stored by reference, not copied. No length check against
 * the previous order is performed here; the only validation is the call to
 * {@link #check_tiebreaking_order}, which runs after the new order has
 * already been stored.
 *
 * @param order The new tiebreaking order. The length of the array should
 *              be the same as the number of categories.
 */
public void set_tiebreaking_order(int[] order) {
    this.tiebreakingOrder = order;
    final int[] installed = get_tiebreaking_order();
    check_tiebreaking_order(installed);
}
/**
 * Checks that the given tiebreaking order is well formed, i.e. that it
 * contains each of the values 0 .. {@code order.length - 1} exactly once.
 * If it does not, an error is reported through {@code Error.fatalErr}.
 * <p>
 * The check works on a sorted copy, so {@code order} itself is not
 * modified.
 *
 * @param order The order to be checked.
 */
public void check_tiebreaking_order(int[] order) {
    // Sort a copy: a valid order, once sorted, holds i at every index i.
    int[] sortedOrder = (int[]) order.clone();
    Arrays.sort(sortedOrder);

    boolean bad = false;
    for (int i = 0; i < sortedOrder.length; i++) {
        if (sortedOrder[i] != i) {
            bad = true;
            break;
        }
    }

    if (bad) {
        // Arrays.toString prints the elements; concatenating the array
        // directly would only print its type and identity hash.
        Error.fatalErr("CatDist::check_tiebreaking_order: Tiebreaking order "
                + Arrays.toString(order) + " has bad form");
    }
}
/**
 * Gives access to the tiebreaking order currently in use.
 *
 * @return The internal tiebreaking array itself; no copy is made.
 */
public int[] get_tiebreaking_order() {
    return this.tiebreakingOrder;
}
/**
 * Smoke test for the CatDist class: builds an InstanceList from the first
 * command line argument, creates a CatDist over its schema, sets its scores
 * from a zero-filled array, and prints "Done.".
 *
 * @param args Command line arguments; {@code args[0]} is handed to the
 *             {@code InstanceList} constructor.
 */
public static void main(String[] args) {
    InstanceList instances = new InstanceList(args[0]);
    Schema instanceSchema = instances.get_schema();
    CatDist distribution = new CatDist(instanceSchema, 1);

    // One zero-initialised score per attribute of the schema.
    double[] zeroScores = new double[instanceSchema.num_attr()];
    distribution.set_scores(zeroScores, CatDist.none, 0);

    System.out.println("Done.");
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -