⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 logitboost.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
			  "not allowed.");
    }

    super.setOptions(options);
  }

  /**
   * Reports the current settings of the Classifier as an option array.
   *
   * The array starts with either "-Q" (when resampling is used) or "-P"
   * followed by the weight threshold, continues with "-F", "-R", "-L" and
   * "-H" each followed by its value, and ends with the options reported by
   * the superclass. Any slots left unused are filled with empty strings.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] inherited = super.getOptions();
    // At most ten entries of our own are written ahead of the inherited ones.
    String [] result = new String [inherited.length + 10];
    int pos = 0;

    if (!getUseResampling()) {
      result[pos++] = "-P";
      result[pos++] = String.valueOf(getWeightThreshold());
    } else {
      result[pos++] = "-Q";
    }

    result[pos++] = "-F";
    result[pos++] = String.valueOf(getNumFolds());
    result[pos++] = "-R";
    result[pos++] = String.valueOf(getNumRuns());
    result[pos++] = "-L";
    result[pos++] = String.valueOf(getLikelihoodThreshold());
    result[pos++] = "-H";
    result[pos++] = String.valueOf(getShrinkage());

    // Append whatever the superclass reported.
    System.arraycopy(inherited, 0, result, pos, inherited.length);
    pos += inherited.length;

    // Pad the remaining slots so no entry is left null.
    for (int i = pos; i < result.length; i++) {
      result[i] = "";
    }
    return result;
  }
  
  /**
   * Returns the tip text for the shrinkage property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String shrinkageTipText() {
    String tip = "Shrinkage parameter (use small value like 0.1 to reduce "
      + "overfitting).";
    return tip;
  }
			 
  /**
   * Gets the shrinkage value currently held in m_Shrinkage.
   *
   * @return the current shrinkage value
   */
  public double getShrinkage() {
    return this.m_Shrinkage;
  }
  
  /**
   * Stores a new shrinkage value in m_Shrinkage.
   *
   * @param newShrinkage the value to assign to the shrinkage
   */
  public void setShrinkage(double newShrinkage) {
    this.m_Shrinkage = newShrinkage;
  }
  
  /**
   * Returns the tip text for the likelihood threshold property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String likelihoodThresholdTipText() {
    String tip = "Threshold on improvement in likelihood.";
    return tip;
  }
			 
  /**
   * Gets the likelihood threshold, which is kept in m_Precision.
   *
   * @return the current likelihood threshold
   */
  public double getLikelihoodThreshold() {
    return this.m_Precision;
  }
  
  /**
   * Sets the likelihood threshold, which is kept in m_Precision.
   *
   * @param newPrecision the value to assign to the likelihood threshold
   */
  public void setLikelihoodThreshold(double newPrecision) {
    this.m_Precision = newPrecision;
  }
  
  /**
   * Returns the tip text for the number-of-runs property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numRunsTipText() {
    String tip = "Number of runs for internal cross-validation.";
    return tip;
  }
  
  /**
   * Gets the number of runs currently held in m_NumRuns.
   *
   * @return the current number of runs
   */
  public int getNumRuns() {
    return this.m_NumRuns;
  }
  
  /**
   * Stores a new number of runs in m_NumRuns.
   *
   * @param newNumRuns the value to assign to the number of runs
   */
  public void setNumRuns(int newNumRuns) {
    this.m_NumRuns = newNumRuns;
  }
  
  /**
   * Returns the tip text for the number-of-folds property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numFoldsTipText() {
    String tip = "Number of folds for internal cross-validation (default 0 "
      + "means no cross-validation is performed).";
    return tip;
  }
  
  /**
   * Gets the number of folds currently held in m_NumFolds.
   *
   * @return the current number of folds
   */
  public int getNumFolds() {
    return this.m_NumFolds;
  }
  
  /**
   * Stores a new number of folds in m_NumFolds.
   *
   * @param newNumFolds the value to assign to the number of folds
   */
  public void setNumFolds(int newNumFolds) {
    this.m_NumFolds = newNumFolds;
  }
  
  /**
   * Returns the tip text for the use-resampling property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String useResamplingTipText() {
    String tip = "Whether resampling is used instead of reweighting.";
    return tip;
  }
  
  /**
   * Switches resampling mode on or off by storing the flag in
   * m_UseResampling.
   *
   * @param r true if resampling should be done
   */
  public void setUseResampling(boolean r) {
    this.m_UseResampling = r;
  }

  /**
   * Reports whether resampling is turned on, as held in m_UseResampling.
   *
   * @return true if resampling is turned on
   */
  public boolean getUseResampling() {
    return this.m_UseResampling;
  }
  
  /**
   * Returns the tip text for the weight threshold property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String weightThresholdTipText() {
    String tip = "Weight threshold for weight pruning (reduce to 90 "
      + "for speeding up learning process).";
    return tip;
  }

  /**
   * Sets the weight threshold by storing it in m_WeightThreshold.
   *
   * @param threshold the percentage of weight mass used for training
   */
  public void setWeightThreshold(int threshold) {
    this.m_WeightThreshold = threshold;
  }

  /**
   * Gets the weight threshold currently held in m_WeightThreshold.
   *
   * @return the percentage of weight mass used for training
   */
  public int getWeightThreshold() {
    return this.m_WeightThreshold;
  }

  /**
   * Builds the boosted classifier
   */
  public void buildClassifier(Instances data) throws Exception {

    m_RandomInstance = new Random(m_Seed);
    Instances boostData, trainData;
    int classIndex = data.classIndex();

    if (data.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("LogitBoost can't handle a numeric class!");
    }
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    
    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
	!m_UseResampling) {
      m_UseResampling = true;
    }
    if (data.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    if (m_Debug) {
      System.err.println("Creating copy of the training data");
    }

    m_NumClasses = data.numClasses();
    m_ClassAttribute = data.classAttribute();

    // Create a copy of the data 
    data = new Instances(data);
    data.deleteWithMissingClass();
    
    // Create the base classifiers
    if (m_Debug) {
      System.err.println("Creating base classifiers");
    }
    m_Classifiers = new Classifier [m_NumClasses][];
    for (int j = 0; j < m_NumClasses; j++) {
      m_Classifiers[j] = Classifier.makeCopies(m_Classifier,
					       getNumIterations());
    }

    // Do we want to select the appropriate number of iterations
    // using cross-validation?
    int bestNumIterations = getNumIterations();
    if (m_NumFolds > 1) {
      if (m_Debug) {
	System.err.println("Processing first fold.");
      }

      // Array for storing the results
      double[] results = new double[getNumIterations()];

      // Iterate through the cv-runs
      for (int r = 0; r < m_NumRuns; r++) {

	// Stratify the data
	data.randomize(m_RandomInstance);
	data.stratify(m_NumFolds);
	
	// Perform the cross-validation
	for (int i = 0; i < m_NumFolds; i++) {
	  
	  // Get train and test folds
	  Instances train = data.trainCV(m_NumFolds, i, m_RandomInstance);
	  Instances test = data.testCV(m_NumFolds, i);
	  
	  // Make class numeric
	  Instances trainN = new Instances(train);
	  trainN.setClassIndex(-1);
	  trainN.deleteAttributeAt(classIndex);
	  trainN.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
	  trainN.setClassIndex(classIndex);
	  m_NumericClassData = new Instances(trainN, 0);
	  
	  // Get class values
	  int numInstances = train.numInstances();
	  double [][] trainFs = new double [numInstances][m_NumClasses];
	  double [][] trainYs = new double [numInstances][m_NumClasses];
	  for (int j = 0; j < m_NumClasses; j++) {
	    for (int k = 0; k < numInstances; k++) {
	      trainYs[k][j] = (train.instance(k).classValue() == j) ? 
		1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses);
	    }
	  }
	  
	  // Perform iterations
	  double[][] probs = initialProbs(numInstances);
	  m_NumGenerated = 0;
	  double sumOfWeights = train.sumOfWeights();
	  for (int j = 0; j < getNumIterations(); j++) {
	    performIteration(trainYs, trainFs, probs, trainN, sumOfWeights);
	    Evaluation eval = new Evaluation(train);
	    eval.evaluateModel(this, test);
	    results[j] += eval.correct();
	  }
	}
      }
      
      // Find the number of iterations with the lowest error
      double bestResult = -Double.MAX_VALUE;
      for (int j = 0; j < getNumIterations(); j++) {
	if (results[j] > bestResult) {
	  bestResult = results[j];
	  bestNumIterations = j;
	}
      }
      if (m_Debug) {
	System.err.println("Best result for " + 
			   bestNumIterations + " iterations: " +
			   bestResult);
      }
    }

    // Build classifier on all the data
    int numInstances = data.numInstances();
    double [][] trainFs = new double [numInstances][m_NumClasses];
    double [][] trainYs = new double [numInstances][m_NumClasses];
    for (int j = 0; j < m_NumClasses; j++) {
      for (int i = 0, k = 0; i < numInstances; i++, k++) {
	trainYs[i][j] = (data.instance(k).classValue() == j) ? 
	  1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses);
      }
    }
    
    // Make class numeric
    data.setClassIndex(-1);
    data.deleteAttributeAt(classIndex);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -