additiveregression.java
来自「Weka」· Java 代码 · 共 527 行 · 第 1/2 页
JAVA
527 行
String [] options = new String [superOptions.length + 2];
    int current = 0;
    options[current++] = "-S"; options[current++] = "" + getShrinkage();
    System.arraycopy(superOptions, 0, options, current, superOptions.length);
    current += superOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns the tip text for the shrinkage property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String shrinkageTipText() {
    return "Shrinkage rate. Smaller values help prevent overfitting and "
      + "have a smoothing effect (but increase learning time). "
      + "Default = 1.0, ie. no shrinkage.";
  }

  /**
   * Sets the shrinkage rate. Each base model's prediction is multiplied by
   * this value, both while fitting residuals and at prediction time.
   *
   * @param l the shrinkage rate.
   */
  public void setShrinkage(double l) {
    m_shrinkage = l;
  }

  /**
   * Gets the shrinkage rate.
   *
   * @return the shrinkage rate applied to each base model's prediction
   */
  public double getShrinkage() {
    return m_shrinkage;
  }

  /**
   * Returns the default capabilities of the classifier: those of the
   * superclass, restricted to numeric and date class attributes.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // class: residuals are computed by subtracting predictions from class
    // values (see residualReplace), so only numeric-valued targets make sense
    result.disableAllClasses();
    result.disableAllClassDependencies();
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.DATE_CLASS);

    return result;
  }

  /**
   * Builds the classifier on the supplied data.
   *
   * A ZeroR (mean) model is fitted first. Then, if any non-class attributes
   * exist, base classifiers are fitted in turn, each on the residuals left
   * by ZeroR plus all previous (shrunken) base models. Fitting stops when the
   * weighted sum of squared residuals no longer decreases by more than
   * {@code Utils.SMALL}, or when all base classifier copies are used.
   *
   * @param data the training data
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    super.buildClassifier(data);

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();

    // Reset up front so a count from a previous build cannot survive the
    // ZeroR-only early return below (measureNumIterations reads it).
    m_NumIterationsPerformed = 0;

    // Add the model for the mean first
    m_zeroR = new ZeroR();
    m_zeroR.buildClassifier(newData);

    // only class? -> use only ZeroR model
    if (newData.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
          + "using ZeroR model instead!");
      m_SuitableData = false;
      return;
    }
    m_SuitableData = true;

    newData = residualReplace(newData, m_zeroR, false);
    double sum = weightedSquaredResidualSum(newData);
    if (m_Debug) {
      System.err.println("Sum of squared residuals "
          + "(predicting the mean) : " + sum);
    }

    // A guarded while loop (instead of do/while) so that a zero-length
    // m_Classifiers array yields a plain ZeroR model rather than an
    // ArrayIndexOutOfBoundsException. As before, the model built in the
    // final iteration is kept even if it did not improve the fit.
    while (m_NumIterationsPerformed < m_Classifiers.length) {
      double previousSum = sum;

      // Build the classifier
      m_Classifiers[m_NumIterationsPerformed].buildClassifier(newData);
      newData = residualReplace(newData,
          m_Classifiers[m_NumIterationsPerformed], true);

      sum = weightedSquaredResidualSum(newData);
      if (m_Debug) {
        System.err.println("Sum of squared residuals : " + sum);
      }
      m_NumIterationsPerformed++;

      // Negated '>' (not '<=') keeps the original behaviour for NaN sums.
      if (!((previousSum - sum) > Utils.SMALL)) {
        break;
      }
    }
  }

  /**
   * Computes the weighted sum of squared class values. When the class values
   * are residuals (as produced by residualReplace), this is the weighted sum
   * of squared residuals.
   *
   * @param data the instances whose class values are summed
   * @return sum over instances of weight * classValue^2
   */
  private double weightedSquaredResidualSum(Instances data) {
    double total = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      total += inst.weight() * inst.classValue() * inst.classValue();
    }
    return total;
  }

  /**
   * Classifies an instance: the ZeroR prediction plus the shrunken
   * predictions of every base model built during training.
   *
   * @param inst the instance to predict
   * @return a prediction for the instance
   * @throws Exception if an error occurs
   */
  public double classifyInstance(Instance inst) throws Exception {
    double prediction = m_zeroR.classifyInstance(inst);

    // default model?
    if (!m_SuitableData) {
      return prediction;
    }

    double shrinkage = getShrinkage();
    for (int i = 0; i < m_NumIterationsPerformed; i++) {
      prediction += m_Classifiers[i].classifyInstance(inst) * shrinkage;
    }
    return prediction;
  }

  /**
   * Replaces the class values of the instances from the current iteration
   * with residuals after predicting with the supplied classifier. The input
   * set is not modified; a copy is returned.
   *
   * @param data the instances to predict
   * @param c the classifier to use
   * @param useShrinkage whether shrinkage is to be applied to the model's output
   * @return a new set of instances with class values replaced by residuals
   * @throws Exception if something goes wrong
   */
  private Instances residualReplace(Instances data, Classifier c,
                                    boolean useShrinkage) throws Exception {
    double pred, residual;
    Instances newInst = new Instances(data);

    for (int i = 0; i < newInst.numInstances(); i++) {
      pred = c.classifyInstance(newInst.instance(i));
      if (useShrinkage) {
        pred *= getShrinkage();
      }
      residual = newInst.instance(i).classValue() - pred;
      newInst.instance(i).setClassValue(residual);
    }
    return newInst;
  }

  /**
   * Returns an enumeration of the additional measure names.
   *
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector newVector = new Vector(1);
    newVector.addElement("measureNumIterations");
    return newVector.elements();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureNumIterations") == 0) {
      return measureNumIterations();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
          + " not supported (AdditiveRegression)");
    }
  }

  /**
   * Returns the number of iterations (base classifiers) completed.
   *
   * @return the number of iterations (same as number of base classifier
   * models)
   */
  public double measureNumIterations() {
    return m_NumIterationsPerformed;
  }

  /**
   * Returns a textual description of the classifier.
   *
   * @return a description of the classifier as a string
   */
  public String toString() {
    // Build state is tracked by the ZeroR model, which buildClassifier always
    // creates; m_NumIterations is only the configured iteration count.
    if (m_zeroR == null) {
      return "Classifier hasn't been built yet!";
    }

    // only ZeroR model?
    if (!m_SuitableData) {
      String simpleName = this.getClass().getName().replaceAll(".*\\.", "");
      StringBuffer buf = new StringBuffer();
      buf.append(simpleName + "\n");
      // "." is a regex here: every character becomes '=' (title underline)
      buf.append(simpleName.replaceAll(".", "=") + "\n\n");
      buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
      buf.append(m_zeroR.toString());
      return buf.toString();
    }

    StringBuffer text = new StringBuffer();
    text.append("Additive Regression\n\n");
    text.append("ZeroR model\n\n" + m_zeroR + "\n\n");
    text.append("Base classifier " + getClassifier().getClass().getName() + "\n\n");
    text.append("" + m_NumIterationsPerformed + " models generated.\n");

    for (int i = 0; i < m_NumIterationsPerformed; i++) {
      text.append("\nModel number " + i + "\n\n" + m_Classifiers[i] + "\n");
    }
    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   * -t training file [-T test file] [-c class index]
   */
  public static void main(String [] argv) {
    runClassifier(new AdditiveRegression(), argv);
  }
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?