aodesr.java

来自「Weka」· Java 代码 · 共 916 行 · 第 1/2 页

JAVA
916
字号
     // calculate probabilities for each possible class value    for(int classVal = 0; classVal < m_NumClasses; classVal++) {        probs[classVal] = 0;       double x = 0;       parentCount = 0;        countsForClass = m_CondiCounts[classVal];       // each attribute has a turn of being the parent       for(int parent = 0; parent < m_NumAttributes; parent++) {          if(attIndex[parent] == -1)             continue;  // skip class attribute or missing value          // determine correct index for the parent in m_CondiCounts matrix          pIndex = attIndex[parent];          // check that the att value has a frequency of m_Limit or greater	  if(m_Frequencies[pIndex] < m_Limit)              continue;                    // delete the generalization attributes.          if(SpecialGeneralArray[parent] != -1)             continue;          countsForClassParent = countsForClass[pIndex];          // block the parent from being its own child          attIndex[parent] = -1;          parentCount++;          double classparentfreq = countsForClassParent[pIndex];          // find the number of missing values for parent's attribute          double missing4ParentAtt =             m_Frequencies[m_StartAttIndex[parent] + m_NumAttValues[parent]];          // calculate the prior probability -- P(parent & classVal)           if (m_Laplace){             x = LaplaceEstimate(classparentfreq, m_SumInstances - missing4ParentAtt,                                     m_NumClasses * m_NumAttValues[parent]);          } else {                       x = MEstimate(classparentfreq, m_SumInstances - missing4ParentAtt,                                     m_NumClasses * m_NumAttValues[parent]);          }              // take into account the value of each attribute          for(int att = 0; att < m_NumAttributes; att++) {             if(attIndex[att] == -1) // skip class attribute or missing value                continue;             // delete the generalization attributes.             
if(SpecialGeneralArray[att] != -1)                continue;                          double missingForParentandChildAtt =                       countsForClassParent[m_StartAttIndex[att] + m_NumAttValues[att]];             if (m_Laplace){                x *= LaplaceEstimate(countsForClassParent[attIndex[att]],                     classparentfreq - missingForParentandChildAtt, m_NumAttValues[att]);             } else {                x *= MEstimate(countsForClassParent[attIndex[att]],                     classparentfreq - missingForParentandChildAtt, m_NumAttValues[att]);             }          }          // add this probability to the overall probability          probs[classVal] += x;           // unblock the parent          attIndex[parent] = pIndex;       }        // check that at least one att was a parent       if(parentCount < 1) {          // do plain naive bayes conditional prob	  probs[classVal] = NBconditionalProb(instance, classVal);          //probs[classVal] = Double.NaN;       } else {           // divide by number of parent atts to get the mean          probs[classVal] /= (double)(parentCount);       }    }    Utils.normalize(probs);    return probs;  }  /**   * Calculates the probability of the specified class for the given test   * instance, using naive Bayes.   
*   * @param instance the instance to be classified   * @param classVal the class for which to calculate the probability   * @return predicted class probability   * @throws Exception if there is a problem generating the prediction   */  public double NBconditionalProb(Instance instance, int classVal)                                                     throws Exception {    double prob;    int attIndex;    double [][] pointer;    // calculate the prior probability    if(m_Laplace) {       prob = LaplaceEstimate(m_ClassCounts[classVal],m_SumInstances,m_NumClasses);     } else {       prob = MEstimate(m_ClassCounts[classVal], m_SumInstances, m_NumClasses);    }    pointer = m_CondiCounts[classVal];        // consider effect of each att value    for(int att = 0; att < m_NumAttributes; att++) {       if(att == m_ClassIndex || instance.isMissing(att))          continue;              // determine correct index for att in m_CondiCounts       attIndex = m_StartAttIndex[att] + (int)instance.value(att);       if (m_Laplace){         prob *= LaplaceEstimate((double)pointer[attIndex][attIndex],                    (double)m_SumForCounts[classVal][att], m_NumAttValues[att]);       } else {           prob *= MEstimate((double)pointer[attIndex][attIndex],                    (double)m_SumForCounts[classVal][att], m_NumAttValues[att]);       }    }    return prob;  }  /**   * Returns the probability estimate, using m-estimate   *   * @param frequency frequency of value of interest   * @param total count of all values   * @param numValues number of different values   * @return the probability estimate   */  public double MEstimate(double frequency, double total,                          double numValues) {        return (frequency + m_MWeight / numValues) / (total + m_MWeight);  }     /**   * Returns the probability estimate, using laplace correction   *   * @param frequency frequency of value of interest   * @param total count of all values   * @param numValues number of different 
values   * @return the probability estimate   */  public double LaplaceEstimate(double frequency, double total,                                double numValues) {        return (frequency + 1.0) / (total + numValues);  }         /**   * Returns an enumeration describing the available options   *   * @return an enumeration of all the available options   */  public Enumeration listOptions() {    Vector newVector = new Vector(5);            newVector.addElement(       new Option("\tOutput debugging information\n",                  "D", 0,"-D"));    newVector.addElement(       new Option("\tImpose a critcal value for specialization-generalization relationship\n"                  + "\t(default is 50)", "C", 1,"-C"));    newVector.addElement(       new Option("\tImpose a frequency limit for superParents\n"                  + "\t(default is 1)", "F", 2,"-F"));    newVector.addElement(       new Option("\tUsing Laplace estimation\n"                  + "\t(default is m-esimation (m=1))",                  "L", 3,"-L"));    newVector.addElement(       new Option("\tWeight value for m-estimation\n"                  + "\t(default is 1.0)", "M", 4,"-M"));    return newVector.elements();  }  /**   * Parses a given list of options. <p/>   *   <!-- options-start -->   * Valid options are:<p/>   *   * <pre> -D   *  Output debugging information   * </pre>   *    * <pre> -F &lt;int&gt;   *  Impose a frequency limit for superParents   *  (default is 1)</pre>   *   * <pre> -L   *  Use Laplace estimation   *  (default is m-estimation)</pre>   *   * <pre> -M &lt;double&gt;   *  Specify the m value of m-estimation   *  (default is 1)</pre>   *   * <pre>-C &lt;int&gt;   *  Specify critical value for specialization-generalization.   *  (default is 50).   
*  Larger values than the default of 50 substantially reduce   *  the risk of incorrectly inferring that one value subsumes   *  another, but also reduces the number of true subsumptions   *  that are detected.</pre>   *   <!-- options-end -->   *   * @param options the list of options as an array of strings   * @throws Exception if an option is not supported   */  public void setOptions(String[] options) throws Exception {    m_Debug = Utils.getFlag('D', options);    String Critical = Utils.getOption('C', options);    if(Critical.length() != 0)        m_Critical = Integer.parseInt(Critical);    else       m_Critical = 50;        String Freq = Utils.getOption('F', options);    if(Freq.length() != 0)        m_Limit = Integer.parseInt(Freq);    else       m_Limit = 1;        m_Laplace = Utils.getFlag('L', options);    String MWeight = Utils.getOption('M', options);     if(MWeight.length() != 0) {       if(m_Laplace)          throw new Exception("weight for m-estimate is pointless if using laplace estimation!");       m_MWeight = Double.parseDouble(MWeight);    } else       m_MWeight = 1.0;        Utils.checkForRemainingOptions(options);  }      /**   * Gets the current settings of the classifier.   
*   * @return an array of strings suitable for passing to setOptions   */  public String [] getOptions() {            Vector result  = new Vector();    if (m_Debug)       result.add("-D");    result.add("-F");    result.add("" + m_Limit);    if (m_Laplace) {       result.add("-L");    } else {       result.add("-M");       result.add("" + m_MWeight);    }            result.add("-C");    result.add("" + m_Critical);    return (String[]) result.toArray(new String[result.size()]);  }   /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String mestWeightTipText() {    return "Set the weight for m-estimate.";  }  /**   * Sets the weight for m-estimate   *   * @param w the weight   */  public void setMestWeight(double w) {    if (getUseLaplace()) {       System.out.println(          "Weight is only used in conjunction with m-estimate - ignored!");    } else {      if(w > 0)         m_MWeight = w;      else         System.out.println("M-Estimate Weight must be greater than 0!");    }  }  /**   * Gets the weight used in m-estimate   *   * @return the weight for m-estimation   */  public double getMestWeight() {    return m_MWeight;  }  /**   * Returns the tip text for this property   * @return tip text for this property suitable for   * displaying in the explorer/experimenter gui   */  public String useLaplaceTipText() {    return "Use Laplace correction instead of m-estimation.";  }  /**   * Gets if laplace correction is being used.   *   * @return Value of m_Laplace.   */  public boolean getUseLaplace() {    return m_Laplace;  }  /**   * Sets if laplace correction is to be used.   *   * @param value Value to assign to m_Laplace.   
*/
  public void setUseLaplace(boolean value) {
    this.m_Laplace = value;
  }

  /**
   * Returns the tip text for the frequency limit property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String frequencyLimitTipText() {
    return "Attributes with a frequency in the train set below "
           + "this value aren't used as parents.";
  }

  /**
   * Sets the frequency limit.
   *
   * @param f the frequency limit
   */
  public void setFrequencyLimit(int f) {
    this.m_Limit = f;
  }

  /**
   * Gets the frequency limit.
   *
   * @return the frequency limit
   */
  public int getFrequencyLimit() {
    return this.m_Limit;
  }

  /**
   * Returns the tip text for the critical value property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String criticalValueTipText() {
    return "Specify critical value for specialization-generalization "
           + "relationship (default 50).";
  }

  /**
   * Sets the critical value.
   *
   * @param c the critical value
   */
  public void setCriticalValue(int c) {
    this.m_Critical = c;
  }

  /**
   * Gets the critical value.
   *
   * @return the critical value
   */
  public int getCriticalValue() {
    return this.m_Critical;
  }

  /**
   * Returns a description of the classifier. When no instances are stored
   * the text only states that no model has been built; otherwise it lists
   * a prior per class followed by the dataset name, the instance and
   * attribute counts, the frequency limit, the critical value and the
   * estimator in use. If an exception occurs while the description is
   * assembled, its message is appended instead of the remaining text.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    StringBuffer out = new StringBuffer("The AODEsr Classifier");

    if (m_Instances == null) {
      out.append(": No model built yet.");
      return out.toString();
    }

    try {
      // one line per class value with its prior probability
      for (int cls = 0; cls < m_NumClasses; cls++) {
        String label = m_Instances.classAttribute().value(cls);
        String prior = Utils.doubleToString(
            ((m_ClassCounts[cls] + 1) / (m_SumInstances + m_NumClasses)), 4, 2);
        out.append("\nClass " + label + ": Prior probability = " + prior + "\n\n");
      }

      String relation = m_Instances.relationName();
      out.append("Dataset: " + relation + "\n");
      out.append("Instances: " + m_NumInstances + "\n");
      out.append("Attributes: " + m_NumAttributes + "\n");
      out.append("Frequency limit for superParents: " + m_Limit + "\n");
      out.append("Critical value for the specializtion-generalization "
                 + "relationship: " + m_Critical + "\n");

      out.append(m_Laplace
          ? "Using LapLace estimation."
          : "Using m-estimation, m = " + m_MWeight);
    } catch (Exception ex) {
      out.append(ex.getMessage());
    }

    return out.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    AODEsr classifier = new AODEsr();
    runClassifier(classifier, argv);
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?