⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nec45.java

📁 NeC4.5 is a variant of C4.5 decision tree, which could generate decision trees more accurate than standard C4.5 decision trees.
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
package nec45;/** * Description: Use NeC4.5 to generate classification trees. * * Reference:   Z.-H. Zhou and Y. Jiang. NeC4.5: neural ensemble based C4.5. IEEE *	     	Transactions on Knowledge and Data Engineering, 2004, 16(6): 770-773. * * ATTN:        This package is free for academic usage. You can run it at your own risk. *	     	For other purposes, please contact Prof. Zhi-Hua Zhou (zhouzh@nju.edu.cn). * * Requirement: To use this package, the whole WEKA environment must be available. *	        Refer: I.H. Witten and E. Frank. Data Mining: Practical Machine Learning *		Tools and Techniques with Java Implementations. Morgan Kaufmann, *		San Francisco, CA, 2000. * * Data format: Both the input and output formats are the same as those used by WEKA. * * ATTN2:       This package was developed by Mr. Ming Li (liming@ai.nju.edu.cn). There *		is a ReadMe file provided for roughly explaining the codes. But for any *		problem concerning the code, please feel free to contact with Mr. Li. *	     	Please note that since the C4.5 decision tree routine used by this *		package is its reimplementation in WEKA, the performance of the package *		would be slightly different from that reported in the paper. * */import java.io.*;import java.util.*;import weka.core.*;import weka.classifiers.*;import weka.classifiers.neural.*;import weka.classifiers.j48.*;/** The default extra data ratio is 1, which can be modified along as well as *  other arguments by calling 'setArgs'. 
*/public class NeC45 extends Classifier{  /** The flag whether to print the classifier-building process */  private boolean m_bPrintBuildingProcess = false;  /** The C4.5 classifier trained on the preprocessed data set */  private J48 m_decisionTree;  /** The object of neural network ensemble via bagging */  private BaggingNN m_baggingNN = new BaggingNN();  /** The original training set */  private Instances m_originalDataSet;  /** The Dataset generated by N*, the neural network ensemble, to train a C4.5 decision tree */  private Instances m_trainingSet;  /** The ratio for extra data, the default value is 1 */  private int m_iMiu = 1;  /** The minimal value of each attribute */  private double[] m_dwAttMinValArr;  /**The value range of each attribute */  private double[] m_dwAttValRangeArr;  /** The random object */  Random m_rand = new Random(0);  //=========================================================================  /**   * The Constructors   */  public NeC45()  {  }  //=========================================================================  /**   * The function to change the arguments of C4.5   * @param miu: the extra data ratio   * @param NNs:  the number of neural networks in ensemble via Baggging   * @param hidUnit: the number of hidden units   * @param learningrate: the learning rate   * @param momentum: the momentum   * @param threshold: the threshold on validation set for training NN to avoid overfitting   */  public void setArgs(int miu, int NNs, int hidUnit,                      double learningrate, double momentum, int epoch, int threshold)  {    m_iMiu = miu;    m_baggingNN.setArgs(NNs,hidUnit, learningrate, momentum,epoch, threshold);  }  /**   * Call this function to get the extra data ratio   * @return: the extra data ratio, miu   */  public int getExtraDataRate()  {    return m_iMiu;  }  /**   * Call this function to get the number of hidden units   * @return: the number of hidden units   */  public int numHiddenUnits()  {    return 
m_baggingNN.numHiddenUnits();  }  /**   * Call this function to get the momentum for training a neural network   * @return: the momentum   */  public double getMomentum()  {    return m_baggingNN.getMomentum();  }  /**   * Call this function to get the learning rate for training a neural network   * @return: the learning rate   */  public double getLearingRate()  {    return m_baggingNN.getLearingRate();  }  /**   * Call this function to get the validation threshold specified to avoid overfitting   * @return: the validation threshold   */  public int getValidationThreshold()  {    return m_baggingNN.getValidationThreshold();  }  /**  * Call this function to get the training epochs for a neural network  * @return: the training epochs  */  public int getEpochs()  {    return m_baggingNN.getEpochs();  }  /**   * Call the function to get the flag whether to print the neural network building   * process in training or not   * @return: The printing flag. true for printing, otherwise false.   */  public boolean isPrintBuildingProcess()  {    return m_bPrintBuildingProcess;  }  /**   * Call this function to build a NeC45 Classifier   * @param i: the original training set   * @throws Exception: some exception   */  public void buildClassifier(Instances i) throws Exception  {    m_originalDataSet = i;    setBaseAndRange(i);    m_trainingSet = new Instances(i);    /* train a neural network ensemble N* from S via Bagging */    if(m_bPrintBuildingProcess)      System.out.println("train a neural network ensemble N* from S via Bagging");    m_baggingNN.buildClassifier(i);    /* process the original training set with the trained ensemble */    if(m_bPrintBuildingProcess)      System.out.println("process the original training set with the trained ensemble ");    int tsOriginalSize = m_trainingSet.numInstances();    for(int instanceId = 0; instanceId < tsOriginalSize; instanceId++)    {      Instance ins = m_trainingSet.instance(instanceId);      double category = 
m_baggingNN.classifyInstance(ins);      ins.setClassValue(category);    }    /* generate extra training data from the trained ensemble */    if(m_bPrintBuildingProcess && m_iMiu!=0)      System.out.println("generate extra training data from the trained ensemble ");    for(int instanceId = 0; instanceId < tsOriginalSize*m_iMiu; instanceId++)    {      Instance ins = randomGenerateInstance();    //generate a random feature vector      double category = m_baggingNN.classifyInstance(ins);      ins.setClassValue(category);      m_trainingSet.add(ins);    }    /* grow a C4.5 decision tree from the new training set */    if(m_bPrintBuildingProcess)      System.out.println("grow a C4.5 decision tree from the new training set ");    m_decisionTree = new J48();    m_decisionTree.setUnpruned(true);    m_decisionTree.buildClassifier(m_trainingSet);  }  /**   * Call this function to set the flag for printing the classifier-building process   * @param b: true for printing the process, false otherwise   */  public void setPrintBaggingProcess(boolean b)  {    m_bPrintBuildingProcess = b;  }  /**   * Call this function to classify an instance with the built NeC4.5 model   * @param ins: the instance to be classified   * @return: the class value of the instance   * @throws Exception: if the NeC4.5 model has not been trained before using this method   */  public double classifyInstance(Instance ins) throws Exception  {    if(m_decisionTree == null)      throw new Exception("NO Classification Modal has been trained!");    return m_decisionTree.classifyInstance(ins);  }  /**   * Returns a string describing the model.   * @return: a string describing the model.   
*/  public String toString()  {    return m_decisionTree.toString();  }  /**   * Call this function to get the measure of the decision tree size   * @return:the size of the tree   */  public double measureSize()  {    return m_decisionTree.measureTreeSize();  }  /**   * Call this function the generate the feature vector randomly   * Note: each attribute value is within the range of original dataset   * @return: randomly generated feature vector.   */  private Instance randomGenerateInstance()  {    Instance ins = new Instance(m_trainingSet.numAttributes());    ins.setDataset(m_trainingSet);    for(int j = 0; j < m_trainingSet.numAttributes(); j++)    {      if(this.m_originalDataSet.attribute(j).isNominal())      {        int ra = Math.abs(m_rand.nextInt());        int iRandval = ra % (int)(m_dwAttValRangeArr[j]+1);        ins.setValue(j, m_dwAttMinValArr[j]+iRandval);      }      else      {        double dwRandval = m_rand.nextDouble() * m_dwAttValRangeArr[j];        ins.setValue(j, m_dwAttMinValArr[j]+dwRandval);      }    }    return ins;  }  /**   * Call this function to set the random seed for building the Random Object   * @param seed: the seed   */  private void setRandomSeed(long seed)  {    m_rand.setSeed(seed);  }  /**   * Call the function to set the MinVal and Range arrays.   * So each value of Attribute i is within the range of Min[i]~(Min[i]+Range[i])   * @param dataset: the dataset to be processed   */  private void setBaseAndRange(Instances dataset)  {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -