📄 id3.java
字号:
* @throws NoSupportForMissingValuesException if instance has missing values
   */
  public double[] distributionForInstance(Instance instance)
    throws NoSupportForMissingValuesException {

    if (instance.hasMissingValue()) {
      throw new NoSupportForMissingValuesException("Id3: no missing values, "
                                                   + "please.");
    }
    if (m_Attribute == null) {
      // Leaf: hand out a copy so callers that normalise or otherwise modify
      // the returned array cannot corrupt the distribution stored in the model.
      return (m_Distribution == null)
        ? null
        : (double[]) m_Distribution.clone();
    } else {
      // Inner node: descend into the branch selected by this instance's value
      // of the split attribute (the value is used as the branch index).
      return m_Successors[(int) instance.value(m_Attribute)].
        distributionForInstance(instance);
    }
  }

  /**
   * Returns a textual representation of the decision tree, or a notice if no
   * model has been built yet. The tree itself is rendered by the private
   * {@code toString(int)} method, starting at level 0.
   *
   * @return a textual description of the classifier
   */
  public String toString() {
    if ((m_Distribution == null) && (m_Successors == null)) {
      return "Id3: No model built yet.";
    }
    return "Id3\n\n" + toString(0);
  }

  /**
   * Computes information gain for an attribute: the entropy of {@code data}
   * minus the entropy of each non-empty subset produced by splitting on
   * {@code att}, weighted by that subset's share of the instances.
   *
   * @param data the data for which info gain is to be computed
   * @param att the attribute
   * @return the information gain for the given attribute and data
   * @throws Exception if computation fails
   */
  private double computeInfoGain(Instances data, Attribute att)
    throws Exception {

    double infoGain = computeEntropy(data);
    Instances[] subsets = splitData(data, att);
    for (int j = 0; j < att.numValues(); j++) {
      if (subsets[j].numInstances() > 0) {
        infoGain -= ((double) subsets[j].numInstances() /
                     (double) data.numInstances()) *
          computeEntropy(subsets[j]);
      }
    }
    return infoGain;
  }

  /**
   * Computes the entropy of a dataset.
* * @param data the data for which entropy is to be computed * @return the entropy of the data's class distribution * @throws Exception if computation fails */ private double computeEntropy(Instances data) throws Exception { double [] classCounts = new double[data.numClasses()]; Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); classCounts[(int) inst.classValue()]++; } double entropy = 0; for (int j = 0; j < data.numClasses(); j++) { if (classCounts[j] > 0) { entropy -= classCounts[j] * Utils.log2(classCounts[j]); } } entropy /= (double) data.numInstances(); return entropy + Utils.log2(data.numInstances()); } /** * Splits a dataset according to the values of a nominal attribute. * * @param data the data which is to be split * @param att the attribute to be used for splitting * @return the sets of instances produced by the split */ private Instances[] splitData(Instances data, Attribute att) { Instances[] splitData = new Instances[att.numValues()]; for (int j = 0; j < att.numValues(); j++) { splitData[j] = new Instances(data, data.numInstances()); } Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); splitData[(int) inst.value(att)].add(inst); } for (int i = 0; i < splitData.length; i++) { splitData[i].compactify(); } return splitData; } /** * Outputs a tree at a certain level. 
* * @param level the level at which the tree is to be printed * @return the tree as string at the given level */ private String toString(int level) { StringBuffer text = new StringBuffer(); if (m_Attribute == null) { if (Instance.isMissingValue(m_ClassValue)) { text.append(": null"); } else { text.append(": " + m_ClassAttribute.value((int) m_ClassValue)); } } else { for (int j = 0; j < m_Attribute.numValues(); j++) { text.append("\n"); for (int i = 0; i < level; i++) { text.append("| "); } text.append(m_Attribute.name() + " = " + m_Attribute.value(j)); text.append(m_Successors[j].toString(level + 1)); } } return text.toString(); } /** * Adds this tree recursively to the buffer. * * @param id the unqiue id for the method * @param buffer the buffer to add the source code to * @return the last ID being used * @throws Exception if something goes wrong */ protected int toSource(int id, StringBuffer buffer) throws Exception { int result; int i; int newID; StringBuffer[] subBuffers; buffer.append("\n"); buffer.append(" protected static double node" + id + "(Object[] i) {\n"); // leaf? 
if (m_Attribute == null) { result = id; if (Double.isNaN(m_ClassValue)) buffer.append(" return Double.NaN;"); else buffer.append(" return " + m_ClassValue + ";"); if (m_ClassAttribute != null) buffer.append(" // " + m_ClassAttribute.value((int) m_ClassValue)); buffer.append("\n"); buffer.append(" }\n"); } else { buffer.append(" // " + m_Attribute.name() + "\n"); // subtree calls subBuffers = new StringBuffer[m_Attribute.numValues()]; newID = id; for (i = 0; i < m_Attribute.numValues(); i++) { newID++; buffer.append(" "); if (i > 0) buffer.append("else "); buffer.append("if (((String) i[" + m_Attribute.index() + "]).equals(\"" + m_Attribute.value(i) + "\"))\n"); buffer.append(" return node" + newID + "(i);\n"); subBuffers[i] = new StringBuffer(); newID = m_Successors[i].toSource(newID, subBuffers[i]); } buffer.append(" else\n"); buffer.append(" throw new IllegalArgumentException(\"Value '\" + i[" + m_Attribute.index() + "] + \"' is not allowed!\");\n"); buffer.append(" }\n"); // output subtree code for (i = 0; i < m_Attribute.numValues(); i++) { buffer.append(subBuffers[i].toString()); } subBuffers = null; result = newID; } return result; } /** * Returns a string that describes the classifier as source. The * classifier will be contained in a class with the given name (there may * be auxiliary classes), * and will contain a method with the signature: * <pre><code> * public static double classify(Object[] i); * </code></pre> * where the array <code>i</code> contains elements that are either * Double, String, with missing values represented as null. The generated * code is public domain and comes with no warranty. <br/> * Note: works only if class attribute is the last attribute in the dataset. * * @param className the name that should be given to the source class. 
* @return the generated Java source, as a string
   * @throws Exception if the source cannot be generated
   */
  public String toSource(String className) throws Exception {
    // the tree's root is always emitted as node0; classify() delegates to it
    final int rootId = 0;

    StringBuffer code = new StringBuffer();
    code.append("class " + className + " {\n");
    code.append(" public static double classify(Object[] i) {\n");
    code.append(" return node" + rootId + "(i);\n");
    code.append(" }\n");

    // append the node methods for the whole tree
    toSource(rootId, code);

    code.append("}\n");
    return code.toString();
  }

  /**
   * Returns the revision of this class, as extracted by RevisionUtils from
   * the embedded revision tag.
   *
   * @return the revision
   */
  public String getRevision() {
    final String revisionTag = "$Revision: 1.23 $";
    return RevisionUtils.extract(revisionTag);
  }

  /**
   * Command-line entry point: runs a fresh Id3 instance through
   * runClassifier with the given options.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    Id3 classifier = new Id3();
    runClassifier(classifier, args);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -