ridor.java
* @param total the total data size fed into the ruleset
* @return the weighted accuracy rate of this rule
*/
private double computeWeightedAcRt(double worthRt, double cover, double total){
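// Scale the rule's accuracy rate by the fraction of the total data it covers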
return (worthRt * (cover/total));
}
/**
* Divides an array of data bags, grouped by their true class label,
* using the specified rule.
* Both the data covered and not covered by the rule are returned
* by this procedure.
*
* @param rule the rule covering the data
* @param dataByClass the array of data to be covered by the rule
* @return the arrays of data both covered and not covered by the rule
*/
private Instances[][] divide(RidorRule rule, Instances[] dataByClass){
int len = dataByClass.length;
Instances[][] dataBags = new Instances[2][len];
for(int i=0; i < len; i++){
Instances[] dvdData = rule.coveredByRule(dataByClass[i]);
dataBags[0][i] = dvdData[0]; // Covered by the rule
dataBags[1][i] = dvdData[1]; // Not covered by the rule
}
return dataBags;
}
/**
* The size of this node of Ridor, i.e. the
* number of rules generated within and below this node
*
* @return the size of this node
*/
public int size(){
int size = 0;
if(rules != null){
for(int i=0; i < rules.length; i++)
size += excepts[i].size(); // The children's size
size += rules.length; // This node's size
}
return size;
}
/**
* Prints all the rules of one node of Ridor.
*
* @return a textual description of one node of Ridor
*/
public String toString(){
StringBuffer text = new StringBuffer();
if(level == 1)
text.append(m_Class.name() + " = " + m_Class.value((int)getDefClass())+
" ("+m_Cover+"/"+m_Err+")\n");
if(rules != null){
for(int i=0; i < rules.length; i++){
for(int j=0; j < level; j++)
text.append(" ");
String cl = m_Class.value((int)(excepts[i].getDefClass()));
text.append(" Except " +
rules[i].toString(m_Class.name(), cl)+
"\n" + excepts[i].toString());
}
}
return text.toString();
}
}
/**
* This class implements a single rule that predicts the 2-class distribution.
*
* A rule consists of antecedents "AND"ed together and the consequent (class value)
* for the classification. In this case, the consequent is the distribution of
* the available classes (always 2 classes) in the dataset.
* In this class, the Information Gain (p*[log(p/t) - log(P/T)]) is used to select
* an antecedent and Reduced Error Pruning (REP) is used to prune the rule.
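* (roughly: p/t are the target-class weight and total weight covered after adding an antecedent,
* and P/T the same quantities before it is added)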
*
*/
private class RidorRule implements WeightedInstancesHandler, Serializable {
/** The internal representation of the class label to be predicted*/
private double m_Class = -1;
/** The class attribute of the data*/
private Attribute m_ClassAttribute;
/** The vector of antecedents of this rule*/
protected FastVector m_Antds = null;
/** The worth rate of this rule, in this case, the accuracy rate on the pruning data*/
private double m_WorthRate = 0;
/** The worth value of this rule, in this case, the accurately predicted weight on the pruning data*/
private double m_Worth = 0;
/** The sum of weights of the data covered by this rule in the pruning data */
private double m_CoverP = 0;
/** The covered and accurately predicted weights of this rule in the growing data */
private double m_CoverG = 0, m_AccuG = 0;
/** The access functions for parameters */
public void setPredictedClass(double cl){ m_Class = cl; }
public double getPredictedClass(){ return m_Class; }
/**
* Builds a single rule learner with REP dealing with 2 classes.
* This rule learner always tries to predict the class with label
* m_Class.
*
* @param instances the training data
* @exception Exception if classifier can't be built successfully
*/
public void buildClassifier(Instances instances) throws Exception {
m_ClassAttribute = instances.classAttribute();
if (!m_ClassAttribute.isNominal())
throw new UnsupportedClassTypeException(" Only nominal class, please.");
if(instances.numClasses() != 2)
throw new Exception(" Only 2 classes, please.");
Instances data = new Instances(instances);
if(Utils.eq(data.sumOfWeights(),0))
throw new Exception(" No training data.");
data.deleteWithMissingClass();
if(Utils.eq(data.sumOfWeights(),0))
throw new Exception(" The class labels of all the training data are missing.");
if(data.numInstances() < m_Folds)
throw new Exception(" Not enough data for REP.");
m_Antds = new FastVector();
/* Split data into Grow and Prune*/
m_Random = new Random(m_Seed);
data.randomize(m_Random);
data.stratify(m_Folds);
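// The first m_Folds-1 folds are used to grow the rule; the held-out last fold is used to prune it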
Instances growData=data.trainCV(m_Folds, m_Folds-1, m_Random);
Instances pruneData=data.testCV(m_Folds, m_Folds-1);
grow(growData); // Build this rule
prune(pruneData); // Prune this rule
}
/**
* Find all the instances in the dataset covered by this rule.
* The instances not covered by the rule are also returned
* by this procedure.
*
* @param insts the dataset to be covered by this rule.
* @return the instances covered and not covered by this rule
*/
public Instances[] coveredByRule(Instances insts){
Instances[] data = new Instances[2];
data[0] = new Instances(insts, insts.numInstances());
data[1] = new Instances(insts, insts.numInstances());
for(int i=0; i<insts.numInstances(); i++){
Instance datum = insts.instance(i);
if(isCover(datum))
data[0].add(datum); // Covered by this rule
else
data[1].add(datum); // Not covered by this rule
}
return data;
}
/**
* Whether the instance is covered by this rule
*
* @param datum the instance in question
* @return the boolean value indicating whether the instance is covered by this rule
*/
public boolean isCover(Instance datum){
boolean isCover=true;
for(int i=0; i<m_Antds.size(); i++){
Antd antd = (Antd)m_Antds.elementAt(i);
if(!antd.isCover(datum)){
isCover = false;
break;
}
}
return isCover;
}
/**
* Whether this rule has antecedents, i.e. whether it is not a default rule
*
* @return the boolean value indicating whether the rule has antecedents
*/
public boolean hasAntds(){
if (m_Antds == null)
return false;
else
return (m_Antds.size() > 0);
}
/**
* Build one rule using the growing data
*
* @param data the growing data used to build the rule
*/
private void grow(Instances data){
Instances growData = new Instances(data);
m_AccuG = computeDefAccu(growData);
m_CoverG = growData.sumOfWeights();
/* Compute the default accuracy rate of the growing data */
double defAcRt= m_AccuG / m_CoverG;
/* Keep the record of which attributes have already been used*/
boolean[] used=new boolean [growData.numAttributes()];
for (int k=0; k<used.length; k++)
used[k]=false;
int numUnused=used.length;
double maxInfoGain;
boolean isContinue = true; // The stopping criterion of this rule
while (isContinue){
maxInfoGain = 0; // We require that infoGain be positive
/* Build a list of antecedents */
Antd oneAntd=null;
Instances coverData = null;
Enumeration emAttr=growData.enumerateAttributes();
int index=-1;
/* Build one condition based on all attributes not used yet*/
while (emAttr.hasMoreElements()){
Attribute att= (Attribute)(emAttr.nextElement());
index++;
Antd antd =null;
if(att.isNumeric())
antd = new NumericAntd(att);
else
antd = new NominalAntd(att);
if(!used[index]){
/* Compute the best information gain for each attribute,
it's stored in the antecedent formed by this attribute.
This procedure returns the data covered by the antecedent*/
Instances coveredData = computeInfoGain(growData, defAcRt, antd);
if(coveredData != null){
double infoGain = antd.getMaxInfoGain();
if(Utils.gr(infoGain, maxInfoGain)){
oneAntd=antd;
coverData = coveredData;
maxInfoGain = infoGain;
}
}
}
}
if(oneAntd == null) return;
//Numeric attributes can be used more than once
if(!oneAntd.getAttr().isNumeric()){
used[oneAntd.getAttr().index()]=true;
numUnused--;
}
m_Antds.addElement((Object)oneAntd);
growData = coverData;// Grow data size is shrinking
defAcRt = oneAntd.getAccuRate();
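// The chosen antecedent's accuracy rate becomes the baseline for the next round of info-gain computation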
/* Stop if no more data, rule perfect, no more attributes */
if(Utils.eq(growData.sumOfWeights(), 0.0) || Utils.eq(defAcRt, 1.0) || (numUnused == 0))
isContinue = false;
}
}
/**
* Compute the best information gain for the specified antecedent
*
* @param instances the data based on which the infoGain is computed
* @param defAcRt the default accuracy rate of data
* @param antd the specific antecedent
* @return the data covered by the antecedent
*/
private Instances computeInfoGain(Instances instances, double defAcRt, Antd antd){
Instances data = new Instances(instances);
/* Split the data into bags.
The information gain of each bag is also calculated in this procedure */
Instances[] splitData = antd.splitData(data, defAcRt, m_Class);
/* Get the bag of data to be used for next antecedents */
if(splitData != null)
return splitData[(int)antd.getAttrValue()];
else return null;
}
/**
* Prune the rule using the pruning data and update the worth parameters for this rule
* The accuracy rate is used to prune the rule.
*
* @param pruneData the pruning data used to prune the rule
*/
private void prune(Instances pruneData){
Instances data=new Instances(pruneData);
double total = data.sumOfWeights();
/* The default accuracy count and the accuracy rate on the pruning data */
double defAccu=0, defAccuRate=0;
int size=m_Antds.size();
if(size == 0) return; // Default rule before pruning
double[] worthRt = new double[size];
double[] coverage = new double[size];
double[] worthValue = new double[size];
for(int w=0; w<size; w++){
worthRt[w]=coverage[w]=worthValue[w]=0.0;
}
/* Calculate accuracy parameters for all the antecedents in this rule */
for(int x=0; x<size; x++){
Antd antd=(Antd)m_Antds.elementAt(x);
Attribute attr= antd.getAttr();
Instances newData = new Instances(data);
data = new Instances(newData, newData.numInstances()); // Make data empty
for(int y=0; y<newData.numInstances(); y++){
Instance ins=newData.instance(y);
if(!ins.isMissing(attr)){ // Attribute not missing
if(antd.isCover(ins)){ // Covered by this antecedent
coverage[x] += ins.weight();
data.add(ins); // Add to data for further pruning
if(Utils.eq(ins.classValue(), m_Class)) // Accurate prediction
worthValue[x] += ins.weight();
}
}
}
if(coverage[x] != 0)
worthRt[x] = worthValue[x]/coverage[x];
}
/* Prune the antecedents according to the accuracy parameters */
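// Work backwards from the last antecedent: each is removed while its
// accuracy rate is lower than that of the preceding antecedent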
for(int z=(size-1); z > 0; z--)
if(Utils.sm(worthRt[z], worthRt[z-1]))
m_Antds.removeElementAt(z);
else break;
/* Check whether this rule is a default rule */
if(m_Antds.size() == 1){
defAccu = computeDefAccu(pruneData);
defAccuRate = defAccu/total; // Compute def. accuracy
if(Utils.sm(worthRt[0], defAccuRate)){ // Becomes a default rule
m_Antds.removeAllElements();
}
}
/* Update the worth parameters of this rule*/
int antdsSize = m_Antds.size();
if(antdsSize != 0){ // Not a default rule