📄 DecisionTable.java
  }

  /**
   * Gets the number of folds for cross validation
   *
   * @return the number of cross validation folds
   */
  public int getCrossVal() {
    return m_CVFolds;
  }

  /**
   * Sets the number of non improving decision tables to consider
   * before abandoning the search.
   *
   * @param stale the number of nodes
   */
  public void setMaxStale(int stale) {
    m_maxStale = stale;
  }

  /**
   * Gets the number of non improving decision tables
   *
   * @return the number of non improving decision tables
   */
  public int getMaxStale() {
    return m_maxStale;
  }

  /**
   * Sets whether IBk should be used instead of the majority class
   *
   * @param ibk true if IBk is to be used
   */
  public void setUseIBk(boolean ibk) {
    m_useIBk = ibk;
  }

  /**
   * Gets whether IBk is being used instead of the majority class
   *
   * @return true if IBk is being used
   */
  public boolean getUseIBk() {
    return m_useIBk;
  }

  /**
   * Sets whether rules are to be printed
   *
   * @param rules true if rules are to be printed
   */
  public void setDisplayRules(boolean rules) {
    m_displayRules = rules;
  }

  /**
   * Gets whether rules are being printed
   *
   * @return true if rules are being printed
   */
  public boolean getDisplayRules() {
    return m_displayRules;
  }

  /**
   * Parses the options for this object.
   *
   * Valid options are: <p>
   *
   * -S num <br>
   * Number of fully expanded non improving subsets to consider
   * before terminating a best first search.
   * (Default = 5) <p>
   *
   * -X num <br>
   * Use cross validation to evaluate features. Use number of folds = 1 for
   * leave one out CV. (Default = leave one out CV) <p>
   *
   * -I <br>
   * Use nearest neighbour instead of global table majority. <p>
   *
   * -R <br>
   * Prints the decision table. <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String optionString;

    resetOptions();

    optionString = Utils.getOption('X', options);
    if (optionString.length() != 0) {
      m_CVFolds = Integer.parseInt(optionString);
    }

    optionString = Utils.getOption('S', options);
    if (optionString.length() != 0) {
      m_maxStale = Integer.parseInt(optionString);
    }

    m_useIBk = Utils.getFlag('I', options);
    m_displayRules = Utils.getFlag('R', options);
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {

    String[] options = new String[7];
    int current = 0;

    options[current++] = "-X";
    options[current++] = "" + m_CVFolds;
    options[current++] = "-S";
    options[current++] = "" + m_maxStale;
    if (m_useIBk) {
      options[current++] = "-I";
    }
    if (m_displayRules) {
      options[current++] = "-R";
    }

    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Generates the classifier.
   *
   * @param data set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    int i;

    m_rr = new Random(1);
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();
    if (m_theInstances.numInstances() == 0) {
      throw new Exception("No training instances without missing class!");
    }
    if (m_theInstances.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    m_disTransform = new DiscretizeFilter();
    if (m_theInstances.classAttribute().isNumeric()) {
      m_classIsNominal = false;

      // use binned discretisation if the class is numeric
      m_disTransform.setUseMDL(false);
      m_disTransform.setBins(10);
      m_disTransform.setInvertSelection(true);

      // Discretize all attributes EXCEPT the class
      String rangeList = "";
      rangeList += (m_theInstances.classIndex() + 1);
      //System.out.println("The class col: " + m_theInstances.classIndex());
      m_disTransform.setAttributeIndices(rangeList);
    } else {
      m_disTransform.setUseBetterEncoding(true);
      m_classIsNominal = true;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    best_first();

    // reduce instances to selected features
    m_delTransform = new AttributeFilter();
    m_delTransform.setInvertSelection(true);

    // set features to keep
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_theInstances.numAttributes();

    // create hash table
    m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (i = 0; i < m_numInstances; i++) {
      Instance inst = m_theInstances.instance(i);
      insertIntoTable(inst, null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      m_ibk.buildClassifier(m_theInstances);
    }

    // Save memory
    m_theInstances = new Instances(m_theInstances, 0);
  }

  /**
   * Calculates the class membership probabilities for the given
   * test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    hashKey thekey;
    double[] tempDist;
    double[] normDist;

    m_disTransform.input(instance);
    m_disTransform.batchFinished();
    instance = m_disTransform.output();

    m_delTransform.input(instance);
    m_delTransform.batchFinished();
    instance = m_delTransform.output();

    thekey = new hashKey(instance, instance.numAttributes());

    // if this one is not in the table
    if ((tempDist = (double[]) m_entries.get(thekey)) == null) {
      if (m_useIBk) {
        tempDist = m_ibk.distributionForInstance(instance);
      } else {
        if (!m_classIsNominal) {
          tempDist = new double[1];
          tempDist[0] = m_majority;
        } else {
          tempDist = new double[m_theInstances.classAttribute().numValues()];
          tempDist[(int) m_majority] = 1.0;
        }
      }
    } else {
      if (!m_classIsNominal) {
        normDist = new double[1];
        normDist[0] = (tempDist[0] / tempDist[1]);
        tempDist = normDist;
      } else {
        // normalise distribution
        normDist = new double[tempDist.length];
        System.arraycopy(tempDist, 0, normDist, 0, tempDist.length);
        Utils.normalize(normDist);
        tempDist = normDist;
      }
    }
    return tempDist;
  }

  /**
   * Returns a string description of the features selected
   *
   * @return a string of features
   */
  public String printFeatures() {

    int i;
    String s = "";

    for (i = 0; i < m_decisionFeatures.length; i++) {
      if (i == 0) {
        s = "" + (m_decisionFeatures[i] + 1);
      } else {
        s += "," + (m_decisionFeatures[i] + 1);
      }
    }
    return s;
  }

  /**
   * Returns the number of rules
   * @return the number of rules
   */
  public double measureNumRules() {
    return m_entries.size();
  }

  /**
   * Returns an enumeration of the additional measure names
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumRules");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @exception IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareTo("measureNumRules") == 0) {
      return measureNumRules();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
                                         + " not supported (DecisionTable)");
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    if (m_entries == null) {
      return "Decision Table: No model built yet.";
    } else {
      StringBuffer text = new StringBuffer();

      text.append("Decision Table:" +
                  "\n\nNumber of training instances: " + m_numInstances +
                  "\nNumber of Rules : " + m_entries.size() + "\n");

      if (m_useIBk) {
        text.append("Non matches covered by IB1.\n");
      } else {
        text.append("Non matches covered by Majority class.\n");
      }

      text.append("Best first search for feature set,\nterminated after " +
                  m_maxStale + " non improving subsets.\n");

      text.append("Evaluation (for feature selection): CV ");
      if (m_CVFolds > 1) {
        text.append("(" + m_CVFolds + " fold) ");
      } else {
        text.append("(leave one out) ");
      }

      text.append("\nFeature set: " + printFeatures());

      if (m_displayRules) {

        // find out the max column width
        int maxColWidth = 0;
        for (int i = 0; i < m_theInstances.numAttributes(); i++) {
          if (m_theInstances.attribute(i).name().length() > maxColWidth) {
            maxColWidth = m_theInstances.attribute(i).name().length();
          }

          if (m_classIsNominal || (i != m_theInstances.classIndex())) {
            Enumeration e = m_theInstances.attribute(i).enumerateValues();
            while (e.hasMoreElements()) {
              String ss = (String) e.nextElement();
              if (ss.length() > maxColWidth) {
                maxColWidth = ss.length();
              }
            }
          }
        }

        text.append("\n\nRules:\n");
        StringBuffer tm = new StringBuffer();
        for (int i = 0; i < m_theInstances.numAttributes(); i++) {
          if (m_theInstances.classIndex() != i) {
            int d = maxColWidth - m_theInstances.attribute(i).name().length();
            tm.append(m_theInstances.attribute(i).name());
            for (int j = 0; j < d + 1; j++) {
              tm.append(" ");
            }
          }
        }
        tm.append(m_theInstances.attribute(m_theInstances.classIndex()).name() + " ");

        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");
        text.append(tm);
        text.append("\n");
        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");

        Enumeration e = m_entries.keys();
        while (e.hasMoreElements()) {
          hashKey tt = (hashKey) e.nextElement();
          text.append(tt.toString(m_theInstances, maxColWidth));
          double[] ClassDist = (double[]) m_entries.get(tt);

          if (m_classIsNominal) {
            int m = Utils.maxIndex(ClassDist);
            try {
              text.append(m_theInstances.classAttribute().value(m) + "\n");
            } catch (Exception ee) {
              System.out.println(ee.getMessage());
            }
          } else {
            text.append((ClassDist[0] / ClassDist[1]) + "\n");
          }
        }

        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");
        text.append("\n");
      }
      return text.toString();
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the command-line options
   */
  public static void main(String[] argv) {

    Classifier scheme;

    try {
      scheme = new DecisionTable();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.out.println(e.getMessage());
    }
  }
}
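A minimal usage sketch, assuming the classic Weka API this file is written against (the Instances(Reader) constructor, setOptions, buildClassifier, distributionForInstance). The ARFF path and option values are placeholders, and the import for DecisionTable itself is omitted here since its package differs between Weka releases.

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Arrays;

import weka.core.Instances;

public class DecisionTableDemo {

  public static void main(String[] args) throws Exception {

    // Load a training set (the file name is a placeholder) and mark the
    // last attribute as the class.
    Instances data = new Instances(new BufferedReader(new FileReader("weather.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Configure the table via the options documented in setOptions():
    // -S 5 : abandon the best first search after 5 non improving subsets
    // -X 1 : leave one out cross validation for feature evaluation
    // -I   : cover non-matching instances with IBk instead of the majority class
    // -R   : include the rules in the toString() output
    DecisionTable dt = new DecisionTable();
    dt.setOptions(new String[] {"-S", "5", "-X", "1", "-I", "-R"});

    dt.buildClassifier(data);
    System.out.println(dt);

    // Predicted class distribution for the first training instance.
    double[] dist = dt.distributionForInstance(data.instance(0));
    System.out.println("P(class) = " + Arrays.toString(dist));
  }
}

The same scheme-specific options can also be supplied on the command line through main above, e.g. "-t weather.arff -S 5 -I -R", since Evaluation.evaluateModel passes unrecognised options on to setOptions.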