⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 c45rule_pane.java

📁 决策树 C45Rule-PANE算法 解决了决策的问题，是从 Quinlan 的算法修改而成
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
   * Call the function to set the MinVal and Range arrays.
   * Afterwards each non-missing value of attribute i lies within the range
   * Min[i] ~ (Min[i] + Range[i]). Missing values are skipped when computing
   * the minimum and maximum.
   * If dataset is null, the arrays are left unchanged.
   * NOTE(review): if attribute i has no non-missing value (or the dataset has
   * no instances), Min[i] stays +Infinity and Range[i] becomes -Infinity;
   * arithmetic on them downstream will typically yield NaN or an infinity.
   * Confirm that callers handle this case.
   * @param dataset: the dataset to be processed
   */
  private void setBaseAndRange(Instances dataset)
  {
    if(dataset == null)
      return;

    int numAttributes = dataset.numAttributes();
    m_dwAttMinValArr = new double[numAttributes];
    m_dwAttValRangeArr = new double[numAttributes];
    for(int j = 0; j < numAttributes; j++)
    {
      double min = Double.POSITIVE_INFINITY;
      double max = Double.NEGATIVE_INFINITY;
      for(int i = 0; i < dataset.numInstances(); i++)
      {
        if(dataset.instance(i).isMissing(j))
          continue;
        double value = dataset.instance(i).value(j);
        if(value < min)
          min = value;
        if(value > max)
          max = value;
      }
      m_dwAttMinValArr[j] = min;
      m_dwAttValRangeArr[j] = max - min;
    }
  }

  //==============================================================================
  /**
   * The main function for testing and comparing Partial C4.5Rule (PART) and
   * C4.5Rule-PANE. Both are evaluated by 10-fold cross-validation on the same
   * randomized copy of the dataset. The last attribute is used as the class.
   * @param args:  specify the dataset path following the pattern
   *               " -t <dataset full path> "; only args[1] is read as the path.
   */
  public static void main(String[] args)
  {
    if(args == null || args.length < 2)
    {
      System.err.println("Usage: java C45Rule_PANE -t <dataset full path>");
      System.exit(1);
    }

    Instances dataset;
    Random rand = new Random(System.currentTimeMillis());
    try
    {
      BufferedReader reader = new BufferedReader(new FileReader(args[1]));
      try
      {
        dataset = new Instances(reader);
      }
      finally
      {
        // Release the file handle even if parsing fails.
        reader.close();
      }
      dataset.setClassIndex(dataset.numAttributes() - 1);
      dataset.randomize(rand);

      // Build a PART to test the error rate via 10 fold cross-validation
      System.out.println("Pure C4.5Rule:");
      Evaluation eval = new Evaluation(dataset);
      eval.crossValidateModel(new PART(), dataset, 10);
      System.out.println(eval.toSummaryString());

      // Build a C4.5 Rule-PANE to test the error rate via 10 fold cross-validation
      System.out.println("\nC4.5Rule-PANE:");
      eval = new Evaluation(dataset);
      C45Rule_PANE c45_rule = new C45Rule_PANE();
      c45_rule.setPrintBaggingProcess(true);
      eval.crossValidateModel(c45_rule, dataset, 10);
      System.out.println(eval.toSummaryString());
    }
    catch(Exception e)
    {
      System.out.println(e.getMessage());
      e.printStackTrace();
      System.exit(1);
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////
/**
 * A Bagging ensemble of neural networks.
 * This class extends weka.classifiers.Bagging and installs a NeuralNetwork as
 * the base learner in its constructor.
 * The default arguments for the base learner are:
 * --The number of hidden units: 10
 * --The learning rate:  0.3
 * --The momentum:  0.2
 * --The maximum training epochs: 500
 * --The validation threshold for avoiding overfitting: 5
 * --The number of NNs for bagging: not set by this class, so the Bagging
 *   default applies (documented as 10 — confirm for the Weka version in use)
 *
 * All of these can be set by calling the method 'setArgs'.
 *
 * REQUIREMENT: The whole packages of weka.
 */
class BaggingNN extends Bagging
{
  /** The flag whether to print the classifier-building process */
  private boolean m_bPrintBuildingProcess = false;

  /** The number of the Hidden Units */
  private int m_iHiddenUnits = 10;

  /** The Learning rate for Neural Networks */
  private double m_dwLearningRate = .3;

  /** The momentum for the Neural Networks */
  private double m_dwMomentum = .2;

  /** The training epochs */
  private int m_iEpochs = 500;

  /** The validation threshold for training NN, avoiding overfitting */
  private int m_iValidationThreshold = 5;

  /**
   *  The Constructor: uses a NeuralNetwork as the base learner.
   */
  public BaggingNN()
  {
    setClassifier(new NeuralNetwork());
  }

  /**
   * Call the function to set the arguments for bagging the NNs.
   * The values are applied to the base learner when buildClassifier is called.
   * @param iNNs: the number of neural networks in the ensemble via Bagging
   * @param hidUnit: the number of hidden units
   * @param learningrate: the learning rate
   * @param momentum: the momentum
   * @param epochs: the training epochs
   * @param threshold: the validation threshold used to stop training early
   */
  public void setArgs(int iNNs, int hidUnit,
                      double learningrate, double momentum,
                      int epochs, int threshold)
  {
    setNumIterations(iNNs);
    m_iHiddenUnits = hidUnit;
    m_dwLearningRate = learningrate;
    m_dwMomentum = momentum;
    m_iEpochs = epochs;
    m_iValidationThreshold = threshold;
  }

  /**
   * Call this function to set the flag for printing the classifier-building process.
   * @param b: true for printing the process, false otherwise
   */
  public void setPrintBaggingProcess(boolean b)
  {
    m_bPrintBuildingProcess = b;
  }

  /**
   * Call this function to get the number of hidden units.
   * @return: the number of hidden units
   */
  public int numHiddenUnits()
  {
    return m_iHiddenUnits;
  }

  /**
   * Call this function to get the momentum for training a neural network.
   * @return: the momentum
   */
  public double getMomentum()
  {
    return m_dwMomentum;
  }

  /**
   * Call this function to get the learning rate for training a neural network.
   * @return: the learning rate
   */
  public double getLearningRate()
  {
    return m_dwLearningRate;
  }

  /**
   * Misspelled alias of getLearningRate(), kept so existing callers still compile.
   * @deprecated use getLearningRate() instead.
   * @return: the learning rate
   */
  public double getLearingRate()
  {
    return getLearningRate();
  }

  /**
   * Call this function to get the validation threshold specified to avoid overfitting.
   * @return: the validation threshold
   */
  public int getValidationThreshold()
  {
    return m_iValidationThreshold;
  }

  /**
   * Call this function to get the training epochs for a neural network.
   * @return: the training epochs
   */
  public int getEpochs()
  {
    return m_iEpochs;
  }

  /**
   * Call the function to get the flag whether to print the neural network building
   * process in training or not.
   * @return: the flag. True for printing, otherwise false.
   */
  public boolean isPrintBuildingProcess()
  {
    return m_bPrintBuildingProcess;
  }

  /**
   * Call the function to build the neural network ensemble via Bagging.
   * For each member network, a bootstrap bag of bagSize instances is drawn with
   * replacement (seeded by m_Seed). The instances never drawn (out-of-bag) are
   * placed at the front of that network's training data and used as its
   * validation set.
   * @param data: - the training set
   * @throws IllegalArgumentException if data contains no instances
   * @throws java.lang.Exception if a base learner fails to build
   */
  public void buildClassifier(Instances data) throws java.lang.Exception
  {
    int numInstances = data.numInstances();
    if(numInstances == 0)
      throw new IllegalArgumentException(
          "BaggingNN: cannot build a classifier from an empty training set");

    // Set arguments for the base learner before it is copied.
    NeuralNetwork base = (NeuralNetwork)m_Classifier;
    base.setHiddenLayers("" + m_iHiddenUnits);
    base.setLearningRate(m_dwLearningRate);
    base.setMomentum(m_dwMomentum);
    base.setTrainingTime(m_iEpochs);
    base.setValidationThreshold(m_iValidationThreshold);

    // copy the base learner
    m_Classifiers = Classifier.makeCopies(m_Classifier, m_NumIterations);

    int bagSize = numInstances * m_BagSizePercent / 100;
    Random random = new Random(m_Seed);

    // Train each base learner via bagging
    for(int j = 0; j < m_Classifiers.length; j++)
    {
      Instances bag = new Instances(data, bagSize);
      boolean[] bInBag = new boolean[numInstances];  // Java initialises to false
      int count = 0;  // number of distinct instances drawn into the bag

      // generating the bagging training set (sampling with replacement)
      while(bag.numInstances() < bagSize)
      {
        int i = (int) (random.nextDouble() * (double) numInstances);
        bag.add(data.instance(i));
        if(!bInBag[i])
        {
          count++;
          bInBag[i] = true;
        }
      }

      /* The original design relies on weka's NN training taking its validation
       * set from the front of the training data (by percentage). So the
       * out-of-bag instances go first, followed by the bag itself.
       * NOTE(review): confirm that the NeuralNetwork version in use does not
       * shuffle its input before taking the validation split.
       */
      int numOutOfBag = numInstances - count;
      Instances bagData = new Instances(data, numOutOfBag);
      for(int i = 0; i < numInstances; i++)
        if(!bInBag[i])
          bagData.add(data.instance(i));
      for(int i = 0; i < bag.numInstances(); i++)
        bagData.add(bag.instance(i));

      NeuralNetwork net = (NeuralNetwork)m_Classifiers[j];
      // bagData is non-empty because numInstances > 0.
      // NOTE(review): the integer percentage truncates, so a few out-of-bag
      // instances may end up in the training part rather than the validation part.
      net.setValidationSetSize(numOutOfBag * 100 / bagData.numInstances());
      if(m_bPrintBuildingProcess)
        System.out.println("Building Neural Network No."+(j+1));
      net.buildClassifier(bagData);
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -