📄 ConjunctiveRule.java
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ConjunctiveRule.java
 *    Copyright (C) 2001 Xin Xu
 */

package weka.classifiers.rules;

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.classifiers.*;

/**
 * This class implements a single conjunctive rule learner that can predict
 * numeric and nominal class labels.<p>
 *
 * A rule consists of antecedents "AND"ed together and the consequent (class
 * value) for the classification/regression. In this case, the consequent is
 * the distribution of the available classes (or the numeric value) in the
 * dataset. If a test instance is not covered by this rule, it is predicted
 * using the default class distribution/value of the data not covered by the
 * rule in the training data.<br>
 * This learner selects an antecedent by computing the information gain of
 * each antecedent and prunes the generated rule using Reduced Error Pruning
 * (REP).<p>
 *
 * For classification, the information of one antecedent is the weighted
 * average of the entropies of both the data covered and not covered by the
 * rule.<br>
 *
 * For regression, the information is the weighted average of the mean-squared
 * errors of both the data covered and not covered by the rule.<p>
 *
 * In pruning, the weighted average of the accuracy rates on the pruning data
 * is used for classification, while the weighted average of the mean-squared
 * errors on the pruning data is used for regression.<p>
 *
 * @author Xin Xu (xx5@cs.waikato.ac.nz)
 * @version $Revision: 1.1.1.1 $
 */
public class ConjunctiveRule extends DistributionClassifier
  implements OptionHandler, WeightedInstancesHandler{

  /** The number of folds to split data into Grow and Prune for REP */
  private int m_Folds = 3;

  /** The class attribute of the data */
  private Attribute m_ClassAttribute;

  /** The vector of antecedents of this rule */
  protected FastVector m_Antds = null;

  /** The default rule distribution of the data not covered */
  protected double[] m_DefDstr = null;

  /** The consequent of this rule */
  protected double[] m_Cnsqt = null;

  /** Number of classes in the training data */
  private int m_NumClasses = 0;

  /** The seed used to perform randomization */
  private long m_Seed = 1;

  /** The Random object used for randomization */
  private Random m_Random = null;

  /** Whether to randomize the data */
  private boolean m_IsRandomized = true;

  /** The predicted classes recorded for each antecedent in the growing data */
  private FastVector m_Targets;

  /** Whether to use exclusive expressions for nominal attributes */
  private boolean m_IsExclude = false;

  /** The minimal number of instance weights within a split */
  private double m_MinNo = 2.0;

  /** The number of antecedents in pre-pruning */
  private int m_NumAntds = -1;
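  /*
   * Typical usage (a sketch, not part of the original source; it assumes a
   * Weka 3.x classpath and a hypothetical ARFF file "data.arff" whose last
   * attribute is the class):
   *
   *   Instances data = new Instances(new FileReader("data.arff"));
   *   data.setClassIndex(data.numAttributes() - 1);
   *   ConjunctiveRule rule = new ConjunctiveRule();
   *   rule.buildClassifier(data);   // grows one rule, then REP-prunes it
   *   double[] dist = rule.distributionForInstance(data.instance(0));
   */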
  /**
   * The single antecedent in the rule, which is composed of an attribute and
   * the corresponding value. There are two inherited classes, namely
   * NumericAntd and NominalAntd, in which the attributes are numeric and
   * nominal respectively.
   */
  private abstract class Antd{

    /** The attribute of the antecedent */
    protected Attribute att;

    /** The attribute value of the antecedent. For a numeric attribute,
        the value is either 0 (1st bag) or 1 (2nd bag) */
    protected double value;

    /** The maximum infoGain achieved by this antecedent test */
    protected double maxInfoGain;

    /** The information of this antecedent test on the growing data */
    protected double inform;

    /** The parameters related to the meanSquaredError of the data not
        covered by the previous antecedents when the class is numeric */
    protected double uncoverWtSq, uncoverWtVl, uncoverSum;

    /** The parameters related to the data not covered by the previous
        antecedents when the class is nominal */
    protected double[] uncover;

    /** Constructor for a nominal class */
    public Antd(Attribute a, double[] unc){
      att = a;
      value = Double.NaN;
      maxInfoGain = 0;
      inform = Double.NaN;
      uncover = unc;
    }

    /** Constructor for a numeric class */
    public Antd(Attribute a, double uncoveredWtSq,
                double uncoveredWtVl, double uncoveredWts){
      att = a;
      value = Double.NaN;
      maxInfoGain = 0;
      inform = Double.NaN;
      uncoverWtSq = uncoveredWtSq;
      uncoverWtVl = uncoveredWtVl;
      uncoverSum = uncoveredWts;
    }

    /* The abstract members for inheritance */
    public abstract Instances[] splitData(Instances data, double defInfo);
    public abstract boolean isCover(Instance inst);
    public abstract String toString();

    /* Get functions of this antecedent */
    public Attribute getAttr(){ return att; }
    public double getAttrValue(){ return value; }
    public double getMaxInfoGain(){ return maxInfoGain; }
    public double getInfo(){ return inform; }

    /**
     * Function used to calculate the weighted mean squared error,
     * i.e., sum[x-avg(x)]^2, based on the given elements of the formula:
     * meanSquaredError = sum(Wi*Xi^2) - (sum(WiXi))^2/sum(Wi)
     *
     * @param weightedSq sum(Wi*Xi^2)
     * @param weightedValue sum(WiXi)
     * @param sum the sum of weights
     * @return the weighted mean-squared error
     */
    protected double wtMeanSqErr(double weightedSq, double weightedValue,
                                 double sum){
      if(Utils.smOrEq(sum, 1.0E-6))
        return 0;
      return (weightedSq - (weightedValue * weightedValue) / sum);
    }

    /**
     * Function used to calculate the entropy of a given vector of values:
     * entropy = (1/sum)*{-sigma[i=1..P](Xi*log2(Xi)) + sum*log2(sum)},
     * where P is the length of the vector
     *
     * @param value the given vector of values
     * @param sum the sum of the given values; it's provided just for efficiency
     * @return the entropy
     */
    protected double entropy(double[] value, double sum){
      if(Utils.smOrEq(sum, 1.0E-6))
        return 0;

      double entropy = 0;
      for(int i=0; i < value.length; i++){
        if(!Utils.eq(value[i], 0))
          entropy -= value[i] * Utils.log2(value[i]);
      }
      entropy += sum * Utils.log2(sum);
      entropy /= sum;
      return entropy;
    }
  }
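  /*
   * Worked check of the two impurity measures above (the numbers are
   * illustrative and not part of the original source):
   *
   *   wtMeanSqErr(10, 4, 2)
   *     // unit-weight values {1, 3}: 10 - 4*4/2 = 2, i.e. (1-2)^2 + (3-2)^2
   *   entropy(new double[]{2, 2}, 4)
   *     // (1/4)*{-(2*log2(2) + 2*log2(2)) + 4*log2(4)} = (1/4)*(-4 + 8) = 1,
   *     // one bit, as expected for a uniform two-class distribution
   */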
  /**
   * The antecedent with a numeric attribute
   */
  private class NumericAntd extends Antd{

    /** The split point for this numeric antecedent */
    private double splitPoint;

    /** Constructor for a nominal class */
    public NumericAntd(Attribute a, double[] unc){
      super(a, unc);
      splitPoint = Double.NaN;
    }

    /** Constructor for a numeric class */
    public NumericAntd(Attribute a, double sq, double vl, double wts){
      super(a, sq, vl, wts);
      splitPoint = Double.NaN;
    }

    /** Get the split point of this numeric antecedent */
    public double getSplitPoint(){
      return splitPoint;
    }

    /**
     * Implements the splitData function. This procedure splits the data into
     * two bags according to the information gain of the numeric attribute
     * value; the data with missing values are stored in the last split. The
     * maximum infoGain is also calculated.
     *
     * @param insts the data to be split
     * @param defInfo the default information for the data
     * @return the array of data after the split
     */
    public Instances[] splitData(Instances insts, double defInfo){
      Instances data = new Instances(insts);
      data.sort(att);
      int total = data.numInstances(); // Total number of instances without
                                       // missing value for att
      maxInfoGain = 0;
      value = 0;

      // Compute the minimum number of instances required in each split
      double minSplit;
      if(m_ClassAttribute.isNominal()){
        minSplit = 0.1 * (data.sumOfWeights()) /
          ((double)m_ClassAttribute.numValues());
        if (Utils.smOrEq(minSplit, m_MinNo))
          minSplit = m_MinNo;
        else if (Utils.gr(minSplit, 25))
          minSplit = 25;
      }
      else
        minSplit = m_MinNo;

      double[] fst = null, snd = null, missing = null;
      if(m_ClassAttribute.isNominal()){
        fst = new double[m_NumClasses];
        snd = new double[m_NumClasses];
        missing = new double[m_NumClasses];
        for(int v=0; v < m_NumClasses; v++)
          fst[v] = snd[v] = missing[v] = 0.0;
      }

      double fstCover=0, sndCover=0, fstWtSq=0, sndWtSq=0, fstWtVl=0, sndWtVl=0;
      int split = 1;           // Current split position
      int prev = 0;            // Previous split position
      int finalSplit = split;  // Final split position

      for(int x=0; x < data.numInstances(); x++){
        Instance inst = data.instance(x);
        if(inst.isMissing(att)){
          total = x;
          break;
        }

        sndCover += inst.weight();
        if(m_ClassAttribute.isNominal())  // Nominal class
          snd[(int)inst.classValue()] += inst.weight();
        else{                             // Numeric class
          sndWtSq += inst.weight() * inst.classValue() * inst.classValue();
          sndWtVl += inst.weight() * inst.classValue();
        }
      }

      // Enough instances with known values?
      if (Utils.sm(sndCover, (2*minSplit)))
        return null;

      double msingWtSq=0, msingWtVl=0;
      Instances missingData = new Instances(data, 0);
      for(int y=total; y < data.numInstances(); y++){
        Instance inst = data.instance(y);
        missingData.add(inst);
        if(m_ClassAttribute.isNominal())
          missing[(int)inst.classValue()] += inst.weight();
        else{
          msingWtSq += inst.weight() * inst.classValue() * inst.classValue();
          msingWtVl += inst.weight() * inst.classValue();
        }
      }

      if(total == 0)
        return null; // Data all missing for the attribute

      splitPoint = data.instance(total-1).value(att);

      for(; split < total; split++){
        if(!Utils.eq(data.instance(split).value(att),   // Can't split within
                     data.instance(prev).value(att))){  // the same value

          // Move the split point
          for(int y=prev; y < split; y++){
            Instance inst = data.instance(y);
            fstCover += inst.weight();
            sndCover -= inst.weight();
            if(m_ClassAttribute.isNominal()){ // Nominal class
              fst[(int)inst.classValue()] += inst.weight();
              snd[(int)inst.classValue()] -= inst.weight();
            }
            else{                             // Numeric class
              fstWtSq += inst.weight() * inst.classValue() * inst.classValue();
              fstWtVl += inst.weight() * inst.classValue();
              sndWtSq -= inst.weight() * inst.classValue() * inst.classValue();
              sndWtVl -= inst.weight() * inst.classValue();
            }
          }

          if(Utils.sm(fstCover, minSplit) || Utils.sm(sndCover, minSplit)){
            prev = split;  // Cannot split because either
            continue;      // split has not enough data
          }

          double fstEntp = 0, sndEntp = 0;
          if(m_ClassAttribute.isNominal()){
            fstEntp = entropy(fst, fstCover);
            sndEntp = entropy(snd, sndCover);
          }
          else{
            fstEntp = wtMeanSqErr(fstWtSq, fstWtVl, fstCover)/fstCover;
            sndEntp = wtMeanSqErr(sndWtSq, sndWtVl, sndCover)/sndCover;
          }
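          /*
           * The candidate split is scored next: the information of a bag is
           * the weighted average of its own impurity and that of everything
           * else (the other bag, the missing data, and the data not covered
           * by earlier antecedents). As an illustrative example (numbers not
           * from the original source): with fstEntp = 0.5 over fstCover = 40
           * and an entropy of 1.0 over the remaining cover of 60, the
           * information is (0.5*40 + 1.0*60)/100 = 0.8 and the gain is
           * defInfo - 0.8.
           */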
          // Which bag has the higher information gain?
          boolean isFirst;
          double fstInfoGain, sndInfoGain;
          double info, infoGain, fstInfo, sndInfo;
          if(m_ClassAttribute.isNominal()){
            double sum = data.sumOfWeights();
            double otherCover, whole = sum + Utils.sum(uncover), otherEntropy;
            double[] other = null;

            // InfoGain of the first bag
            other = new double[m_NumClasses];
            for(int z=0; z < m_NumClasses; z++)
              other[z] = uncover[z] + snd[z] + missing[z];
            otherCover = whole - fstCover;
            otherEntropy = entropy(other, otherCover);
            // Weighted average
            fstInfo = (fstEntp*fstCover + otherEntropy*otherCover)/whole;
            fstInfoGain = defInfo - fstInfo;

            // InfoGain of the second bag
            other = new double[m_NumClasses];
            for(int z=0; z < m_NumClasses; z++)
              other[z] = uncover[z] + fst[z] + missing[z];
            otherCover = whole - sndCover;
            otherEntropy = entropy(other, otherCover);
            // Weighted average
            sndInfo = (sndEntp*sndCover + otherEntropy*otherCover)/whole;
            sndInfoGain = defInfo - sndInfo;
          }
          else{
            double sum = data.sumOfWeights();
            double otherWtSq = (sndWtSq + msingWtSq + uncoverWtSq),
              otherWtVl = (sndWtVl + msingWtVl + uncoverWtVl),
              otherCover = (sum - fstCover + uncoverSum);

            fstInfo = Utils.eq(fstCover, 0) ? 0 : (fstEntp * fstCover);
            fstInfo += wtMeanSqErr(otherWtSq, otherWtVl, otherCover);
            fstInfoGain = defInfo - fstInfo;

            otherWtSq = (fstWtSq + msingWtSq + uncoverWtSq);
            otherWtVl = (fstWtVl + msingWtVl + uncoverWtVl);
            otherCover = sum - sndCover + uncoverSum;
            sndInfo = Utils.eq(sndCover, 0) ? 0 : (sndEntp * sndCover);
            sndInfo += wtMeanSqErr(otherWtSq, otherWtVl, otherCover);
            sndInfoGain = defInfo - sndInfo;
          }

          if(Utils.gr(fstInfoGain, sndInfoGain) ||
             (Utils.eq(fstInfoGain, sndInfoGain) && Utils.sm(fstEntp, sndEntp))){
            isFirst = true;
            infoGain = fstInfoGain;
            info = fstInfo;
          }
          else{
            isFirst = false;
            infoGain = sndInfoGain;
            info = sndInfo;
          }

          // Check whether this is the maximum infoGain so far
          boolean isUpdate = Utils.gr(infoGain, maxInfoGain);
          if(isUpdate){
            splitPoint = ((data.instance(split).value(att)) +
                          (data.instance(prev).value(att)))/2.0;
            value = ((isFirst) ? 0 : 1);
            inform = info;
            maxInfoGain = infoGain;
            finalSplit = split;
          }
          prev = split;
        }
      }

      /* Split the data */
      Instances[] splitData = new Instances[3];
      splitData[0] = new Instances(data, 0, finalSplit);
      splitData[1] = new Instances(data, finalSplit, total-finalSplit);
      splitData[2] = new Instances(missingData);

      return splitData;
    }

    /**
     * Whether the instance is covered by this antecedent
     *
     * @param inst the instance in question
     * @return the boolean value indicating whether the instance is covered
     *         by this antecedent
     */
    public boolean isCover(Instance inst){
      boolean isCover = false;

      if(!inst.isMissing(att)){
        if(Utils.eq(value, 0)){
          if(Utils.smOrEq(inst.value(att), splitPoint))
            isCover = true;
        }
        else if(Utils.gr(inst.value(att), splitPoint))
          isCover = true;
      }
      return isCover;
    }

    /**
     * Prints this antecedent
     *
     * @return a textual description of this antecedent
     */
    public String toString() {
      String symbol = Utils.eq(value, 0.0) ? " <= " : " > ";
      return (att.name() + symbol + Utils.doubleToString(splitPoint, 6));
    }
  }

  /**
   * The antecedent with a nominal attribute
   */
  class NominalAntd extends Antd{

    /* The parameters of infoGain calculated for each attribute value */
    private double[][] stats;
    private double[] coverage;
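As a quick illustration of the numeric antecedent's coverage test and printed form, the following stand-alone sketch mirrors the isCover and toString logic above; the class name, attribute name, and numbers are illustrative and not part of the original file:

// Stand-alone illustration of NumericAntd's isCover/toString semantics;
// not part of ConjunctiveRule.java.
public class NumericAntdSketch {
  public static void main(String[] args) {
    double splitPoint = 2.45; // midpoint between two adjacent attribute values
    double value = 0;         // 0 selects the "<=" bag, 1 selects the ">" bag
    double attValue = 2.3;    // attribute value of a hypothetical test instance

    boolean covers = (value == 0) ? (attValue <= splitPoint)
                                  : (attValue > splitPoint);
    String symbol = (value == 0) ? " <= " : " > ";
    // Prints: petalwidth <= 2.45 covers: true
    System.out.println("petalwidth" + symbol + splitPoint + " covers: " + covers);
  }
}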