📄 alternatingtree.java
字号:
package jboost.atree;import java.io.BufferedReader;import java.io.FileNotFoundException;import java.io.FileReader;import java.io.IOException;import java.io.Serializable;import java.util.ArrayList;import java.util.Date;import java.util.HashMap;import java.util.Iterator;import java.util.Map;import java.util.Set;import java.util.Stack;import java.util.StringTokenizer;import java.util.TreeMap;import java.util.TreeSet;import java.util.Vector;import jboost.ComplexLearner;import jboost.WritablePredictor;import jboost.booster.Booster;import jboost.booster.Prediction;import jboost.examples.AttributeDescription;import jboost.examples.ExampleDescription;import jboost.examples.Instance;import jboost.learner.IncompAttException;import jboost.learner.SplitterBuilder;import jboost.learner.Summary;/** An Alternating Tree classifier * * @author Nigel Duffy */public class AlternatingTree implements WritablePredictor,Serializable { /** The default constructor * * @param the root node of the tree */ public AlternatingTree(PredictorNode r) { root=r; } /** A null constructor */ public AlternatingTree() { root=null; } /** Make a prediction */ public Prediction predict(Instance instance) throws IncompAttException { Prediction retval=root.predict(instance); return(retval); } /** Make a prediction */ public Prediction predict(Instance instance, int numIters) throws IncompAttException { //return predict(instance); return orderPredict(instance, numIters); } /** Make a iteration orderd prediction */ public Prediction orderPredict(Instance instance, int numIters) throws IncompAttException { Prediction retval=root.orderPredict(instance, numIters); return(retval); } /** Generate a textual explanation of the prediction */ public String explain(Instance instance) throws IncompAttException { return root.explain(instance); } /** Make a classification * I am not sure that this function should exist. */ public int classify(Instance instance) { return(-1); } /** * Convert to a human readable form. */ public String toString() { String retval=new String("----------------------------------------------\n"); retval+=root; retval+="----------------------------------------------\n"; return(retval); } /** Converts this AlternatingTree to java */ public String toJava(String cname, String fname, String specFileName, ExampleDescription exampleDescription) throws FileNotFoundException, IOException { String code = ""; tokenMap = new HashMap(); tokenList = new Vector(); numTokens = 0; maxTextAttr = 0; maxAttr = 0; realAttrs = new TreeSet(); discreteAttrs = new TreeSet(); textAttrs = new TreeMap(); String fname_int = fname + "_int"; code += "" + "/**\n" + " This class provides static methods for evaluating a jboost-trained\n" + " classifier on new data. This part of the code can stand by itself.\n" + "\n"; if (specFileName != null) { code += "" + " In addition, this class includes a main which, when run, reads data\n" + " from standard input of the same form as that used during training,\n" + " and outputs corresponding predictions to standard output. This\n" + " part of the code requires other jboost classes.\n" + "\n"; } code += "" + " This classifier was automatically generated by jboost on\n" + " " + (new Date()) + ".\n" + "**/\n" + "\n" + "import java.util.*;\n" + "\n" + "public class " + cname + "{\n" + "\n" + " /**\n" + " Evaluates this classifier on an example represented by an array\n" + " of Objects and returns an array of scores, one for each class.\n" + " Finite attributes must be represented by an Integer specifying\n" + " the index of the chosen value. Text attributes are given by a\n" + " String. Number attributes are represented by a Double. In all\n" + " cases, an undefined attribute is indicated by a null pointer.\n" + " @param at an array of Objects corresponding to the attributes\n" + " specified in the spec file on which this classifier\n" + " was trained. Specifically, these objects are:\n" + "<pre>\n" + " * index attr.type data.type name\n" + " * ------------------------------------------\n"; AttributeDescription[] ad = exampleDescription.getAttributes(); for (int i = 0; i < ad.length; i++) { String s; String key = ""; String t = ad[i].getType(); if (t.equals("number")) { s = "number Double "; } else if (t.equals("text")) { s = "text String "; } else if (t.equals("finite")) { s = "finite Integer "; for (int j = 0; j < ad[i].getNoOfValues(); j++) key += (j == 0 ? " * key: " : " * ") + padInteger(j, 5) + " = " + ad[i].getAttributeValue(j) + "\n"; } else { System.err.println("Warning: unrecognized type for attribute " + i + ": " + t); s = "??? ??? "; } code += " * " + padInteger(i, 5) + " " + s + " " + ad[i].getAttributeName() + "\n" + key; } String label_code = ""; label_code += "" + "</pre>\n" + " @return an array of scores correpsonding to the classes:\n" + "<pre>\n" + " * index class name\n" + " * ------------------------\n"; AttributeDescription la = exampleDescription.getLabelDescription(); for (int j = 0; j < la.getNoOfValues(); j++) label_code += " * " + padInteger(j, 5) + " " + la.getAttributeValue(j) + "\n"; code += label_code + "</pre>\n" + " **/\n" + " static public double[] predict(Object[] at) {\n" + " attr = at;\n" + " int i,j,a,n;\n" + " StringTokenizer st;\n" + " String s;\n" + " Object v;\n" + " Enumeration e;\n" + " for (i = 0; i < num_text_attr; i++) {\n" + " a = text_attr[i];\n" + " if (!defined_attr(a))\n" + " continue;\n" + " Arrays.fill(tokens[a], false);\n" + " try {\n" + " s = (String) attr[a];\n" + " }\n" + " catch (ClassCastException ex) {\n" + " throw new IllegalArgumentException\n" + " (\"Expected attribute \" + a + \" to be of type String\");\n" + " }\n" + " st = new StringTokenizer(s);\n" + " n = st.countTokens();\n" + " String[] words = new String[n];\n" + " for (j = 0; j < n; j++)\n" + " words[j] = st.nextToken();\n" + " for (j = 0; j < text_patterns[i].length; j++) {\n" + " setPattern(words, text_patterns[i][j]);\n" + " while(moreTokens())\n" + " if ((v = hash.get(nextToken())) != null)\n" + " tokens[a][((Integer) v).intValue()] = true;\n" + " }\n" + " }\n" + " return " + fname_int + "();\n" + " }\n" + "\n" + " /**\n" + " Evaluates this classifier on an example represented by an array\n" + " of Strings and returns an array of scores, one for each class.\n" + " These Strings represent the values of the attributes similar\n" + " to their representation in a data file. Null pointers can be\n" + " passed for undefined attributes.\n" + "\n" + " @param at an array of Objects corresponding to the attributes\n" + " specified in the spec file on which this classifier\n" + " was trained. Specifically, these objects are:\n" + "<pre>\n" + " * index attr.type name\n" + " * ------------------------------------------\n"; for (int i = 0; i < ad.length; i++) { String s; String key = ""; String t = ad[i].getType(); if (t.equals("number")) { s = "number "; } else if (t.equals("text")) { s = "text "; } else if (t.equals("finite")) { s = "finite "; for (int j = 0; j < ad[i].getNoOfValues(); j++) key += (j == 0 ? " * values: " : " * ") + ad[i].getAttributeValue(j) + "\n"; } else { System.err.println("Warning: unrecognized type for attribute " + i + ": " + t); s = "??? "; } code += " * " + padInteger(i, 5) + " " + s + ad[i].getAttributeName() + "\n" + key; } code += label_code + "</pre>\n" + " **/\n" + " static public double[] " + fname + "(String[] as) {\n" + " int j, a;\n" + " Object v;\n" + " Object[] attr = new Object[as.length];\n" + "\n" + " for (j = 0; j < real_attr.length; j++) {\n" + " a = real_attr[j];\n" + " try{ \n" + " attr[a] = (as[a] == null || as[a].trim().equals(\"\")\n" + " ? null\n" + " : (new Double(as[a])));\n" + " }\n" + " catch (NumberFormatException e) {\n" + " throw new IllegalArgumentException\n" + " (\"Expected attribute \" + a + \" to contain a String parsable as a double\");\n" + " }\n" + " }\n" + " for (j = 0; j < text_attr.length; j++) {\n" + " a = text_attr[j];\n" + " attr[a] = as[a];\n" + " }\n" + " for (j = 0; j < disc_attr.length; j++) {\n" + " String s = null;\n" + " a = disc_attr[j];\n" + " if (as[a] == null || (s = as[a].trim()).equals(\"\"))\n" + " attr[a] = null;\n" + " else if ((v = disc_val_map[j].get(s)) == null) {\n" + " throw new IllegalArgumentException\n" + " (\"Illegal value for attribute \" + a + \":\" + s);\n" + " } else\n" + " attr[a] = ((Integer) v);\n" + " }\n" + "\n" + " return " + fname + "(attr);\n" + " }\n" + "\n" + " static private double[] " + fname_int + "() {\n" + " reset_pred();\n" + makeCode(root, " ") + "\n" + " return finalize_pred();\n" + " }\n"; code += " static private String[] keys = {\n"; for (int i = 0; i < numTokens; i++) code += " \"" + checkChar((String)tokenList.get(i)) + "\",\n"; code += " };\n" + " static private final int num_keys = " + numTokens + ";\n" + " static private boolean[][] tokens = new boolean[" + (maxTextAttr+1) + "][];\n" + " static private int text_attr[] = {"; for (Iterator i = textAttrs.keySet().iterator(); i.hasNext(); ) code += ((Integer) i.next()) + ","; code += " };\n" + " static private final int num_text_attr = " + textAttrs.size() + ";\n" + " static private boolean[][][] text_patterns = {\n"; for (Iterator i = textAttrs.keySet().iterator(); i.hasNext();) { code += " {\n"; for (Iterator j = ((Set) textAttrs.get(i.next())).iterator(); j.hasNext();) { String p = (String) j.next(); code += " {"; int l = p.length(); for (int k = 0; k < l; k++) code += (p.charAt(k) == '1' ? "true," : "false,"); code += "},\n"; } code += " },\n"; } code += " };\n" + " static private int real_attr[] = {"; for (Iterator i = realAttrs.iterator(); i.hasNext(); ) code += ((Integer) i.next()) + ","; code += " };\n" + " static private int disc_attr[] = {"; for (Iterator i = discreteAttrs.iterator(); i.hasNext(); ) code += ((Integer) i.next()) + ","; code += " };\n" + " static private Object[] attr;\n" + " static private Map hash = null;\n" + " static private Map[] disc_val_map = null;\n" + " static private String[][] disc_attr_vals = {\n"; for (Iterator i = discreteAttrs.iterator(); i.hasNext(); ) { AttributeDescription a = exampleDescription. getAttributeDescription(((Integer) i.next()).intValue()); code += " {\n"; for (int j = 0; j < a.getNoOfValues(); j++) code += " \"" + checkChar(a.getAttributeValue(j)) + "\",\n"; code += " },\n"; } code += "" + " };\n" + "\n" + " static {\n" + " disc_val_map = new Map[disc_attr.length];\n" + " for (int i = 0; i < disc_attr.length; i++) {\n" + " disc_val_map[i] = new TreeMap();\n" + " for (int j = 0; j < disc_attr_vals[i].length; j++)\n" + " disc_val_map[i].put(disc_attr_vals[i][j], new Integer(j));\n" + " }\n" + " }\n" + "\n" + " static {\n" + " if (hash == null) {\n" + " hash = new HashMap();\n" + " for (int i = 0; i < num_keys; i++)\n" + " hash.put(keys[i], new Integer(i));\n" + " }\n" + " }\n" + "\n"
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -