📄 additiveregression.java
字号:
/**
 * Gets the debug flag. When it is set, buildClassifier prints the sum of
 * squared residuals to standard error after each model is added.
 *
 * @return true if debugging output is enabled
 */
public boolean getDebug() {
  final boolean debugOn = m_debug;
  return debugOn;
}
/**
 * Returns the GUI tip text for the classifier property.
 *
 * @return a short description of the property, suitable for display in the
 * Explorer/Experimenter GUI
 */
public String classifierTipText() {
  final String tip = "Classifier to use";
  return tip;
}
/**
 * Sets the base classifier. buildClassifier trains a fresh copy of it on the
 * current residuals at every iteration, and refuses to run if it is null.
 *
 * @param classifier the base classifier, with all of its options already set
 */
public void setClassifier(Classifier classifier) {
  this.m_Classifier = classifier;
}
/**
 * Gets the base classifier that is copied for each iteration.
 *
 * @return the configured base classifier (may be null if none was set)
 */
public Classifier getClassifier() {
  return this.m_Classifier;
}
/**
 * Builds a specification string for the base classifier: its fully qualified
 * class name, followed by its options when it implements OptionHandler.
 *
 * @return the classifier specification string
 */
protected String getClassifierSpec() {
  Classifier base = getClassifier();
  String spec = base.getClass().getName();
  if (base instanceof OptionHandler) {
    String[] options = ((OptionHandler) base).getOptions();
    spec = spec + " " + Utils.joinOptions(options);
  }
  return spec;
}
/**
 * Returns the GUI tip text for the maxModels property.
 *
 * @return a short description of the property, suitable for display in the
 * Explorer/Experimenter GUI
 */
public String maxModelsTipText() {
  final String tip =
    "Max models to generate. <= 0 indicates no maximum, ie. continue until "
    + "error reduction threshold is reached.";
  return tip;
}
/**
 * Sets the upper limit on the number of base-classifier models to build.
 *
 * @param maxM the maximum number of models; a value of 0 or less means no
 * limit, so iteration continues until the error stops decreasing
 */
public void setMaxModels(int maxM) {
  this.m_maxModels = maxM;
}
/**
 * Gets the upper limit on the number of base-classifier models to build.
 *
 * @return the maximum number of models (0 or less means no limit)
 */
public int getMaxModels() {
  return this.m_maxModels;
}
/**
 * Returns the GUI tip text for the shrinkage property.
 *
 * @return a short description of the property, suitable for display in the
 * Explorer/Experimenter GUI
 */
public String shrinkageTipText() {
  final String tip =
    "Shrinkage rate. Smaller values help prevent overfitting and "
    + "have a smoothing effect (but increase learning time). "
    + "Default = 1.0, ie. no shrinkage.";
  return tip;
}
/**
 * Sets the shrinkage rate. It multiplies the output of every model after the
 * first (mean) model, both when computing residuals during training and
 * when predicting.
 *
 * @param l the shrinkage rate
 */
public void setShrinkage(double l) {
  this.m_shrinkage = l;
}
/**
 * Gets the shrinkage rate applied to every model after the first.
 *
 * @return the shrinkage rate
 */
public double getShrinkage() {
  return this.m_shrinkage;
}
/**
 * Builds the additive model on the supplied training data.
 * <p>
 * A ZeroR model (predicting the class mean) is fitted first and its output is
 * never shrunk. Each following iteration trains a fresh copy of the base
 * classifier on the current residuals, then replaces the class values with
 * the new residuals (after shrinkage). Iteration stops when the weighted sum
 * of squared residuals decreases by no more than {@link Utils#SMALL}, or when
 * the maximum number of models has been reached (only if m_maxModels &gt; 0).
 * The model from the final iteration is kept even if it did not reduce the
 * error.
 *
 * @param data the training data; the class attribute must be numeric
 * @exception Exception if no base classifier has been set, if the class is
 * not numeric, or if a model could not be built successfully
 */
public void buildClassifier(Instances data) throws Exception {
  m_additiveModels = new FastVector();
  if (m_Classifier == null) {
    throw new Exception("No base classifiers have been set!");
  }
  // Residuals only make sense for a numeric target, so reject every other
  // class type (nominal, string, ...) up front, not just nominal ones.
  if (!data.classAttribute().isNumeric()) {
    throw new UnsupportedClassTypeException("Class must be numeric!");
  }
  // Work on a copy so the caller's data is untouched; instances without a
  // class value cannot contribute a residual.
  Instances newData = new Instances(data);
  newData.deleteWithMissingClass();
  m_classIndex = newData.classIndex();

  // Add the model for the mean first (its output is not shrunk).
  ZeroR zr = new ZeroR();
  zr.buildClassifier(newData);
  m_additiveModels.addElement(zr);
  newData = residualReplace(newData, zr, false);
  double sum = sumOfSquaredResiduals(newData);
  if (m_debug) {
    System.err.println("Sum of squared residuals "
                       + "(predicting the mean) : " + sum);
  }

  int modelCount = 0;
  double previousSum;
  do {
    previousSum = sum;
    Classifier nextC = Classifier.makeCopies(m_Classifier, 1)[0];
    nextC.buildClassifier(newData);
    m_additiveModels.addElement(nextC);
    newData = residualReplace(newData, nextC, true);
    sum = sumOfSquaredResiduals(newData);
    if (m_debug) {
      System.err.println("Sum of squared residuals : " + sum);
    }
    modelCount++;
  } while (((previousSum - sum) > Utils.SMALL)
           && (m_maxModels <= 0 || modelCount < m_maxModels));
}

/**
 * Computes the weighted sum of squared class values of the given instances.
 * The class values are the current residuals, so this is the weighted
 * residual sum of squares.
 *
 * @param residuals instances whose class values hold residuals
 * @return sum over all instances of weight * residual * residual
 */
private double sumOfSquaredResiduals(Instances residuals) {
  double sum = 0;
  for (int i = 0; i < residuals.numInstances(); i++) {
    Instance inst = residuals.instance(i);
    double r = inst.classValue();
    sum += inst.weight() * r * r;
  }
  return sum;
}
/**
 * Predicts the class value of an instance by summing the outputs of all
 * models. The first model (the mean model) contributes its raw prediction;
 * every later model's prediction is multiplied by the shrinkage rate first.
 *
 * @param inst the instance to predict
 * @return the predicted (numeric) class value
 * @exception Exception if any model fails to classify the instance
 */
public double classifyInstance(Instance inst) throws Exception {
  double total = 0;
  int index = 0;
  while (index < m_additiveModels.size()) {
    Classifier model = (Classifier) m_additiveModels.elementAt(index);
    double contribution = model.classifyInstance(inst);
    if (index != 0) {
      contribution = contribution * getShrinkage();
    }
    total += contribution;
    index++;
  }
  return total;
}
/**
 * Returns a copy of the instances in which each class value has been replaced
 * by its residual after predicting with the supplied classifier.
 *
 * @param data the instances to predict (not modified)
 * @param c the classifier whose predictions are subtracted
 * @param useShrinkage whether the classifier's output is multiplied by the
 * shrinkage rate before being subtracted
 * @return a new set of instances whose class values are residuals
 * @exception Exception if the classifier fails to classify an instance
 */
private Instances residualReplace(Instances data, Classifier c,
                                  boolean useShrinkage) throws Exception {
  Instances residuals = new Instances(data);
  for (int idx = 0; idx < residuals.numInstances(); idx++) {
    Instance current = residuals.instance(idx);
    double predicted = c.classifyInstance(current);
    if (useShrinkage) {
      predicted *= getShrinkage();
    }
    current.setClassValue(current.classValue() - predicted);
  }
  return residuals;
}
/**
 * Returns an enumeration of the names of the additional measures this scheme
 * can report. This is the method name expected by Weka's
 * AdditionalMeasureProducer contract.
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  Vector newVector = new Vector(1);
  newVector.addElement("measureNumIterations");
  return newVector.elements();
}

/**
 * Misspelled alias of {@link #enumerateMeasures()}, kept so that any existing
 * callers of the old name continue to work.
 *
 * @return an enumeration of the measure names
 * @deprecated use {@link #enumerateMeasures()} instead
 */
public Enumeration emerateMeasures() {
  return enumerateMeasures();
}
/**
 * Returns the value of the named additional measure.
 *
 * @param additionalMeasureName the name of the measure to query; compared
 * case-insensitively against "measureNumIterations"
 * @return the value of the named measure
 * @exception IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  if (additionalMeasureName.compareToIgnoreCase("measureNumIterations") != 0) {
    throw new IllegalArgumentException(additionalMeasureName
                                       + " not supported (AdditiveRegression)");
  }
  return measureNumIterations();
}
/**
 * Returns the number of models held by this classifier.
 * <p>
 * NOTE(review): this is the size of the model list. That list also contains
 * the initial ZeroR (mean) model added by buildClassifier, so the value is one
 * more than the number of base-classifier iterations. Confirm this is the
 * intended meaning of the measure.
 *
 * @return the number of models generated, as a double
 */
public double measureNumIterations() {
  int models = m_additiveModels.size();
  return models;
}
/**
 * Returns a textual description of the classifier: the base classifier's
 * class name, the number of models, and each model's own description in
 * order. Model number 0 is the initial mean model.
 *
 * @return a description of the classifier as a string, or a notice if the
 * classifier has not been built yet
 */
public String toString() {
  // Guard against null as well as empty: until buildClassifier runs, the
  // model list may not have been created yet.
  if (m_additiveModels == null || m_additiveModels.size() == 0) {
    return "Classifier hasn't been built yet!";
  }
  StringBuffer text = new StringBuffer();
  text.append("Additive Regression\n\n");
  text.append("Base classifier ")
      .append(getClassifier().getClass().getName())
      .append("\n\n");
  text.append(m_additiveModels.size()).append(" models generated.\n");
  for (int i = 0; i < m_additiveModels.size(); i++) {
    text.append("\nModel number ").append(i).append("\n\n")
        .append(m_additiveModels.elementAt(i)).append("\n");
  }
  return text.toString();
}
/**
 * Command-line entry point: evaluates an AdditiveRegression model and prints
 * the evaluation output. On failure, the exception message is printed to
 * standard error.
 *
 * @param argv should contain the following arguments:
 * -t training file [-T test file] [-c class index]
 */
public static void main(String[] argv) {
  try {
    AdditiveRegression scheme = new AdditiveRegression();
    String evaluationOutput = Evaluation.evaluateModel(scheme, argv);
    System.out.println(evaluationOutput);
  } catch (Exception e) {
    System.err.println(e.getMessage());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -