📄 nec45.java
package nec45;

/**
 * Description: Use NeC4.5 to generate classification trees.
 *
 * Reference: Z.-H. Zhou and Y. Jiang. NeC4.5: neural ensemble based C4.5. IEEE
 *            Transactions on Knowledge and Data Engineering, 2004, 16(6): 770-773.
 *
 * ATTN: This package is free for academic usage. You can run it at your own risk.
 *       For other purposes, please contact Prof. Zhi-Hua Zhou (zhouzh@nju.edu.cn).
 *
 * Requirement: To use this package, the whole WEKA environment must be available.
 * Refer: I.H. Witten and E. Frank. Data Mining: Practical Machine Learning
 *        Tools and Techniques with Java Implementations. Morgan Kaufmann,
 *        San Francisco, CA, 2000.
 *
 * Data format: Both the input and output formats are the same as those used by WEKA.
 *
 * ATTN2: This package was developed by Mr. Ming Li (liming@ai.nju.edu.cn). There
 *        is a ReadMe file provided to roughly explain the code. For any problem
 *        concerning the code, please feel free to contact Mr. Li. Please note
 *        that since the C4.5 decision tree routine used by this package is its
 *        reimplementation in WEKA, the performance of the package may differ
 *        slightly from that reported in the paper.
 */

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.classifiers.*;
import weka.classifiers.neural.*;
import weka.classifiers.j48.*;

/**
 * The default extra data ratio is 1; it, like the other arguments, can be
 * modified by calling 'setArgs'.
 */
public class NeC45 extends Classifier
{
  /** The flag whether to print the classifier-building process */
  private boolean m_bPrintBuildingProcess = false;

  /** The C4.5 classifier trained on the preprocessed data set */
  private J48 m_decisionTree;

  /** The neural network ensemble built via Bagging */
  private BaggingNN m_baggingNN = new BaggingNN();

  /** The original training set */
  private Instances m_originalDataSet;

  /** The data set generated by N*, the neural network ensemble, to train a C4.5 decision tree */
  private Instances m_trainingSet;

  /** The ratio for extra data; the default value is 1 */
  private int m_iMiu = 1;

  /** The minimal value of each attribute */
  private double[] m_dwAttMinValArr;

  /** The value range of each attribute */
  private double[] m_dwAttValRangeArr;

  /** The random object */
  Random m_rand = new Random(0);

  //=========================================================================

  /**
   * The constructor
   */
  public NeC45()
  {
  }

  //=========================================================================

  /**
   * The function to change the arguments of NeC4.5
   * @param miu: the extra data ratio
   * @param NNs: the number of neural networks in the ensemble built via Bagging
   * @param hidUnit: the number of hidden units
   * @param learningrate: the learning rate
   * @param momentum: the momentum
   * @param epoch: the number of training epochs
   * @param threshold: the threshold on the validation set for training a NN to avoid overfitting
   */
  public void setArgs(int miu, int NNs, int hidUnit, double learningrate,
                      double momentum, int epoch, int threshold)
  {
    m_iMiu = miu;
    m_baggingNN.setArgs(NNs, hidUnit, learningrate, momentum, epoch, threshold);
  }

  /**
   * Call this function to get the extra data ratio
   * @return: the extra data ratio, miu
   */
  public int getExtraDataRate()
  {
    return m_iMiu;
  }

  /**
   * Call this function to get the number of hidden units
   * @return: the number of hidden units
   */
  public int numHiddenUnits()
  {
    return m_baggingNN.numHiddenUnits();
  }

  /**
   * Call this function to get the momentum for training a neural network
   * @return: the momentum
   */
  public double getMomentum()
  {
    return m_baggingNN.getMomentum();
  }

  /**
   * Call this function to get the learning rate for training a neural network
   * @return: the learning rate
   */
  public double getLearingRate()
  {
    return m_baggingNN.getLearingRate();
  }

  /**
   * Call this function to get the validation threshold specified to avoid overfitting
   * @return: the validation threshold
   */
  public int getValidationThreshold()
  {
    return m_baggingNN.getValidationThreshold();
  }

  /**
   * Call this function to get the training epochs for a neural network
   * @return: the training epochs
   */
  public int getEpochs()
  {
    return m_baggingNN.getEpochs();
  }

  /**
   * Call this function to get the flag indicating whether the
   * classifier-building process is printed during training
   * @return: the printing flag; true for printing, otherwise false
   */
  public boolean isPrintBuildingProcess()
  {
    return m_bPrintBuildingProcess;
  }

  /**
   * Call this function to build a NeC4.5 classifier
   * @param i: the original training set
   * @throws Exception: if the classifier cannot be built successfully
   */
  public void buildClassifier(Instances i) throws Exception
  {
    m_originalDataSet = i;
    setBaseAndRange(i);
    m_trainingSet = new Instances(i);

    /* train a neural network ensemble N* from S via Bagging */
    if (m_bPrintBuildingProcess)
      System.out.println("train a neural network ensemble N* from S via Bagging");
    m_baggingNN.buildClassifier(i);

    /* process the original training set with the trained ensemble,
       i.e. replace each original label with the ensemble's prediction */
    if (m_bPrintBuildingProcess)
      System.out.println("process the original training set with the trained ensemble");
    int tsOriginalSize = m_trainingSet.numInstances();
    for (int instanceId = 0; instanceId < tsOriginalSize; instanceId++)
    {
      Instance ins = m_trainingSet.instance(instanceId);
      double category = m_baggingNN.classifyInstance(ins);
      ins.setClassValue(category);
    }

    /* generate extra training data from the trained ensemble */
    if (m_bPrintBuildingProcess && m_iMiu != 0)
      System.out.println("generate extra training data from the trained ensemble");
    for (int instanceId = 0; instanceId < tsOriginalSize * m_iMiu; instanceId++)
    {
      Instance ins = randomGenerateInstance(); // generate a random feature vector
      double category = m_baggingNN.classifyInstance(ins);
      ins.setClassValue(category);
      m_trainingSet.add(ins);
    }

    /* grow a C4.5 decision tree from the new training set */
    if (m_bPrintBuildingProcess)
      System.out.println("grow a C4.5 decision tree from the new training set");
    m_decisionTree = new J48();
    m_decisionTree.setUnpruned(true);
    m_decisionTree.buildClassifier(m_trainingSet);
  }

  /**
   * Call this function to set the flag for printing the classifier-building process
   * @param b: true for printing the process, false otherwise
   */
  public void setPrintBaggingProcess(boolean b)
  {
    m_bPrintBuildingProcess = b;
  }

  /**
   * Call this function to classify an instance with the built NeC4.5 model
   * @param ins: the instance to be classified
   * @return: the class value of the instance
   * @throws Exception: if the NeC4.5 model has not been trained before using this method
   */
  public double classifyInstance(Instance ins) throws Exception
  {
    if (m_decisionTree == null)
      throw new Exception("No classification model has been trained!");
    return m_decisionTree.classifyInstance(ins);
  }

  /**
   * Returns a string describing the model.
   * @return: a string describing the model
   */
  public String toString()
  {
    return m_decisionTree.toString();
  }

  /**
   * Call this function to get the measure of the decision tree size
   * @return: the size of the tree
   */
  public double measureSize()
  {
    return m_decisionTree.measureTreeSize();
  }

  /**
   * Call this function to generate a feature vector randomly.
   * Note: each attribute value is within the range of the original data set.
   * @return: the randomly generated feature vector
   */
  private Instance randomGenerateInstance()
  {
    Instance ins = new Instance(m_trainingSet.numAttributes());
    ins.setDataset(m_trainingSet);
    for (int j = 0; j < m_trainingSet.numAttributes(); j++)
    {
      if (this.m_originalDataSet.attribute(j).isNominal())
      {
        int ra = Math.abs(m_rand.nextInt());
        int iRandval = ra % (int) (m_dwAttValRangeArr[j] + 1);
        ins.setValue(j, m_dwAttMinValArr[j] + iRandval);
      }
      else
      {
        double dwRandval = m_rand.nextDouble() * m_dwAttValRangeArr[j];
        ins.setValue(j, m_dwAttMinValArr[j] + dwRandval);
      }
    }
    return ins;
  }

  /**
   * Call this function to set the seed of the random object
   * @param seed: the seed
   */
  private void setRandomSeed(long seed)
  {
    m_rand.setSeed(seed);
  }

  /**
   * Call this function to set the MinVal and Range arrays,
   * so that each value of attribute i lies within Min[i]~(Min[i]+Range[i])
   * @param dataset: the dataset to be processed
   */
  private void setBaseAndRange(Instances dataset)
  {
    // NOTE: the source page truncates the listing at this point. The body
    // below is a reconstruction from the field comments above, not the
    // original code: it records each attribute's minimum value and value
    // range so that randomGenerateInstance() can sample within them.
    int numAtts = dataset.numAttributes();
    m_dwAttMinValArr = new double[numAtts];
    m_dwAttValRangeArr = new double[numAtts];
    for (int j = 0; j < numAtts; j++)
    {
      if (dataset.attribute(j).isNominal())
      {
        // nominal values are internally coded as 0 .. numValues()-1
        m_dwAttMinValArr[j] = 0;
        m_dwAttValRangeArr[j] = dataset.attribute(j).numValues() - 1;
      }
      else
      {
        double min = Double.MAX_VALUE, max = -Double.MAX_VALUE;
        for (int k = 0; k < dataset.numInstances(); k++)
        {
          double val = dataset.instance(k).value(j);
          if (val < min) min = val;
          if (val > max) max = val;
        }
        m_dwAttMinValArr[j] = min;
        m_dwAttValRangeArr[j] = max - min;
      }
    }
  }
}
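
A minimal usage sketch (not part of the original package): assuming the WEKA 3.2-era API that this code is written against, and a hypothetical ARFF file "iris.arff", the classifier might be driven as below. The class name NeC45Demo and all argument values are illustrative only, not recommendations from the paper.

import java.io.FileReader;
import weka.core.Instances;
import nec45.NeC45;

public class NeC45Demo
{
  public static void main(String[] args) throws Exception
  {
    // load a training set in WEKA's ARFF format (hypothetical file name)
    Instances data = new Instances(new FileReader("iris.arff"));
    data.setClassIndex(data.numAttributes() - 1); // class is the last attribute

    NeC45 nec45 = new NeC45();
    // extra data ratio 1, 20 networks, 10 hidden units, learning rate 0.1,
    // momentum 0.9, 500 epochs, validation threshold 5 -- illustrative values
    nec45.setArgs(1, 20, 10, 0.1, 0.9, 500, 5);
    nec45.setPrintBaggingProcess(true);
    nec45.buildClassifier(data);

    System.out.println(nec45);                               // the grown C4.5 tree
    System.out.println("tree size: " + nec45.measureSize());
    double pred = nec45.classifyInstance(data.instance(0));  // internal class index
    System.out.println("predicted class: " + data.classAttribute().value((int) pred));
  }
}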