c45rule_pane.java
/**
 * Set the MinVal and Range arrays, so that each value of attribute i lies
 * within the range Min[i] ~ (Min[i] + Range[i]).
 * @param dataset the dataset to be processed
 */
private void setBaseAndRange(Instances dataset) {
    if (dataset != null) {
        double min;
        double max;
        double value;
        m_dwAttMinValArr = new double[dataset.numAttributes()];
        m_dwAttValRangeArr = new double[dataset.numAttributes()];
        for (int j = 0; j < dataset.numAttributes(); j++) {
            min = Double.POSITIVE_INFINITY;
            max = Double.NEGATIVE_INFINITY;
            // Scan all non-missing values of attribute j for its minimum and maximum
            for (int i = 0; i < dataset.numInstances(); i++) {
                if (!dataset.instance(i).isMissing(j)) {
                    value = dataset.instance(i).value(j);
                    if (value < min) min = value;
                    if (value > max) max = value;
                }
            }
            m_dwAttMinValArr[j] = min;
            m_dwAttValRangeArr[j] = max - min;
        }
    }
}

//==============================================================================
/**
 * The main function for testing and comparing pure C4.5Rule (PART) and
 * C4.5Rule-PANE.
 * @param args specify the dataset path following the pattern
 *             " -t <dataset full path> "
 */
public static void main(String[] args) {
    Instances dataset;
    Random rand = new Random(System.currentTimeMillis());
    try {
        BufferedReader reader = new BufferedReader(new FileReader(args[1]));
        dataset = new Instances(reader);
        dataset.setClassIndex(dataset.numAttributes() - 1);
        dataset.randomize(rand);

        /** Build a PART learner and estimate its error rate via 10-fold cross-validation */
        System.out.println("Pure C4.5Rule:");
        Evaluation eval = new Evaluation(dataset);
        eval.crossValidateModel(new PART(), dataset, 10);
        System.out.println(eval.toSummaryString());

        /** Build a C4.5Rule-PANE and estimate its error rate via 10-fold cross-validation */
        System.out.println("\nC4.5Rule-PANE:");
        eval = new Evaluation(dataset);
        C45Rule_PANE c45_rule = new C45Rule_PANE();
        c45_rule.setPrintBaggingProcess(true);
        eval.crossValidateModel(c45_rule, dataset, 10);
        System.out.println(eval.toSummaryString());
    } catch (Exception e) {
        System.out.println(e.getMessage());
        e.printStackTrace();
        System.exit(1);
    }
}
}
////////////////////////////////////////////////////////////////////////////////////////
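/*
 * Invocation sketch for the comparison above: a minimal command line, assuming
 * the surrounding class compiles as C45Rule_PANE and an ARFF dataset exists at
 * the given path (the path here is only a placeholder). main() reads the path
 * from args[1], so the "-t" flag must come first:
 *
 *     java C45Rule_PANE -t /path/to/dataset.arff
 *
 * This prints two 10-fold cross-validation summaries: one for the pure C4.5
 * rule learner (PART) and one for C4.5Rule-PANE.
 */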
/**
 * This class extends weka.classifiers.Bagging.
 * It automatically sets the base learner to a neural network.
 * The default arguments for the base learner are:
 *  -- the number of hidden units: 10
 *  -- the learning rate: 0.3
 *  -- the momentum: 0.2
 *  -- the maximum number of training epochs: 500
 *  -- the validation threshold for avoiding overfitting: 5
 *  -- the number of NNs for bagging: 10
 *
 * All the arguments can be set by calling the method 'setArgs'.
 *
 * REQUIREMENT: the whole weka package.
 */
class BaggingNN extends Bagging {

    /** Whether to print the classifier-building process */
    private boolean m_bPrintBuildingProcess = false;
    /** The number of hidden units */
    private int m_iHiddenUnits = 10;
    /** The learning rate for the neural networks */
    private double m_dwLearningRate = .3;
    /** The momentum for the neural networks */
    private double m_dwMomentum = .2;
    /** The number of training epochs */
    private int m_iEpochs = 500;
    /** The validation threshold for training a NN, used to avoid overfitting */
    private int m_iValidationThreshold = 5;
    /** The training set */
    private Instances m_trainingSet;

    /**
     * The constructor.
     */
    public BaggingNN() {
        setClassifier(new NeuralNetwork());
    }

    /**
     * Set the arguments for bagging the NNs.
     * @param iNNs the number of neural networks in the ensemble built via bagging
     * @param hidUnit the number of hidden units
     * @param learningrate the learning rate
     * @param momentum the momentum
     * @param epochs the number of training epochs
     * @param threshold the validation threshold
     */
    public void setArgs(int iNNs, int hidUnit, double learningrate,
                        double momentum, int epochs, int threshold) {
        setNumIterations(iNNs);
        m_iHiddenUnits = hidUnit;
        m_dwLearningRate = learningrate;
        m_dwMomentum = momentum;
        m_iEpochs = epochs;
        m_iValidationThreshold = threshold;
    }

    /**
     * Set the flag for printing the classifier-building process.
     * @param b true to print the process, false otherwise
     */
    public void setPrintBaggingProcess(boolean b) {
        m_bPrintBuildingProcess = b;
    }

    /**
     * Get the number of hidden units.
     * @return the number of hidden units
     */
    public int numHiddenUnits() {
        return m_iHiddenUnits;
    }

    /**
     * Get the momentum for training a neural network.
     * @return the momentum
     */
    public double getMomentum() {
        return m_dwMomentum;
    }

    /**
     * Get the learning rate for training a neural network.
     * @return the learning rate
     */
    public double getLearningRate() {
        return m_dwLearningRate;
    }

    /**
     * Get the validation threshold specified to avoid overfitting.
     * @return the validation threshold
     */
    public int getValidationThreshold() {
        return m_iValidationThreshold;
    }

    /**
     * Get the number of training epochs for a neural network.
     * @return the number of training epochs
     */
    public int getEpochs() {
        return m_iEpochs;
    }

    /**
     * Get the flag indicating whether to print the neural network building
     * process during training.
     * @return the flag: true for printing, false otherwise
     */
    public boolean isPrintBuildingProcess() {
        return m_bPrintBuildingProcess;
    }
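    /*
     * A small worked example of the sampling done in buildClassifier below,
     * assuming 100 training instances and m_BagSizePercent = 100: each bag is
     * filled with 100 draws with replacement, which on average covers about 63%
     * of the distinct instances. The remaining ~37 unselected instances are then
     * placed at the front of the bag's training set and declared the validation
     * set, e.g. setValidationSetSize(37 * 100 / 137) = 27 percent.
     */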
    /**
     * Build the neural network ensemble via bagging.
     * @param data the training set
     * @throws java.lang.Exception
     */
    public void buildClassifier(Instances data) throws java.lang.Exception {
        /** @todo Override this weka.classifiers.Bagging method */
        m_trainingSet = data;

        /** Set the arguments for the base learner */
        ((NeuralNetwork) m_Classifier).setHiddenLayers("" + m_iHiddenUnits);
        ((NeuralNetwork) m_Classifier).setLearningRate(m_dwLearningRate);
        ((NeuralNetwork) m_Classifier).setMomentum(m_dwMomentum);
        ((NeuralNetwork) m_Classifier).setTrainingTime(m_iEpochs);
        ((NeuralNetwork) m_Classifier).setValidationThreshold(m_iValidationThreshold);

        /** Copy the base learner */
        m_Classifiers = Classifier.makeCopies(m_Classifier, m_NumIterations);
        int bagSize = data.numInstances() * m_BagSizePercent / 100;
        Random random = new Random(m_Seed);

        /** Train each base learner via bagging */
        for (int j = 0; j < m_Classifiers.length; j++) {
            Instances temp = new Instances(data, bagSize);
            boolean[] bInBag = new boolean[data.numInstances()]; // tracks which instances have been selected
            for (int i = 0; i < data.numInstances(); i++)
                bInBag[i] = false;
            int count = 0; // counter for the number of distinct instances in the bag

            /** Generate the bagged training set by sampling with replacement */
            while (temp.numInstances() < bagSize) {
                int i = (int) (random.nextDouble() * (double) data.numInstances());
                temp.add(data.instance(i));
                if (!bInBag[i]) {
                    count++;
                    bInBag[i] = true;
                }
            }

            /**
             * Since the NN training algorithm in weka uses, as the validation set, a
             * given percentage of instances located at the front of the training set,
             * we use the instances that did not make it into the bag as the validation
             * set. So we put them at the front of the training set, set the percentage
             * accordingly, and then append the bagged instances to the rear to form
             * the training set fed to the NN training algorithm in weka.
             */
            Instances bagData = new Instances(data, data.numInstances() - count);
            for (int i = 0; i < data.numInstances(); i++)
                if (!bInBag[i])
                    bagData.add(data.instance(i));
            for (int i = 0; i < temp.numInstances(); i++)
                bagData.add(temp.instance(i));

            NeuralNetwork net = (NeuralNetwork) m_Classifiers[j];
            net.setValidationSetSize((data.numInstances() - count) * 100 / bagData.numInstances());
            if (m_bPrintBuildingProcess)
                System.out.println("Building Neural Network No." + (j + 1));
            net.buildClassifier(bagData);
        }
    }
}
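/*
 * Minimal usage sketch for BaggingNN (hypothetical values; 'dataset' is assumed
 * to be a weka.core.Instances object with its class index already set):
 *
 *     BaggingNN bagger = new BaggingNN();
 *     // 10 networks, 10 hidden units, learning rate 0.3, momentum 0.2,
 *     // 500 epochs, validation threshold 5 -- i.e., the documented defaults
 *     bagger.setArgs(10, 10, 0.3, 0.2, 500, 5);
 *     bagger.setPrintBaggingProcess(true);
 *     bagger.buildClassifier(dataset);
 *     double prediction = bagger.classifyInstance(dataset.instance(0));
 */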