ridor.java
          fstInfoGain = Utils.eq(fstAccuRate, 0) ? 0 :
            (fstAccu * (Utils.log2(fstAccuRate) - Utils.log2(defAcRt)));
          sndInfoGain = Utils.eq(sndAccuRate, 0) ? 0 :
            (sndAccu * (Utils.log2(sndAccuRate) - Utils.log2(defAcRt)));

          if(Utils.gr(fstInfoGain, sndInfoGain) ||
             (Utils.eq(fstInfoGain, sndInfoGain) &&
              (Utils.grOrEq(fstAccuRate, sndAccuRate)))){
            isFirst = true;
            infoGain = fstInfoGain;
            accRate = fstAccuRate;
            accurate = fstAccu;
            coverage = fstCover;
          }
          else{
            isFirst = false;
            infoGain = sndInfoGain;
            accRate = sndAccuRate;
            accurate = sndAccu;
            coverage = sndCover;
          }

          boolean isUpdate = Utils.gr(infoGain, maxInfoGain);

          /* Check whether so far the max infoGain */
          if(isUpdate){
            splitPoint = (data.instance(split).value(att) + data.instance(prev).value(att)) / 2;
            value = ((isFirst) ? 0 : 1);
            accuRate = accRate;
            accu = accurate;
            cover = coverage;
            maxInfoGain = infoGain;
            finalSplit = split;
          }
          prev = split;
        }
      }

      /* Split the data */
      Instances[] splitData = new Instances[2];
      splitData[0] = new Instances(data, 0, finalSplit);
      splitData[1] = new Instances(data, finalSplit, total - finalSplit);

      return splitData;
    }

    /**
     * Whether the instance is covered by this antecedent
     *
     * @param inst the instance in question
     * @return the boolean value indicating whether the instance is covered
     *         by this antecedent
     */
    public boolean isCover(Instance inst){
      boolean isCover = false;

      if(!inst.isMissing(att)){
        if(Utils.eq(value, 0)){
          if(Utils.smOrEq(inst.value(att), splitPoint))
            isCover = true;
        }
        else if(Utils.gr(inst.value(att), splitPoint))
          isCover = true;
      }
      return isCover;
    }

    /**
     * Prints this antecedent
     *
     * @return a textual description of this antecedent
     */
    public String toString() {
      String symbol = Utils.eq(value, 0.0) ? " <= " : " > ";
      return (att.name() + symbol + Utils.doubleToString(splitPoint, 6));
    }
  }
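  /*
   * A quick worked example of the information-gain score used above (the
   * numbers are made up for illustration and are not taken from this source).
   * For a candidate antecedent covering total weight t, of which weight p
   * carries the target class, and a default accuracy rate defAcRt, the score is
   *
   *   infoGain = p * (log2(p/t) - log2(defAcRt))
   *
   * e.g. with p = 8.0, t = 10.0 and defAcRt = 0.5:
   *
   *   infoGain = 8.0 * (log2(0.8) - log2(0.5))
   *            = 8.0 * (-0.322 - (-1.0))
   *            = about 5.42
   *
   * i.e. the covered positive weight times the improvement in log-accuracy
   * over the default rule; the candidate with the largest such score wins.
   */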
  /**
   * The antecedent with nominal attribute
   */
  private class NominalAntd extends Antd{

    /* The parameters of infoGain calculated for each attribute value */
    private double[] accurate;
    private double[] coverage;
    private double[] infoGain;

    /* Constructor */
    public NominalAntd(Attribute a){
      super(a);
      int bag = att.numValues();
      accurate = new double[bag];
      coverage = new double[bag];
      infoGain = new double[bag];
    }

    /**
     * Implements the splitData function.
     * This procedure is to split the data into bags according
     * to the nominal attribute value.
     * The infoGain for each bag is also calculated.
     *
     * @param data the data to be split
     * @param defAcRt the default accuracy rate for data
     * @param cl the class label to be predicted
     * @return the array of data after split
     */
    public Instances[] splitData(Instances data, double defAcRt, double cl){
      int bag = att.numValues();
      Instances[] splitData = new Instances[bag];

      for(int x=0; x<bag; x++){
        accurate[x] = coverage[x] = infoGain[x] = 0;
        splitData[x] = new Instances(data, data.numInstances());
      }

      for(int x=0; x<data.numInstances(); x++){
        Instance inst = data.instance(x);
        if(!inst.isMissing(att)){
          int v = (int)inst.value(att);
          splitData[v].add(inst);
          coverage[v] += inst.weight();
          if(Utils.eq(inst.classValue(), cl))
            accurate[v] += inst.weight();
        }
      }

      // Check if >=2 splits have more than the minimal data
      int count = 0;
      for(int x=0; x<bag; x++){
        double t = coverage[x];
        if(Utils.grOrEq(t, m_MinNo)){
          double p = accurate[x];
          if(!Utils.eq(t, 0.0))
            infoGain[x] = p * ((Utils.log2(p/t)) - (Utils.log2(defAcRt)));
          ++count;
        }
      }

      if(count < 2) // Don't split
        return null;

      value = (double)Utils.maxIndex(infoGain);
      cover = coverage[(int)value];
      accu = accurate[(int)value];
      if(!Utils.eq(cover, 0))
        accuRate = accu / cover;
      else
        accuRate = 0;

      maxInfoGain = infoGain[(int)value];

      return splitData;
    }

    /**
     * Whether the instance is covered by this antecedent
     *
     * @param inst the instance in question
     * @return the boolean value indicating whether the instance is covered
     *         by this antecedent
     */
    public boolean isCover(Instance inst){
      boolean isCover = false;
      if(!inst.isMissing(att)){
        if(Utils.eq(inst.value(att), value))
          isCover = true;
      }
      return isCover;
    }

    /**
     * Prints this antecedent
     *
     * @return a textual description of this antecedent
     */
    public String toString() {
      return (att.name() + " = " + att.value((int)value));
    }
  }
  /**
   * Builds the ripple-down rule learner.
   *
   * @param instances the training data
   * @exception Exception if the classifier can't be built successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    Instances data = new Instances(instances);

    if (data.checkForStringAttributes())
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    if(Utils.eq(data.sumOfWeights(), 0))
      throw new Exception("No training data.");

    data.deleteWithMissingClass();
    if(Utils.eq(data.sumOfWeights(), 0))
      throw new Exception("The class labels of all the training data are missing.");

    int numCl = data.numClasses();
    m_Root = new Ridor_node();
    m_Class = instances.classAttribute();     // The original class label

    if(!m_Class.isNominal())
      throw new UnsupportedClassTypeException("Only nominal class, please.");

    int index = data.classIndex();
    m_Cover = data.sumOfWeights();

    /* Create a binary attribute */
    FastVector binary_values = new FastVector(2);
    binary_values.addElement("otherClasses");
    binary_values.addElement("defClass");
    Attribute attr = new Attribute("newClass", binary_values);
    data.insertAttributeAt(attr, index);
    data.setClassIndex(index);                // The new class label

    /* Partition the data into bags according to their original class values */
    Instances[] dataByClass = new Instances[numCl];
    for(int i=0; i < numCl; i++)
      dataByClass[i] = new Instances(data, data.numInstances()); // Empty bags
    for(int i=0; i < data.numInstances(); i++){                  // Partitioning
      Instance inst = data.instance(i);
      inst.setClassValue(0);                       // Set the new class value to 0
      dataByClass[(int)inst.value(index+1)].add(inst);
    }

    for(int i=0; i < numCl; i++)
      dataByClass[i].deleteAttributeAt(index+1);   // Delete the original class

    m_Root.findRules(dataByClass, 0);
  }

  /**
   * Classify the test instance with the rule learner
   *
   * @param datum the instance to be classified
   * @return the classification
   */
  public double classifyInstance(Instance datum){
    return classify(m_Root, datum);
  }

  /**
   * Classify the test instance with one node of Ridor
   *
   * @param node the node of Ridor used to classify the test instance
   * @param datum the instance to be classified
   * @return the classification
   */
  private double classify(Ridor_node node, Instance datum){
    double classValue = node.getDefClass();
    RidorRule[] rules = node.getRules();

    if(rules != null){
      Ridor_node[] excepts = node.getExcepts();
      for(int i=0; i < excepts.length; i++){
        if(rules[i].isCover(datum)){
          classValue = classify(excepts[i], datum);
          break;
        }
      }
    }

    return classValue;
  }
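  /*
   * Sketch of how classify() resolves a prediction on a toy model (the rule
   * set below is hypothetical, not produced by this code): a root node with
   * default class "no" and one exception rule "outlook = sunny", whose
   * exception node has default class "yes" and no further rules.
   *
   *   - an instance with outlook = rainy matches no rule at the root,
   *     so the root's default class "no" is returned;
   *   - an instance with outlook = sunny fires the rule, classify() recurses
   *     into the exception node, finds no matching rule there, and returns
   *     that node's default class "yes".
   */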
  /**
   * Returns an enumeration describing the available options.
   * Valid options are: <p>
   *
   * -F number <br>
   * Set the number of folds for reduced error pruning. One fold is
   * used as the pruning set. (Default: 3) <p>
   *
   * -S number <br>
   * Set the number of shuffles for randomization. (Default: 10) <p>
   *
   * -A <br>
   * Whether to use the error rate of all the data to select the default
   * class in each step. If not set, the learner only uses the error rate
   * on the pruning data. <p>
   *
   * -M <br>
   * Whether to use the majority class as the default class in each step,
   * instead of choosing the default class based on the error rate
   * (when the flag is not set). <p>
   *
   * -N number <br>
   * Set the minimal weight of the instances within a split.
   * (Default: 2) <p>
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(5);

    newVector.addElement(new Option("\tSet number of folds for IREP\n" +
                                    "\tOne fold is used as pruning set.\n" +
                                    "\t(default 3)", "F", 1, "-F <number of folds>"));
    newVector.addElement(new Option("\tSet number of shuffles to randomize\n" +
                                    "\tthe data in order to get better rule.\n" +
                                    "\t(default 10)", "S", 1, "-S <number of shuffles>"));
    newVector.addElement(new Option("\tSet flag of whether use the error rate \n" +
                                    "\tof all the data to select the default class\n" +
                                    "\tin each step. If not set, the learner will only use" +
                                    "\tthe error rate in the pruning data", "A", 0, "-A"));
    newVector.addElement(new Option("\t Set flag of whether use the majority class as\n" +
                                    "\tthe default class in each step instead of \n" +
                                    "\tchoosing default class based on the error rate\n" +
                                    "\t(if the flag is not set)", "M", 0, "-M"));
    newVector.addElement(new Option("\tSet the minimal weights of instances\n" +
                                    "\twithin a split.\n" +
                                    "\t(default 2.0)", "N", 1, "-N <min. weights>"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String numFoldsString = Utils.getOption('F', options);
    if (numFoldsString.length() != 0)
      m_Folds = Integer.parseInt(numFoldsString);
    else
      m_Folds = 3;

    String numShuffleString = Utils.getOption('S', options);
    if (numShuffleString.length() != 0)
      m_Shuffle = Integer.parseInt(numShuffleString);
    else
      m_Shuffle = 1;

    String seedString = Utils.getOption('s', options);
    if (seedString.length() != 0)
      m_Seed = Integer.parseInt(seedString);
    else
      m_Seed = 1;

    String minNoString = Utils.getOption('N', options);
    if (minNoString.length() != 0)
      m_MinNo = Double.parseDouble(minNoString);
    else
      m_MinNo = 2.0;

    m_IsAllErr = Utils.getFlag('A', options);
    m_IsMajority = Utils.getFlag('M', options);
  }
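  /*
   * Example command line for the options parsed above (the package name
   * weka.classifiers.rules.Ridor and the data file name are assumptions for
   * illustration, not taken from this listing):
   *
   *   java weka.classifiers.rules.Ridor -t weather.arff -F 3 -S 10 -N 2.0 -A
   *
   * -t names the training file and is handled by Evaluation.evaluateModel()
   * in main(); -F, -S, -N, -A and -M are consumed by setOptions() above.
   */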
  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    String[] options = new String[8];
    int current = 0;

    options[current++] = "-F"; options[current++] = "" + m_Folds;
    options[current++] = "-S"; options[current++] = "" + m_Shuffle;
    options[current++] = "-N"; options[current++] = "" + m_MinNo;
    if(m_IsAllErr)
      options[current++] = "-A";
    if(m_IsMajority)
      options[current++] = "-M";

    while (current < options.length)
      options[current++] = "";

    return options;
  }

  /** Set and get members for parameters */
  public void setFolds(int fold){ m_Folds = fold; }
  public int getFolds(){ return m_Folds; }
  public void setShuffle(int sh){ m_Shuffle = sh; }
  public int getShuffle(){ return m_Shuffle; }
  public void setSeed(int s){ m_Seed = s; }
  public int getSeed(){ return m_Seed; }
  public void setWholeDataErr(boolean a){ m_IsAllErr = a; }
  public boolean getWholeDataErr(){ return m_IsAllErr; }
  public void setMajorityClass(boolean m){ m_IsMajority = m; }
  public boolean getMajorityClass(){ return m_IsMajority; }
  public void setMinNo(double m){ m_MinNo = m; }
  public double getMinNo(){ return m_MinNo; }

  /**
   * Returns an enumeration of the additional measure names
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumRules");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @exception IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareTo("measureNumRules") == 0)
      return numRules();
    else
      throw new IllegalArgumentException(additionalMeasureName +
                                         " not supported (Ripple down rule learner)");
  }

  /**
   * Measures the total number of rules in the model
   *
   * @return the number of rules
   */
  private double numRules(){
    int size = 0;
    if(m_Root != null)
      size = m_Root.size();

    return (double)(size + 1); // Add the default rule
  }

  /**
   * Prints all the rules of the rule learner.
   *
   * @return a textual description of the classifier
   */
  public String toString() {
    if (m_Root == null)
      return "RIpple DOwn Rule Learner(Ridor): No model built yet.";

    return ("RIpple DOwn Rule Learner(Ridor) rules\n" +
            "--------------------------------------\n\n" + m_Root.toString() +
            "\nTotal number of rules (incl. the default rule): " + (int)numRules());
  }

  /**
   * Main method.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    try {
      System.out.println(Evaluation.evaluateModel(new Ridor(), args));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}
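For completeness, here is a minimal sketch of how this learner could be driven programmatically through the classic Weka API. It assumes the class is compiled into the usual weka.classifiers.rules package and that an ARFF file named iris.arff is available; both names are placeholders, not taken from this listing. The setters used are the ones defined above.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.rules.Ridor;
import weka.core.Instance;
import weka.core.Instances;

public class RidorDemo {
  public static void main(String[] args) throws Exception {
    // Load a nominal-class data set (file name is a placeholder).
    Instances data = new Instances(new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Configure the learner: 3 pruning folds, 10 shuffles, min split weight 2.0.
    Ridor ridor = new Ridor();
    ridor.setFolds(3);
    ridor.setShuffle(10);
    ridor.setMinNo(2.0);

    // Build the ripple-down rule set and print it.
    ridor.buildClassifier(data);
    System.out.println(ridor);

    // Classify the first training instance.
    Instance first = data.instance(0);
    double pred = ridor.classifyInstance(first);
    System.out.println("Predicted class: " + data.classAttribute().value((int) pred));
  }
}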