📄 vfi.java
字号:
} return options; }

  /**
   * Generates the classifier.
   *
   * For every numeric attribute the minimum and maximum value observed for
   * each class are collected; the sorted, de-duplicated set of these end
   * points (plus -infinity and +infinity) forms the interval bounds for the
   * attribute. For every nominal attribute each value is its own interval.
   * Weighted class counts are then accumulated per interval.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (!m_weightByConfidence) {
      TINY = 0.0;
    }

    if (instances.classIndex() == -1) {
      throw new Exception("No class attribute assigned");
    }
    if (!instances.classAttribute().isNominal()) {
      throw new UnsupportedClassTypeException("VFI: class attribute needs to be nominal!");
    }

    // Copy the data, then drop instances whose class value is missing
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_ClassIndex = instances.classIndex();
    m_NumClasses = instances.numClasses();
    m_globalCounts = new double [m_NumClasses];
    m_maxEntrop = Math.log(m_NumClasses) / Math.log(2);

    m_Instances = new Instances(instances, 0); // Copy the structure for ref

    int numAttributes = instances.numAttributes();
    int numInstances = instances.numInstances();

    // Layout of the raw bounds for one attribute:
    //   [0]      -infinity sentinel
    //   [2c+1]   smallest value seen for class c (starts at +infinity)
    //   [2c+2]   largest value seen for class c  (starts at -infinity)
    //   [2C+1]   +infinity sentinel
    int lastBound = 2 * m_NumClasses + 1;
    m_intervalBounds = new double[numAttributes][lastBound + 1];
    for (int j = 0; j < numAttributes; j++) {
      double [] bounds = m_intervalBounds[j];
      bounds[0] = Double.NEGATIVE_INFINITY;
      for (int c = 0; c < m_NumClasses; c++) {
        bounds[2 * c + 1] = Double.POSITIVE_INFINITY;
        bounds[2 * c + 2] = Double.NEGATIVE_INFINITY;
      }
      bounds[lastBound] = Double.POSITIVE_INFINITY;
    }

    // find upper and lower bounds for numeric attributes
    for (int j = 0; j < numAttributes; j++) {
      if (j == m_ClassIndex || !instances.attribute(j).isNumeric()) {
        continue;
      }
      double [] bounds = m_intervalBounds[j];
      for (int i = 0; i < numInstances; i++) {
        Instance inst = instances.instance(i);
        if (inst.isMissing(j)) {
          continue;
        }
        double val = inst.value(j);
        int lower = (int) inst.classValue() * 2 + 1;
        int upper = lower + 1;
        if (val < bounds[lower]) {
          bounds[lower] = val;
        }
        if (val > bounds[upper]) {
          bounds[upper] = val;
        }
      }
    }

    m_counts = new double [numAttributes][][];

    // sort intervals
    for (int i = 0; i < numAttributes; i++) {
      if (instances.attribute(i).isNumeric()) {
        double [] bounds = m_intervalBounds[i];
        int [] order = Utils.sort(bounds);

        // remove any duplicate bounds
        int distinct = 1;
        for (int j = 1; j < order.length; j++) {
          if (bounds[order[j]] != bounds[order[j - 1]]) {
            distinct++;
          }
        }
        double [] reordered = new double [distinct];
        reordered[0] = bounds[order[0]];
        int next = 1;
        for (int j = 1; j < order.length; j++) {
          if (bounds[order[j]] != bounds[order[j - 1]]) {
            reordered[next] = bounds[order[j]];
            next++;
          }
        }
        m_intervalBounds[i] = reordered;
        m_counts[i] = new double [distinct][m_NumClasses];
      } else if (i != m_ClassIndex) { // nominal attribute
        m_counts[i] = new double [instances.attribute(i).numValues()][m_NumClasses];
      }
    }

    // collect class counts
    for (int i = 0; i < numInstances; i++) {
      Instance inst = instances.instance(i);
      int classVal = (int) inst.classValue();
      double w = inst.weight();
      m_globalCounts[classVal] += w;

      for (int j = 0; j < numAttributes; j++) {
        if (j == m_ClassIndex || inst.isMissing(j)) {
          continue;
        }
        if (instances.attribute(j).isNumeric()) {
          double val = inst.value(j);
          double [] bounds = m_intervalBounds[j];
          // Scan from the highest bound down; interval k covers the values
          // strictly above bounds[k].
          for (int k = bounds.length - 1; k >= 0; k--) {
            if (val > bounds[k]) {
              m_counts[j][k][classVal] += w;
              break;
            } else if (val == bounds[k]) {
              if (k > 0) {
                // A value sitting exactly on a bound is shared equally
                // between the two adjacent intervals.
                m_counts[j][k][classVal] += (w / 2.0);
                m_counts[j][k - 1][classVal] += (w / 2.0);
              } else {
                // There is no interval below the lowest bound, so the
                // whole weight goes to interval 0 (avoids index -1).
                m_counts[j][k][classVal] += w;
              }
              break;
            }
          }
        } else { // nominal attribute
          m_counts[j][(int) inst.value(j)][classVal] += w;
        }
      }
    }
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {
    if (m_Instances == null) {
      return "VFI: Classifier not built yet!";
    }
    StringBuffer sb = new StringBuffer("Voting feature intervals classifier\n");

    /* Output the intervals and class counts for each attribute */
    /*    for (int i = 0; i < m_Instances.numAttributes(); i++) {
      if (i != m_ClassIndex) {
        sb.append("\n"+m_Instances.attribute(i).name()+" :\n");
        if (m_Instances.attribute(i).isNumeric()) {
          for (int j = 0; j < m_intervalBounds[i].length; j++) {
            sb.append(m_intervalBounds[i][j]).append("\n");
            if (j != m_intervalBounds[i].length-1) {
              for (int k = 0; k < m_NumClasses; k++) {
                sb.append(m_counts[i][j][k]+" ");
              }
            }
            sb.append("\n");
          }
        } else {
          for (int j = 0; j < m_Instances.attribute(i).numValues(); j++) {
            sb.append(m_Instances.attribute(i).value(j)).append("\n");
            for (int k = 0; k < m_NumClasses; k++) {
              sb.append(m_counts[i][j][k]+" ");
            }
            sb.append("\n");
          }
        }
      }
    } */
    return sb.toString();
  }

  /**
   * Computes the class distribution for the given test instance.
   *
   * Every non-missing, non-class attribute casts a normalized vote based on
   * the class counts of the interval (or nominal value) the instance falls
   * into; the votes are summed, optionally weighted, and normalized. An
   * attribute whose interval holds no counts casts a uniform vote, and a
   * uniform distribution is returned when no attribute can vote at all.
   *
   * @param instance the instance to be classified
   * @return the estimated class membership distribution for the instance
   * @exception Exception if the instance can't be classified
   */
  public double [] distributionForInstance(Instance instance) throws Exception {
    double [] dist = new double[m_NumClasses];
    double weight = 1.0;

    for (int i = 0; i < instance.numAttributes(); i++) {
      if (i == m_ClassIndex || instance.isMissing(i)) {
        continue;
      }
      double val = instance.value(i);
      // Fresh vote vector per attribute so nothing carries over from the
      // previous attribute for classes that are skipped below.
      double [] temp = new double[m_NumClasses];

      if (instance.attribute(i).isNumeric()) {
        boolean ok = false;
        double [] bounds = m_intervalBounds[i];
        for (int k = bounds.length - 1; k >= 0; k--) {
          if (val > bounds[k]) {
            for (int j = 0; j < m_NumClasses; j++) {
              if (m_globalCounts[j] > 0) {
                temp[j] = ((m_counts[i][k][j] + TINY) / (m_globalCounts[j] + TINY));
              }
            }
            ok = true;
            break;
          } else if (val == bounds[k]) {
            for (int j = 0; j < m_NumClasses; j++) {
              if (m_globalCounts[j] > 0) {
                // On a bound: average the two adjacent intervals. At the
                // lowest bound there is no interval below, so use interval 0.
                double count = (k > 0)
                  ? ((m_counts[i][k][j] + m_counts[i][k - 1][j]) / 2.0)
                  : m_counts[i][k][j];
                temp[j] = (count + TINY) / (m_globalCounts[j] + TINY);
              }
            }
            ok = true;
            break;
          }
        }
        if (!ok) {
          throw new Exception("This shouldn't happen");
        }
      } else { // nominal attribute
        for (int j = 0; j < m_NumClasses; j++) {
          if (m_globalCounts[j] > 0) {
            temp[j] = ((m_counts[i][(int) val][j] + TINY) / (m_globalCounts[j] + TINY));
          }
        }
      }

      normalizeOrUniform(temp);

      if (m_weightByConfidence) {
        weight = weka.core.ContingencyTables.entropy(temp);
        weight = Math.pow(weight, m_bias);
        if (weight < 1.0) {
          weight = 1.0;
        }
      }

      for (int j = 0; j < m_NumClasses; j++) {
        dist[j] += (temp[j] * weight);
      }
    }

    normalizeOrUniform(dist);
    return dist;
  }

  /**
   * Normalizes the array in place so that it sums to one. An array whose
   * entries sum to exactly zero is replaced by the uniform distribution
   * rather than being handed to Utils.normalize, which cannot rescale it.
   *
   * @param values the array to normalize in place
   */
  private static void normalizeOrUniform(double [] values) {
    double sum = 0.0;
    for (int j = 0; j < values.length; j++) {
      sum += values[j];
    }
    if (sum == 0.0) {
      for (int j = 0; j < values.length; j++) {
        values[j] = 1.0 / values.length;
      }
    } else {
      Utils.normalize(values);
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param args should contain command line arguments for evaluation
   * (see Evaluation).
   */
  public static void main(String [] args) {
    try {
      System.out.println(Evaluation.evaluateModel(new VFI(), args));
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -