📄 ridor.java
字号:
m_ClassAttribute = instances.classAttribute(); if (!m_ClassAttribute.isNominal()) throw new UnsupportedClassTypeException(" Only nominal class, please."); if(instances.numClasses() != 2) throw new Exception(" Only 2 classes, please."); Instances data = new Instances(instances); if(Utils.eq(data.sumOfWeights(),0)) throw new Exception(" No training data."); data.deleteWithMissingClass(); if(Utils.eq(data.sumOfWeights(),0)) throw new Exception(" The class labels of all the training data are missing."); if(data.numInstances() < m_Folds) throw new Exception(" Not enough data for REP."); m_Antds = new FastVector(); /* Split data into Grow and Prune*/ data.stratify(m_Folds); Instances growData=data.trainCV(m_Folds, m_Folds-1); Instances pruneData=data.testCV(m_Folds, m_Folds-1); grow(growData); // Build this rule prune(pruneData); // Prune this rule } /** * Find all the instances in the dataset covered by this rule. * The instances not covered will also be deducted from the the original data * and returned by this procedure. * * @param insts the dataset to be covered by this rule. * @return the instances covered and not covered by this rule */ public Instances[] coveredByRule(Instances insts){ Instances[] data = new Instances[2]; data[0] = new Instances(insts, insts.numInstances()); data[1] = new Instances(insts, insts.numInstances()); for(int i=0; i<insts.numInstances(); i++){ Instance datum = insts.instance(i); if(isCover(datum)) data[0].add(datum); // Covered by this rule else data[1].add(datum); // Not covered by this rule } return data; } /** * Whether the instance covered by this rule * * @param inst the instance in question * @return the boolean value indicating whether the instance is covered by this rule */ public boolean isCover(Instance datum){ boolean isCover=true; for(int i=0; i<m_Antds.size(); i++){ Antd antd = (Antd)m_Antds.elementAt(i); if(!antd.isCover(datum)){ isCover = false; break; } } return isCover; } /** * Whether this rule has antecedents, i.e. whether it is a default rule * * @return the boolean value indicating whether the rule has antecedents */ public boolean hasAntds(){ if (m_Antds == null) return false; else return (m_Antds.size() > 0); } /** * Build one rule using the growing data * * @param data the growing data used to build the rule */ private void grow(Instances data){ Instances growData = new Instances(data); m_AccuG = computeDefAccu(growData); m_CoverG = growData.sumOfWeights(); /* Compute the default accurate rate of the growing data */ double defAcRt= m_AccuG / m_CoverG; /* Keep the record of which attributes have already been used*/ boolean[] used=new boolean [growData.numAttributes()]; for (int k=0; k<used.length; k++) used[k]=false; int numUnused=used.length; double maxInfoGain; boolean isContinue = true; // The stopping criterion of this rule while (isContinue){ maxInfoGain = 0; // We require that infoGain be positive /* Build a list of antecedents */ Antd oneAntd=null; Instances coverData = null; Enumeration enumAttr=growData.enumerateAttributes(); int index=-1; /* Build one condition based on all attributes not used yet*/ while (enumAttr.hasMoreElements()){ Attribute att= (Attribute)(enumAttr.nextElement()); index++; Antd antd =null; if(att.isNumeric()) antd = new NumericAntd(att); else antd = new NominalAntd(att); if(!used[index]){ /* Compute the best information gain for each attribute, it's stored in the antecedent formed by this attribute. This procedure returns the data covered by the antecedent*/ Instances coveredData = computeInfoGain(growData, defAcRt, antd); if(coveredData != null){ double infoGain = antd.getMaxInfoGain(); if(Utils.gr(infoGain, maxInfoGain)){ oneAntd=antd; coverData = coveredData; maxInfoGain = infoGain; } } } } if(oneAntd == null) return; //Numeric attributes can be used more than once if(!oneAntd.getAttr().isNumeric()){ used[oneAntd.getAttr().index()]=true; numUnused--; } m_Antds.addElement((Object)oneAntd); growData = coverData;// Grow data size is shrinking defAcRt = oneAntd.getAccuRate(); /* Stop if no more data, rule perfect, no more attributes */ if(Utils.eq(growData.sumOfWeights(), 0.0) || Utils.eq(defAcRt, 1.0) || (numUnused == 0)) isContinue = false; } } /** * Compute the best information gain for the specified antecedent * * @param data the data based on which the infoGain is computed * @param defAcRt the default accuracy rate of data * @param antd the specific antecedent * @return the data covered by the antecedent */ private Instances computeInfoGain(Instances instances, double defAcRt, Antd antd){ Instances data = new Instances(instances); /* Split the data into bags. The information gain of each bag is also calculated in this procedure */ Instances[] splitData = antd.splitData(data, defAcRt, m_Class); /* Get the bag of data to be used for next antecedents */ if(splitData != null) return splitData[(int)antd.getAttrValue()]; else return null; } /** * Prune the rule using the pruning data and update the worth parameters for this rule * The accuracy rate is used to prune the rule. * * @param pruneData the pruning data used to prune the rule */ private void prune(Instances pruneData){ Instances data=new Instances(pruneData); double total = data.sumOfWeights(); /* The default accurate# and the the accuracy rate on pruning data */ double defAccu=0, defAccuRate=0; int size=m_Antds.size(); if(size == 0) return; // Default rule before pruning double[] worthRt = new double[size]; double[] coverage = new double[size]; double[] worthValue = new double[size]; for(int w=0; w<size; w++){ worthRt[w]=coverage[w]=worthValue[w]=0.0; } /* Calculate accuracy parameters for all the antecedents in this rule */ for(int x=0; x<size; x++){ Antd antd=(Antd)m_Antds.elementAt(x); Attribute attr= antd.getAttr(); Instances newData = new Instances(data); data = new Instances(newData, newData.numInstances()); // Make data empty for(int y=0; y<newData.numInstances(); y++){ Instance ins=newData.instance(y); if(!ins.isMissing(attr)){ // Attribute not missing if(antd.isCover(ins)){ // Covered by this antecedent coverage[x] += ins.weight(); data.add(ins); // Add to data for further pruning if(Utils.eq(ins.classValue(), m_Class)) // Accurate prediction worthValue[x] += ins.weight(); } } } if(coverage[x] != 0) worthRt[x] = worthValue[x]/coverage[x]; } /* Prune the antecedents according to the accuracy parameters */ for(int z=(size-1); z > 0; z--) if(Utils.sm(worthRt[z], worthRt[z-1])) m_Antds.removeElementAt(z); else break; /* Check whether this rule is a default rule */ if(m_Antds.size() == 1){ defAccu = computeDefAccu(pruneData); defAccuRate = defAccu/total; // Compute def. accuracy if(Utils.sm(worthRt[0], defAccuRate)){ // Becomes a default rule m_Antds.removeAllElements(); } } /* Update the worth parameters of this rule*/ int antdsSize = m_Antds.size(); if(antdsSize != 0){ // Not a default rule m_Worth = worthValue[antdsSize-1]; // WorthValues of the last antecedent m_WorthRate = worthRt[antdsSize-1]; m_CoverP = coverage[antdsSize-1]; Antd last = (Antd)m_Antds.lastElement(); m_CoverG = last.getCover(); m_AccuG = last.getAccu(); } else{ // Default rule m_Worth = defAccu; // Default WorthValues m_WorthRate = defAccuRate; m_CoverP = total; } } /** * Private function to compute default number of accurate instances * in the specified data for m_Class * * @param data the data in question * @return the default accuracy number */ private double computeDefAccu(Instances data){ double defAccu=0; for(int i=0; i<data.numInstances(); i++){ Instance inst = data.instance(i); if(Utils.eq(inst.classValue(), m_Class)) defAccu += inst.weight(); } return defAccu; } /** The following are get functions after prune() has set the value of worthRate and worth*/ public double getWorthRate(){ return m_WorthRate; } public double getWorth(){ return m_Worth; } public double getCoverP(){ return m_CoverP; } public double getCoverG(){ return m_CoverG; } public double getAccuG(){ return m_AccuG; } /** * Prints this rule with the specified class label * * @param att the string standing for attribute in the consequent of this rule * @param cl the string standing for value in the consequent of this rule * @return a textual description of this rule with the specified class label */ public String toString(String att, String cl) { StringBuffer text = new StringBuffer(); if(m_Antds.size() > 0){ for(int j=0; j< (m_Antds.size()-1); j++) text.append("(" + ((Antd)(m_Antds.elementAt(j))).toString()+ ") and "); text.append("("+((Antd)(m_Antds.lastElement())).toString() + ")"); } text.append(" => " + att + " = " + cl); text.append(" ("+m_CoverG+"/"+(m_CoverG - m_AccuG)+") ["+ m_CoverP+"/"+(m_CoverP - m_Worth)+"]"); return text.toString(); } /** * Prints this rule * * @return a textual description of this rule */ public String toString() { return toString(m_ClassAttribute.name(), m_ClassAttribute.value((int)m_Class)); } } /** * The single antecedent in the rule, which is composed of an attribute and * the corresponding value. There are two inherited classes, namely NumericAntd * and NominalAntd in which the attributes are numeric and nominal respectively. */ private abstract class Antd{ /* The attribute of the antecedent */ protected Attribute att; /* The attribute value of the antecedent. For numeric attribute, value is either 0(1st bag) or 1(2nd bag) */ protected double value; /* The maximum infoGain achieved by this antecedent test */ protected double maxInfoGain; /* The accurate rate of this antecedent test on the growing data */ protected double accuRate; /* The coverage of this antecedent */ protected double cover; /* The accurate data for this antecedent */ protected double accu; /* Constructor*/ public Antd(Attribute a){ att=a; value=Double.NaN; maxInfoGain = 0; accuRate = Double.NaN; cover = Double.NaN; accu = Double.NaN; } /* The abstract members for inheritance */ public abstract Instances[] splitData(Instances data, double defAcRt, double cla); public abstract boolean isCover(Instance inst); public abstract String toString(); /* Get functions of this antecedent */ public Attribute getAttr(){ return att; } public double getAttrValue(){ return value; } public double getMaxInfoGain(){ return maxInfoGain; } public double getAccuRate(){ return accuRate; } public double getAccu(){ return accu; } public double getCover(){ return cover; } } /** * The antecedent with numeric attribute */ private class NumericAntd extends Antd{ /* The split point for this numeric antecedent */ private double splitPoint; /* Constructor*/ public NumericAntd(Attribute a){ super(a); splitPoint = Double.NaN; } /* Get split point of this numeric antecedent */ public double getSplitPoint(){ return splitPoint; } /** * Implements the splitData function. * This procedure is to split the data into two bags according * to the information gain of the numeric attribute value * The maximum infoGain is also calculated. * * @param insts the data to be split * @param defAcRt the default accuracy rate for data * @param cl the class label to be predicted * @return the array of data after split */ public Instances[] splitData(Instances insts, double defAcRt, double cl){ Instances data = new Instances(insts); data.sort(att); int total=data.numInstances();// Total number of instances without // missing value for att int split=1; // Current split position int prev=0; // Previous split position int finalSplit=split; // Final split position maxInfoGain = 0; value = 0; // Compute minimum number of Instances required in each split double minSplit = 0.1 * (data.sumOfWeights()) / 2.0; if (Utils.smOrEq(minSplit,m_MinNo)) minSplit = m_MinNo; else if (Utils.gr(minSplit,25)) minSplit = 25; double fstCover=0, sndCover=0, fstAccu=0, sndAccu=0; for(int x=0; x<data.numInstances(); x++){ Instance inst = data.instance(x); if(inst.isMissing(att)){ total = x; break; } sndCover += inst.weight(); if(Utils.eq(inst.classValue(), cl)) sndAccu += inst.weight(); } // Enough Instances with known values? if (Utils.sm(sndCover,(2*minSplit))) return null; if(total == 0) return null; // Data all missing for the attribute splitPoint = data.instance(total-1).value(att); for(; split < total; split++){ if(!Utils.eq(data.instance(split).value(att), data.instance(prev).value(att))){ // Can't split within same value for(int y=prev; y<split; y++){ Instance inst = data.instance(y); fstCover += inst.weight(); sndCover -= inst.weight(); if(Utils.eq(data.instance(y).classValue(), cl)){ fstAccu += inst.weight(); // First bag positive# ++ sndAccu -= inst.weight(); // Second bag positive# -- } } if(Utils.sm(fstCover, minSplit) || Utils.sm(sndCover, minSplit)){ prev=split; // Cannot split because either continue; // split has not enough data } double fstAccuRate = 0, sndAccuRate = 0; if(!Utils.eq(fstCover,0)) fstAccuRate = fstAccu/fstCover; if(!Utils.eq(sndCover,0)) sndAccuRate = sndAccu/sndCover; /* Which bag has higher information gain? */ boolean isFirst; double fstInfoGain, sndInfoGain; double accRate, infoGain, coverage, accurate;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -