📄 lwl.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LWL.java
* Copyright (C) 1999, 2002, 2003 Len Trigg, Eibe Frank, Ashraf M. Kibriya
*
*/
package weka.classifiers.lazy;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.DecisionStump;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.SingleClassifierEnhancer;
import java.io.*;
import java.util.*;
import weka.core.*;
/**
* Locally-weighted learning. Uses an instance-based algorithm to
* assign instance weights which are then used by a specified
* WeightedInstancesHandler. A good choice for classification is
* NaiveBayes. LinearRegression is suitable for regression problems.
* For more information, see<p>
*
* Eibe Frank, Mark Hall, and Bernhard Pfahringer (2003). Locally
* Weighted Naive Bayes. Working Paper 04/03, Department of Computer
* Science, University of Waikato.
*
* Atkeson, C., A. Moore, and S. Schaal (1996) <i>Locally weighted
* learning</i>
* <a href="ftp://ftp.cc.gatech.edu/pub/people/cga/air1.ps.gz">download
* postscript</a>. <p>
*
* Valid options are:<p>
*
* -D <br>
* Produce debugging output. <p>
*
* -N <br>
* Do not normalize numeric attributes' values in distance calculation.<p>
*
* -K num <br>
* Set the number of neighbours used for setting kernel bandwidth.
* (default all) <p>
*
* -U num <br>
* Set the weighting kernel shape to use. 0 = Linear, 1 = Epanechnikov,
* 2 = Tricube, 3 = Inverse, 4 = Gaussian and 5 = Constant.
* (default 0 = Linear) <p>
*
* -W classname <br>
* Specify the full class name of a base classifier (which needs
* to be a WeightedInstancesHandler).<p>
*
* @author Len Trigg (trigg@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @author Ashraf M. Kibriya (amk14@waikato.ac.nz)
* @version $Revision: 1.1 $
*/
public class LWL extends SingleClassifierEnhancer
implements UpdateableClassifier, WeightedInstancesHandler {
/**
* The training instances used for classification. Assigned in
* buildClassifier (after instances with a missing class value have been
* deleted) and extended by updateClassifier.
*/
protected Instances m_Train;
/**
* The minimum values for numeric attributes, indexed by attribute index.
* Every entry is set to NaN in buildClassifier before the training
* instances are passed to updateMinMax.
*/
protected double [] m_Min;
/**
* The maximum values for numeric attributes, indexed by attribute index.
* Every entry is set to NaN in buildClassifier before the training
* instances are passed to updateMinMax.
*/
protected double [] m_Max;
/**
* True if numeric attributes' values should not be normalized in distance
* calculation. Set from the -N flag in setOptions.
*/
protected boolean m_NoAttribNorm=false;
/**
* The number of neighbours used to select the kernel bandwidth. setKNN
* stores 0 for any non-positive request; the initial value of -1 is only
* present until setKNN is called for the first time.
*/
protected int m_kNN = -1;
/**
* The weighting kernel method currently selected; one of the kernel
* constants declared below.
*/
protected int m_WeightKernel = LINEAR;
/**
* True if all instances should be used as neighbours. setKNN sets this to
* true when it is given a value <= 0 and to false otherwise.
*/
protected boolean m_UseAllK = true;
/**
* The available kernel weighting methods. The numeric values are the ones
* accepted by setWeightingKernel and therefore by the -U option.
*/
protected static final int LINEAR = 0;
protected static final int EPANECHNIKOV = 1;
protected static final int TRICUBE = 2;
protected static final int INVERSE = 3;
protected static final int GAUSS = 4;
protected static final int CONSTANT = 5;
/**
 * Provides a short description of this classifier followed by the
 * literature references it is based on.
 *
 * @return the description text, suitable for displaying in the
 * explorer/experimenter gui
 */
public String globalInfo() {
  // All parts are compile-time constants, so the result is a single
  // constant string exactly as before.
  final String summary =
      "Class for performing locally weighted learning. Can do "
      + "classification (e.g. using naive Bayes) or regression (e.g. using "
      + "linear regression). The base learner needs to implement "
      + "WeightedInstancesHandler. For more info, see\n\n";
  final String naiveBayesReference =
      "Eibe Frank, Mark Hall, and Bernhard Pfahringer (2003). \"Locally "
      + "Weighted Naive Bayes\". Conference on Uncertainty in AI.\n\n";
  final String lwlReference =
      "Atkeson, C., A. Moore, and S. Schaal (1996) \"Locally weighted "
      + "learning\" AI Reviews.";
  return summary + naiveBayesReference + lwlReference;
}
/**
 * Creates a locally weighted learner whose base classifier is initially
 * a DecisionStump.
 */
public LWL() {
  super();
  this.m_Classifier = new DecisionStump();
}
/**
 * Names the classifier that serves as the default base learner.
 *
 * @return the fully qualified class name of the default base classifier
 */
protected String defaultClassifierString() {
  final String defaultBaseLearner = "weka.classifiers.trees.DecisionStump";
  return defaultBaseLearner;
}
/**
 * Returns an enumeration describing the available options: the three
 * options handled by this class (-N, -K, -U) followed by every option
 * reported by the superclass.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  Vector newVector = new Vector(3);

  newVector.addElement(new Option("\tDo not normalize numeric attributes' "
                                  + "values in distance calculation.\n"
                                  + "\t(default DO normalization)",
                                  "N", 0, "-N"));
  newVector.addElement(new Option("\tSet the number of neighbours used to set"
                                  + " the kernel bandwidth.\n"
                                  + "\t(default all)",
                                  "K", 1, "-K <number of neighbours>"));
  // The help text lists every kernel that setWeightingKernel accepts,
  // including 5=Constant, which was previously missing here.
  newVector.addElement(new Option("\tSet the weighting kernel shape to use."
                                  + " 0=Linear, 1=Epanechnikov,\n"
                                  + "\t2=Tricube, 3=Inverse, 4=Gaussian, "
                                  + "5=Constant.\n"
                                  + "\t(default 0 = Linear)",
                                  "U", 1, "-U <number of weighting method>"));

  // Append the options contributed by the superclass.
  Enumeration enu = super.listOptions();
  while (enu.hasMoreElements()) {
    newVector.addElement(enu.nextElement());
  }
  return newVector.elements();
}
/**
 * Applies the settings contained in the given option array. Options
 * handled here are consumed first; the array is then handed to the
 * superclass. Valid options are:<p>
 *
 * -D <br>
 * Produce debugging output. <p>
 *
 * -N <br>
 * Do not normalize numeric attributes' values in distance calculation.
 * (default DO normalization)<p>
 *
 * -K num <br>
 * Set the number of neighbours used for setting kernel bandwidth.
 * (default all) <p>
 *
 * -U num <br>
 * Set the weighting kernel shape to use. 0 = Linear, 1 = Epanechnikov,
 * 2 = Tricube, 3 = Inverse, 4 = Gaussian and 5 = Constant.
 * (default 0 = Linear) <p>
 *
 * -W classname <br>
 * Specify the full class name of a base classifier (which needs
 * to be a WeightedInstancesHandler).<p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {
  // -K: an absent option means "use all neighbours", encoded as 0.
  String neighbourArg = Utils.getOption('K', options);
  setKNN(neighbourArg.length() == 0 ? 0 : Integer.parseInt(neighbourArg));

  // -U: an absent option selects the linear kernel.
  String kernelArg = Utils.getOption('U', options);
  setWeightingKernel(kernelArg.length() == 0
                     ? LINEAR
                     : Integer.parseInt(kernelArg));

  setDontNormalize(Utils.getFlag('N', options));

  super.setOptions(options);
}
/**
 * Gets the current settings of the classifier: -U and -K with their
 * values, then either "-N" or an empty placeholder string, followed by
 * the options reported by the superclass.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {
  String[] inherited = super.getOptions();
  // Five slots are filled by this class: "-U" value "-K" value and the
  // -N flag (or an empty string when the flag is not set).
  String[] result = new String[inherited.length + 5];
  int next = 0;

  result[next++] = "-U";
  result[next++] = Integer.toString(getWeightingKernel());
  // -K is always written, with whatever getKNN() currently reports.
  result[next++] = "-K";
  result[next++] = Integer.toString(getKNN());
  result[next++] = getDontNormalize() ? "-N" : "";

  System.arraycopy(inherited, 0, result, next, inherited.length);
  return result;
}
/**
 * Returns the tip text for the KNN property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String KNNTipText() {
  final String tip =
      "How many neighbours are used to determine the width of the "
      + "weighting function (<= 0 means all neighbours).";
  return tip;
}
/**
 * Sets the number of neighbours used for kernel bandwidth setting.
 * The bandwidth is taken as the distance to the kth neighbour.
 * Any value <= 0 is stored as 0 and switches to using all neighbours.
 *
 * @param knn the number of neighbours included inside the kernel
 * bandwidth, or 0 to specify using all neighbors.
 */
public void setKNN(int knn) {
  // A non-positive request means "all neighbours" and is normalised to 0.
  m_UseAllK = (knn <= 0);
  m_kNN = m_UseAllK ? 0 : knn;
}
/**
 * Reports the number of neighbours used for kernel bandwidth setting.
 * The bandwidth is taken as the distance to the kth neighbour.
 *
 * @return the number of neighbours included inside the kernel
 * bandwidth, or 0 for all neighbours
 */
public int getKNN() {
  return this.m_kNN;
}
/**
 * Returns the tip text for the weightingKernel property. The text lists
 * the numeric code of every kernel accepted by setWeightingKernel.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String weightingKernelTipText() {
  // Fixed the misspelt "Epnechnikov" and the missing space where the two
  // literals join (previously rendered as "Epnechnikov,2 = Tricube").
  return "Determines weighting function. [0 = Linear, 1 = Epanechnikov, "
    + "2 = Tricube, 3 = Inverse, 4 = Gaussian and 5 = Constant. "
    + "(default 0 = Linear)].";
}
/**
 * Selects the kernel weighting method. Only LINEAR, EPANECHNIKOV,
 * TRICUBE, INVERSE, GAUSS and CONSTANT are accepted; any other value
 * is ignored and the current selection is kept.
 *
 * @param kernel the new kernel method to use. Must be one of LINEAR,
 * EPANECHNIKOV, TRICUBE, INVERSE, GAUSS or CONSTANT.
 */
public void setWeightingKernel(int kernel) {
  switch (kernel) {
    case LINEAR:
    case EPANECHNIKOV:
    case TRICUBE:
    case INVERSE:
    case GAUSS:
    case CONSTANT:
      m_WeightKernel = kernel;
      break;
    default:
      // Unknown kernel code: leave the current setting untouched.
      break;
  }
}
/**
 * Reports the kernel weighting method currently selected.
 *
 * @return the kernel method in use. Will be one of LINEAR,
 * EPANECHNIKOV, TRICUBE, INVERSE, GAUSS or CONSTANT.
 */
public int getWeightingKernel() {
  return this.m_WeightKernel;
}
/**
 * Returns the tip text for the dontNormalize property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String dontNormalizeTipText() {
  final String tip =
      "Turns off normalization for attribute values in distance "
      + "calculation.";
  return tip;
}
/**
 * Reports whether normalization of numeric attribute values is switched
 * off for the distance calculation.
 *
 * @return true if normalization is not to be performed
 */
public boolean getDontNormalize() {
  return this.m_NoAttribNorm;
}
/**
 * Sets whether the numeric attribute values are not to be normalized
 * when calculating the distances between instances.
 *
 * @param dontNormalize true if normalization is not to be performed
 */
public void setDontNormalize(boolean dontNormalize) {
  // The parameter was previously called "normalize", the opposite of
  // what a true value means, and did not match the documented @param.
  m_NoAttribNorm = dontNormalize;
}
/**
 * Looks up the minimum value recorded for an attribute.
 *
 * @param index the index of the attribute
 * @return the minimum observed value
 */
protected double getAttributeMin(int index) {
  final double[] minima = m_Min;
  return minima[index];
}
/**
 * Looks up the maximum value recorded for an attribute.
 *
 * @param index the index of the attribute
 * @return the maximum observed value
 */
protected double getAttributeMax(int index) {
  final double[] maxima = m_Max;
  return maxima[index];
}
/**
 * Generates the classifier. The supplied data is validated, stored
 * (without the instances whose class value is missing) and used to
 * initialise the per-attribute minimum and maximum values. The base
 * classifier itself is not trained in this method.
 *
 * @param instances set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {
  // --- validation -------------------------------------------------------
  boolean handlesWeights = m_Classifier instanceof WeightedInstancesHandler;
  if (!handlesWeights) {
    throw new IllegalArgumentException("Classifier must be a "
                                       + "WeightedInstancesHandler!");
  }
  if (instances.classIndex() < 0) {
    throw new Exception("No class attribute assigned to instances");
  }
  if (instances.checkForStringAttributes()) {
    throw new UnsupportedAttributeTypeException("Cannot handle string "
                                                + "attributes!");
  }

  // --- training data ----------------------------------------------------
  // Build our own Instances object over the full range of the supplied
  // data, then throw away the instances with a missing class.
  m_Train = new Instances(instances, 0, instances.numInstances());
  m_Train.deleteWithMissingClass();

  // --- attribute ranges -------------------------------------------------
  // Start every range at NaN, then feed each training instance to
  // updateMinMax.
  int attributeCount = m_Train.numAttributes();
  m_Min = new double[attributeCount];
  m_Max = new double[attributeCount];
  Arrays.fill(m_Min, Double.NaN);
  Arrays.fill(m_Max, Double.NaN);

  int position = 0;
  while (position < m_Train.numInstances()) {
    updateMinMax(m_Train.instance(position));
    position++;
  }
}
/**
 * Adds the supplied instance to the training set. Instances whose class
 * value is missing are silently skipped.
 *
 * @param instance the instance to add
 * @exception Exception if instance could not be incorporated
 * successfully
 */
public void updateClassifier(Instance instance) throws Exception {
  // The instance's dataset header must match that of the training data.
  if (!m_Train.equalHeaders(instance.dataset())) {
    throw new Exception("Incompatible instance types");
  }
  if (instance.classIsMissing()) {
    return;
  }
  updateMinMax(instance);
  m_Train.add(instance);
}
/**
* Calculates the class membership probabilities for the given test instance.
*
* @param instance the instance to be classified
* @return predicted class probability distribution
* @exception Exception if distribution can't be computed successfully
*/
public double[] distributionForInstance(Instance instance) throws Exception {
if (m_Train.numInstances() == 0) {
throw new Exception("No training instances!");
}
updateMinMax(instance);
//Get the distances to each training instance
double [] distance = new double [m_Train.numInstances()];
MyHeap h;
int k = distance.length-1; //sortKey.length - 1;
if (!m_UseAllK && (m_kNN < k)) {
k = m_kNN;
h = new MyHeap(k);
}
else
h = new MyHeap(distance.length);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -