// ConjunctiveRule.java (code-viewer page header removed during cleanup)
else
if(att.isNumeric())
antd = new NumericAntd(att, uncoveredWtSq, uncoveredWtVl, uncoveredWts);
else
antd = new NominalAntd(att, uncoveredWtSq, uncoveredWtVl, uncoveredWts);
if(!used[index]){
/* Compute the best information gain for each attribute,
it's stored in the antecedent formed by this attribute.
This procedure returns the data covered by the antecedent*/
Instances[] coveredData = computeInfoGain(growData, defInfo, antd);
if(coveredData != null){
double infoGain = antd.getMaxInfoGain();
boolean isUpdate = Utils.gr(infoGain, maxInfoGain);
if(isUpdate){
oneAntd=antd;
coverData = coveredData[0];
uncoverData = coveredData[1];
maxInfoGain = infoGain;
}
}
}
}
if(oneAntd == null)
break;
//Numeric attributes can be used more than once
if(!oneAntd.getAttr().isNumeric()){
used[oneAntd.getAttr().index()]=true;
numUnused--;
}
m_Antds.addElement(oneAntd);
growData = coverData;// Grow data size is shrinking
for(int x=0; x < uncoverData.numInstances(); x++){
Instance datum = uncoverData.instance(x);
if(m_ClassAttribute.isNumeric()){
uncoveredWtSq += datum.weight() * datum.classValue() * datum.classValue();
uncoveredWtVl += datum.weight() * datum.classValue();
uncoveredWts += datum.weight();
classDstr[0][0] -= datum.weight() * datum.classValue();
classDstr[1][0] += datum.weight() * datum.classValue();
}
else{
classDstr[0][(int)datum.classValue()] -= datum.weight();
classDstr[1][(int)datum.classValue()] += datum.weight();
}
}
// Store class distribution of growing data
tmp = new double[2][m_NumClasses];
for(int y=0; y < m_NumClasses; y++){
if(m_ClassAttribute.isNominal()){
tmp[0][y] = classDstr[0][y];
tmp[1][y] = classDstr[1][y];
}
else{
tmp[0][y] = classDstr[0][y]/(whole-uncoveredWts);
tmp[1][y] = classDstr[1][y]/uncoveredWts;
}
}
m_Targets.addElement(tmp);
defInfo = oneAntd.getInfo();
int numAntdsThreshold = (m_NumAntds == -1) ? Integer.MAX_VALUE : m_NumAntds;
if(Utils.eq(growData.sumOfWeights(), 0.0) ||
(numUnused == 0) ||
(m_Antds.size() >= numAntdsThreshold))
isContinue = false;
}
}
m_Cnsqt = ((double[][])(m_Targets.lastElement()))[0];
m_DefDstr = ((double[][])(m_Targets.lastElement()))[1];
}
/**
* Compute the best information gain for the specified antecedent
*
* @param data the data based on which the infoGain is computed
* @param defInfo the default information of data
* @param antd the specific antecedent
* @return the data covered and not covered by the antecedent
*/
/**
 * Computes the best information gain for the specified antecedent and
 * partitions the data by whether the antecedent covers it.
 *
 * @param instances the data on which the info gain is computed
 * @param defInfo the default information of the data
 * @param antd the antecedent under consideration
 * @return a two-element array: index 0 holds the instances covered by the
 *         antecedent, index 1 the instances not covered; null if the split
 *         could not be performed
 */
private Instances[] computeInfoGain(Instances instances, double defInfo, Antd antd){
  Instances workingCopy = new Instances(instances);

  /* Split the data into bags; the information gain of each bag is
     calculated inside splitData() and stored in the antecedent. */
  Instances[] bags = antd.splitData(workingCopy, defInfo);
  if(bags == null)
    return null;

  /* Collect the bag matching the antecedent's attribute value and merge
     all remaining bags (except the last, which holds missing values). */
  Instances matched = new Instances(workingCopy, 0);
  Instances unmatched = new Instances(workingCopy, 0);
  int valueIndex = (int)antd.getAttrValue();
  for(int bag = 0; bag < (bags.length - 1); bag++){
    if(bag == valueIndex)
      matched = bags[bag];
    else{
      for(int i = 0; i < bags[bag].numInstances(); i++)
        unmatched.add(bags[bag].instance(i));
    }
  }

  /* For an exclusive nominal expression the covered/uncovered roles are
     swapped; inclusive nominal and numeric antecedents use them as-is. */
  Instances[] coveredData = new Instances[2];
  boolean exclusive = antd.getAttr().isNominal() && !((NominalAntd)antd).isIn();
  if(exclusive){
    coveredData[0] = new Instances(unmatched);
    coveredData[1] = new Instances(matched);
  }
  else{
    coveredData[0] = new Instances(matched);
    coveredData[1] = new Instances(unmatched);
  }

  /* Instances with a missing value on this attribute (the last bag) are
     always treated as not covered. */
  Instances missingBag = bags[bags.length - 1];
  for(int i = 0; i < missingBag.numInstances(); i++)
    coveredData[1].add(missingBag.instance(i));

  return coveredData;
}
/**
* Prune the rule using the pruning data.
* The weighted average of accuracy rate/mean-squared error is
* used to prune the rule.
*
* @param pruneData the pruning data used to prune the rule
*/
private void prune(Instances pruneData){
Instances data=new Instances(pruneData);
Instances otherData = new Instances(data, 0);
double total = data.sumOfWeights();
/* The default accuracy measure on the pruning data: mean-squared error
   for a numeric class, accuracy rate for a nominal class. */
double defAccu;
if(m_ClassAttribute.isNumeric())
defAccu = meanSquaredError(pruneData,
((double[][])m_Targets.firstElement())[0][0]);
else{
int predict = Utils.maxIndex(((double[][])m_Targets.firstElement())[0]);
defAccu = computeAccu(pruneData, predict)/total;
}
int size=m_Antds.size();
if(size == 0){
// No antecedents: this is already the default rule, nothing to prune.
m_Cnsqt = ((double[][])m_Targets.lastElement())[0];
m_DefDstr = ((double[][])m_Targets.lastElement())[1];
return; // Default rule before pruning
}
double[] worthValue = new double[size];
/* Calculate accuracy parameters for all the antecedents in this rule.
   worthValue[x] is the weighted average quality of the rule truncated
   after antecedent x (covered part + uncovered part, over total weight). */
for(int x=0; x<size; x++){
Antd antd=(Antd)m_Antds.elementAt(x);
Attribute attr= antd.getAttr();
Instances newData = new Instances(data);
if(Utils.eq(newData.sumOfWeights(),0.0))
break;
data = new Instances(newData, newData.numInstances()); // Make data empty
// Route each instance to the covered set (for the next antecedent) or
// to the accumulated "other" set.
for(int y=0; y<newData.numInstances(); y++){
Instance ins=newData.instance(y);
if(antd.isCover(ins)) // Covered by this antecedent
data.add(ins); // Add to data for further
else
otherData.add(ins); // Not covered by this antecedent
}
double covered, other;
double[][] classes =
(double[][])m_Targets.elementAt(x+1); // m_Targets has one more element
if(m_ClassAttribute.isNominal()){
// Nominal class: score each side by the weight it predicts correctly.
int coverClass = Utils.maxIndex(classes[0]),
otherClass = Utils.maxIndex(classes[1]);
covered = computeAccu(data, coverClass);
other = computeAccu(otherData, otherClass);
}
else{
// Numeric class: score each side by its weighted squared error.
double coverClass = classes[0][0],
otherClass = classes[1][0];
covered = (data.sumOfWeights())*meanSquaredError(data, coverClass);
other = (otherData.sumOfWeights())*meanSquaredError(otherData, otherClass);
}
worthValue[x] = (covered + other)/total;
}
/* Prune the antecedents (from last to first) whenever dropping the last
   one does not hurt the worth value. Stops at the first kept antecedent. */
for(int z=(size-1); z > 0; z--){
// Treatment to avoid precision problems: use a relative delta when the
// worth value is below 1.0. For nominal classes larger worth is better;
// for numeric classes (squared error) smaller is better, so the
// difference is taken in the opposite direction.
double valueDelta;
if(m_ClassAttribute.isNominal()){
if(Utils.sm(worthValue[z], 1.0))
valueDelta = (worthValue[z] - worthValue[z-1]) / worthValue[z];
else
valueDelta = worthValue[z] - worthValue[z-1];
}
else{
if(Utils.sm(worthValue[z], 1.0))
valueDelta = (worthValue[z-1] - worthValue[z]) / worthValue[z];
else
valueDelta = (worthValue[z-1] - worthValue[z]);
}
if(Utils.smOrEq(valueDelta, 0.0)){
// Antecedent z adds nothing: drop it and its target distribution
// (m_Targets is offset by one, hence z+1).
m_Antds.removeElementAt(z);
m_Targets.removeElementAt(z+1);
}
else break;
}
// Check whether this rule is a default rule: if the single remaining
// antecedent is no better than the default accuracy, drop it too.
if(m_Antds.size() == 1){
double valueDelta;
if(m_ClassAttribute.isNominal()){
if(Utils.sm(worthValue[0], 1.0))
valueDelta = (worthValue[0] - defAccu) / worthValue[0];
else
valueDelta = (worthValue[0] - defAccu);
}
else{
if(Utils.sm(worthValue[0], 1.0))
valueDelta = (defAccu - worthValue[0]) / worthValue[0];
else
valueDelta = (defAccu - worthValue[0]);
}
if(Utils.smOrEq(valueDelta, 0.0)){
m_Antds.removeAllElements();
m_Targets.removeElementAt(1);
}
}
// Consequent and default distribution come from the last surviving target.
m_Cnsqt = ((double[][])(m_Targets.lastElement()))[0];
m_DefDstr = ((double[][])(m_Targets.lastElement()))[1];
}
/**
* Private function to compute number of accurate instances
* based on the specified predicted class
*
* @param data the data in question
* @param clas the predicted class
* @return the default accuracy number
*/
/**
 * Computes the total weight of instances whose class value equals the
 * specified predicted class.
 *
 * @param data the data in question
 * @param clas the predicted class index
 * @return the summed weight of the correctly predicted instances
 */
private double computeAccu(Instances data, int clas){
  double correctWeight = 0;
  int count = data.numInstances();
  for(int idx = 0; idx < count; idx++){
    Instance current = data.instance(idx);
    if(clas == (int)current.classValue())
      correctWeight += current.weight();
  }
  return correctWeight;
}
/**
* Private function to compute the squared error of
* the specified data and the specified mean
*
* @param data the data in question
* @param mean the specified mean
* @return the default mean-squared error
*/
/**
 * Computes the weighted mean-squared error of the data around the
 * specified mean.
 *
 * @param data the data in question
 * @param mean the specified mean
 * @return the weighted mean-squared error, or 0 if the data has no weight
 */
private double meanSquaredError(Instances data, double mean){
  double totalWeight = data.sumOfWeights();
  // Guard against a division by zero on empty / zero-weight data.
  if(Utils.eq(totalWeight, 0.0))
    return 0;
  double weightedSqErr = 0;
  for(int idx = 0; idx < data.numInstances(); idx++){
    Instance current = data.instance(idx);
    double residual = current.classValue() - mean;
    weightedSqErr += current.weight() * residual * residual;
  }
  return weightedSqErr / totalWeight;
}
/**
* Prints this rule with the specified class label
*
* @param att the string standing for attribute in the consequent of this rule
* @param cl the string standing for value in the consequent of this rule
* @return a textual description of this rule with the specified class label
*/
/**
 * Prints this rule with the specified class label.
 *
 * @param att the string standing for the attribute in the consequent of this rule
 * @param cl the string standing for the value in the consequent of this rule
 * @return a textual description of this rule with the specified class label
 */
public String toString(String att, String cl) {
  // StringBuilder instead of StringBuffer: no synchronization is needed
  // for a method-local buffer.
  StringBuilder text = new StringBuilder();
  if(m_Antds.size() > 0){
    // All antecedents but the last are followed by " and ".
    for(int j = 0; j < (m_Antds.size() - 1); j++)
      text.append('(').append(((Antd)(m_Antds.elementAt(j))).toString()).append(") and ");
    text.append('(').append(((Antd)(m_Antds.lastElement())).toString()).append(')');
  }
  text.append(" => ").append(att).append(" = ").append(cl);
  return text.toString();
}
/**
* Prints this rule
*
* @return a textual description of this rule
*/
/**
 * Prints this rule, including the class distributions of the data
 * covered and not covered by the rule when the class is nominal.
 *
 * @return a textual description of this rule
 */
public String toString() {
  String title =
    "\n\nSingle conjunctive rule learner:\n"+
    "--------------------------------\n", body = null;
  // StringBuilder instead of StringBuffer: no synchronization is needed
  // for a method-local buffer.
  StringBuilder text = new StringBuilder();
  if(m_ClassAttribute != null){
    if(m_ClassAttribute.isNominal()){
      // Nominal class: consequent is the majority class of the covered data.
      body = toString(m_ClassAttribute.name(), m_ClassAttribute.value(Utils.maxIndex(m_Cnsqt)));
      text.append("\n\nClass distributions:\nCovered by the rule:\n");
      for(int k=0; k < m_Cnsqt.length; k++)
        text.append(m_ClassAttribute.value(k)+ "\t");
      text.append('\n');
      for(int l=0; l < m_Cnsqt.length; l++)
        text.append(Utils.doubleToString(m_Cnsqt[l], 6)+"\t");
      text.append("\n\nNot covered by the rule:\n");
      for(int k=0; k < m_DefDstr.length; k++)
        text.append(m_ClassAttribute.value(k)+ "\t");
      text.append('\n');
      for(int l=0; l < m_DefDstr.length; l++)
        text.append(Utils.doubleToString(m_DefDstr[l], 6)+"\t");
    }
    else
      // Numeric class: consequent is the predicted mean of the covered data.
      body = toString(m_ClassAttribute.name(), Utils.doubleToString(m_Cnsqt[0], 6));
  }
  return (title + body + text.toString());
}
/**
* Main method.
*
* @param args the options for the classifier
*/
/**
 * Main method for running this classifier from the command line.
 *
 * @param args the options for the classifier
 */
public static void main(String[] args) {
  try {
    String evaluation = Evaluation.evaluateModel(new ConjunctiveRule(), args);
    System.out.println(evaluation);
  } catch (Exception e) {
    e.printStackTrace();
    System.err.println(e.getMessage());
  }
}
}
// (code-viewer page footer with keyboard-shortcut help removed during cleanup)