
📄 logitboost.java

📁 wekaUT: Weka-based semi-supervised learning classifiers developed at the University of Texas at Austin
💻 JAVA
📖 Page 1 of 3
  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] classifierOptions = new String [0];
    if ((m_Classifier != null) &&
        (m_Classifier instanceof OptionHandler)) {
      classifierOptions = ((OptionHandler)m_Classifier).getOptions();
    }
    String [] options = new String [classifierOptions.length + 17];
    int current = 0;
    if (getDebug()) {
      options[current++] = "-D";
    }
    if (getUseResampling()) {
      options[current++] = "-Q";
    } else {
      options[current++] = "-P";
      options[current++] = "" + getWeightThreshold();
    }
    if (getSeed() != 1) {
      options[current++] = "-S"; options[current++] = "" + getSeed();
    }
    options[current++] = "-I"; options[current++] = "" + getMaxIterations();
    options[current++] = "-F"; options[current++] = "" + getNumFolds();
    options[current++] = "-R"; options[current++] = "" + getNumRuns();
    options[current++] = "-L"; options[current++] = "" + getLikelihoodThreshold();
    options[current++] = "-H"; options[current++] = "" + getShrinkage();
    if (getClassifier() != null) {
      options[current++] = "-W";
      options[current++] = getClassifier().getClass().getName();
    }
    options[current++] = "--";
    System.arraycopy(classifierOptions, 0, options, current,
                     classifierOptions.length);
    current += classifierOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Get the value of Shrinkage.
   *
   * @return Value of Shrinkage.
   */
  public double getShrinkage() {

    return m_Shrinkage;
  }

  /**
   * Set the value of Shrinkage.
   *
   * @param newShrinkage Value to assign to Shrinkage.
   */
  public void setShrinkage(double newShrinkage) {

    m_Shrinkage = newShrinkage;
  }

  /**
   * Get the value of Precision.
   *
   * @return Value of Precision.
   */
  public double getLikelihoodThreshold() {

    return m_Precision;
  }

  /**
   * Set the value of Precision.
   *
   * @param newPrecision Value to assign to Precision.
   */
  public void setLikelihoodThreshold(double newPrecision) {

    m_Precision = newPrecision;
  }

  /**
   * Get the value of NumRuns.
   *
   * @return Value of NumRuns.
   */
  public int getNumRuns() {

    return m_NumRuns;
  }

  /**
   * Set the value of NumRuns.
   *
   * @param newNumRuns Value to assign to NumRuns.
   */
  public void setNumRuns(int newNumRuns) {

    m_NumRuns = newNumRuns;
  }

  /**
   * Get the value of NumFolds.
   *
   * @return Value of NumFolds.
   */
  public int getNumFolds() {

    return m_NumFolds;
  }

  /**
   * Set the value of NumFolds.
   *
   * @param newNumFolds Value to assign to NumFolds.
   */
  public void setNumFolds(int newNumFolds) {

    m_NumFolds = newNumFolds;
  }

  /**
   * Set resampling mode
   *
   * @param r true if resampling should be done
   */
  public void setUseResampling(boolean r) {

    m_UseResampling = r;
  }

  /**
   * Get whether resampling is turned on
   *
   * @return true if resampling output is on
   */
  public boolean getUseResampling() {

    return m_UseResampling;
  }

  /**
   * Set seed for resampling.
   *
   * @param seed the seed for resampling
   */
  public void setSeed(int seed) {

    m_Seed = seed;
  }

  /**
   * Get seed for resampling.
   *
   * @return the seed for resampling
   */
  public int getSeed() {

    return m_Seed;
  }

  /**
   * Set the classifier for boosting. The learner should be able to
   * handle numeric class attributes.
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier(Classifier newClassifier) {

    m_Classifier = newClassifier;
  }

  /**
   * Get the classifier used as the classifier
   *
   * @return the classifier used as the classifier
   */
  public Classifier getClassifier() {

    return m_Classifier;
  }

  /**
   * Set the maximum number of boost iterations
   *
   * @param maxIterations the maximum number of boost iterations
   */
  public void setMaxIterations(int maxIterations) {

    m_MaxIterations = maxIterations;
  }

  /**
   * Get the maximum number of boost iterations
   *
   * @return the maximum number of boost iterations
   */
  public int getMaxIterations() {

    return m_MaxIterations;
  }

  /**
   * Set weight thresholding
   *
   * @param threshold the percentage of weight mass used for training
   */
  public void setWeightThreshold(int threshold) {

    m_WeightThreshold = threshold;
  }

  /**
   * Get the degree of weight thresholding
   *
   * @return the percentage of weight mass used for training
   */
  public int getWeightThreshold() {

    return m_WeightThreshold;
  }

  /**
   * Set debugging mode
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {

    m_Debug = debug;
  }

  /**
   * Get whether debugging is turned on
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {

    return m_Debug;
  }

  /**
   * Builds the boosted classifier
   */
  public void buildClassifier(Instances data) throws Exception {

    m_RandomInstance = new Random(m_Seed);
    Instances boostData, trainData;
    int classIndex = data.classIndex();

    if (data.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("LogitBoost can't handle a numeric class!");
    }
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
        !m_UseResampling) {
      m_UseResampling = true;
    }
    if (data.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    if (m_Debug) {
      System.err.println("Creating copy of the training data");
    }

    m_NumClasses = data.numClasses();
    m_ClassAttribute = data.classAttribute();

    // Create a copy of the data
    data = new Instances(data);
    data.deleteWithMissingClass();

    // Create the base classifiers
    if (m_Debug) {
      System.err.println("Creating base classifiers");
    }
    m_Classifiers = new Classifier [m_NumClasses][];
    for (int j = 0; j < m_NumClasses; j++) {
      m_Classifiers[j] = Classifier.makeCopies(m_Classifier,
                                               getMaxIterations());
    }

    // Do we want to select the appropriate number of iterations
    // using cross-validation?
    int bestNumIterations = getMaxIterations();
    if (m_NumFolds > 1) {
      if (m_Debug) {
        System.err.println("Processing first fold.");
      }

      // Array for storing the results
      double[] results = new double[getMaxIterations()];

      // Iterate through the cv-runs
      for (int r = 0; r < m_NumRuns; r++) {

        // Stratify the data
        data.randomize(m_RandomInstance);
        data.stratify(m_NumFolds);

        // Perform the cross-validation
        for (int i = 0; i < m_NumFolds; i++) {

          // Get train and test folds
          Instances train = data.trainCV(m_NumFolds, i);
          Instances test = data.testCV(m_NumFolds, i);

          // Make class numeric
          Instances trainN = new Instances(train);
          trainN.setClassIndex(-1);
          trainN.deleteAttributeAt(classIndex);
          trainN.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
          trainN.setClassIndex(classIndex);
          m_NumericClassData = new Instances(trainN, 0);

          // Get class values
          int numInstances = train.numInstances();
          double [][] trainFs = new double [numInstances][m_NumClasses];
          double [][] trainYs = new double [numInstances][m_NumClasses];
          for (int j = 0; j < m_NumClasses; j++) {
            for (int k = 0; k < numInstances; k++) {
              trainYs[k][j] = (train.instance(k).classValue() == j) ?
                1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses);
            }
          }

          // Perform iterations
          double[][] probs = initialProbs(numInstances);
          m_NumIterations = 0;
          double sumOfWeights = train.sumOfWeights();
          for (int j = 0; j < getMaxIterations(); j++) {
            performIteration(trainYs, trainFs, probs, trainN, sumOfWeights);
            Evaluation eval = new Evaluation(train);
            eval.evaluateModel(this, test);
            results[j] += eval.correct();
          }
        }
      }

      // Find the number of iterations with the lowest error
      double bestResult = -Double.MAX_VALUE;
      for (int j = 0; j < getMaxIterations(); j++) {
        if (results[j] > bestResult) {
          bestResult = results[j];
          bestNumIterations = j;
        }
      }
      if (m_Debug) {
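
The listing breaks off inside buildClassifier; the remaining pages hold the rest of the cross-validation and boosting loop. The setters above correspond one-to-one to the command-line flags emitted by getOptions(), so the class can also be configured programmatically. Below is a minimal usage sketch, not taken from the original file: it assumes the flat pre-3.3 Weka package layout that wekaUT builds on (weka.classifiers.LogitBoost, weka.classifiers.DecisionStump) and a placeholder ARFF path.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.core.Instances;
import weka.classifiers.DecisionStump;  // assumed location (flat pre-3.3 package layout)
import weka.classifiers.LogitBoost;     // assumed location, matching this source tree

public class LogitBoostDemo {

  public static void main(String[] args) throws Exception {

    // Load a dataset with a nominal class; the path is a placeholder.
    Instances data = new Instances(new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Configure the booster through the setters shown in the listing.
    LogitBoost lb = new LogitBoost();
    lb.setClassifier(new DecisionStump()); // base learner must handle a numeric pseudo-class
    lb.setMaxIterations(10);               // -I
    lb.setNumFolds(0);                     // -F: a value > 1 selects the iteration count by CV
    lb.setNumRuns(1);                      // -R
    lb.setUseResampling(false);            // -Q would turn resampling on instead
    lb.setWeightThreshold(100);            // -P, used when resampling is off
    lb.setSeed(1);                         // -S
    lb.setDebug(false);                    // -D

    // Train the boosted model and classify one instance.
    lb.buildClassifier(data);
    double prediction = lb.classifyInstance(data.instance(0));
    System.out.println("Predicted class index: " + prediction);
  }
}

DecisionStump is used here only as an illustrative base learner that can handle the numeric pseudo-class required by setClassifier(); with setNumFolds() above 1, buildClassifier() would additionally run the internal cross-validation shown in the listing to choose the number of iterations.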
