📄 discretizefilter.java
字号:
int first, int lastPlusOne) throws Exception {
double[][] counts, bestCounts;
double[] priorCounts, left, right, cutPoints;
double currentCutPoint = -Double.MAX_VALUE, bestCutPoint = -1,
currentEntropy, bestEntropy, priorEntropy, gain;
int bestIndex = -1, numInstances = 0, numCutPoints = 0;
// Compute number of instances in set
if ((lastPlusOne - first) < 2) {
return null;
}
// Compute class counts.
counts = new double[2][instances.numClasses()];
for (int i = first; i < lastPlusOne; i++) {
numInstances += instances.instance(i).weight();
counts[1][(int)instances.instance(i).classValue()] +=
instances.instance(i).weight();
}
// Save prior counts
priorCounts = new double[instances.numClasses()];
System.arraycopy(counts[1], 0, priorCounts, 0,
instances.numClasses());
// Entropy of the full set
priorEntropy = ContingencyTables.entropy(priorCounts);
bestEntropy = priorEntropy;
// Find best entropy.
bestCounts = new double[2][instances.numClasses()];
for (int i = first; i < (lastPlusOne - 1); i++) {
counts[0][(int)instances.instance(i).classValue()] +=
instances.instance(i).weight();
counts[1][(int)instances.instance(i).classValue()] -=
instances.instance(i).weight();
if (Utils.sm(instances.instance(i).value(attIndex),
instances.instance(i + 1).value(attIndex))) {
currentCutPoint = instances.instance(i).value(attIndex); //+
//instances.instance(i + 1).value(attIndex)) / 2.0;
currentEntropy = ContingencyTables.entropyConditionedOnRows(counts);
if (Utils.sm(currentEntropy, bestEntropy)) {
bestCutPoint = currentCutPoint;
bestEntropy = currentEntropy;
bestIndex = i;
System.arraycopy(counts[0], 0,
bestCounts[0], 0, instances.numClasses());
System.arraycopy(counts[1], 0,
bestCounts[1], 0, instances.numClasses());
}
numCutPoints++;
}
}
// Use worse encoding?
if (!m_UseBetterEncoding) {
numCutPoints = (lastPlusOne - first) - 1;
}
// Checks if gain is zero
gain = priorEntropy - bestEntropy;
if (Utils.eq(gain, 0)) {
return null;
}
// Check if split is to be accepted
if ((m_UseKononenko && KononenkosMDL(priorCounts, bestCounts,
numInstances, numCutPoints)) ||
(!m_UseKononenko && FayyadAndIranisMDL(priorCounts, bestCounts,
numInstances, numCutPoints))) {
// Select split points for the left and right subsets
left = cutPointsForSubset(instances, attIndex, first, bestIndex + 1);
right = cutPointsForSubset(instances, attIndex,
bestIndex + 1, lastPlusOne);
// Merge cutpoints and return them
if ((left == null) && (right) == null) {
cutPoints = new double[1];
cutPoints[0] = bestCutPoint;
} else if (right == null) {
cutPoints = new double[left.length + 1];
System.arraycopy(left, 0, cutPoints, 0, left.length);
cutPoints[left.length] = bestCutPoint;
} else if (left == null) {
cutPoints = new double[1 + right.length];
cutPoints[0] = bestCutPoint;
System.arraycopy(right, 0, cutPoints, 1, right.length);
} else {
cutPoints = new double[left.length + right.length + 1];
System.arraycopy(left, 0, cutPoints, 0, left.length);
cutPoints[left.length] = bestCutPoint;
System.arraycopy(right, 0, cutPoints, left.length + 1, right.length);
}
return cutPoints;
} else
return null;
}
/**
 * Computes equal-width cut points for a single attribute.
 *
 * The range [min, max] of the attribute's non-missing values is divided
 * into m_NumBins intervals of identical width, and the m_NumBins - 1
 * interior boundaries are stored in m_CutPoints[index]. The entry is set
 * to null when m_NumBins is not greater than one, when the attribute has
 * no non-missing values, or when all non-missing values are equal.
 *
 * @param index the index of the attribute to set cutpoints for
 */
protected void calculateCutPointsByEqualWidthBinning(int index) {
  Instances data = getInputFormat();

  // Observed range of the non-missing values. Both bounds stay 0 when every
  // value is missing, which gives a zero width and hence no cut points.
  boolean anySeen = false;
  double lowest = 0, highest = 0;
  for (int row = 0; row < data.numInstances(); row++) {
    Instance inst = data.instance(row);
    if (inst.isMissing(index)) {
      continue;
    }
    double v = inst.value(index);
    if (!anySeen) {
      lowest = v;
      highest = v;
      anySeen = true;
    } else if (v > highest) {
      highest = v;
    } else if (v < lowest) {
      lowest = v;
    }
  }

  double width = (highest - lowest) / m_NumBins;
  double[] boundaries = null;
  if (m_NumBins > 1 && width > 0) {
    boundaries = new double[m_NumBins - 1];
    for (int b = 0; b < boundaries.length; b++) {
      boundaries[b] = lowest + width * (b + 1);
    }
  }
  m_CutPoints[index] = boundaries;
}
/**
 * Computes equal-frequency cut points for a single attribute.
 *
 * The instances are sorted on the attribute and scanned in order. A cut
 * point is placed midway between two distinct consecutive values once the
 * accumulated weight reaches the target weight per bin
 * (sumOfWeights / m_NumBins). Each cut subtracts one target's worth, so any
 * excess carries over to the next bin. At most m_NumBins - 1 cut points are
 * produced. m_CutPoints[index] is set to null when none are found.
 *
 * @param index the index of the attribute to set cutpoints for
 */
protected void calculateCutPointsByEqualFrequencyBinning(int index) {
  // Copy data so that it can be sorted without touching the input format
  Instances data = new Instances(getInputFormat());
  data.sort(index);

  // Both scans below stop at the first missing value. This assumes
  // Instances.sort() places missing values last -- TODO confirm.
  double sumOfWeights = 0;
  for (int i = 0; i < data.numInstances(); i++) {
    if (data.instance(i).isMissing(index)) {
      break;
    }
    sumOfWeights += data.instance(i).weight();
  }
  double freq = sumOfWeights / m_NumBins;

  // Never allocate a negative-length array when m_NumBins < 1.
  double[] cutPoints = new double[Math.max(m_NumBins - 1, 0)];
  double counter = 0;
  int cpindex = 0;
  // Stop once the array is full. Zero-weight instances (or rounding in the
  // running counter) can make the threshold fire more than m_NumBins - 1
  // times, which would otherwise write past the end of cutPoints.
  for (int i = 0; i < data.numInstances() - 1 && cpindex < cutPoints.length; i++) {
    if (data.instance(i).isMissing(index)) {
      break;
    }
    counter += data.instance(i).weight();
    double current = data.instance(i).value(index);
    double next = data.instance(i + 1).value(index);
    // Only cut between distinct values, and only once enough weight is seen
    if (current < next && counter >= freq) {
      cutPoints[cpindex] = (current + next) / 2;
      cpindex++;
      counter = counter - freq;
    }
  }

  if (cpindex == 0) {
    m_CutPoints[index] = null;
  } else {
    double[] cp = new double[cpindex];
    System.arraycopy(cutPoints, 0, cp, 0, cpindex);
    m_CutPoints[index] = cp;
  }
}
/**
 * Chooses the number of equal-width bins for one attribute and stores the
 * resulting cut points in m_CutPoints[index].
 *
 * Every bin count from 1 to m_NumBins is tried. For each count:
 * <ul>
 * <li>The non-missing values are histogrammed by weight into equal-width
 *     bins over [min, max].</li>
 * <li>The leave-one-out entropy estimate
 *     -sum_k n_k * log((n_k - 1) / width) is computed, where n_k is the
 *     weight in bin k.</li>
 * <li>Any bin with n_k &lt; 2 disqualifies that count.</li>
 * </ul>
 * The count with the lowest estimate wins, and ties keep the smaller count.
 * null is stored when a single bin wins or the range has zero width.
 *
 * @param index the attribute index
 */
protected void findNumBins(int index) {
  Instances data = getInputFormat();

  // Range of the non-missing values. The maximum must start at
  // -Double.MAX_VALUE: Double.MIN_VALUE is the smallest *positive* double,
  // so -Double.MIN_VALUE would leave max near 0 for all-negative data.
  double min = Double.MAX_VALUE, max = -Double.MAX_VALUE;
  for (int i = 0; i < data.numInstances(); i++) {
    Instance current = data.instance(i);
    if (!current.isMissing(index)) {
      double val = current.value(index);
      if (val > max) {
        max = val;
      }
      if (val < min) {
        min = val;
      }
    }
  }

  // Find best number of bins
  double bestEntropy = Double.MAX_VALUE;
  int bestNumBins = 1;
  for (int numBins = 1; numBins <= m_NumBins; numBins++) {
    double binWidth = (max - min) / numBins;

    // Weighted histogram over numBins equal-width bins
    double[] distribution = new double[numBins];
    for (int j = 0; j < data.numInstances(); j++) {
      Instance current = data.instance(j);
      if (current.isMissing(index)) {
        continue;
      }
      double val = current.value(index);
      // First bin whose upper edge is >= val. The top edge
      // min + numBins * binWidth can round to just below max, so a value
      // past every edge goes to the last bin instead of being dropped.
      int k = 0;
      while (k < numBins - 1 && !(val <= min + (k + 1) * binWidth)) {
        k++;
      }
      distribution[k] += current.weight();
    }

    // Compute cross-validated entropy
    double entropy = 0;
    for (int k = 0; k < numBins; k++) {
      if (distribution[k] < 2) {
        entropy = Double.MAX_VALUE;
        break;
      }
      entropy -= distribution[k] * Math.log((distribution[k] - 1) / binWidth);
    }

    // Best entropy so far? (strict comparison keeps the smaller count on ties)
    if (entropy < bestEntropy) {
      bestEntropy = entropy;
      bestNumBins = numBins;
    }
  }

  // Cut points use the width of the winning bin count, not the width of the
  // last candidate tried in the loop above.
  double bestBinWidth = (max - min) / bestNumBins;
  double[] cutPoints = null;
  if ((bestNumBins > 1) && (bestBinWidth > 0)) {
    cutPoints = new double[bestNumBins - 1];
    for (int i = 1; i < bestNumBins; i++) {
      cutPoints[i - 1] = min + bestBinWidth * i;
    }
  }
  m_CutPoints[index] = cutPoints;
}
/**
 * Builds the output header from the current cut points and the input
 * format, and passes it to setOutputFormat(Instances).
 *
 * If m_CutPoints is null, the output format is set to null. Otherwise every
 * attribute that is in m_DiscretizeCols and numeric is replaced as follows:
 * <ul>
 * <li>no cut points: one nominal attribute with the single value 'All';</li>
 * <li>m_MakeBinary false: one nominal attribute with one value per
 *     interval;</li>
 * <li>m_MakeBinary true: one two-valued attribute per cut point. The class
 *     index is shifted by the extra attributes inserted before it.</li>
 * </ul>
 * All other attributes are copied unchanged.
 *
 * @throws Exception as declared; exceptions from the calls made here propagate
 */
protected void setOutputFormat() throws Exception {
  if (m_CutPoints == null) {
    setOutputFormat(null);
    return;
  }

  Instances input = getInputFormat();
  int numAtts = input.numAttributes();
  int inputClassIndex = input.classIndex();
  int outputClassIndex = inputClassIndex;
  FastVector outputAtts = new FastVector(numAtts);

  for (int att = 0; att < numAtts; att++) {
    if (!(m_DiscretizeCols.isInRange(att) && input.attribute(att).isNumeric())) {
      // Not selected for discretization: carry the attribute over as is.
      outputAtts.addElement(input.attribute(att).copy());
      continue;
    }
    String attName = input.attribute(att).name();
    double[] cuts = m_CutPoints[att];

    if (cuts == null) {
      // No split for this attribute: a single catch-all value.
      FastVector only = new FastVector(1);
      only.addElement("'All'");
      outputAtts.addElement(new Attribute(attName, only));
    } else if (m_MakeBinary) {
      // k cut points turn one attribute into k, i.e. k - 1 extra columns.
      if (att < inputClassIndex) {
        outputClassIndex += cuts.length - 1;
      }
      for (int c = 0; c < cuts.length; c++) {
        String cut = Utils.doubleToString(cuts[c], 6);
        FastVector pair = new FastVector(2);
        pair.addElement("'(-inf-" + cut + "]'");
        pair.addElement("'(" + cut + "-inf)'");
        outputAtts.addElement(new Attribute(attName, pair));
      }
    } else {
      // One label per interval: (-inf, c0], (c0, c1], ..., (c_last, inf).
      FastVector labels = new FastVector(1);
      labels.addElement("'(-inf-" + Utils.doubleToString(cuts[0], 6) + "]'");
      for (int c = 1; c < cuts.length; c++) {
        labels.addElement("'(" + Utils.doubleToString(cuts[c - 1], 6) + "-"
            + Utils.doubleToString(cuts[c], 6) + "]'");
      }
      labels.addElement("'("
          + Utils.doubleToString(cuts[cuts.length - 1], 6) + "-inf)'");
      outputAtts.addElement(new Attribute(attName, labels));
    }
  }

  Instances outputFormat = new Instances(input.relationName(), outputAtts, 0);
  outputFormat.setClassIndex(outputClassIndex);
  setOutputFormat(outputFormat);
}
/**
 * Converts a single instance to the output format and adds the result to
 * the end of the output queue.
 *
 * Each attribute selected for discretization (in m_DiscretizeCols and
 * numeric) is mapped through m_CutPoints:
 * <ul>
 * <li>to 0 when the attribute has no cut points;</li>
 * <li>to the index of the first cut point that is &gt;= the value
 *     (cuts.length if none is);</li>
 * <li>when m_MakeBinary is set, to one 0/1 value per cut point.</li>
 * </ul>
 * Missing values stay missing, and all other attributes are copied
 * unchanged. Sparse input produces a SparseInstance. The instance weight is
 * preserved.
 *
 * @param instance the instance to convert
 */
protected void convertInstance(Instance instance) {
  double[] newVals = new double[outputFormatPeek().numAttributes()];
  int pos = 0;

  for (int att = 0; att < getInputFormat().numAttributes(); att++) {
    boolean discretize = m_DiscretizeCols.isInRange(att)
        && getInputFormat().attribute(att).isNumeric();
    if (!discretize) {
      newVals[pos++] = instance.value(att);
      continue;
    }

    double value = instance.value(att);
    boolean missing = instance.isMissing(att);
    double[] cuts = m_CutPoints[att];

    if (cuts == null) {
      newVals[pos++] = missing ? Instance.missingValue() : 0;
    } else if (m_MakeBinary) {
      // One output column per cut point: 0 if value <= cut, otherwise 1.
      for (int c = 0; c < cuts.length; c++) {
        if (missing) {
          newVals[pos++] = Instance.missingValue();
        } else {
          newVals[pos++] = (value <= cuts[c]) ? 0 : 1;
        }
      }
    } else {
      if (missing) {
        newVals[pos] = Instance.missingValue();
      } else {
        // Interval index: first cut point not below the value.
        int interval = 0;
        while (interval < cuts.length && !(value <= cuts[interval])) {
          interval++;
        }
        newVals[pos] = interval;
      }
      pos++;
    }
  }

  Instance converted = (instance instanceof SparseInstance)
      ? new SparseInstance(instance.weight(), newVals)
      : new Instance(instance.weight(), newVals);
  copyStringValues(converted, false, instance.dataset(), getInputStringIndex(),
                   getOutputFormat(), getOutputStringIndex());
  converted.setDataset(getOutputFormat());
  push(converted);
}
/**
 * Main method for testing this class.
 *
 * With the -b flag the filter runs via Filter.batchFilterFile, otherwise via
 * Filter.filterFile. Any exception is reported through {@code log} rather
 * than propagated.
 *
 * @param argv should contain arguments to the filter: use -h for help
 */
public static void main(String [] argv) {
  try {
    if (Utils.getFlag('b', argv)) {
      Filter.batchFilterFile(new DiscretizeFilter(), argv);
    } else {
      Filter.filterFile(new DiscretizeFilter(), argv);
    }
  } catch (Exception ex) {
    // Report the message as before, but never as a bare "null" for
    // exceptions without one. Keep the stack trace at debug level instead
    // of discarding it, so usage output stays uncluttered.
    String message = ex.getMessage();
    log.error(message != null ? message : ex.toString());
    log.debug("DiscretizeFilter failed", ex);
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -