📄 conjunctiverule.java
字号:
return isCover; } // NOTE(review): tail of a method defined before this chunk — left untouched

  /**
   * Whether this rule has antecedents, i.e. whether it is a default rule
   *
   * @return the boolean value indicating whether the rule has antecedents
   */
  public boolean hasAntds(){
    if (m_Antds == null)
      return false;
    else
      return (m_Antds.size() > 0);
  }

  /**
   * Build one rule using the growing data
   *
   * @param data the growing data used to build the rule
   */
  private void grow(Instances data){
    Instances growData = new Instances(data);
    double defInfo;
    double whole = data.sumOfWeights();

    if(m_NumAntds != 0){

      /* Class distribution for data both covered and not covered by one antecedent:
         row [0] tracks instances still covered by the rule, row [1] the uncovered ones */
      double[][] classDstr = new double[2][m_NumClasses];

      /* Compute the default information of the growing data */
      for(int j=0; j < m_NumClasses; j++){
        classDstr[0][j] = 0;
        classDstr[1][j] = 0;
      }
      if(m_ClassAttribute.isNominal()){
        // Nominal class: default info is the entropy of the weighted class counts
        for(int i=0; i < growData.numInstances(); i++){
          Instance datum = growData.instance(i);
          classDstr[0][(int)datum.classValue()] += datum.weight();
        }
        defInfo = ContingencyTables.entropy(classDstr[0]);
      }
      else{
        // Numeric class: classDstr[0][0] accumulates the weighted sum of class values
        for(int i=0; i < growData.numInstances(); i++){
          Instance datum = growData.instance(i);
          classDstr[0][0] += datum.weight() * datum.classValue();
        }

        // No need to be divided by the denominator because
        // it's always the same
        double defMean = (classDstr[0][0] / whole);
        defInfo = meanSquaredError(growData, defMean) * growData.sumOfWeights();
      }

      // Store the default class distribution
      double[][] tmp = new double[2][m_NumClasses];
      for(int y=0; y < m_NumClasses; y++){
        if(m_ClassAttribute.isNominal()){
          tmp[0][y] = classDstr[0][y];
          tmp[1][y] = classDstr[1][y];
        }
        else{
          tmp[0][y] = classDstr[0][y]/whole;
          tmp[1][y] = classDstr[1][y];
        }
      }
      m_Targets.addElement(tmp);

      /* Keep the record of which attributes have already been used */
      boolean[] used=new boolean[growData.numAttributes()];
      for (int k=0; k<used.length; k++)
        used[k]=false;
      int numUnused=used.length;

      double maxInfoGain, uncoveredWtSq=0, uncoveredWtVl=0, uncoveredWts=0;
      boolean isContinue = true; // The stopping criterion of this rule

      while (isContinue){
        maxInfoGain = 0; // We require that infoGain be positive

        /* Build a list of antecedents */
        Antd oneAntd=null;
        Instances coverData = null, uncoverData = null;
        Enumeration enumAttr=growData.enumerateAttributes();
        int index=-1;

        /* Build one condition based on all attributes not used yet */
        while (enumAttr.hasMoreElements()){
          Attribute att= (Attribute)(enumAttr.nextElement());
          index++;

          Antd antd =null;
          if(m_ClassAttribute.isNominal()){
            if(att.isNumeric())
              antd = new NumericAntd(att, classDstr[1]);
            else
              antd = new NominalAntd(att, classDstr[1]);
          }
          else if(att.isNumeric())
            antd = new NumericAntd(att, uncoveredWtSq, uncoveredWtVl, uncoveredWts);
          else
            antd = new NominalAntd(att, uncoveredWtSq, uncoveredWtVl, uncoveredWts);

          if(!used[index]){
            /* Compute the best information gain for each attribute,
               it's stored in the antecedent formed by this attribute.
               This procedure returns the data covered by the antecedent */
            Instances[] coveredData = computeInfoGain(growData, defInfo, antd);
            if(coveredData != null){
              double infoGain = antd.getMaxInfoGain();
              boolean isUpdate = Utils.gr(infoGain, maxInfoGain);

              if(isUpdate){
                oneAntd=antd;
                coverData = coveredData[0];
                uncoverData = coveredData[1];
                maxInfoGain = infoGain;
              }
            }
          }
        }

        if(oneAntd == null)
          break; // No antecedent achieved a positive information gain

        // Numeric attributes can be used more than once
        if(!oneAntd.getAttr().isNumeric()){
          used[oneAntd.getAttr().index()]=true;
          numUnused--;
        }

        m_Antds.addElement(oneAntd);
        growData = coverData; // Grow data size is shrinking

        // Shift the weight of each newly-uncovered instance from the
        // covered distribution (row 0) to the uncovered one (row 1)
        for(int x=0; x < uncoverData.numInstances(); x++){
          Instance datum = uncoverData.instance(x);
          if(m_ClassAttribute.isNumeric()){
            uncoveredWtSq += datum.weight() * datum.classValue() * datum.classValue();
            uncoveredWtVl += datum.weight() * datum.classValue();
            uncoveredWts += datum.weight();
            classDstr[0][0] -= datum.weight() * datum.classValue();
            classDstr[1][0] += datum.weight() * datum.classValue();
          }
          else{
            classDstr[0][(int)datum.classValue()] -= datum.weight();
            classDstr[1][(int)datum.classValue()] += datum.weight();
          }
        }

        // Store class distribution of growing data
        tmp = new double[2][m_NumClasses];
        for(int y=0; y < m_NumClasses; y++){
          if(m_ClassAttribute.isNominal()){
            tmp[0][y] = classDstr[0][y];
            tmp[1][y] = classDstr[1][y];
          }
          else{
            tmp[0][y] = classDstr[0][y]/(whole-uncoveredWts);
            tmp[1][y] = classDstr[1][y]/uncoveredWts;
          }
        }
        m_Targets.addElement(tmp);

        defInfo = oneAntd.getInfo();
        // m_NumAntds == -1 means "no limit on the number of antecedents"
        int numAntdsThreshold = (m_NumAntds == -1) ? Integer.MAX_VALUE : m_NumAntds;

        if(Utils.eq(growData.sumOfWeights(), 0.0) ||
           (numUnused == 0) ||
           (m_Antds.size() >= numAntdsThreshold))
          isContinue = false;
      }
    }

    // Consequent ([0]) and default distribution ([1]) come from the last stored pair
    m_Cnsqt = ((double[][])(m_Targets.lastElement()))[0];
    m_DefDstr = ((double[][])(m_Targets.lastElement()))[1];
  }

  /**
   * Compute the best information gain for the specified antecedent
   *
   * @param instances the data based on which the infoGain is computed
   * @param defInfo the default information of data
   * @param antd the specific antecedent
   * @return the data covered and not covered by the antecedent, or null if
   *         the antecedent could not split the data
   */
  private Instances[] computeInfoGain(Instances instances, double defInfo, Antd antd){
    Instances data = new Instances(instances);

    /* Split the data into bags.
       The information gain of each bag is also calculated in this procedure */
    Instances[] splitData = antd.splitData(data, defInfo);
    Instances[] coveredData = new Instances[2];

    /* Get the bag of data to be used for next antecedents */
    Instances tmp1 = new Instances(data, 0);
    Instances tmp2 = new Instances(data, 0);

    if(splitData == null)
      return null;

    // Last bag of splitData holds instances with missing values (see below)
    for(int x=0; x < (splitData.length-1); x++){
      if(x == ((int)antd.getAttrValue()))
        tmp1 = splitData[x];
      else{
        for(int y=0; y < splitData[x].numInstances(); y++)
          tmp2.add(splitData[x].instance(y));
      }
    }

    if(antd.getAttr().isNominal()){ // Nominal attributes
      if(((NominalAntd)antd).isIn()){ // Inclusive expression
        coveredData[0] = new Instances(tmp1);
        coveredData[1] = new Instances(tmp2);
      }
      else{ // Exclusive expression
        coveredData[0] = new Instances(tmp2);
        coveredData[1] = new Instances(tmp1);
      }
    }
    else{ // Numeric attributes
      coveredData[0] = new Instances(tmp1);
      coveredData[1] = new Instances(tmp2);
    }

    /* Add data with missing value */
    for(int z=0; z<splitData[splitData.length-1].numInstances(); z++)
      coveredData[1].add(splitData[splitData.length-1].instance(z));

    return coveredData;
  }

  /**
   * Prune the rule using the pruning data.
   * The weighted average of accuracy rate/mean-squared error is
   * used to prune the rule.
*
   * @param pruneData the pruning data used to prune the rule
   */
  private void prune(Instances pruneData){
    Instances data=new Instances(pruneData);
    Instances otherData = new Instances(data, 0);
    double total = data.sumOfWeights();

    /* The default accuracy number and the accuracy rate on pruning data */
    double defAccu;
    if(m_ClassAttribute.isNumeric())
      defAccu = meanSquaredError(pruneData,
                                 ((double[][])m_Targets.firstElement())[0][0]);
    else{
      int predict = Utils.maxIndex(((double[][])m_Targets.firstElement())[0]);
      defAccu = computeAccu(pruneData, predict)/total;
    }

    int size=m_Antds.size();
    if(size == 0){
      m_Cnsqt = ((double[][])m_Targets.lastElement())[0];
      m_DefDstr = ((double[][])m_Targets.lastElement())[1];
      return; // Default rule before pruning
    }

    double[] worthValue = new double[size];

    /* Calculate accuracy parameters for all the antecedents in this rule */
    for(int x=0; x<size; x++){
      Antd antd=(Antd)m_Antds.elementAt(x);
      Attribute attr= antd.getAttr();
      Instances newData = new Instances(data);
      if(Utils.eq(newData.sumOfWeights(),0.0))
        break;

      data = new Instances(newData, newData.numInstances()); // Make data empty
      for(int y=0; y<newData.numInstances(); y++){
        Instance ins=newData.instance(y);
        if(antd.isCover(ins))   // Covered by this antecedent
          data.add(ins);        // Add to data for further
        else
          otherData.add(ins);   // Not covered by this antecedent
      }

      double covered, other;
      // m_Targets has one more element: index 0 is the pre-growing default,
      // so the distribution after antecedent x sits at x+1
      double[][] classes = (double[][])m_Targets.elementAt(x+1);
      if(m_ClassAttribute.isNominal()){
        int coverClass = Utils.maxIndex(classes[0]),
            otherClass = Utils.maxIndex(classes[1]);
        covered = computeAccu(data, coverClass);
        other = computeAccu(otherData, otherClass);
      }
      else{
        double coverClass = classes[0][0], otherClass = classes[1][0];
        covered = (data.sumOfWeights())*meanSquaredError(data, coverClass);
        other = (otherData.sumOfWeights())*meanSquaredError(otherData, otherClass);
      }

      worthValue[x] = (covered + other)/total;
    }

    /* Prune the antecedents according to the accuracy parameters */
    for(int z=(size-1); z > 0; z--){

      // Treatment to avoid precision problems: use a relative delta
      // when the worth value is below 1.0
      double valueDelta;
      if(m_ClassAttribute.isNominal()){
        if(Utils.sm(worthValue[z], 1.0))
          valueDelta = (worthValue[z] - worthValue[z-1]) / worthValue[z];
        else
          valueDelta = worthValue[z] - worthValue[z-1];
      }
      else{
        // For a numeric class the worth is an error, so the sign is flipped
        if(Utils.sm(worthValue[z], 1.0))
          valueDelta = (worthValue[z-1] - worthValue[z]) / worthValue[z];
        else
          valueDelta = (worthValue[z-1] - worthValue[z]);
      }

      if(Utils.smOrEq(valueDelta, 0.0)){
        m_Antds.removeElementAt(z);
        m_Targets.removeElementAt(z+1);
      }
      else break;
    }

    // Check whether this rule is a default rule
    if(m_Antds.size() == 1){
      double valueDelta;
      if(m_ClassAttribute.isNominal()){
        if(Utils.sm(worthValue[0], 1.0))
          valueDelta = (worthValue[0] - defAccu) / worthValue[0];
        else
          valueDelta = (worthValue[0] - defAccu);
      }
      else{
        if(Utils.sm(worthValue[0], 1.0))
          valueDelta = (defAccu - worthValue[0]) / worthValue[0];
        else
          valueDelta = (defAccu - worthValue[0]);
      }

      if(Utils.smOrEq(valueDelta, 0.0)){
        m_Antds.removeAllElements();
        m_Targets.removeElementAt(1);
      }
    }

    m_Cnsqt = ((double[][])(m_Targets.lastElement()))[0];
    m_DefDstr = ((double[][])(m_Targets.lastElement()))[1];
  }

  /**
   * Private function to compute number of accurate instances
   * based on the specified predicted class
   *
   * @param data the data in question
   * @param clas the predicted class
   * @return the default accuracy number
   */
  private double computeAccu(Instances data, int clas){
    double accu = 0;
    for(int i=0; i<data.numInstances(); i++){
      Instance inst = data.instance(i);
      if((int)inst.classValue() == clas)
        accu += inst.weight();
    }
    return accu;
  }

  /**
   * Private function to compute the squared error of
   * the specified data and the specified mean
   *
   * @param data the data in question
   * @param mean the specified mean
   * @return the default mean-squared error
   */
  private double meanSquaredError(Instances data, double mean){
    if(Utils.eq(data.sumOfWeights(),0.0))
      return 0;

    double mSqErr=0, sum = data.sumOfWeights();
    for(int i=0; i <
data.numInstances(); i++){ // continues the loop header split across the chunk boundary
      Instance datum = data.instance(i);
      mSqErr += datum.weight()*
        (datum.classValue() - mean)*
        (datum.classValue() - mean);
    }

    return (mSqErr / sum);
  }

  /**
   * Prints this rule with the specified class label
   *
   * @param att the string standing for attribute in the consequent of this rule
   * @param cl the string standing for value in the consequent of this rule
   * @return a textual description of this rule with the specified class label
   */
  public String toString(String att, String cl) {
    StringBuffer text = new StringBuffer();
    if(m_Antds.size() > 0){
      for(int j=0; j< (m_Antds.size()-1); j++)
        text.append("(" + ((Antd)(m_Antds.elementAt(j))).toString()+ ") and ");
      text.append("("+((Antd)(m_Antds.lastElement())).toString() + ")");
    }
    text.append(" => " + att + " = " + cl);

    return text.toString();
  }

  /**
   * Prints this rule
   *
   * @return a textual description of this rule
   */
  public String toString() {
    String title =
      "\n\nSingle conjunctive rule learner:\n"+
      "--------------------------------\n", body = null;
    StringBuffer text = new StringBuffer();
    if(m_ClassAttribute != null){
      if(m_ClassAttribute.isNominal()){
        body = toString(m_ClassAttribute.name(),
                        m_ClassAttribute.value(Utils.maxIndex(m_Cnsqt)));

        text.append("\n\nClass distributions:\nCovered by the rule:\n");
        for(int k=0; k < m_Cnsqt.length; k++)
          text.append(m_ClassAttribute.value(k)+ "\t");
        text.append('\n');
        for(int l=0; l < m_Cnsqt.length; l++)
          text.append(Utils.doubleToString(m_Cnsqt[l], 6)+"\t");

        text.append("\n\nNot covered by the rule:\n");
        for(int k=0; k < m_DefDstr.length; k++)
          text.append(m_ClassAttribute.value(k)+ "\t");
        text.append('\n');
        for(int l=0; l < m_DefDstr.length; l++)
          text.append(Utils.doubleToString(m_DefDstr[l], 6)+"\t");
      }
      else
        // Numeric class: the consequent is the predicted mean value
        body = toString(m_ClassAttribute.name(),
                        Utils.doubleToString(m_Cnsqt[0], 6));
    }
    return (title + body + text.toString());
  }

  /**
   * Main method.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    try {
      System.out.println(Evaluation.evaluateModel(new ConjunctiveRule(), args));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -