📄 DecisionTable.java
   *  (default weka.attributeSelection.BestFirst)</pre>
   *
   * <pre> -X <number of folds>
   *  Use cross validation to evaluate features.
   *  Use number of folds = 1 for leave one out CV.
   *  (Default = leave one out CV)</pre>
   *
   * <pre> -I
   *  Use nearest neighbour instead of global table majority.</pre>
   *
   * <pre> -R
   *  Display decision table rules.
   * </pre>
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String optionString;

    resetOptions();

    optionString = Utils.getOption('X', options);
    if (optionString.length() != 0) {
      m_CVFolds = Integer.parseInt(optionString);
    }

    m_useIBk = Utils.getFlag('I', options);

    m_displayRules = Utils.getFlag('R', options);

    optionString = Utils.getOption('E', options);
    if (optionString.length() != 0) {
      if (optionString.equals("acc")) {
        setEvaluationMeasure(new SelectedTag(EVAL_ACCURACY, TAGS_EVALUATION));
      } else if (optionString.equals("rmse")) {
        setEvaluationMeasure(new SelectedTag(EVAL_RMSE, TAGS_EVALUATION));
      } else if (optionString.equals("mae")) {
        setEvaluationMeasure(new SelectedTag(EVAL_MAE, TAGS_EVALUATION));
      } else if (optionString.equals("auc")) {
        setEvaluationMeasure(new SelectedTag(EVAL_AUC, TAGS_EVALUATION));
      } else {
        throw new IllegalArgumentException("Invalid evaluation measure");
      }
    }

    String searchString = Utils.getOption('S', options);
    if (searchString.length() == 0)
      searchString = weka.attributeSelection.BestFirst.class.getName();
    String[] searchSpec = Utils.splitOptions(searchString);
    if (searchSpec.length == 0) {
      throw new IllegalArgumentException("Invalid search specification string");
    }
    String searchName = searchSpec[0];
    searchSpec[0] = "";
    setSearch(ASSearch.forName(searchName, searchSpec));
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {

    String[] options = new String[9];
    int current = 0;

    options[current++] = "-X";
    options[current++] = "" + m_CVFolds;

    if (m_evaluationMeasure != EVAL_DEFAULT) {
      options[current++] = "-E";
      switch (m_evaluationMeasure) {
      case EVAL_ACCURACY:
        options[current++] = "acc";
        break;
      case EVAL_RMSE:
        options[current++] = "rmse";
        break;
      case EVAL_MAE:
        options[current++] = "mae";
        break;
      case EVAL_AUC:
        options[current++] = "auc";
        break;
      }
    }
    if (m_useIBk) {
      options[current++] = "-I";
    }
    if (m_displayRules) {
      options[current++] = "-R";
    }

    options[current++] = "-S";
    options[current++] = "" + getSearchSpec();

    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Gets the search specification string, which contains the class name of
   * the search method and any options to it
   *
   * @return the search string.
   */
  protected String getSearchSpec() {
    ASSearch s = getSearch();
    if (s instanceof OptionHandler) {
      return s.getClass().getName() + " "
        + Utils.joinOptions(((OptionHandler) s).getOptions());
    }
    return s.getClass().getName();
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    if (m_evaluationMeasure != EVAL_ACCURACY && m_evaluationMeasure != EVAL_AUC) {
      result.enable(Capability.NUMERIC_CLASS);
      result.enable(Capability.DATE_CLASS);
    }
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Sets up a dummy subset evaluator that basically just delegates
   * evaluation to the estimatePerformance method in DecisionTable
   */
  protected void setUpEvaluator() {
    m_evaluator = new SubsetEvaluator() {
      public void buildEvaluator(Instances data) throws Exception {
      }

      public double evaluateSubset(BitSet subset) throws Exception {
        int fc = 0;
        for (int jj = 0; jj < m_numAttributes; jj++) {
          if (subset.get(jj)) {
            fc++;
          }
        }
        return estimatePerformance(subset, fc);
      }
    };
  }

  protected boolean m_saveMemory = true;

  /**
   * Generates the classifier.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();

    m_rr = new Random(1);

    if (m_theInstances.classAttribute().isNominal()) {
      // Set up class priors
      m_classPriorCounts = new double[data.classAttribute().numValues()];
      Arrays.fill(m_classPriorCounts, 1.0);
      for (int i = 0; i < data.numInstances(); i++) {
        Instance curr = data.instance(i);
        m_classPriorCounts[(int) curr.classValue()] += curr.weight();
      }
      m_classPriors = m_classPriorCounts.clone();
      Utils.normalize(m_classPriors);
    }

    setUpEvaluator();

    if (m_theInstances.classAttribute().isNumeric()) {
      m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
      m_classIsNominal = false;

      // use binned discretisation if the class is numeric
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setBins(10);
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setInvertSelection(true);

      // Discretize all attributes EXCEPT the class
      String rangeList = "";
      rangeList += (m_theInstances.classIndex() + 1);
      // System.out.println("The class col: " + m_theInstances.classIndex());

      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform)
        .setAttributeIndices(rangeList);
    } else {
      m_disTransform = new weka.filters.supervised.attribute.Discretize();
      ((weka.filters.supervised.attribute.Discretize) m_disTransform).setUseBetterEncoding(true);
      m_classIsNominal = true;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    // Perform the search
    int[] selected = m_search.search(m_evaluator, m_theInstances);

    m_decisionFeatures = new int[selected.length + 1];
    System.arraycopy(selected, 0, m_decisionFeatures, 0, selected.length);
    m_decisionFeatures[m_decisionFeatures.length - 1] = m_theInstances.classIndex();

    // reduce instances to selected features
    m_delTransform = new Remove();
    m_delTransform.setInvertSelection(true);

    // set features to keep
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_dtInstances.numAttributes();

    // create hash table
    m_entries = new Hashtable((int) (m_dtInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (int i = 0; i < m_numInstances; i++) {
      Instance inst = m_dtInstances.instance(i);
      insertIntoTable(inst, null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      m_ibk.buildClassifier(m_theInstances);
    }

    // Save memory
    if (m_saveMemory) {
      m_theInstances = new Instances(m_theInstances, 0);
      m_dtInstances = new Instances(m_dtInstances, 0);
    }
  }

  /**
   * Calculates the class membership probabilities for the given
   * test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    hashKey thekey;
    double[] tempDist;
    double[] normDist;

    m_disTransform.input(instance);
    m_disTransform.batchFinished();
    instance = m_disTransform.output();

    m_delTransform.input(instance);
    m_delTransform.batchFinished();
    instance = m_delTransform.output();

    thekey = new hashKey(instance, instance.numAttributes(), false);

    // if this one is not in the table
    if ((tempDist = (double[]) m_entries.get(thekey)) == null) {
      if (m_useIBk) {
        tempDist = m_ibk.distributionForInstance(instance);
      } else {
        if (!m_classIsNominal) {
          tempDist = new double[1];
          tempDist[0] = m_majority;
        } else {
          tempDist = m_classPriors.clone();
          /* tempDist = new double[m_theInstances.classAttribute().numValues()];
             tempDist[(int) m_majority] = 1.0; */
        }
      }
    } else {
      if (!m_classIsNominal) {
        normDist = new double[1];
        normDist[0] = (tempDist[0] / tempDist[1]);
        tempDist = normDist;
      } else {
        // normalise distribution
        normDist = new double[tempDist.length];
        System.arraycopy(tempDist, 0, normDist, 0, tempDist.length);
        Utils.normalize(normDist);
        tempDist = normDist;
      }
    }
    return tempDist;
  }

  /**
   * Returns a string description of the features selected
   *
   * @return a string of features
   */
  public String printFeatures() {

    int i;
    String s = "";

    for (i = 0; i < m_decisionFeatures.length; i++) {
      if (i == 0) {
        s = "" + (m_decisionFeatures[i] + 1);
      } else {
        s += "," + (m_decisionFeatures[i] + 1);
      }
    }
    return s;
  }

  /**
   * Returns the number of rules
   * @return the number of rules
   */
  public double measureNumRules() {
    return m_entries.size();
  }

  /**
   * Returns an enumeration of the additional measure names
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumRules");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureNumRules") == 0) {
      return measureNumRules();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
        + " not supported (DecisionTable)");
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    if (m_entries == null) {
      return "Decision Table: No model built yet.";
    } else {
      StringBuffer text = new StringBuffer();

      text.append("Decision Table:" +
        "\n\nNumber of training instances: " + m_numInstances +
        "\nNumber of Rules : " + m_entries.size() + "\n");

      if (m_useIBk) {
        text.append("Non matches covered by IB1.\n");
      } else {
        text.append("Non matches covered by Majority class.\n");
      }

      text.append(m_search.toString());
      /* text.append("Best first search for feature set,\nterminated after " +
         m_maxStale + " non improving subsets.\n"); */

      text.append("Evaluation (for feature selection): CV ");
      if (m_CVFolds > 1) {
        text.append("(" + m_CVFolds + " fold) ");
      } else {
        text.append("(leave one out) ");
      }

      text.append("\nFeature set: " + printFeatures());

      if (m_displayRules) {

        // find out the max column width
        int maxColWidth = 0;
        for (int i = 0; i < m_dtInstances.numAttributes(); i++) {
          if (m_dtInstances.attribute(i).name().length() > maxColWidth) {
            maxColWidth = m_dtInstances.attribute(i).name().length();
          }

          if (m_classIsNominal || (i != m_dtInstances.classIndex())) {
            Enumeration e = m_dtInstances.attribute(i).enumerateValues();
            while (e.hasMoreElements()) {
              String ss = (String) e.nextElement();
              if (ss.length() > maxColWidth) {
                maxColWidth = ss.length();
              }
            }
          }
        }

        text.append("\n\nRules:\n");
        StringBuffer tm = new StringBuffer();
        for (int i = 0; i < m_dtInstances.numAttributes(); i++) {
          if (m_dtInstances.classIndex() != i) {
            int d = maxColWidth - m_dtInstances.attribute(i).name().length();
            tm.append(m_dtInstances.attribute(i).name());
            for (int j = 0; j < d + 1; j++) {
              tm.append(" ");
            }
          }
        }
        tm.append(m_dtInstances.attribute(m_dtInstances.classIndex()).name() + " ");

        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");
        text.append(tm);
        text.append("\n");
        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");

        Enumeration e = m_entries.keys();
        while (e.hasMoreElements()) {
          hashKey tt = (hashKey) e.nextElement();
          text.append(tt.toString(m_dtInstances, maxColWidth));
          double[] ClassDist = (double[]) m_entries.get(tt);

          if (m_classIsNominal) {
            int m = Utils.maxIndex(ClassDist);
            try {
              text.append(m_dtInstances.classAttribute().value(m) + "\n");
            } catch (Exception ee) {
              System.out.println(ee.getMessage());
            }
          } else {
            text.append((ClassDist[0] / ClassDist[1]) + "\n");
          }
        }

        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");
        text.append("\n");
      }
      return text.toString();
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the command-line options
   */
  public static void main(String[] argv) {
    runClassifier(new DecisionTable(), argv);
  }
}
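The listing above ends with the class's main method, which simply hands the command-line arguments to Weka's standard runClassifier harness. For programmatic use, the short sketch below shows one way to train and query a DecisionTable. It is a minimal example, not part of the original file: it assumes the standard Weka 3 API (weka.core.converters.ConverterUtils.DataSource, weka.core.Utils) and a hypothetical ARFF file path ("iris.arff") that you would replace with your own data.

// Minimal usage sketch (assumption: Weka 3 on the classpath; "iris.arff" is a placeholder path).
import java.util.Arrays;
import weka.classifiers.rules.DecisionTable;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ConverterUtils.DataSource;

public class DecisionTableExample {
  public static void main(String[] args) throws Exception {
    // Load a dataset and declare the last attribute as the class.
    Instances data = DataSource.read("iris.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // Configure the table via the same option strings parsed by setOptions():
    // -X 5 -> 5-fold CV for feature-subset evaluation, -R -> print the rules.
    DecisionTable dt = new DecisionTable();
    dt.setOptions(Utils.splitOptions("-X 5 -R"));

    // Build the model and print its textual description (the toString() above).
    dt.buildClassifier(data);
    System.out.println(dt);

    // Query the class distribution for the first training instance.
    double[] dist = dt.distributionForInstance(data.instance(0));
    System.out.println("Class distribution: " + Arrays.toString(dist));
  }
}

From the command line, a comparable run would look something like
java weka.classifiers.rules.DecisionTable -t iris.arff -X 5 -R
since main delegates to runClassifier, which handles the standard -t training-file option.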