📄 conjunctiverule.java

📁 MacroWeka extends the well-known data mining tool Weka
💻 JAVA
	  else
	    if(att.isNumeric())
	      antd = new NumericAntd(att, uncoveredWtSq, uncoveredWtVl, uncoveredWts);
	    else
	      antd = new NominalAntd(att, uncoveredWtSq, uncoveredWtVl, uncoveredWts); 
		    
	  if(!used[index]){
	    /* Compute the best information gain for this attribute;
	       it is stored in the antecedent formed by the attribute.
	       This call returns the data covered by the antecedent. */
	    Instances[] coveredData = computeInfoGain(growData, defInfo, antd);	 
			
	    if(coveredData != null){
	      double infoGain = antd.getMaxInfoGain();			
	      boolean isUpdate = Utils.gr(infoGain, maxInfoGain);
			    
	      if(isUpdate){
		oneAntd=antd;
		coverData = coveredData[0]; 
		uncoverData = coveredData[1];  
		maxInfoGain = infoGain;		    
	      }
	    }
	  }
	}
		
	if(oneAntd == null) 		
	  break;	    
		
	// Numeric attributes can be used more than once
	if(!oneAntd.getAttr().isNumeric()){ 
	  used[oneAntd.getAttr().index()]=true;
	  numUnused--;
	}
		
	m_Antds.addElement(oneAntd);
	growData = coverData; // The growing data shrinks to the instances covered so far
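	// Move the statistics of the newly uncovered instances from the covered
	// class distribution (classDstr[0]) to the uncovered one (classDstr[1]);
	// for a numeric class also accumulate the weighted sums needed for the
	// mean of the uncovered data.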
		
	for(int x=0; x < uncoverData.numInstances(); x++){
	  Instance datum = uncoverData.instance(x);
	  if(m_ClassAttribute.isNumeric()){
	    uncoveredWtSq += datum.weight() * datum.classValue() * datum.classValue();
	    uncoveredWtVl += datum.weight() * datum.classValue();
	    uncoveredWts += datum.weight();
	    classDstr[0][0] -= datum.weight() * datum.classValue();
	    classDstr[1][0] += datum.weight() * datum.classValue();
	  }
	  else{
	    classDstr[0][(int)datum.classValue()] -= datum.weight();
	    classDstr[1][(int)datum.classValue()] += datum.weight();
	  }
	}	       
		
	// Store class distribution of growing data
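	// For a nominal class the weighted counts are stored as they are; for a
	// numeric class the weighted sums are divided by the corresponding total
	// weight, i.e. the weighted mean class value of the covered and of the
	// uncovered instances.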
	tmp = new double[2][m_NumClasses];
	for(int y=0; y < m_NumClasses; y++){
	  if(m_ClassAttribute.isNominal()){	
	    tmp[0][y] = classDstr[0][y];
	    tmp[1][y] = classDstr[1][y];
	  }
	  else{
	    tmp[0][y] = classDstr[0][y]/(whole-uncoveredWts);
	    tmp[1][y] = classDstr[1][y]/uncoveredWts;
	  }
	}
	m_Targets.addElement(tmp);  
		
	defInfo = oneAntd.getInfo();
	int numAntdsThreshold = (m_NumAntds == -1) ? Integer.MAX_VALUE : m_NumAntds;

	if(Utils.eq(growData.sumOfWeights(), 0.0) || 
	   (numUnused == 0) ||
	   (m_Antds.size() >= numAntdsThreshold))
	  isContinue = false;
      }
    }
	
    m_Cnsqt = ((double[][])(m_Targets.lastElement()))[0];
    m_DefDstr = ((double[][])(m_Targets.lastElement()))[1];	
  }
    
  /** 
   * Compute the best information gain for the specified antecedent
   *  
   * @param instances the data based on which the infoGain is computed
   * @param defInfo the default information of data
   * @param antd the specific antecedent
   * @return the data covered and not covered by the antecedent
   */
  private Instances[] computeInfoGain(Instances instances, double defInfo, Antd antd){	
    Instances data = new Instances(instances);
	
    /* Split the data into bags.
       The information gain of each bag is also calculated in this procedure */
    Instances[] splitData = antd.splitData(data, defInfo); 
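    /* splitData holds one bag per nominal value (or one per side of a numeric
       split); its last element collects the instances whose value for this
       attribute is missing, and the bag indexed by antd.getAttrValue() is the
       one selected by the antecedent. */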
    Instances[] coveredData = new Instances[2];
	
    /* Get the bag of data to be used for next antecedents */
    Instances tmp1 = new Instances(data, 0);
    Instances tmp2 = new Instances(data, 0);
	
    if(splitData == null)	    
      return null;
	
    for(int x=0; x < (splitData.length-1); x++){
      if(x == ((int)antd.getAttrValue()))
	tmp1 = splitData[x];
      else{		
	for(int y=0; y < splitData[x].numInstances(); y++)
	  tmp2.add(splitData[x].instance(y));		    
      }		
    }
	
    if(antd.getAttr().isNominal()){ // Nominal attributes
      if(((NominalAntd)antd).isIn()){ // Inclusive expression
	coveredData[0] = new Instances(tmp1);
	coveredData[1] = new Instances(tmp2);
      }
      else{                           // Exclusive expression
	coveredData[0] = new Instances(tmp2);
	coveredData[1] = new Instances(tmp1);
      }	
    }
    else{                           // Numeric attributes
      coveredData[0] = new Instances(tmp1);
      coveredData[1] = new Instances(tmp2);
    }
	
    /* Add data with missing value */
    for(int z=0; z<splitData[splitData.length-1].numInstances(); z++)
      coveredData[1].add(splitData[splitData.length-1].instance(z));
	
    return coveredData;
  }
    
  /**
   * Prune the rule using the pruning data.
   * The weighted average of accuracy rate/mean-squared error is 
   * used to prune the rule.
   *
   * @param pruneData the pruning data used to prune the rule
   */    
  private void prune(Instances pruneData){
    Instances data=new Instances(pruneData);
    Instances otherData = new Instances(data, 0);
    double total = data.sumOfWeights();
	
    /* The default accuracy rate (or mean-squared error for a numeric class) on the pruning data */
    double defAccu;
    if(m_ClassAttribute.isNumeric())
      defAccu = meanSquaredError(pruneData,
				 ((double[][])m_Targets.firstElement())[0][0]);
    else{
      int predict = Utils.maxIndex(((double[][])m_Targets.firstElement())[0]);
      defAccu = computeAccu(pruneData, predict)/total;
    }
	
    int size=m_Antds.size();
    if(size == 0){
      m_Cnsqt = ((double[][])m_Targets.lastElement())[0];
      m_DefDstr = ((double[][])m_Targets.lastElement())[1];
      return; // Default rule before pruning
    }
	
    double[] worthValue = new double[size];
	
    /* Calculate accuracy parameters for all the antecedents in this rule */
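    /* worthValue[x] scores the rule truncated after antecedent x: for a
       nominal class it is the weighted number of correctly predicted
       instances (covered plus uncovered) divided by the total weight; for a
       numeric class it is a weight-averaged mean-squared error, so smaller
       values are better. */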
    for(int x=0; x<size; x++){
      Antd antd=(Antd)m_Antds.elementAt(x);
      Attribute attr= antd.getAttr();
      Instances newData = new Instances(data);
      if(Utils.eq(newData.sumOfWeights(),0.0))
	break;
	    
      data = new Instances(newData, newData.numInstances()); // Make data empty
	    
      for(int y=0; y<newData.numInstances(); y++){
	Instance ins=newData.instance(y);
	if(antd.isCover(ins))              // Covered by this antecedent
	  data.add(ins);                 // Add to data for further antecedents
	else
	  otherData.add(ins);            // Not covered by this antecedent
      }
	    
      double covered, other;	    
      double[][] classes = 
	(double[][])m_Targets.elementAt(x+1); // m_Targets has one more element
      if(m_ClassAttribute.isNominal()){		
	int coverClass = Utils.maxIndex(classes[0]),
	  otherClass = Utils.maxIndex(classes[1]);
		
	covered = computeAccu(data, coverClass); 
	other = computeAccu(otherData, otherClass);		
      }
      else{
	double coverClass = classes[0][0],
	  otherClass = classes[1][0];
	covered = (data.sumOfWeights())*meanSquaredError(data, coverClass); 
	other = (otherData.sumOfWeights())*meanSquaredError(otherData, otherClass);
      }
	    
      worthValue[x] = (covered + other)/total;
    }
	
    /* Prune the antecedents according to the accuracy parameters */
    for(int z=(size-1); z > 0; z--){	
      // Treatment to avoid precision problems
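      // If the worth value is below 1, the difference to the previous
      // antecedent is taken relative to it; for a numeric class the sign is
      // flipped because a smaller (mean-squared-error based) worth is better.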
      double valueDelta;
      if(m_ClassAttribute.isNominal()){
	if(Utils.sm(worthValue[z], 1.0))
	  valueDelta = (worthValue[z] - worthValue[z-1]) / worthValue[z];
	else
	  valueDelta = worthValue[z] - worthValue[z-1];
      }
      else{
	if(Utils.sm(worthValue[z], 1.0))
	  valueDelta = (worthValue[z-1] - worthValue[z]) / worthValue[z];
	else
	  valueDelta = (worthValue[z-1] - worthValue[z]);
      }
	    
      if(Utils.smOrEq(valueDelta, 0.0)){	
	m_Antds.removeElementAt(z);
	m_Targets.removeElementAt(z+1);
      }
      else  break;
    }	
	
    // Check whether this rule is a default rule
    if(m_Antds.size() == 1){
      double valueDelta;
      if(m_ClassAttribute.isNominal()){
	if(Utils.sm(worthValue[0], 1.0))
	  valueDelta = (worthValue[0] - defAccu) / worthValue[0];
	else
	  valueDelta = (worthValue[0] - defAccu);
      }
      else{
	if(Utils.sm(worthValue[0], 1.0))
	  valueDelta = (defAccu - worthValue[0]) / worthValue[0];
	else
	  valueDelta = (defAccu - worthValue[0]);
      }
	    
      if(Utils.smOrEq(valueDelta, 0.0)){
	m_Antds.removeAllElements();
	m_Targets.removeElementAt(1);
      }
    }
	
    m_Cnsqt = ((double[][])(m_Targets.lastElement()))[0];
    m_DefDstr = ((double[][])(m_Targets.lastElement()))[1];
  }
    
  /**
   * Private function to compute number of accurate instances
   * based on the specified predicted class
   * 
   * @param data the data in question
   * @param clas the predicted class
   * @return the weighted number of instances whose class value equals the specified class
   */
  private double computeAccu(Instances data, int clas){ 
    double accu = 0;
    for(int i=0; i<data.numInstances(); i++){
      Instance inst = data.instance(i);
      if((int)inst.classValue() == clas)
	accu += inst.weight();
    }
    return accu;
  }
    

  /**
   * Private function to compute the squared error of
   * the specified data and the specified mean
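   * (computed as the sum of weight * (classValue - mean)^2 over all
   * instances, divided by the total sum of weights)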
   * 
   * @param data the data in question
   * @param mean the specified mean
   * @return the weighted mean-squared error
   */
  private double meanSquaredError(Instances data, double mean){ 
    if(Utils.eq(data.sumOfWeights(),0.0))
      return 0;
	
    double mSqErr=0, sum = data.sumOfWeights();
    for(int i=0; i < data.numInstances(); i++){
      Instance datum = data.instance(i);
      mSqErr += datum.weight()*
	(datum.classValue() - mean)*
	(datum.classValue() - mean);
    }	 
	
    return (mSqErr / sum);
  }
     
  /**
   * Prints this rule with the specified class label
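   * The produced text has the form
   * "(antd1) and (antd2) and ... => att = cl".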
   *
   * @param att the string standing for attribute in the consequent of this rule
   * @param cl the string standing for value in the consequent of this rule
   * @return a textual description of this rule with the specified class label
   */
  public String toString(String att, String cl) {
    StringBuffer text =  new StringBuffer();
    if(m_Antds.size() > 0){
      for(int j=0; j< (m_Antds.size()-1); j++)
	text.append("(" + ((Antd)(m_Antds.elementAt(j))).toString()+ ") and ");
      text.append("("+((Antd)(m_Antds.lastElement())).toString() + ")");
    }
    text.append(" => " + att + " = " + cl);
	
    return text.toString();
  }
    
  /**
   * Prints this rule
   *
   * @return a textual description of this rule
   */
  public String toString() {
    String title = 
      "\n\nSingle conjunctive rule learner:\n"+
      "--------------------------------\n", body = null;
    StringBuffer text =  new StringBuffer();
    if(m_ClassAttribute != null){
      if(m_ClassAttribute.isNominal()){
	body = toString(m_ClassAttribute.name(), m_ClassAttribute.value(Utils.maxIndex(m_Cnsqt)));
		
	text.append("\n\nClass distributions:\nCovered by the rule:\n");
	for(int k=0; k < m_Cnsqt.length; k++)
	  text.append(m_ClassAttribute.value(k)+ "\t");
	text.append('\n');
	for(int l=0; l < m_Cnsqt.length; l++)
	  text.append(Utils.doubleToString(m_Cnsqt[l], 6)+"\t");
		
	text.append("\n\nNot covered by the rule:\n");
	for(int k=0; k < m_DefDstr.length; k++)
	  text.append(m_ClassAttribute.value(k)+ "\t");
	text.append('\n');
	for(int l=0; l < m_DefDstr.length; l++)
	  text.append(Utils.doubleToString(m_DefDstr[l], 6)+"\t");	    
      }
      else
	body = toString(m_ClassAttribute.name(), Utils.doubleToString(m_Cnsqt[0], 6));
    }
    return (title + body + text.toString());
  }
    
  /**
   * Main method.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {	
    try {
      System.out.println(Evaluation.evaluateModel(new ConjunctiveRule(), args));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}
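
For readers who want to use the learner programmatically rather than through Evaluation.evaluateModel, the following is a minimal, self-contained sketch kept separate from the listing above. It assumes a Weka build in which this class lives in the weka.classifiers.rules package and in which the Instances(Reader) constructor is available; the file name weather.arff is only a placeholder.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.rules.ConjunctiveRule;
import weka.core.Instances;

public class ConjunctiveRuleDemo {
  public static void main(String[] args) throws Exception {
    // Load an ARFF data set; the path is a placeholder.
    Instances data = new Instances(new BufferedReader(new FileReader("weather.arff")));
    // Use the last attribute as the class attribute.
    data.setClassIndex(data.numAttributes() - 1);

    // Grow and prune a single conjunctive rule on the data.
    ConjunctiveRule rule = new ConjunctiveRule();
    rule.buildClassifier(data);

    // Print the learned rule and the prediction for the first instance.
    System.out.println(rule);
    System.out.println("Prediction for the first instance: "
                       + rule.classifyInstance(data.instance(0)));
  }
}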
