⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 logitboost.java

📁 :<<数据挖掘--实用机器学习技术及java实现>>一书的配套源程序
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier(Classifier newClassifier) {

    m_Classifier = newClassifier;
  }

  /**
   * Get the classifier used as the classifier
   *
   * @return the classifier used as the classifier
   */
  public Classifier getClassifier() {

    return m_Classifier;
  }

  /**
   * Set the maximum number of boost iterations
   *
   * @param maxIterations the maximum number of boost iterations
   */
  public void setMaxIterations(int maxIterations) {

    m_MaxIterations = maxIterations;
  }

  /**
   * Get the maximum number of boost iterations
   *
   * @return the maximum number of boost iterations
   */
  public int getMaxIterations() {

    return m_MaxIterations;
  }

  /**
   * Set weight thresholding
   *
   * @param threshold the percentage of weight mass used for training
   */
  public void setWeightThreshold(int threshold) {

    m_WeightThreshold = threshold;
  }

  /**
   * Get the degree of weight thresholding
   *
   * @return the percentage of weight mass used for training
   */
  public int getWeightThreshold() {

    return m_WeightThreshold;
  }

  /**
   * Set debugging mode
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {

    m_Debug = debug;
  }

  /**
   * Get whether debugging is turned on
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {

    return m_Debug;
  }

  /**
   * Boosting method. Boosts any classifier that can handle weighted
   * instances.
   *
   * @param data the training data to be used for generating the
   * boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    Random randomInstance = new Random(m_Seed);
    Instances boostData, trainData;
    int classIndex = data.classIndex();

    if (data.classAttribute().isNumeric()) {
      throw new Exception("LogitBoost can't handle a numeric class!");
    }
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    // Base learners that cannot weight instances must be trained on
    // weight-based resamples instead.
    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
        !m_UseResampling) {
      m_UseResampling = true;
    }
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }
    if (m_Debug) {
      System.err.println("Creating copy of the training data");
    }
    m_NumClasses = data.numClasses();
    m_ClassAttribute = data.classAttribute();

    // Create a copy of the data with the class transformed into numeric:
    // LogitBoost fits regressors to a numeric pseudo-response per class.
    boostData = new Instances(data);
    boostData.deleteWithMissingClass();
    int numInstances = boostData.numInstances();

    // Temporarily unset the class index while swapping in the pseudo class
    boostData.setClassIndex(-1);
    boostData.deleteAttributeAt(classIndex);
    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    boostData.setClassIndex(classIndex);
    m_NumericClassData = new Instances(boostData, 0);

    // trainFs[i][j]: accumulated ensemble score F_j for instance i.
    // trainYs[i][j]: 1 if instance i belongs to class j, else 0.
    double [][] trainFs = new double [numInstances][m_NumClasses];
    double [][] trainYs = new double [numInstances][m_NumClasses];
    for (int j = 0; j < m_NumClasses; j++) {
      // k walks the ORIGINAL data, skipping instances whose class is
      // missing, so that i stays aligned with boostData (which had them
      // deleted above).
      for (int i = 0, k = 0; i < numInstances; i++, k++) {
        while (data.instance(k).classIsMissing()) k++;
        trainYs[i][j] = (data.instance(k).classValue() == j) ? 1 : 0;
      }
    }

    if (m_Debug) {
      System.err.println("Creating base classifiers");
    }

    // Create the base classifiers: one per class per iteration
    m_Classifiers = new Classifier [m_NumClasses][];
    for (int j = 0; j < m_NumClasses; j++) {
      m_Classifiers[j] = Classifier.makeCopies(m_Classifier,
                                               getMaxIterations());
    }

    // Do boosting iterations
    for (m_NumIterations = 0; m_NumIterations < getMaxIterations();
         m_NumIterations++) {
      if (m_Debug) {
        System.err.println("Training classifier " + (m_NumIterations + 1));
      }
      for (int j = 0; j < m_NumClasses; j++) {
        if (m_Debug) {
          System.err.println("\t...for class " + (j + 1)
                             + " (" + m_ClassAttribute.name()
                             + "=" + m_ClassAttribute.value(j) + ")");
        }

        // Set instance pseudoclass (working response z) and weights.
        // z = (y - p) / (p (1 - p)), with the y==1 / y==0 cases clamped
        // to +/- Z_MAX to keep the response bounded when p is extreme.
        for (int i = 0; i < numInstances; i++) {
          double p = RtoP(trainFs[i], j);
          Instance current = boostData.instance(i);
          double z, actual = trainYs[i][j];
          if (actual == 1) {
            z = 1.0 / p;
            if (z > Z_MAX) { // threshold
              z = Z_MAX;
            }
          } else if (actual == 0) {
            z = -1.0 / (1.0 - p);
            if (z < -Z_MAX) { // threshold
              z = -Z_MAX;
            }
          } else {
            z = (actual - p) / (p * (1 - p));
          }
          // Newton-step weight p(1-p), floored to avoid zero weights
          double w = Math.max(p * (1 - p), VERY_SMALL);
          current.setValue(classIndex, z);
          current.setWeight(numInstances * w);
        }

        // Select instances to train the classifier on
        if (m_WeightThreshold < 100) {
          trainData = selectWeightQuantile(boostData,
                                           (double)m_WeightThreshold/100);
        } else {
          trainData = new Instances(boostData,0,numInstances);
          if (m_UseResampling) {
            double[] weights = new double[boostData.numInstances()];
            for (int kk = 0; kk < weights.length; kk++) {
              weights[kk] = boostData.instance(kk).weight();
            }
            trainData = boostData.resampleWithWeights(randomInstance,
                                                      weights);
          }
        }

        // Build the classifier
        m_Classifiers[j][m_NumIterations].buildClassifier(trainData);
      }

      // Evaluate / increment trainFs from the classifier: the per-class
      // predictions are centered so the F_j scores stay sum-to-zero.
      for (int i = 0; i < numInstances; i++) {
        double [] pred = new double [m_NumClasses];
        double predSum = 0;
        for (int j = 0; j < m_NumClasses; j++) {
          pred[j] = m_Classifiers[j][m_NumIterations]
            .classifyInstance(boostData.instance(i));
          predSum += pred[j];
        }
        predSum /= m_NumClasses;
        for (int j = 0; j < m_NumClasses; j++) {
          trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses-1)
            / m_NumClasses;
        }
      }
    }
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if instance could not be classified
   * successfully
   */
  public double [] distributionForInstance(Instance instance)
    throws Exception {

    // Work on a copy re-attached to the numeric-class header built in
    // buildClassifier, so the base regressors accept it.
    instance = (Instance)instance.copy();
    instance.setDataset(m_NumericClassData);
    double [] Fs = new double [m_NumClasses];
    for (int i = 0; i < m_NumIterations; i++) {
      double [] Fi = new double [m_NumClasses];
      double Fsum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
        Fi[j] = m_Classifiers[j][i].classifyInstance(instance);
        Fsum += Fi[j];
      }
      // Same centered update as in training
      Fsum /= m_NumClasses;
      for (int j = 0; j < m_NumClasses; j++) {
        Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses;
      }
    }
    double [] distribution = new double [m_NumClasses];
    for (int j = 0; j < m_NumClasses; j++) {
      distribution[j] = RtoP(Fs, j);
    }
    Utils.normalize(distribution);
    return distribution;
  }

  /**
   * Returns the boosted model as Java source code.
   *
   * @return the tree as Java source code
   * @exception Exception if something goes wrong
   */
  public String toSource(String className) throws Exception {

    if (m_NumIterations == 0) {
      throw new Exception("No model built yet");
    }
    if (!(m_Classifiers[0][0] instanceof Sourcable)) {
      throw new Exception("Base learner " + m_Classifier.getClass().getName()
                          + " is not Sourcable");
    }
    StringBuffer text = new StringBuffer("class ");
    text.append(className).append(" {\n\n");

    // Generated softmax helper. BUGFIX: the numerator must be centered by
    // Rcenter too (it previously emitted Math.exp(R[j]) / Rsum, which is
    // off by a factor of e^Rcenter and can overflow for large scores).
    text.append("  private static double RtoP(double []R, int j) {\n"+
                "    double Rcenter = 0;\n"+
                "    for (int i = 0; i < R.length; i++) {\n"+
                "      Rcenter += R[i];\n"+
                "    }\n"+
                "    Rcenter /= R.length;\n"+
                "    double Rsum = 0;\n"+
                "    for (int i = 0; i < R.length; i++) {\n"+
                "      Rsum += Math.exp(R[i] - Rcenter);\n"+
                "    }\n"+
                "    return Math.exp(R[j] - Rcenter) / Rsum;\n"+
                "  }\n\n");

    text.append("  public static double classify(Object [] i) {\n" +
                "    double [] d = distribution(i);\n" +
                "    double maxV = d[0];\n" +
                "    int maxI = 0;\n"+
                "    for (int j = 1; j < " + m_NumClasses + "; j++) {\n"+
                "      if (d[j] > maxV) { maxV = d[j]; maxI = j; }\n"+
                "    }\n    return (double) maxI;\n  }\n\n");

    text.append("  public static double [] distribution(Object [] i) {\n");
    text.append("    double [] Fs = new double [" + m_NumClasses + "];\n");
    text.append("    double [] Fi = new double [" + m_NumClasses + "];\n");
    text.append("    double Fsum;\n");
    // Unrolled copy of the distributionForInstance accumulation loop
    for (int i = 0; i < m_NumIterations; i++) {
      text.append("    Fsum = 0;\n");
      for (int j = 0; j < m_NumClasses; j++) {
        text.append("    Fi[" + j + "] = " + className + '_' +j + '_' + i
                    + ".classify(i); Fsum += Fi[" + j + "];\n");
      }
      text.append("    Fsum /= " + m_NumClasses + ";\n");
      text.append("    for (int j = 0; j < " + m_NumClasses + "; j++) {");
      text.append(" Fs[j] += (Fi[j] - Fsum) * "
                  + (m_NumClasses - 1) + " / " + m_NumClasses + "; }\n");
    }

    text.append("    double [] dist = new double [" + m_NumClasses + "];\n" +
                "    for (int j = 0; j < " + m_NumClasses + "; j++) {\n"+
                "      dist[j] = RtoP(Fs, j);\n"+
                "    }\n    return dist;\n");
    text.append("  }\n}\n");

    // Emit the source of every base model the generated code references
    for (int i = 0; i < m_Classifiers.length; i++) {
      for (int j = 0; j < m_Classifiers[i].length; j++) {
        text.append(((Sourcable)m_Classifiers[i][j])
                    .toSource(className + '_' + i + '_' + j));
      }
    }
    return text.toString();
  }

  /**
   * Returns description of the boosted classifier.
   *
   * @return description of the boosted classifier as a string
   */
  public String toString() {

    StringBuffer text = new StringBuffer();

    if (m_NumIterations == 0) {
      text.append("LogitBoost: No model built yet.");
    } else {
      text.append("LogitBoost: Base classifiers and their weights: \n");
      for (int i = 0; i < m_NumIterations; i++) {
        text.append("\nIteration "+(i+1));
        for (int j = 0; j < m_NumClasses; j++) {
          text.append("\n\tClass " + (j + 1)
                      + " (" + m_ClassAttribute.name()
                      + "=" + m_ClassAttribute.value(j) + ")\n\n"
                      + m_Classifiers[j][i].toString() + "\n");
        }
      }
      text.append("Number of performed iterations: " +
                  m_NumIterations + "\n");
    }

    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new LogitBoost(), argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -