⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 logitboost.java

📁 为了下东西 随便发了个 datamining 的源代码
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    data.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    data.setClassIndex(classIndex);
    m_NumericClassData = new Instances(data, 0);
	
    // Perform iterations
    double[][] probs = initialProbs(numInstances);
    double logLikelihood = logLikelihood(trainYs, probs);
    m_NumGenerated = 0;
    if (m_Debug) {
      System.err.println("Avg. log-likelihood: " + logLikelihood);
    }
    double sumOfWeights = data.sumOfWeights();
    for (int j = 0; j < bestNumIterations; j++) {
      double previousLoglikelihood = logLikelihood;
      performIteration(trainYs, trainFs, probs, data, sumOfWeights);
      logLikelihood = logLikelihood(trainYs, probs);
      if (m_Debug) {
	System.err.println("Avg. log-likelihood: " + logLikelihood);
      }
      if (Math.abs(previousLoglikelihood - logLikelihood) < m_Precision) {
	return;
      }
    }
  }

  /**
   * Builds the starting probability matrix, in which every instance is
   * assigned the uniform distribution over the classes.
   *
   * @param numInstances the number of training instances (rows)
   * @return a numInstances x m_NumClasses matrix whose every entry is
   * 1.0 / m_NumClasses
   */
  private double[][] initialProbs(int numInstances) {

    double uniform = 1.0 / m_NumClasses;
    double[][] result = new double[numInstances][m_NumClasses];
    for (double[] row : result) {
      for (int c = 0; c < row.length; c++) {
        row[c] = uniform;
      }
    }
    return result;
  }

  /**
   * Computes the average negative log-likelihood of the observed classes
   * under the given probability estimates.
   *
   * An entry of trainYs marks the observed class when it equals
   * 1.0 - m_Offset; only those entries contribute -log(probability).
   *
   * @param trainYs the class indicator matrix, one row per instance
   * @param probs the current class probability estimates, indexed like trainYs
   * @return the summed -log(probability) of the observed classes divided by
   * the number of rows of trainYs
   */
  private double logLikelihood(double[][] trainYs, double[][] probs) {

    final double target = 1.0 - m_Offset;
    double total = 0;
    int row = 0;
    while (row < trainYs.length) {
      double[] ys = trainYs[row];
      for (int c = 0; c < m_NumClasses; c++) {
        if (ys[c] == target) {
          total -= Math.log(probs[row][c]);
        }
      }
      row++;
    }
    return total / (double) trainYs.length;
  }

  /**
   * Performs one boosting iteration.
   *
   * For every class a base learner is fitted to a copy of data whose class
   * value is replaced by a working response z and whose instance weights are
   * multiplied by a working weight w. The new learners' shrunken,
   * class-centred predictions are then added to trainFs, the probability
   * estimates in probs are recomputed from trainFs, and m_NumGenerated is
   * incremented.
   *
   * @param trainYs the class indicator matrix; an entry equal to
   * 1 - m_Offset marks the observed class of that instance
   * @param trainFs the additive per-instance, per-class scores; updated in place
   * @param probs the current probability estimates; read to compute z and w,
   * then each row is replaced with a freshly computed array
   * @param data the training instances whose class attribute receives the
   * working response (on a copy)
   * @param origSumOfWeights the total weight the rescaled working weights are
   * normalized to
   * @exception Exception propagated from building or evaluating a base
   * classifier, or from the data-handling calls
   */
  private void performIteration(double[][] trainYs,
				double[][] trainFs,
				double[][] probs,
				Instances data,
				double origSumOfWeights) throws Exception {

    if (m_Debug) {
      System.err.println("Training classifier " + (m_NumGenerated + 1));
    }

    // Build the new models (one per class for this iteration)
    for (int j = 0; j < m_NumClasses; j++) {
      if (m_Debug) {
	System.err.println("\t...for class " + (j + 1)
			   + " (" + m_ClassAttribute.name() 
			   + "=" + m_ClassAttribute.value(j) + ")");
      }
    
      // Make a copy because we want to keep the original weights in data.
      // NOTE(review): assumes the Instances copy constructor copies the
      // instances themselves, so setValue/setWeight below do not touch data.
      Instances boostData = new Instances(data);
      
      // Set instance pseudoclass and weights
      for (int i = 0; i < probs.length; i++) {

	// Compute the working response z and working weight w.
	// Observed class: z = 1/p; otherwise z = -1/(1-p). Both are capped at
	// +/-Z_MAX so that p close to 0 or 1 cannot produce unbounded responses.
	double p = probs[i][j];
	double z, actual = trainYs[i][j];
	if (actual == 1 - m_Offset) {
	  z = 1.0 / p;
	  if (z > Z_MAX) { // threshold
	    z = Z_MAX;
	  }
	} else {
	  z = -1.0 / (1.0 - p);
	  if (z < -Z_MAX) { // threshold
	    z = -Z_MAX;
	  }
	}
	// With unclamped z and m_Offset == 0 this equals p * (1 - p)
	double w = (actual - p) / z;

	// Set values for instance: class value becomes z, weight is multiplied by w
	Instance current = boostData.instance(i);
	current.setValue(boostData.classIndex(), z);
	current.setWeight(current.weight() * w);
      }
      
      // Scale the weights so they sum to origSumOfWeights (helps with some
      // base learners).
      // NOTE(review): if every w is 0, sumOfWeights is 0 and scalingFactor
      // becomes infinite (weights turn NaN); this case is not guarded here.
      double sumOfWeights = boostData.sumOfWeights();
      double scalingFactor = (double)origSumOfWeights / sumOfWeights;
      for (int i = 0; i < probs.length; i++) {
	Instance current = boostData.instance(i);
	current.setWeight(current.weight() * scalingFactor);
      }

      // Select instances to train the classifier on: either the weight
      // quantile given by m_WeightThreshold (percent), or, when the threshold
      // is 100 and resampling is enabled, a weighted resample drawn with
      // m_RandomInstance; otherwise the full weighted boostData.
      Instances trainData = boostData;
      if (m_WeightThreshold < 100) {
	trainData = selectWeightQuantile(boostData, 
					 (double)m_WeightThreshold / 100);
      } else {
	if (m_UseResampling) {
	  double[] weights = new double[boostData.numInstances()];
	  for (int kk = 0; kk < weights.length; kk++) {
	    weights[kk] = boostData.instance(kk).weight();
	  }
	  trainData = boostData.resampleWithWeights(m_RandomInstance, 
						    weights);
	}
      }
      
      // Build the classifier for class j in slot m_NumGenerated
      m_Classifiers[j][m_NumGenerated].buildClassifier(trainData);
    }      
    
    // Evaluate / increment trainFs from the classifier.
    // Each prediction is multiplied by m_Shrinkage, centred on the mean over
    // classes, and scaled by (K - 1) / K where K = m_NumClasses.
    // m_NumGenerated still indexes the classifiers built above at this point.
    for (int i = 0; i < trainFs.length; i++) {
      double [] pred = new double [m_NumClasses];
      double predSum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
	pred[j] = m_Shrinkage * m_Classifiers[j][m_NumGenerated]
	  .classifyInstance(data.instance(i));
	predSum += pred[j];
      }
      predSum /= m_NumClasses;
      for (int j = 0; j < m_NumClasses; j++) {
	trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses - 1) 
	  / m_NumClasses;
      }
    }
    m_NumGenerated++;
    
    // Compute the current probability estimates (softmax of the updated scores)
    for (int i = 0; i < trainYs.length; i++) {
      probs[i] = probs(trainFs[i]);
    }
  }

  /**
   * Returns the base classifiers built so far, one row per class and one
   * column per performed iteration. The returned arrays are new, but the
   * classifier objects in them are the same instances the model uses.
   *
   * @return an m_NumClasses x m_NumGenerated array of the built classifiers
   */
  public Classifier[][] classifiers() {

    Classifier[][] built = new Classifier[m_NumClasses][m_NumGenerated];
    if (m_NumGenerated > 0) {
      for (int cls = 0; cls < built.length; cls++) {
        System.arraycopy(m_Classifiers[cls], 0, built[cls], 0, m_NumGenerated);
      }
    }
    return built;
  }

  /**
   * Converts additive scores into class probabilities with a softmax.
   * The largest score is subtracted before exponentiating, so for finite
   * scores every exponent is at most 0 and Math.exp cannot overflow.
   *
   * @param Fs the per-class additive model scores
   * @return a new array of probabilities, one per entry of Fs, normalized by
   * Utils.normalize with the sum of the exponentials
   */
  private double[] probs(double[] Fs) {

    double largest = -Double.MAX_VALUE;
    for (double f : Fs) {
      if (f > largest) {
        largest = f;
      }
    }
    double[] result = new double[Fs.length];
    double total = 0;
    for (int k = 0; k < result.length; k++) {
      double e = Math.exp(Fs[k] - largest);
      result[k] = e;
      total += e;
    }
    Utils.normalize(result, total);
    return result;
  }
    
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * The additive score of each class is accumulated over all generated
   * iterations exactly as during training. Each base prediction is scaled by
   * m_Shrinkage, centred on the mean over classes and multiplied by
   * (K - 1) / K, where K is the number of classes. The scores are then turned
   * into probabilities by probs(double[]).
   *
   * @param instance the instance to be classified; it is not modified, since
   * a copy is attached to the numeric pseudo-class header
   * @return predicted class probability distribution
   * @exception Exception if instance could not be classified
   * successfully
   */
  public double [] distributionForInstance(Instance instance) 
    throws Exception {

    // Work on a copy so the caller's instance keeps its own dataset reference.
    instance = (Instance)instance.copy();
    instance.setDataset(m_NumericClassData);
    double [] pred = new double [m_NumClasses];
    double [] Fs = new double [m_NumClasses]; 
    for (int i = 0; i < m_NumGenerated; i++) {
      double predSum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
        // Apply the same shrinkage used when the training scores were updated;
        // otherwise predictions diverge from the fitted model when shrinkage != 1.
        pred[j] = m_Shrinkage * m_Classifiers[j][i].classifyInstance(instance);
        predSum += pred[j];
      }
      predSum /= m_NumClasses;
      for (int j = 0; j < m_NumClasses; j++) {
        Fs[j] += (pred[j] - predSum) * (m_NumClasses - 1) 
          / m_NumClasses;
      }
    }

    return probs(Fs);
  }

  /**
   * Returns the boosted model as Java source code.
   *
   * The generated class mirrors the boosting update used in training. For
   * every performed iteration it does the following:
   * - scales each base classifier's output by the shrinkage factor;
   * - centres the outputs across classes;
   * - multiplies them by (K - 1) / K (K = number of classes);
   * - accumulates the per-class scores.
   * The scores are finally mapped to probabilities by a max-shifted softmax.
   * Each built base classifier is emitted as its own class named
   * className_classIndex_iteration.
   *
   * @param className the name of the generated class
   * @return the boosted model as Java source code
   * @exception Exception if no model has been built yet or the base learner
   * is not Sourcable
   */
  public String toSource(String className) throws Exception {

    if (m_NumGenerated == 0) {
      throw new Exception("No model built yet");
    }
    if (!(m_Classifiers[0][0] instanceof Sourcable)) {
      throw new Exception("Base learner " + m_Classifier.getClass().getName()
			  + " is not Sourcable");
    }

    StringBuffer text = new StringBuffer("class ");
    text.append(className).append(" {\n\n");
    // Softmax of R at index j. Shifting by the maximum (as probs() does)
    // keeps every exponent <= 0 so Math.exp cannot overflow, and the same
    // shift is applied to numerator and denominator.
    text.append("  private static double RtoP(double []R, int j) {\n"+
		"    double Rmax = R[0];\n"+
		"    for (int i = 1; i < R.length; i++) {\n"+
		"      if (R[i] > Rmax) { Rmax = R[i]; }\n"+
		"    }\n"+
		"    double Rsum = 0;\n"+
		"    for (int i = 0; i < R.length; i++) {\n"+
		"      Rsum += Math.exp(R[i] - Rmax);\n"+
		"    }\n"+
		"    return Math.exp(R[j] - Rmax) / Rsum;\n"+
		"  }\n\n");

    text.append("  public static double classify(Object [] i) {\n" +
                "    double [] d = distribution(i);\n" +
                "    double maxV = d[0];\n" +
		"    int maxI = 0;\n"+
		"    for (int j = 1; j < " + m_NumClasses + "; j++) {\n"+
		"      if (d[j] > maxV) { maxV = d[j]; maxI = j; }\n"+
		"    }\n    return (double) maxI;\n  }\n\n");

    text.append("  public static double [] distribution(Object [] i) {\n");
    text.append("    double [] Fs = new double [" + m_NumClasses + "];\n");
    text.append("    double [] Fi = new double [" + m_NumClasses + "];\n");
    text.append("    double Fsum;\n");
    for (int i = 0; i < m_NumGenerated; i++) {
      text.append("    Fsum = 0;\n");
      for (int j = 0; j < m_NumClasses; j++) {
	// Apply the same shrinkage that was used when the training scores were updated.
	text.append("    Fi[" + j + "] = " + m_Shrinkage + " * "
		    + className + '_' + j + '_' + i
		    + ".classify(i); Fsum += Fi[" + j + "];\n");
      }
      text.append("    Fsum /= " + m_NumClasses + ";\n");
      text.append("    for (int j = 0; j < " + m_NumClasses + "; j++) {");
      text.append(" Fs[j] += (Fi[j] - Fsum) * "
		  + (m_NumClasses - 1) + " / " + m_NumClasses + "; }\n");
    }
    
    text.append("    double [] dist = new double [" + m_NumClasses + "];\n" +
		"    for (int j = 0; j < " + m_NumClasses + "; j++) {\n"+
		"      dist[j] = RtoP(Fs, j);\n"+
		"    }\n    return dist;\n");
    text.append("  }\n}\n");

    // Emit only the base models that were actually generated. Slots at or
    // beyond m_NumGenerated (e.g. after boosting stopped early) are never
    // referenced by distribution() above.
    for (int i = 0; i < m_Classifiers.length; i++) {
      for (int j = 0; j < m_NumGenerated; j++) {
	text.append(((Sourcable)m_Classifiers[i][j])
		    .toSource(className + '_' + i + '_' + j));
      }
    }
    return text.toString();
  }

  /**
   * Returns a textual description of the boosted model. For every performed
   * iteration it lists the base classifier built for each class, followed by
   * the number of performed iterations.
   *
   * @return description of the boosted classifier as a string
   */
  public String toString() {

    if (m_NumGenerated == 0) {
      return "LogitBoost: No model built yet.";
    }

    StringBuffer out =
      new StringBuffer("LogitBoost: Base classifiers and their weights: \n");
    for (int iter = 1; iter <= m_NumGenerated; iter++) {
      out.append("\nIteration ").append(iter);
      for (int cls = 0; cls < m_NumClasses; cls++) {
        out.append("\n\tClass ").append(cls + 1)
           .append(" (").append(m_ClassAttribute.name())
           .append("=").append(m_ClassAttribute.value(cls)).append(")\n\n")
           .append(m_Classifiers[cls][iter - 1].toString()).append("\n");
      }
    }
    out.append("Number of performed iterations: ")
       .append(m_NumGenerated).append("\n");
    return out.toString();
  }

  /**
   * Command-line entry point. It evaluates a new LogitBoost classifier with
   * the given options via Evaluation.evaluateModel and prints the resulting
   * report. Any exception is reported by printing its stack trace and message
   * to standard error.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      LogitBoost booster = new LogitBoost();
      System.out.println(Evaluation.evaluateModel(booster, argv));
    } catch (Exception failure) {
      failure.printStackTrace();
      System.err.println(failure.getMessage());
    }
  }
}


  

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -