📄 discretize.java
字号:
}
/**
 * Chooses the attributes to discretize from a user-style range list.
 * Only the numeric attributes within the selection are discretized; any
 * other attribute is copied to the output unchanged.
 *
 * @param rangeList the attributes to discretize, indexed from 1 because the
 * string normally comes from a user, e.g. <code>first-3,5,6-last</code>
 * @exception IllegalArgumentException if an invalid range list is supplied
 */
public void setAttributeIndices(String rangeList) {

  // The range object performs all parsing and validation.
  m_DiscretizeCols.setRanges(rangeList);
}
/**
 * Chooses the attributes to discretize from a program-supplied array.
 * The indices are turned into a range-list string and passed on to
 * {@link #setAttributeIndices(String)}, so the same rule applies: only the
 * numeric attributes among them are discretized.
 *
 * @param attributes the indices of the attributes to discretize, indexed
 * from 0 because the array normally comes from a program
 * @exception IllegalArgumentException if an invalid set of ranges
 * is supplied
 */
public void setAttributeIndicesArray(int [] attributes) {

  String asRangeList = Range.indicesToRangeList(attributes);
  setAttributeIndices(asRangeList);
}
/**
 * Returns the cut points computed for an attribute.
 *
 * @param attributeIndex the index (from 0) of the attribute to get the
 * cut points of
 * @return the cut points for that attribute, or null if no cut points have
 * been computed yet, the attribute is not being discretized, or no split
 * was accepted for it. The filter's internal array is returned, not a
 * copy, so callers should not modify it.
 * @exception ArrayIndexOutOfBoundsException if cut points have been
 * computed and attributeIndex is not a valid attribute index
 */
public double [] getCutPoints(int attributeIndex) {

  return (m_CutPoints == null) ? null : m_CutPoints[attributeIndex];
}
/**
 * Computes cut points for every selected numeric attribute and stores them
 * in m_CutPoints. There is one slot per input attribute; slots for
 * attributes that are not discretized stay null.
 */
protected void calculateCutPoints() {

  Instances format = getInputFormat();
  int numAttributes = format.numAttributes();
  m_CutPoints = new double [numAttributes] [];

  // Sorting is done on a lazily created copy so the input data keeps its
  // original order.
  Instances sortedCopy = null;
  int att = numAttributes - 1;
  while (att >= 0) {
    boolean selected = m_DiscretizeCols.isInRange(att)
      && format.attribute(att).isNumeric();
    if (selected) {
      if (sortedCopy == null) {
        sortedCopy = new Instances(format);
      }
      calculateCutPointsByMDL(att, sortedCopy);
    }
    att--;
  }
}
/**
 * Computes the MDL cut points for a single attribute and stores them in
 * m_CutPoints.
 *
 * @param index the index of the attribute to set cutpoints for
 * @param data the instances to use. They are sorted in place on the
 * attribute, so pass a copy if the original order matters.
 */
protected void calculateCutPointsByMDL(int index,
                                       Instances data) {

  data.sort(data.attribute(index));

  // Only the prefix before the first instance that is missing this
  // attribute is used. This presumes the sort puts missing values last;
  // confirm against Instances.sort.
  int total = data.numInstances();
  int end = 0;
  while (end < total && !data.instance(end).isMissing(index)) {
    end++;
  }
  m_CutPoints[index] = cutPointsForSubset(data, index, 0, end);
}
/**
 * Decides whether a candidate split is accepted under Kononenko's MDL
 * criterion.
 *
 * The split is accepted when encoding the class distribution and the
 * instances before the split costs more bits than encoding them after the
 * split. The after-split cost includes log2(numCutPoints) bits for
 * choosing the cut point.
 *
 * @param priorCounts class counts of the whole subset
 * @param bestCounts class counts of each side of the candidate split
 * @param numInstances total weight of the instances in the subset
 * @param numCutPoints number of candidate cut points considered
 * @return true if the split should be accepted
 */
private boolean KononenkosMDL(double[] priorCounts,
                              double[][] bestCounts,
                              double numInstances,
                              int numCutPoints) {

  // Number of classes that actually occur in the subset
  int presentClasses = 0;
  for (double count : priorCounts) {
    if (count > 0) {
      presentClasses++;
    }
  }

  // Coding cost before the split: distribution + instances
  double distributionBits = SpecialFunctions.log2Binomial(numInstances
                                                          + presentClasses - 1,
                                                          presentClasses - 1);
  double instanceBits = SpecialFunctions.log2Multinomial(numInstances,
                                                         priorCounts);
  double costBefore = instanceBits + distributionBits;

  // Coding cost after the split, summed over both sides
  double splitDistributionBits = 0;
  double splitInstanceBits = 0;
  for (double[] side : bestCounts) {
    double sideWeight = Utils.sum(side);
    splitDistributionBits +=
      SpecialFunctions.log2Binomial(sideWeight + presentClasses - 1,
                                    presentClasses - 1);
    splitInstanceBits += SpecialFunctions.log2Multinomial(sideWeight, side);
  }
  double costAfter = Utils.log2(numCutPoints) + splitDistributionBits
    + splitInstanceBits;

  return costBefore > costAfter;
}
/**
 * Decides whether a candidate split is accepted under Fayyad and Irani's
 * MDL criterion.
 *
 * The split is accepted when its information gain exceeds
 * (log2(numCutPoints) + delta) / numInstances, where
 * delta = log2(3^k - 2) - (k * E - kRight * ERight - kLeft * ELeft).
 * Here k is the number of classes present and E is an entropy; the Left and
 * Right terms refer to the two sides of the split.
 *
 * @param priorCounts class counts of the whole subset
 * @param bestCounts class counts of each side of the candidate split
 * (row 0 = left, row 1 = right)
 * @param numInstances total weight of the instances in the subset
 * @param numCutPoints number of candidate cut points considered
 * @return true if the split should be accepted
 */
private boolean FayyadAndIranisMDL(double[] priorCounts,
                                   double[][] bestCounts,
                                   double numInstances,
                                   int numCutPoints) {

  double priorEntropy = ContingencyTables.entropy(priorCounts);
  double splitEntropy = ContingencyTables.entropyConditionedOnRows(bestCounts);
  double gain = priorEntropy - splitEntropy;

  // Number of classes occurring overall, on the left and on the right
  int classesTotal = 0;
  for (double count : priorCounts) {
    if (count > 0) {
      classesTotal++;
    }
  }
  int classesLeft = 0;
  for (double count : bestCounts[0]) {
    if (count > 0) {
      classesLeft++;
    }
  }
  int classesRight = 0;
  for (double count : bestCounts[1]) {
    if (count > 0) {
      classesRight++;
    }
  }

  double entropyLeft = ContingencyTables.entropy(bestCounts[0]);
  double entropyRight = ContingencyTables.entropy(bestCounts[1]);

  double delta = Utils.log2(Math.pow(3, classesTotal) - 2)
    - (((double) classesTotal * priorEntropy)
       - (classesRight * entropyRight)
       - (classesLeft * entropyLeft));
  double threshold = (Utils.log2(numCutPoints) + delta) / (double) numInstances;

  return gain > threshold;
}
/**
 * Recursively selects cut points for a range of instances that is already
 * sorted on the attribute.
 *
 * The boundary between distinct values with the lowest class entropy is
 * chosen first. It is then checked with the configured MDL criterion
 * (Kononenko's or Fayyad and Irani's). If accepted, both halves are split
 * further in the same way.
 *
 * @param instances the instances, sorted on attIndex
 * @param attIndex the index of the attribute being discretized
 * @param first index of the first instance of the range
 * @param lastPlusOne index one past the last instance of the range
 * @return the cut points in ascending order, or null if the range should
 * not be split
 */
private double[] cutPointsForSubset(Instances instances, int attIndex,
                                    int first, int lastPlusOne) {

  // A range with fewer than two instances cannot be split
  if ((lastPlusOne - first) < 2) {
    return null;
  }
  int numClasses = instances.numClasses();

  // Total weight and class counts. counts[0] is the left side and counts[1]
  // the right side; initially every instance is on the right.
  // numInstances must be a double: weights can be fractional, and an int
  // accumulator would silently truncate them on every "+=".
  double numInstances = 0;
  double[][] counts = new double[2][numClasses];
  for (int i = first; i < lastPlusOne; i++) {
    numInstances += instances.instance(i).weight();
    counts[1][(int)instances.instance(i).classValue()] +=
      instances.instance(i).weight();
  }

  // Save prior counts and the entropy of the full range
  double[] priorCounts = new double[numClasses];
  System.arraycopy(counts[1], 0, priorCounts, 0, numClasses);
  double priorEntropy = ContingencyTables.entropy(priorCounts);

  // Move instances from right to left one at a time, evaluating every
  // boundary between two distinct attribute values.
  double bestEntropy = priorEntropy;
  double bestCutPoint = -1;
  int bestIndex = -1;
  int numCutPoints = 0;
  double[][] bestCounts = new double[2][numClasses];
  for (int i = first; i < (lastPlusOne - 1); i++) {
    int classIdx = (int)instances.instance(i).classValue();
    double weight = instances.instance(i).weight();
    counts[0][classIdx] += weight;
    counts[1][classIdx] -= weight;
    double value = instances.instance(i).value(attIndex);
    double nextValue = instances.instance(i + 1).value(attIndex);
    if (value < nextValue) {
      double currentCutPoint = (value + nextValue) / 2.0;
      double currentEntropy = ContingencyTables.entropyConditionedOnRows(counts);
      if (currentEntropy < bestEntropy) {
        bestCutPoint = currentCutPoint;
        bestEntropy = currentEntropy;
        bestIndex = i;
        System.arraycopy(counts[0], 0, bestCounts[0], 0, numClasses);
        System.arraycopy(counts[1], 0, bestCounts[1], 0, numClasses);
      }
      numCutPoints++;
    }
  }

  // The worse encoding charges for every gap between instances, not just
  // for the boundaries between distinct values.
  if (!m_UseBetterEncoding) {
    numCutPoints = (lastPlusOne - first) - 1;
  }

  // No entropy reduction means no useful split
  double gain = priorEntropy - bestEntropy;
  if (gain <= 0) {
    return null;
  }

  // Check whether the split is accepted
  boolean accepted = m_UseKononenko
    ? KononenkosMDL(priorCounts, bestCounts, numInstances, numCutPoints)
    : FayyadAndIranisMDL(priorCounts, bestCounts, numInstances, numCutPoints);
  if (!accepted) {
    return null;
  }

  // Select split points for the left and right subsets
  double[] left = cutPointsForSubset(instances, attIndex, first, bestIndex + 1);
  double[] right = cutPointsForSubset(instances, attIndex,
                                      bestIndex + 1, lastPlusOne);

  // Merge as: left cut points, this cut point, right cut points
  int numLeft = (left == null) ? 0 : left.length;
  int numRight = (right == null) ? 0 : right.length;
  double[] cutPoints = new double[numLeft + 1 + numRight];
  if (left != null) {
    System.arraycopy(left, 0, cutPoints, 0, numLeft);
  }
  cutPoints[numLeft] = bestCutPoint;
  if (right != null) {
    System.arraycopy(right, 0, cutPoints, numLeft + 1, numRight);
  }
  return cutPoints;
}
/**
 * Builds the output format from the current cut points and the input
 * format, then passes it to setOutputFormat(Instances). If no cut points
 * exist yet, null is passed instead.
 *
 * Each selected numeric attribute becomes nominal:
 * <ul>
 * <li>With no cut points it gets the single value 'All'.</li>
 * <li>Otherwise it gets one interval label per bin.</li>
 * <li>In binary mode it becomes one two-valued attribute per cut point.
 * The class index is then shifted by the extra attributes created before
 * it.</li>
 * </ul>
 * All other attributes are copied unchanged.
 */
protected void setOutputFormat() {

  if (m_CutPoints == null) {
    setOutputFormat(null);
    return;
  }

  Instances inputFormat = getInputFormat();
  int numAtts = inputFormat.numAttributes();
  int inputClassIndex = inputFormat.classIndex();
  int classIndex = inputClassIndex;
  FastVector attributes = new FastVector(numAtts);

  for (int i = 0; i < numAtts; i++) {
    Attribute current = inputFormat.attribute(i);
    boolean discretized = m_DiscretizeCols.isInRange(i) && current.isNumeric();
    if (!discretized) {
      attributes.addElement(current.copy());
      continue;
    }

    double[] cuts = m_CutPoints[i];
    if (cuts == null) {
      // No accepted split: one catch-all value, in either mode
      FastVector single = new FastVector(1);
      single.addElement("'All'");
      attributes.addElement(new Attribute(current.name(), single));
    } else if (m_MakeBinary) {
      // One two-valued attribute per cut point
      if (i < inputClassIndex) {
        classIndex += cuts.length - 1;
      }
      for (double cut : cuts) {
        FastVector pair = new FastVector(2);
        pair.addElement("'(-inf-"
                        + Utils.doubleToString(cut, 6) + "]'");
        pair.addElement("'("
                        + Utils.doubleToString(cut, 6) + "-inf)'");
        attributes.addElement(new Attribute(current.name(), pair));
      }
    } else {
      // One label per bin: first bin, inner bins, last bin
      FastVector labels = new FastVector(1);
      labels.addElement("'(-inf-"
                        + Utils.doubleToString(cuts[0], 6) + "]'");
      for (int j = 1; j < cuts.length; j++) {
        labels.addElement("'("
                          + Utils.doubleToString(cuts[j - 1], 6) + "-"
                          + Utils.doubleToString(cuts[j], 6) + "]'");
      }
      labels.addElement("'("
                        + Utils.doubleToString(cuts[cuts.length - 1], 6)
                        + "-inf)'");
      attributes.addElement(new Attribute(current.name(), labels));
    }
  }

  Instances outputFormat =
    new Instances(inputFormat.relationName(), attributes, 0);
  outputFormat.setClassIndex(classIndex);
  setOutputFormat(outputFormat);
}
/**
 * Converts one input instance to the output format and pushes the result
 * onto the output queue.
 *
 * Selected numeric attributes are handled as follows:
 * <ul>
 * <li>With no cut points, the value becomes 0.</li>
 * <li>Otherwise it becomes the index of its bin; a value equal to a cut
 * point falls in the lower bin.</li>
 * <li>In binary mode it becomes one 0/1 value per cut point: 0 when the
 * value is at or below that cut point.</li>
 * </ul>
 * Missing values stay missing. All other attributes are copied unchanged.
 *
 * @param instance the instance to convert
 */
protected void convertInstance(Instance instance) {

  Instances inputFormat = getInputFormat();
  double [] vals = new double [outputFormatPeek().numAttributes()];
  int out = 0;

  for (int i = 0; i < inputFormat.numAttributes(); i++) {
    boolean discretized = m_DiscretizeCols.isInRange(i)
      && inputFormat.attribute(i).isNumeric();
    if (!discretized) {
      vals[out++] = instance.value(i);
      continue;
    }

    double[] cuts = m_CutPoints[i];
    double currentVal = instance.value(i);
    boolean missing = instance.isMissing(i);

    if (cuts == null) {
      vals[out++] = missing ? Instance.missingValue() : 0;
    } else if (m_MakeBinary) {
      for (double cut : cuts) {
        if (missing) {
          vals[out] = Instance.missingValue();
        } else {
          vals[out] = (currentVal <= cut) ? 0 : 1;
        }
        out++;
      }
    } else {
      if (missing) {
        vals[out] = Instance.missingValue();
      } else {
        int bin;
        for (bin = 0; bin < cuts.length; bin++) {
          if (currentVal <= cuts[bin]) {
            break;
          }
        }
        vals[out] = bin;
      }
      out++;
    }
  }

  Instance converted = (instance instanceof SparseInstance)
    ? new SparseInstance(instance.weight(), vals)
    : new Instance(instance.weight(), vals);
  copyStringValues(converted, false, instance.dataset(), getInputStringIndex(),
                   getOutputFormat(), getOutputStringIndex());
  converted.setDataset(getOutputFormat());
  push(converted);
}
/**
 * Command-line entry point for trying out this filter.
 *
 * With the -b flag it calls Filter.batchFilterFile; otherwise it calls
 * Filter.filterFile. If an exception is thrown, its message is printed to
 * standard output.
 *
 * @param argv should contain arguments to the filter: use -h for help
 */
public static void main(String [] argv) {

  try {
    boolean batchMode = Utils.getFlag('b', argv);
    Discretize filter = new Discretize();
    if (batchMode) {
      Filter.batchFilterFile(filter, argv);
    } else {
      Filter.filterFile(filter, argv);
    }
  } catch (Exception ex) {
    System.out.println(ex.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -