additiveregression.java

来自「Weka」· Java 代码 · 共 527 行 · 第 1/2 页

JAVA
527
字号
    String [] options = new String [superOptions.length + 2];
    int current = 0;

    options[current++] = "-S"; options[current++] = "" + getShrinkage();

    System.arraycopy(superOptions, 0, options, current,
                     superOptions.length);

    current += superOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns the tip text for the shrinkage property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String shrinkageTipText() {
    return "Shrinkage rate. Smaller values help prevent overfitting and "
      + "have a smoothing effect (but increase learning time). "
      +"Default = 1.0, ie. no shrinkage.";
  }

  /**
   * Sets the shrinkage rate. Every base model's prediction is multiplied by
   * this value, both when computing training residuals and when predicting.
   *
   * @param l the shrinkage rate.
   */
  public void setShrinkage(double l) {
    m_shrinkage = l;
  }

  /**
   * Gets the shrinkage rate.
   *
   * @return the value of the shrinkage (learning) rate
   */
  public double getShrinkage() {
    return m_shrinkage;
  }

  /**
   * Returns the capabilities inherited from the superclass, with the class
   * attribute restricted to numeric or date types.
   *
   * @return      the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // class: only numeric (and date) targets make sense for residual fitting
    result.disableAllClasses();
    result.disableAllClassDependencies();
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);

    return result;
  }

  /**
   * Builds the additive model on the supplied data.
   *
   * First a ZeroR model is fitted to the data. Then each configured base
   * classifier is fitted in turn to the residuals left by the models before
   * it. The residuals are shrunk by the shrinkage rate. Fitting stops early
   * once an iteration no longer reduces the weighted sum of squared residuals
   * by more than {@code Utils.SMALL}. If the data contains only the class
   * attribute, only the ZeroR model is used.
   *
   * @param data the training data
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    super.buildClassifier(data);

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();

    // Add the model for the mean first
    m_zeroR = new ZeroR();
    m_zeroR.buildClassifier(newData);

    // Reset before any early return so a rebuild never reports the iteration
    // count of a previously built model.
    m_NumIterationsPerformed = 0;

    // only class? -> use only ZeroR model
    if (newData.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
          + "using ZeroR model instead!");
      m_SuitableData = false;
      return;
    }
    m_SuitableData = true;

    newData = residualReplace(newData, m_zeroR, false);
    double sum = weightedSquaredClassSum(newData);
    if (m_Debug) {
      System.err.println("Sum of squared residuals "
                         +"(predicting the mean) : " + sum);
    }

    // Check the bound before indexing: a do/while here would touch
    // m_Classifiers[0] even when no iterations are configured.
    while (m_NumIterationsPerformed < m_Classifiers.length) {
      double previousSum = sum;

      // Build the classifier on the current residuals
      Classifier model = m_Classifiers[m_NumIterationsPerformed];
      model.buildClassifier(newData);
      newData = residualReplace(newData, model, true);

      sum = weightedSquaredClassSum(newData);
      if (m_Debug) {
        System.err.println("Sum of squared residuals : "+sum);
      }
      m_NumIterationsPerformed++;

      // Stop once the error no longer drops by more than Utils.SMALL. The
      // condition is written as a negation so that a NaN error also stops.
      if (!((previousSum - sum) > Utils.SMALL)) {
        break;
      }
    }
  }

  /**
   * Computes the weight-scaled sum of squared class values. When applied to
   * residual data, this is the weighted sum of squared residuals.
   *
   * @param data the instances whose class values are summed
   * @return sum over instances of weight * classValue * classValue
   */
  private static double weightedSquaredClassSum(Instances data) {
    double sum = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      double value = inst.classValue();
      sum += inst.weight() * value * value;
    }
    return sum;
  }

  /**
   * Classifies an instance. The result is the ZeroR prediction plus the
   * shrunk prediction of every base model that was built.
   *
   * @param inst the instance to predict
   * @return a prediction for the instance
   * @throws Exception if an error occurs
   */
  public double classifyInstance(Instance inst) throws Exception {
    double prediction = m_zeroR.classifyInstance(inst);

    // default model?
    if (!m_SuitableData) {
      return prediction;
    }

    double shrinkage = getShrinkage();
    for (int i = 0; i < m_NumIterationsPerformed; i++) {
      prediction += m_Classifiers[i].classifyInstance(inst) * shrinkage;
    }
    return prediction;
  }

  /**
   * Replaces the class values of the instances from the current iteration
   * with the residuals after predicting with the supplied classifier. The
   * input instances are copied, not modified.
   *
   * @param data the instances to predict
   * @param c the classifier to use
   * @param useShrinkage whether shrinkage is to be applied to the model's output
   * @return a new set of instances with class values replaced by residuals
   * @throws Exception if something goes wrong
   */
  private Instances residualReplace(Instances data, Classifier c,
                                    boolean useShrinkage) throws Exception {
    double pred,residual;
    Instances newInst = new Instances(data);

    for (int i = 0; i < newInst.numInstances(); i++) {
      pred = c.classifyInstance(newInst.instance(i));
      if (useShrinkage) {
        pred *= getShrinkage();
      }
      residual = newInst.instance(i).classValue() - pred;
      newInst.instance(i).setClassValue(residual);
    }
    return newInst;
  }

  /**
   * Returns an enumeration of the additional measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumIterations");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure (matched case-insensitively).
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.equalsIgnoreCase("measureNumIterations")) {
      return measureNumIterations();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
                          + " not supported (AdditiveRegression)");
    }
  }

  /**
   * Returns the number of iterations (base classifiers) completed.
   *
   * @return the number of iterations (same as number of base classifier
   * models)
   */
  public double measureNumIterations() {
    return m_NumIterationsPerformed;
  }

  /**
   * Returns a textual description of the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {
    // buildClassifier assigns m_zeroR on every successful build; checking it
    // first also avoids dereferencing a null model in the branch below.
    if (m_zeroR == null) {
      return "Classifier hasn't been built yet!";
    }

    // only ZeroR model?
    if (!m_SuitableData) {
      String simpleName = this.getClass().getName().replaceAll(".*\\.", "");
      StringBuffer buf = new StringBuffer();
      buf.append(simpleName + "\n");
      // underline the class name with '=' characters of the same length
      buf.append(simpleName.replaceAll(".", "=") + "\n\n");
      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
      buf.append(m_zeroR.toString());
      return buf.toString();
    }

    StringBuffer text = new StringBuffer();
    text.append("Additive Regression\n\n");
    text.append("ZeroR model\n\n" + m_zeroR + "\n\n");
    text.append("Base classifier "
                + getClassifier().getClass().getName()
                + "\n\n");
    text.append("" + m_NumIterationsPerformed + " models generated.\n");

    for (int i = 0; i < m_NumIterationsPerformed; i++) {
      text.append("\nModel number " + i + "\n\n" +
                  m_Classifiers[i] + "\n");
    }
    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   * -t training file [-T test file] [-c class index]
   */
  public static void main(String [] argv) {
    runClassifier(new AdditiveRegression(), argv);
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?