⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 naivebayescat.java

📁 Naive Bayes算法java代码
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    * The Kullback Leibler distance is defined as:    * sum over all x  p(x) log(p(x)/q(x))    * We assume that no p(x) or q(x) are zero    * @param p - the first array    * @param q - the second array.    * @return - the computer distance between p and q.    */  public static double kullback_leibler_distance(double[] p, double[] q) {    if(p.length != q.length) {      Error.fatalErr("kullback_leibler_distance: p and q arrays have different " +	               "sizes");    }    // @@ The KL distance needs to be fixed so that p can be zero    // @@ (real-world prob).  In that case p*log(p/q) = 0, which    // @@ we need to special case.  Ronny    double sum = 0;    double sump = 0;    double sumq = 0;    for(int i=0; i<p.length; i++) {	 if(p[i] <= 0) {	    Error.fatalErr("p(" + i + ") <= zero");       }	 if(q[i] <= 0) {	    Error.fatalErr("q(" + i + ") <= zero");       }	 sum += (p[i] * MLJ.log_bin(p[i] / q[i]));	 sump += p[i];
	 sumq += q[i];
    }
    MLJ.verify_approx_equal(sump, (double)(1), "kullback_leibler_distance:: " +
			   "sum of p doesn't add to 1");
    MLJ.verify_approx_equal(sumq, (double)(1), "kullback_leibler_distance:: " +
			   "sum of q doesn't add to 1");
    if (sum < 0) {
      Error.err("kullback_leibler_distance: sum < 0: " + sum + endl);
    }
    return sum;
  }  /** Prints detail probability values with attribute names.    * Used when logging maximum detail in categorize().    */  private void log_prob(double[] prob, Schema instSchema) throws IOException {
    for (int i = 0; i <= prob.length; i++) {
      logOptions.get_log_stream().write("Prob=" + prob[i] +
                             " for label " + i + " (" +
	                       instSchema.category_to_label_string(i) + ')' + endl);
      logOptions.get_log_stream().write(endl);
    }// <--this might not go here

  }  /** Check state of object after training.  Checks
    *    1) BagCounter is ok, and
    *    2) the number of test cases is > 0, and
    *    3) that there are no variances = 0.
    * @param level - verification level (passed through to sub-checks).
    */
  public void OK(int level) {
//e55    nominCounts.OK(level);
   
    if (nominCounts.label_counts().length < 2 ) { // 1 comes free for unkn
      Error.fatalErr("NaiveBayesCat.OK: BagCounter has less than 2 labels");
    }
    MLJ.verify_strictly_greater(trainWeight, 0, "NaiveBayesCat.OK: total training weight " +
                                "must be strictly greater than zero");
    if ( continNorm != null ) {
      int labelNumVals = nominCounts.label_counts().length - 1;
      for (int attrVal = 0; attrVal < numAttributes; attrVal++) {
        if ( nominCounts.value_counts()[attrVal] == null ) { // continuous //JWP if this doesn't work use extra variable!!!
          // Bug fix: this loop previously ran labelVal from -1, a leftover of
          // the original C++ -1-offset indexing ((*continNorm)(attrVal,labelVal));
          // in Java, continNorm[attrVal][-1] always throws
          // ArrayIndexOutOfBoundsException.  score() indexes continNorm with
          // 0..labelNumVals (labelNumVals + 1 slots, the first presumably for
          // the unknown category), so the same range is checked here.
          for (int labelVal = 0; labelVal <= labelNumVals; labelVal++) {
            NBNorm nbn = continNorm[attrVal][labelVal];
            if ( !nbn.hasData ) {
	        Error.fatalErr("NaiveBayesCat.OK: Normal Distribution data " +
	          "missing for (label, attribute) (" + labelVal + ", " + attrVal + "). ");
            }
            if ( nbn.var<=0 ) {
              // Message fix: the variance value was promised ("Variance = for")
              // but never printed; the typo "Varience" is also corrected.
              Error.fatalErr("NaiveBayesCat.OK: Variance must be > 0 for continuous attributes.  Variance = " +
                 nbn.var + " for (label, attribute) (" + labelVal + ", " + attrVal + "). ");
            }
          }
        }
      }
    }
  }

  /** Rescale the probabilities so that the highest is 1
    * This is to avoid underflows.
    * The array is modified in place: every entry is divided by the largest
    * entry, so the maximum becomes 1 and relative ordering is preserved.
    * @param prob - an array of probabilities to be scaled in place.
    */
  static void rescale_prob(double[] prob) {
    double peak = findMax(prob);
    // Performance shows this is time critical.
    // Guard against division by zero (happens on shuttle-small): when the
    // largest entry is not positive, leave the array untouched.
    if (peak <= 0) {
      return;
    }
    for (int i = 0; i < prob.length; i++) {
      prob[i] /= peak;
      MLJ.ASSERT( (prob[i] >= 0), "NaiveBayesCat.rescale_prob()");
    }
  }

  /** Returns a category given an instance by checking all 
    * attributes in schema and returning category with highest
    * relative probability.
    * The relative probability is being estimated for each label.
    * The label with the highest values is the category returned.
    * The probability for a given label is
    * P(Nominal Attributes)*P(Continuous Attributes)
    * Since the probability is a product,  we can factor out any
    * constants that will be multiplied times every label, since
    * this will not change the ordering of labels.
    * P(Continuous Attribute Value X) is calculated using the normal
    * density: 
    * Normal(X) = 1/(sqrt(2*pi)*std-dev)*exp((-1/2)*(X-mean)^2/var)
    * This calculation can be stripped of the constant
    * (sqrt(2*pi)) without changing the outcome.
    * P(Nominal Attributes) is calculated as the percentage of a
    * label's training set that had the test instance's value
    * for each attribute.
    * The majority label is returned if all are equal.
    * See this file's header for more information.
    * @param instance - the instance to be scored.
    * @return a CatDist built from the rescaled relative probabilities.
    */
  public CatDist score(Instance instance) {
    int attrNum;
    int labelVal;
    // NOTE(review): this local shadows the type name Schema and is never
    // read below (get_schema() is called directly instead) -- confirm
    // before removing.
    Schema Schema = get_schema();
    int labelNumVals = get_schema().num_label_values();
    // starts at -1 for the unknown category: the array has one extra slot,
    // presumably index 0 for the unknown category -- TODO confirm against
    // init_class_prob.
    double[] prob = new double[labelNumVals + 1];
    // Sanity check: the number of attributes and the number of labels of
    // the training set of the categorizer should correspond with the Schema
    // of the instance being categorized.

    //?OK();
    logOptions.LOG(3, "Instance to categorize: ");
    //IFLOG(3,instance.display_unlabelled(logOptions.get_log_stream())); 
    logOptions.LOG(3, endl);

    int trainLabelCount = nominCounts.label_counts().length - 1;
    if ( labelNumVals != trainLabelCount || numAttributes != get_schema().num_attr()) {
      Error.fatalErr("NaiveBayesCat.categorize: Schema of instance to be categorized does not match Schema of training set.");
    }
    init_class_prob(nominCounts, trainWeight, prob, useLaplace, useEvidenceProjection, evidenceFactor);
    logOptions.LOG(4, "Initial class probabilities" + endl);
    //IFLOG(4, this.log_prob(prob, schema));

    // compute kl distances here if needed.  Break constness to keep
    // categorize() a logically const function.
    // We compute these here instead of in the constructor because we don't
    // know if we're using unknown auto mode until we actually categorize.
    if(unknownIsValue == unknownAuto && unkIsVal == null) {
      (this).compute_kl_distances();
    }
    // loop through each attribute in instance:
    for (attrNum=0; attrNum < numAttributes; attrNum++) {
      AttrInfo ai = get_schema().attr_info(attrNum);
      if(!ai.can_cast_to_nominal() && unknownIsValue != unknownNo) {
	 Error.fatalErr("NaiveBayesCat.categorize: UNKNOWN_IS_VALUE is set and " +
	    ai.name() + " is a real value with unknowns.  UNKNOWN_IS_VALUE settings of " +
	    "yes and auto are not supported for undiscretized real values " +
	    "with unknowns.");
      }
	 
      // determine whether or not to treat unknowns as values for this
      // attribute.
      boolean useUnknowns = false;
      if(unknownIsValue == unknownYes) {
	  useUnknowns = true;
      }
      else if(unknownIsValue == unknownAuto) {
	  // NOTE(review): this assertion requires unkIsVal[attrNum] to already
	  // be true, which makes the if() below always take the true branch.
	  // The intent was probably to assert unkIsVal != null (it is computed
	  // lazily above) -- confirm before changing.
	  MLJ.ASSERT(unkIsVal[attrNum], "NaiveBayesCat.score(): unkIsVal["+attrNum+"]");
	  if(unkIsVal[attrNum]) {
	    useUnknowns = true;
        }
      }
      
      if (!useUnknowns && ai.is_unknown(instance.get_value(attrNum))) {
	  logOptions.LOG(4, "Skipping unknown value for attribute " + attrNum + endl);
      }
      else {
	 // continuous attr
	 if ( nominCounts.value_counts()[attrNum] == null ) {
	    logOptions.LOG(4, endl + "Continuous Attribute " + attrNum + endl);

	    MLJ.ASSERT( continNorm != null, "NaiveBayesCat.score(): continNorm");
	    if ( !ai.can_cast_to_real() ) {
	       Error.fatalErr("NaiveBayesCat.categorize: Schema of instance to be " +
		  "categorized does not match Schema of training set. " +
		  "Attribute Number " + attrNum + " is continuous in training " +
		  "set and nominal in instance schema.");
          }
	    double realVal = ai.get_real_val(instance.values[attrNum]);
	    for (labelVal = 0; labelVal < prob.length; labelVal++) {
	       NBNorm nbn = continNorm[attrNum][labelVal];
	       double distToMean = realVal - nbn.mean;
	       double stdDev = Math.sqrt(nbn.var);
	       // exp term of the normal density; the 1/sqrt(2*pi) constant is
	       // dropped because it is identical for every label.
	       double e2The = Math.exp( -1 * distToMean * distToMean / (2*nbn.var) );
	       prob[labelVal] *= e2The / stdDev;
		  
	       logOptions.LOG(5, " P(" + labelVal + "): times " + e2The / stdDev
		   + ", X = " + realVal + ",  Probability so far = " +
	         prob[labelVal] + endl);
	    }
	 }
	 else { // nominal attribute
	    logOptions.LOG(4, endl + "Nominal Attribute " + attrNum + endl);
	    if ( ! ai.can_cast_to_nominal() ) {
	       Error.fatalErr("NaiveBayesCat.categorize: Schema of instance to be " +
		  "categorized does not match Schema of training set. " +
		  "Attribute Number " + attrNum + " is nominal in training " +
		  "set and continuous in instance schema.");
          }
	    NominalAttrInfo nai = ai.cast_to_nominal();
	    int nomVal = nai.get_nominal_val(instance.get_value(attrNum));

	    double[] estProb = new double[prob.length];
          for (int i = 0; i < estProb.length; i++) {
            estProb[i] = 0.0;
          }	    

	    // The value should never be out of range of the BagCounter.
	    // Even in a non-fixed value set, this value should have
	    // been converted into an unknown during reading/assimilation.
	    if(nomVal > nominCounts.attr_counts()[attrNum].length-1 ) {
	       Error.fatalErr("Value for attribute " + attrNum + " is out of " +
		  "range.  This indicates that this instance was not " +
		  "correctly assimilated into the training schema.");
          }
	    else {
	      // loop through each label val, updating cumulative vector
	      // The generate_cond_probability function provides several
	      // options for handling zero counts here.
	      for (labelVal = 0;labelVal < prob.length;labelVal++) {
		  // include unknowns in the label count: (langley didn't)
              double labelCount = nominCounts.label_count(labelVal);
              double labelMatch = nominCounts.val_count(labelVal, attrNum, nomVal);
              // If there are Nulls, we have to account for them
              //   in the laplace correction, or else the sum of
              //   the probabilities don't sum up to 1.  
              //   A failure will occur in t_DFInducer.c
              boolean hasNulls = (nominCounts.attr_count(attrNum, Globals.UNKNOWN_CATEGORY_VAL) > 0);
		  estProb[labelVal] = generate_cond_probability(labelMatch, labelCount, nai.num_values() + (hasNulls?1:0), numAttributes);
	      }

	      // The evidence projection algorithm produces unnormalized
	      // probabilities.  If we're using this algorithm, normalize
	      // here.
	      if(useEvidenceProjection) {
              double sum = sumArray(estProb);
              for(labelVal = 0; labelVal < prob.length; labelVal++) {
		    estProb[labelVal] /= sum;
              }
	      }
          }

	    // accumulate probabilities
	    for(int i=0; i < estProb.length; i++) {
	       prob[i] *= estProb[i];
	       logOptions.LOG(4, "P | L=" + i + " = " + estProb[i] + ".  Cumulative prob = " + prob[i] + endl);
	    }
	  }

	  // Since these are unscaled relative probabilities, we rescale them
	  //   to 0-1, so that we don't get underflows
	  // NOTE(review): rescale_prob is static; invoking it through "this."
	  // works but is misleading.
	  this.rescale_prob(prob);
	  logOptions.LOG(4, "Relative probabilities after scaling: " + endl);
	  //IFLOG(4, this.log_prob(prob, schema));
      } // if unknown
    }
    // place the probabilities into a CatDist
    CatDist retDist = new CatDist(get_schema(), prob, CatDist.none, 1.0);
    // set the tiebreaking ordering to favor label values with more
    // instances (higher label counts)
    //JWP retDist.set_tiebreaking_by_values(nominCounts.label_counts());
       return retDist;
  }

  /** Sets the m value for the Laplace correction.
    * @param m - the new m-estimate factor; must be non-negative.
    */
  public void set_m_estimate_factor(double m) {
    // Accept only non-negative factors; anything else is a fatal error.
    if (m >= 0) {
      mEstimateFactor = m;
    }
    else {
      Error.fatalErr("NaiveBayesCat.set_m_estimate_factor() : illegal m_estimate_" +
	 "factor value : " + m);
    }
  }

  /** Sets the factor used when no training matches exist for a value.
    * @param nm - the new no-matches factor.
    */
  public void set_no_matches_factor(double nm) { noMatchesFactor = nm; }
  /** Sets unknownIsValue, which controls how unknown attribute values are
    * treated.  Only the values 1, 2, and 3 are accepted; any other value is
    * rejected with a (non-fatal) Error.err message and the current setting
    * is left unchanged.
    * @param unk - the new value of unknownIsValue; must be 1, 2, or 3.
    */
  public void set_unknown_is_value(int unk) {
    // Guard clause: reject out-of-range settings without touching state.
    if (unk < 1 || unk > 3) {
      Error.err("NaiveBayesCat.set_unknown_is_value(): unknownIsValue cannot be " + unk);
      return;
    }
    unknownIsValue = unk;
  }
  /** Enables or disables the evidence projection algorithm used in score().
    * @param b - true to use evidence projection.
    */
  public void set_use_evidence_projection(boolean b) { useEvidenceProjection = b;}
  /** Enables or disables the Laplace correction for probability estimates.
    * @param lap - true to apply the Laplace correction.
    */
  public void set_use_laplace(boolean lap) { useLaplace = lap; }
  /** Computes the sum of all elements in the given array.
    * @param d - an array of doubles to add.
    * @return the sum of the elements of d.
    */
  public static double sumArray(double [] d) {
    double total = 0.0;
    int i = 0;
    while (i < d.length) {
      total += d[i];
      i++;
    }
    return total;
  }

  /** @return false; backfitting is not supported by this categorizer. */
  public boolean supports_backfit() { return false; }

  /** @return the total weight of the training instances. */
  public double total_train_weight() { return trainWeight; }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -