
📄 additiveregression.java

📁 Java source code for the entire process of the data mining software ALPHAMINERR
💻 JAVA
  public boolean getDebug() {
    return m_debug;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String classifierTipText() {
    return "Classifier to use";
  }

  /**
   * Sets the classifier
   *
   * @param classifier the classifier with all options set.
   */
  public void setClassifier(Classifier classifier) {

    m_Classifier = classifier;
  }

  /**
   * Gets the classifier used.
   *
   * @return the classifier
   */
  public Classifier getClassifier() {

    return m_Classifier;
  }
  
  /**
   * Gets the classifier specification string, which contains the class name of
   * the classifier and any options to the classifier
   *
   * @return the classifier string.
   */
  protected String getClassifierSpec() {
    
    Classifier c = getClassifier();
    if (c instanceof OptionHandler) {
      return c.getClass().getName() + " "
	+ Utils.joinOptions(((OptionHandler)c).getOptions());
    }
    return c.getClass().getName();
  }
  
  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxModelsTipText() {
    return "Max models to generate. <= 0 indicates no maximum, i.e., continue until "
      + "the error reduction threshold is reached.";
  }

  /**
   * Set the maximum number of models to generate
   * @param maxM the maximum number of models
   */
  public void setMaxModels(int maxM) {
    m_maxModels = maxM;
  }

  /**
   * Get the max number of models to generate
   * @return the max number of models to generate
   */
  public int getMaxModels() {
    return m_maxModels;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String shrinkageTipText() {
    return "Shrinkage rate. Smaller values help prevent overfitting and "
      + "have a smoothing effect (but increase learning time). "
      + "Default = 1.0, i.e., no shrinkage.";
  }

  /**
   * Set the shrinkage parameter
   *
   * @param l the shrinkage rate.
   */
  public void setShrinkage(double l) {
    m_shrinkage = l;
  }

  /**
   * Get the shrinkage rate.
   *
   * @return the value of the learning rate
   */
  public double getShrinkage() {
    return m_shrinkage;
  }

  /**
   * Build the classifier on the supplied data
   *
   * @param data the training data
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    m_additiveModels = new FastVector();

    if (m_Classifier == null) {
      throw new Exception("No base classifier has been set!");
    }
    if (data.classAttribute().isNominal()) {
      throw new UnsupportedClassTypeException("Class must be numeric!");
    }
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();
    m_classIndex = newData.classIndex();

    double sum = 0;
    double temp_sum = 0;
    // Add the model for the mean first
    ZeroR zr = new ZeroR();
    zr.buildClassifier(newData);
    m_additiveModels.addElement(zr);
    newData = residualReplace(newData, zr, false);
    for (int i = 0; i < newData.numInstances(); i++) {
      sum += newData.instance(i).weight() *
	newData.instance(i).classValue() *
	newData.instance(i).classValue();
    }
    if (m_debug) {
      System.err.println("Sum of squared residuals "
			 +"(predicting the mean) : "+sum);
    }

    int modelCount = 0;
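    // Repeatedly fit a fresh copy of the base classifier to the current
    // residuals, stopping when the weighted sum of squared residuals no
    // longer drops noticeably or the model limit (if set) is reached.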
    do {
      temp_sum = sum;
      Classifier nextC = Classifier.makeCopies(m_Classifier, 1)[0];
      nextC.buildClassifier(newData);
      m_additiveModels.addElement(nextC);
      newData = residualReplace(newData, nextC, true);
      sum = 0;
      for (int i = 0; i < newData.numInstances(); i++) {
	sum += newData.instance(i).weight() *
	  newData.instance(i).classValue() *
	  newData.instance(i).classValue();
      }
      if (m_debug) {
	System.err.println("Sum of squared residuals : "+sum);
      }
      modelCount++;
    } while (((temp_sum - sum) > Utils.SMALL) && 
	     (m_maxModels > 0 ? (modelCount < m_maxModels) : true));
  }

  /**
   * Classify an instance.
   *
   * @param inst the instance to predict
   * @return a prediction for the instance
   * @exception Exception if an error occurs
   */
  public double classifyInstance(Instance inst) throws Exception {
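    // Sum the outputs of every model; shrinkage scales each model's
    // contribution except that of the initial ZeroR (mean) model.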
    double prediction = 0;

    for (int i = 0; i < m_additiveModels.size(); i++) {
      Classifier current = (Classifier)m_additiveModels.elementAt(i);
      double toAdd = current.classifyInstance(inst);
      if (i > 0) {
	toAdd *= getShrinkage();
      }
      prediction += toAdd;
    }

    return prediction;
  }

  /**
   * Replace the class values of the instances from the current iteration
   * with residuals after predicting with the supplied classifier.
   *
   * @param data the instances to predict
   * @param c the classifier to use
   * @param useShrinkage whether shrinkage is to be applied to the model's output
   * @return a new set of instances with class values replaced by residuals
   * @exception Exception if predictions cannot be made with the supplied classifier
   */
  private Instances residualReplace(Instances data, Classifier c, 
				    boolean useShrinkage) throws Exception {
    double pred, residual;
    Instances newInst = new Instances(data);

    for (int i = 0; i < newInst.numInstances(); i++) {
      pred = c.classifyInstance(newInst.instance(i));
      if (useShrinkage) {
	pred *= getShrinkage();
      }
      residual = newInst.instance(i).classValue() - pred;
      newInst.instance(i).setClassValue(residual);
    }
    //    System.err.print(newInst);
    return newInst;
  }

  /**
   * Returns an enumeration of the additional measure names
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumIterations");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @exception IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureNumIterations") == 0) {
      return measureNumIterations();
    } else {
      throw new IllegalArgumentException(additionalMeasureName 
			  + " not supported (AdditiveRegression)");
    }
  }

  /**
   * return the number of iterations (base classifiers) completed
   * @return the number of iterations (same as number of base classifier
   * models)
   */
  public double measureNumIterations() {
    return m_additiveModels.size();
  }

  /**
   * Returns textual description of the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {
    StringBuffer text = new StringBuffer();

    if (m_additiveModels.size() == 0) {
      return "Classifier hasn't been built yet!";
    }

    text.append("Additive Regression\n\n");
    text.append("Base classifier " 
		+ getClassifier().getClass().getName()
		+ "\n\n");
    text.append(""+m_additiveModels.size()+" models generated.\n");

    for (int i = 0; i < m_additiveModels.size(); i++) {
      text.append("\nModel number " + i + "\n\n" +
		  m_additiveModels.elementAt(i) + "\n");
    }

    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   * -t training file [-T test file] [-c class index]
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new AdditiveRegression(),
						  argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
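
For reference, here is a minimal usage sketch that is not part of the original file. It assumes a Weka-style environment in which Instances can be read from an ARFF file and the dataset's class attribute is numeric; the file name housing.arff, the demo class name, and the import/package location of AdditiveRegression are placeholders that must be adapted to this code base.

import java.io.BufferedReader;
import java.io.FileReader;
import weka.core.Instances;

public class AdditiveRegressionDemo {

  public static void main(String[] args) throws Exception {
    // Load a dataset whose class attribute is numeric (file name is a placeholder).
    Instances data = new Instances(new BufferedReader(new FileReader("housing.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    AdditiveRegression ar = new AdditiveRegression();
    // Depending on the build, a base learner may have to be set explicitly,
    // e.g. ar.setClassifier(new weka.classifiers.trees.DecisionStump());
    ar.setShrinkage(0.5);  // smaller values smooth the fit but need more iterations
    ar.setMaxModels(10);   // cap the number of boosting iterations
    ar.buildClassifier(data);

    // Predict the class value of the first instance, then print the whole model.
    System.out.println("Prediction: " + ar.classifyInstance(data.instance(0)));
    System.out.println(ar);
  }
}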
