// decisiontable.java
// Source: "Multiple data-mining algorithms written in Java (clustering,
// classification, preprocessing, etc.)" — Java code, 1,531 lines total,
// page 1 of 3.
  /**
   * Gets the number of folds for cross validation
   *
   * @return the number of cross validation folds
   */
  public int getCrossVal() {
    return m_CVFolds;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxStaleTipText() {
    return "Sets the number of non improving decision tables to consider "
      + "before abandoning the search.";
  }

  /**
   * Sets the number of non improving decision tables to consider
   * before abandoning the search.
   *
   * @param stale the number of nodes
   */
  public void setMaxStale(int stale) {
    m_maxStale = stale;
  }

  /**
   * Gets the number of non improving decision tables
   *
   * @return the number of non improving decision tables
   */
  public int getMaxStale() {
    return m_maxStale;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String useIBkTipText() {
    return "Sets whether IBk should be used instead of the majority class.";
  }

  /**
   * Sets whether IBk should be used instead of the majority class
   *
   * @param ibk true if IBk is to be used
   */
  public void setUseIBk(boolean ibk) {
    m_useIBk = ibk;
  }

  /**
   * Gets whether IBk is being used instead of the majority class
   *
   * @return true if IBk is being used
   */
  public boolean getUseIBk() {
    return m_useIBk;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String displayRulesTipText() {
    return "Sets whether rules are to be printed.";
  }

  /**
   * Sets whether rules are to be printed
   *
   * @param rules true if rules are to be printed
   */
  public void setDisplayRules(boolean rules) {
    m_displayRules = rules;
  }

  /**
   * Gets whether rules are being printed
   *
   * @return true if rules are being printed
   */
  public boolean getDisplayRules() {
    return m_displayRules;
  }

  /**
   * Parses the options for this object. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S &lt;number of non improving nodes&gt;
   *  Number of fully expanded non improving subsets to consider
   *  before terminating a best first search.
   *  Use in conjunction with -B. (Default = 5)</pre>
   *
   * <pre> -X &lt;number of folds&gt;
   *  Use cross validation to evaluate features.
   *  Use number of folds = 1 for leave one out CV.
   *  (Default = leave one out CV)</pre>
   *
   * <pre> -I
   *  Use nearest neighbour instead of global table majority.</pre>
   *
   * <pre> -R
   *  Display decision table rules.
   * </pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String optionString;

    resetOptions();

    optionString = Utils.getOption('X', options);
    if (optionString.length() != 0) {
      m_CVFolds = Integer.parseInt(optionString);
    }

    optionString = Utils.getOption('S', options);
    if (optionString.length() != 0) {
      m_maxStale = Integer.parseInt(optionString);
    }

    m_useIBk = Utils.getFlag('I', options);

    m_displayRules = Utils.getFlag('R', options);
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    // -X val, -S val, plus optional -I and -R; unused slots are padded
    // with empty strings, per the Weka option-handling convention.
    String[] options = new String[7];
    int current = 0;

    options[current++] = "-X"; options[current++] = "" + m_CVFolds;
    options[current++] = "-S"; options[current++] = "" + m_maxStale;
    if (m_useIBk) {
      options[current++] = "-I";
    }
    if (m_displayRules) {
      options[current++] = "-R";
    }
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return      the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Generates the classifier.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();

    m_rr = new Random(1);

    if (m_theInstances.classAttribute().isNumeric()) {
      m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
      m_classIsNominal = false;

      // use binned discretisation if the class is numeric
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).
        setBins(10);
      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).
        setInvertSelection(true);

      // Discretize all attributes EXCEPT the class (invertSelection is on,
      // so the range below names the single column to leave alone)
      String rangeList = "";
      rangeList += (m_theInstances.classIndex() + 1);

      ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).
        setAttributeIndices(rangeList);
    } else {
      m_disTransform = new weka.filters.supervised.attribute.Discretize();
      ((weka.filters.supervised.attribute.Discretize) m_disTransform).
        setUseBetterEncoding(true);
      m_classIsNominal = true;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    // feature selection: best-first search over attribute subsets
    best_first();

    // reduce instances to selected features
    m_delTransform = new Remove();
    m_delTransform.setInvertSelection(true);

    // set features to keep
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_theInstances.numAttributes();

    // create hash table
    m_entries = new Hashtable((int) (m_theInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (int i = 0; i < m_numInstances; i++) {
      Instance inst = m_theInstances.instance(i);
      insertIntoTable(inst, null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      m_ibk.buildClassifier(m_theInstances);
    }

    // Save memory: keep only the header of the training data
    m_theInstances = new Instances(m_theInstances, 0);
  }

  /**
   * Calculates the class membership probabilities for the given
   * test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance)
    throws Exception {

    hashKey thekey;
    double[] tempDist;
    double[] normDist;

    // apply the same discretisation and feature-removal filters
    // that were fitted during training
    m_disTransform.input(instance);
    m_disTransform.batchFinished();
    instance = m_disTransform.output();

    m_delTransform.input(instance);
    m_delTransform.batchFinished();
    instance = m_delTransform.output();

    thekey = new hashKey(instance, instance.numAttributes(), false);

    // if this one is not in the table
    if ((tempDist = (double[]) m_entries.get(thekey)) == null) {
      if (m_useIBk) {
        tempDist = m_ibk.distributionForInstance(instance);
      } else {
        if (!m_classIsNominal) {
          tempDist = new double[1];
          tempDist[0] = m_majority;
        } else {
          tempDist = new double[m_theInstances.classAttribute().numValues()];
          tempDist[(int) m_majority] = 1.0;
        }
      }
    } else {
      if (!m_classIsNominal) {
        // numeric class: stored entry is {sum, count} -> return the mean
        normDist = new double[1];
        normDist[0] = (tempDist[0] / tempDist[1]);
        tempDist = normDist;
      } else {
        // normalise distribution
        normDist = new double[tempDist.length];
        System.arraycopy(tempDist, 0, normDist, 0, tempDist.length);
        Utils.normalize(normDist);
        tempDist = normDist;
      }
    }
    return tempDist;
  }

  /**
   * Returns a string description of the features selected,
   * as a comma separated list of 1-based attribute indices.
   *
   * @return a string of features
   */
  public String printFeatures() {
    // StringBuffer instead of repeated String concatenation (O(n) not O(n^2))
    StringBuffer s = new StringBuffer();

    for (int i = 0; i < m_decisionFeatures.length; i++) {
      if (i > 0) {
        s.append(",");
      }
      s.append(m_decisionFeatures[i] + 1);
    }
    return s.toString();
  }

  /**
   * Returns the number of rules
   *
   * @return the number of rules
   */
  public double measureNumRules() {
    return m_entries.size();
  }

  /**
   * Returns an enumeration of the additional measure names
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumRules");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureNumRules") == 0) {
      return measureNumRules();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
        + " not supported (DecisionTable)");
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {
    if (m_entries == null) {
      return "Decision Table: No model built yet.";
    } else {
      StringBuffer text = new StringBuffer();

      text.append("Decision Table:"
        + "\n\nNumber of training instances: " + m_numInstances
        + "\nNumber of Rules : " + m_entries.size() + "\n");

      if (m_useIBk) {
        text.append("Non matches covered by IB1.\n");
      } else {
        text.append("Non matches covered by Majority class.\n");
      }

      text.append("Best first search for feature set,\nterminated after "
        + m_maxStale + " non improving subsets.\n");

      text.append("Evaluation (for feature selection): CV ");
      if (m_CVFolds > 1) {
        text.append("(" + m_CVFolds + " fold) ");
      } else {
        text.append("(leave one out) ");
      }
      text.append("\nFeature set: " + printFeatures());

      if (m_displayRules) {

        // find out the max column width over attribute names and,
        // where applicable, their nominal values
        int maxColWidth = 0;
        for (int i = 0; i < m_theInstances.numAttributes(); i++) {
          if (m_theInstances.attribute(i).name().length() > maxColWidth) {
            maxColWidth = m_theInstances.attribute(i).name().length();
          }

          if (m_classIsNominal || (i != m_theInstances.classIndex())) {
            Enumeration e = m_theInstances.attribute(i).enumerateValues();
            while (e.hasMoreElements()) {
              String ss = (String) e.nextElement();
              if (ss.length() > maxColWidth) {
                maxColWidth = ss.length();
              }
            }
          }
        }

        text.append("\n\nRules:\n");

        // header row: all non-class attribute names padded to column
        // width, then the class attribute name
        StringBuffer tm = new StringBuffer();
        for (int i = 0; i < m_theInstances.numAttributes(); i++) {
          if (m_theInstances.classIndex() != i) {
            int d = maxColWidth - m_theInstances.attribute(i).name().length();
            tm.append(m_theInstances.attribute(i).name());
            for (int j = 0; j < d + 1; j++) {
              tm.append(" ");
            }
          }
        }

        tm.append(m_theInstances.attribute(m_theInstances.classIndex()).name()
          + "  ");

        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");
        text.append(tm);
        text.append("\n");
        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");

        // one row per hash-table entry: the key's attribute values
        // followed by the predicted class
        Enumeration e = m_entries.keys();
        while (e.hasMoreElements()) {
          hashKey tt = (hashKey) e.nextElement();
          text.append(tt.toString(m_theInstances, maxColWidth));
          double[] ClassDist = (double[]) m_entries.get(tt);

          if (m_classIsNominal) {
            int m = Utils.maxIndex(ClassDist);
            try {
              text.append(m_theInstances.classAttribute().value(m) + "\n");
            } catch (Exception ee) {
              System.out.println(ee.getMessage());
            }
          } else {
            // numeric class: entry is {sum, count} -> print the mean
            text.append((ClassDist[0] / ClassDist[1]) + "\n");
          }
        }

        for (int i = 0; i < tm.length() + 10; i++) {
          text.append("=");
        }
        text.append("\n");
        text.append("\n");
      }
      return text.toString();
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the command-line options
   */
  public static void main(String[] argv) {

    Classifier scheme;

    try {
      scheme = new DecisionTable();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.out.println(e.getMessage());
    }
  }
}

// (Code-viewer UI residue, not part of the source — keyboard shortcuts:
//  copy code: Ctrl+C, search code: Ctrl+F, fullscreen: F11,
//  increase font size: Ctrl+=, decrease font size: Ctrl+-, show shortcuts: ?)