📄 vfi.java
字号:
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception {
  // Builds, for every attribute, the sorted interval end points (numeric
  // attributes) and the per-interval / per-value class weight counts that
  // distributionForInstance later turns into votes.
  //
  // NOTE(review): TINY is only ever lowered here and nothing in this method
  // restores it, so a rebuild after re-enabling m_weightByConfidence keeps
  // TINY == 0.0. Confirm whether TINY should be reset on every build.
  if (!m_weightByConfidence) {
    TINY = 0.0;
  }
  if (instances.classIndex() == -1) {
    throw new Exception("No class attribute assigned");
  }
  if (!instances.classAttribute().isNominal()) {
    throw new UnsupportedClassTypeException("VFI: class attribute needs to be nominal!");
  }
  // Work on a copy so deleting instances does not touch the caller's data.
  instances = new Instances(instances);
  instances.deleteWithMissingClass();
  m_ClassIndex = instances.classIndex();
  m_NumClasses = instances.numClasses();
  m_globalCounts = new double[m_NumClasses];
  m_maxEntrop = Math.log(m_NumClasses) / Math.log(2);
  m_Instances = new Instances(instances, 0); // Copy the structure for ref

  final int numAttributes = instances.numAttributes();
  final int numBounds = 2 + 2 * m_NumClasses;

  // Initial layout of each bounds row: [0] = -inf, [2c+1] = running minimum
  // for class c, [2c+2] = running maximum for class c, [last] = +inf. The
  // per-class slots start "empty" (min = +inf, max = -inf) so the first
  // observed value replaces them.
  m_intervalBounds = new double[numAttributes][numBounds];
  for (int a = 0; a < numAttributes; a++) {
    double[] bounds = m_intervalBounds[a];
    bounds[0] = Double.NEGATIVE_INFINITY;
    for (int b = 1; b < numBounds - 1; b++) {
      bounds[b] = (b % 2 == 1) ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
    }
    bounds[numBounds - 1] = Double.POSITIVE_INFINITY;
  }

  // Record the lowest and highest observed value of each numeric attribute
  // for every class; missing values are ignored.
  for (int a = 0; a < numAttributes; a++) {
    if (a == m_ClassIndex || !instances.attribute(a).isNumeric()) {
      continue;
    }
    double[] bounds = m_intervalBounds[a];
    for (int i = 0; i < instances.numInstances(); i++) {
      Instance inst = instances.instance(i);
      if (inst.isMissing(a)) {
        continue;
      }
      double val = inst.value(a);
      int lower = (int) inst.classValue() * 2 + 1;
      if (val < bounds[lower]) {
        bounds[lower] = val;
      }
      if (val > bounds[lower + 1]) {
        bounds[lower + 1] = val;
      }
    }
  }

  // Sort the collected end points and drop duplicates. Numeric attributes get
  // one count row per remaining bound; row k accumulates weight for values
  // strictly between bound k and bound k+1 (values exactly on a bound are
  // split below). Nominal attributes get one row per value index.
  m_counts = new double[numAttributes][][];
  for (int a = 0; a < numAttributes; a++) {
    if (instances.attribute(a).isNumeric()) {
      m_intervalBounds[a] = sortedDistinct(m_intervalBounds[a]);
      m_counts[a] = new double[m_intervalBounds[a].length][m_NumClasses];
    } else if (a != m_ClassIndex) { // nominal attribute
      m_counts[a] = new double[instances.attribute(a).numValues()][m_NumClasses];
    }
  }

  // Accumulate instance weights per class, globally and per attribute
  // interval / nominal value.
  for (int i = 0; i < instances.numInstances(); i++) {
    Instance inst = instances.instance(i);
    int cls = (int) inst.classValue();
    double w = inst.weight();
    m_globalCounts[cls] += w;
    for (int a = 0; a < numAttributes; a++) {
      if (a == m_ClassIndex || inst.isMissing(a)) {
        continue;
      }
      if (instances.attribute(a).isNumeric()) {
        double[] bounds = m_intervalBounds[a];
        double[][] counts = m_counts[a];
        double val = inst.value(a);
        for (int k = bounds.length - 1; k >= 0; k--) {
          if (val > bounds[k]) {
            counts[k][cls] += w;
            break;
          }
          if (val == bounds[k]) {
            if (k > 0) {
              // A value exactly on a bound is shared equally by the two
              // neighbouring intervals.
              counts[k][cls] += w / 2.0;
              counts[k - 1][cls] += w / 2.0;
            } else {
              // Only reachable when val equals the lowest bound (-Infinity):
              // there is no interval below it, so credit interval 0 fully
              // instead of indexing counts[-1].
              counts[k][cls] += w;
            }
            break;
          }
        }
      } else {
        // nominal attribute: the value index selects the row
        m_counts[a][(int) inst.value(a)][cls] += w;
      }
    }
  }
}

/**
 * Returns a new array holding the given values in ascending order with
 * exact duplicates removed.
 *
 * @param values the values to sort and de-duplicate
 * @return the distinct values, ascending
 */
private static double[] sortedDistinct(double[] values) {
  int[] order = Utils.sort(values);
  double[] distinct = new double[values.length];
  int count = 0;
  for (int idx : order) {
    double v = values[idx];
    // In sorted order an equal run is contiguous, so comparing with the last
    // kept value is enough to drop duplicates.
    if (count == 0 || v != distinct[count - 1]) {
      distinct[count++] = v;
    }
  }
  return java.util.Arrays.copyOf(distinct, count);
}
/**
* Returns a description of this classifier.
*
* @return a description of this classifier as a string.
*/
public String toString() {
  // Before buildClassifier has run there is no header to describe.
  if (m_Instances == null) {
    return "VFI: Classifier not built yet!";
  }
  // Local, single-threaded buffer: StringBuilder is sufficient.
  StringBuilder sb =
    new StringBuilder("Voting feature intervals classifier\n");
  // The per-attribute dump below is commented out, so only the header line
  // above is returned.
  /* Output the intervals and class counts for each attribute */
  /* for (int i = 0; i < m_Instances.numAttributes(); i++) {
    if (i != m_ClassIndex) {
      sb.append("\n"+m_Instances.attribute(i).name()+" :\n");
      if (m_Instances.attribute(i).isNumeric()) {
        for (int j = 0; j < m_intervalBounds[i].length; j++) {
          sb.append(m_intervalBounds[i][j]).append("\n");
          if (j != m_intervalBounds[i].length-1) {
            for (int k = 0; k < m_NumClasses; k++) {
              sb.append(m_counts[i][j][k]+" ");
            }
          }
          sb.append("\n");
        }
      } else {
        for (int j = 0; j < m_Instances.attribute(i).numValues(); j++) {
          sb.append(m_Instances.attribute(i).value(j)).append("\n");
          for (int k = 0; k < m_NumClasses; k++) {
            sb.append(m_counts[i][j][k]+" ");
          }
          sb.append("\n");
        }
      }
    }
  } */
  return sb.toString();
}
/**
* Computes the class distribution for the given test instance.
*
* @param instance the instance to be classified
* @return the normalized class distribution (one entry per class) for the instance
* @exception Exception if the instance can't be classified
*/
public double [] distributionForInstance(Instance instance)
  throws Exception {
  // Each non-missing, non-class attribute casts a normalized vote vector over
  // the classes; votes are summed (optionally confidence-weighted) and the
  // total is normalized into the returned distribution.
  double[] dist = new double[m_NumClasses];
  double[] temp = new double[m_NumClasses];
  for (int i = 0; i < instance.numAttributes(); i++) {
    if (i == m_ClassIndex || instance.isMissing(i)) {
      continue;
    }
    // Reset the per-attribute vote vector. Classes with no training weight
    // are never assigned below, so without this they would keep a value left
    // over from a previous attribute (e.g. the uniform fallback further down).
    java.util.Arrays.fill(temp, 0.0);
    double val = instance.value(i);
    if (instance.attribute(i).isNumeric()) {
      double[] bounds = m_intervalBounds[i];
      double[][] counts = m_counts[i];
      boolean ok = false;
      for (int k = bounds.length - 1; k >= 0; k--) {
        if (val > bounds[k]) {
          // val lies strictly inside interval k
          for (int j = 0; j < m_NumClasses; j++) {
            if (m_globalCounts[j] > 0) {
              temp[j] = ((counts[k][j] + TINY) / (m_globalCounts[j] + TINY));
            }
          }
          ok = true;
          break;
        } else if (val == bounds[k]) {
          // val sits on bound k: average the two neighbouring intervals.
          // Below the lowest bound there is no interval, so use interval 0
          // alone rather than indexing counts[-1].
          for (int j = 0; j < m_NumClasses; j++) {
            if (m_globalCounts[j] > 0) {
              double c = (k > 0)
                ? (counts[k][j] + counts[k - 1][j]) / 2.0
                : counts[k][j];
              temp[j] = (c + TINY) / (m_globalCounts[j] + TINY);
            }
          }
          ok = true;
          break;
        }
      }
      if (!ok) {
        throw new Exception("This shouldn't happen");
      }
    } else { // nominal attribute
      for (int j = 0; j < m_NumClasses; j++) {
        if (m_globalCounts[j] > 0) {
          temp[j] = ((m_counts[i][(int) val][j] + TINY) / (m_globalCounts[j] + TINY));
        }
      }
    }
    // Normalize this attribute's vote; fall back to uniform if it is all zero.
    double sum = Utils.sum(temp);
    if (sum <= 0) {
      for (int j = 0; j < temp.length; j++) {
        temp[j] = 1.0 / (double) temp.length;
      }
    } else {
      Utils.normalize(temp, sum);
    }
    double weight = 1.0;
    if (m_weightByConfidence) {
      // NOTE(review): if m_bias is negative and the entropy is 0, Math.pow
      // yields +Infinity here — confirm this cannot occur / is intended.
      weight = weka.core.ContingencyTables.entropy(temp);
      weight = Math.pow(weight, m_bias);
      if (weight < 1.0) {
        weight = 1.0;
      }
    }
    for (int j = 0; j < m_NumClasses; j++) {
      dist[j] += (temp[j] * weight);
    }
  }
  double sum = Utils.sum(dist);
  if (sum <= 0) {
    for (int j = 0; j < dist.length; j++) {
      dist[j] = 1.0 / (double) dist.length;
    }
    return dist;
  } else {
    Utils.normalize(dist, sum);
    return dist;
  }
}
/**
* Main method for testing this class.
*
* @param args should contain command line arguments for evaluation
* (see Evaluation).
*/
public static void main(String [] args) {
  // Hand the command-line options to Weka's standard evaluation routine and
  // print its report; any failure is reported as a stack trace instead of
  // being propagated.
  try {
    VFI classifier = new VFI();
    String report = Evaluation.evaluateModel(classifier, args);
    System.out.println(report);
  } catch (Exception failure) {
    failure.printStackTrace();
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -