📄 additiveregression.java
字号:
}

  /**
   * Returns the tip text for the debug property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String debugTipText() {
    return "Turn on debugging output";
  }

  /**
   * Sets whether debugging output is produced. When enabled, buildClassifier
   * prints the sum of squared residuals to System.err after each model is
   * added, and residualReplace reports instances it could not predict.
   *
   * @param d true if debugging output is to be produced
   */
  public void setDebug(boolean d) {
    m_debug = d;
  }

  /**
   * Gets whether debugging has been turned on.
   *
   * @return true if debugging has been turned on
   */
  public boolean getDebug() {
    return m_debug;
  }

  /**
   * Returns the tip text for the classifier property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String classifierTipText() {
    return "Classifier to use";
  }

  /**
   * Sets the base classifier. buildClassifier trains a fresh copy of it
   * (via Classifier.makeCopies) on the residuals of each iteration.
   *
   * @param classifier the classifier with all options set.
   */
  public void setClassifier(Classifier classifier) {
    m_Classifier = classifier;
  }

  /**
   * Gets the base classifier.
   *
   * @return the classifier
   */
  public Classifier getClassifier() {
    return m_Classifier;
  }

  /**
   * Gets the classifier specification string: the class name of the base
   * classifier followed by its options, if it is an OptionHandler.
   * NOTE(review): throws NullPointerException if no classifier has been set.
   *
   * @return the classifier string.
   */
  protected String getClassifierSpec() {
    Classifier c = getClassifier();
    if (c instanceof OptionHandler) {
      return c.getClass().getName() + " "
        + Utils.joinOptions(((OptionHandler) c).getOptions());
    }
    return c.getClass().getName();
  }

  /**
   * Returns the tip text for the maxModels property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxModelsTipText() {
    return "Max models to generate. <= 0 indicates no maximum, ie. continue until "
      + "error reduction threshold is reached.";
  }

  /**
   * Sets the maximum number of models to generate. Values &lt;= 0 disable the
   * cap, so boosting runs until the error reduction threshold is reached.
   *
   * @param maxM the maximum number of models
   */
  public void setMaxModels(int maxM) {
    m_maxModels = maxM;
  }

  /**
   * Gets the maximum number of models to generate.
   *
   * @return the max number of models to generate
   */
  public int getMaxModels() {
    return m_maxModels;
  }

  /**
   * Returns the tip text for the shrinkage property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String shrinkageTipText() {
    return "Shrinkage rate. Smaller values help prevent overfitting and "
      + "have a smoothing effect (but increase learning time). "
      + "Default = 1.0, ie. no shrinkage.";
  }

  /**
   * Sets the shrinkage rate. Every model's prediction is multiplied by this
   * rate, both when residuals are computed during training and when
   * predictions are summed in classifyInstance.
   *
   * @param l the shrinkage rate.
   */
  public void setShrinkage(double l) {
    m_shrinkage = l;
  }

  /**
   * Gets the shrinkage rate.
   *
   * @return the shrinkage rate (learning rate)
   */
  public double getShrinkage() {
    return m_shrinkage;
  }

  /**
   * Builds the additive model on the supplied data.
   *
   * <p>A ZeroR model is fit first. Then copies of the base classifier are fit
   * one after another, each on the residuals left by the models before it.
   * The loop stops when a new model fails to reduce the weighted sum of
   * squared residuals by more than Utils.SMALL, or when m_maxModels (if
   * &gt; 0) base models have been built.
   *
   * @param data the training data
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    m_additiveModels = new FastVector();

    if (m_Classifier == null) {
      throw new Exception("No base classifiers have been set!");
    }
    // Residual fitting needs a numeric target: reject every non-numeric
    // class type (nominal, string, ...), not only nominal ones.
    if (!data.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("Class must be numeric!");
    }

    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();
    m_classIndex = newData.classIndex();

    // Add the model for the mean first
    ZeroR zr = new ZeroR();
    zr.buildClassifier(newData);
    m_additiveModels.addElement(zr);
    newData = residualReplace(newData, zr);
    double sum = weightedSquaredClassSum(newData);
    if (m_debug) {
      System.err.println("Sum of squared residuals "
                         + "(predicting the mean) : " + sum);
    }

    int modelCount = 0;
    double temp_sum;
    do {
      temp_sum = sum;
      Classifier nextC = Classifier.makeCopies(m_Classifier, 1)[0];
      nextC.buildClassifier(newData);
      m_additiveModels.addElement(nextC);
      newData = residualReplace(newData, nextC);
      sum = weightedSquaredClassSum(newData);
      if (m_debug) {
        System.err.println("Sum of squared residuals : " + sum);
      }
      modelCount++;
    } while (((temp_sum - sum) > Utils.SMALL)
             && (m_maxModels > 0 ? (modelCount < m_maxModels) : true));

    // Remove the last classifier: when the loop ends on the error threshold,
    // that model did not improve the fit enough to keep.
    // NOTE(review): it is also dropped when the loop ended on the
    // m_maxModels cap, so m_additiveModels (ZeroR included) holds at most
    // m_maxModels entries when the cap is > 0.
    m_additiveModels.removeElementAt(m_additiveModels.size() - 1);
  }

  /**
   * Classifies an instance by summing the predictions of all stored models
   * (the ZeroR mean model first, then the base models), each scaled by the
   * shrinkage rate.
   *
   * @param inst the instance to predict
   * @return a prediction for the instance
   * @exception Exception if an error occurs
   */
  public double classifyInstance(Instance inst) throws Exception {
    double prediction = 0;
    for (int i = 0; i < m_additiveModels.size(); i++) {
      Classifier current = (Classifier) m_additiveModels.elementAt(i);
      prediction += (current.classifyInstance(inst) * getShrinkage());
    }
    return prediction;
  }

  /**
   * Replaces the class values of a copy of the instances with the residuals
   * left after predicting with the supplied classifier (prediction scaled by
   * the shrinkage rate). The input instances are not modified.
   *
   * <p>Best effort: if the classifier throws for an instance, that instance
   * keeps its current class value. The failure is reported on System.err
   * when debugging output is enabled.
   *
   * @param data the instances to predict
   * @param c the classifier to use
   * @return a new set of instances with class values replaced by residuals
   */
  private Instances residualReplace(Instances data, Classifier c) {
    Instances newInst = new Instances(data);
    for (int i = 0; i < newInst.numInstances(); i++) {
      Instance current = newInst.instance(i);
      try {
        double pred = c.classifyInstance(current) * getShrinkage();
        current.setClassValue(current.classValue() - pred);
      } catch (Exception ex) {
        // Deliberately continue with the unchanged class value, but don't
        // hide the failure when the user asked for debugging output.
        if (m_debug) {
          System.err.println("Could not compute residual for instance "
                             + i + ": " + ex.getMessage());
        }
      }
    }
    return newInst;
  }

  /**
   * Computes the weighted sum of squared class values. Once the class values
   * have been replaced by residuals, this is the weighted sum of squared
   * residuals.
   *
   * @param data the instances to sum over
   * @return sum over instances of weight * classValue * classValue
   */
  private double weightedSquaredClassSum(Instances data) {
    double total = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      total += inst.weight() * inst.classValue() * inst.classValue();
    }
    return total;
  }

  /**
   * Returns an enumeration of the additional measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumIterations");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @exception IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareTo("measureNumIterations") == 0) {
      return measureNumIterations();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
                                         + " not supported (AdditiveRegression)");
    }
  }

  /**
   * Returns the number of stored models. This count includes the initial
   * ZeroR mean model. It is 0 when no model list exists yet.
   *
   * @return the number of iterations (same as number of base classifier
   * models)
   */
  public double measureNumIterations() {
    return (m_additiveModels == null) ? 0 : m_additiveModels.size();
  }

  /**
   * Returns a textual description of the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {
    if (m_additiveModels == null || m_additiveModels.size() == 0) {
      return "Classifier hasn't been built yet!";
    }
    StringBuffer text = new StringBuffer();
    text.append("Additive Regression\n\n");
    text.append("Base classifier "
                + getClassifier().getClass().getName()
                + "\n\n");
    // The count includes the initial ZeroR mean model.
    text.append("" + m_additiveModels.size() + " models generated.\n");
    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   * -t training file [-T test file] [-c class index]
   */
  public static void main(String [] argv) {
    try {
      System.out.println(Evaluation.evaluateModel(new AdditiveRegression(),
                                                  argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -