⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lwr.java

📁 :<<数据挖掘--实用机器学习技术及java实现>>一书的配套源程序
💻 JAVA
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    LWR.java
 *    Copyright (C) 1999 Len Trigg
 *
 */

package weka.classifiers;

import java.io.*;
import java.util.*;

import weka.core.*;

/**
 * Locally-weighted regression. Uses an instance-based algorithm to assign
 * instance weights which are then used by a linear regression model.  For
 * more information, see<p>
 *
 * Atkeson, C., A. Moore, and S. Schaal (1996) <i>Locally weighted
 * learning</i>
 * <a href="ftp://ftp.cc.gatech.edu/pub/people/cga/air1.ps.gz">download
 * postscript</a>. <p>
 *
 * Valid options are:<p>
 *
 * -D <br>
 * Produce debugging output. <p>
 *
 * -K num <br>
 * Set the number of neighbours used for setting kernel bandwidth.
 * (default all) <p>
 *
 * -W num <br>
 * Set the weighting kernel shape to use. 1 = Inverse, 2 = Gaussian.
 * (default 0 = Linear) <p>
 *
 * @author Len Trigg (trigg@cs.waikato.ac.nz)
 * @version $Revision: 1.9 $
 */
public class LWR extends Classifier
  implements OptionHandler, UpdateableClassifier, WeightedInstancesHandler {

  /** The training instances used for classification. */
  protected Instances m_Train;

  /** The minimum values for numeric attributes (NaN until observed). */
  protected double [] m_Min;

  /** The maximum values for numeric attributes (NaN until observed). */
  protected double [] m_Max;

  /** True if debugging output should be printed. */
  protected boolean m_Debug;

  /** The number of neighbours used to select the kernel bandwidth. */
  protected int m_kNN = 5;

  /** The weighting kernel method currently selected. */
  protected int m_WeightKernel = LINEAR;

  /** True if m_kNN should be set to all instances. */
  protected boolean m_UseAllK = true;

  /* The available kernel weighting methods. */
  protected static final int LINEAR  = 0;
  protected static final int INVERSE = 1;
  protected static final int GAUSS   = 2;

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(2);
    newVector.addElement(new Option("\tProduce debugging output.\n"
				    + "\t(default no debugging output)",
				    "D", 0, "-D"));
    newVector.addElement(new Option("\tSet the number of neighbors used to set"
				    + " the kernel bandwidth.\n"
				    + "\t(default all)",
				    "K", 1, "-K <number of neighbours>"));
    newVector.addElement(new Option("\tSet the weighting kernel shape to use."
				    + " 1 = Inverse, 2 = Gaussian.\n"
				    + "\t(default 0 = Linear)",
				    "W", 1,"-W <number of weighting method>"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -D <br>
   * Produce debugging output. <p>
   *
   * -K num <br>
   * Set the number of neighbours used for setting kernel bandwidth.
   * (default all) <p>
   *
   * -W num <br>
   * Set the weighting kernel shape to use. 1 = Inverse, 2 = Gaussian.
   * (default 0 = Linear) <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String knnString = Utils.getOption('K', options);
    if (knnString.length() != 0) {
      setKNN(Integer.parseInt(knnString));
    } else {
      setKNN(0);
    }
    String weightString = Utils.getOption('W', options);
    if (weightString.length() != 0) {
      setWeightingKernel(Integer.parseInt(weightString));
    } else {
      setWeightingKernel(LINEAR);
    }
    setDebug(Utils.getFlag('D', options));
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] options = new String [5];
    int current = 0;

    if (getDebug()) {
      options[current++] = "-D";
    }
    options[current++] = "-W"; options[current++] = "" + getWeightingKernel();
    if (!m_UseAllK) {
      options[current++] = "-K"; options[current++] = "" + getKNN();
    }
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Sets whether debugging output should be produced
   *
   * @param debug true if debugging output should be printed
   */
  public void setDebug(boolean debug) {

    m_Debug = debug;
  }

  /**
   * Gets whether debugging output should be produced
   *
   * @return true if debugging output should be printed
   */
  public boolean getDebug() {

    return m_Debug;
  }

  /**
   * Sets the number of neighbours used for kernel bandwidth setting.
   * The bandwidth is taken as the distance to the kth neighbour.
   *
   * @param knn the number of neighbours included inside the kernel
   * bandwidth, or 0 to specify using all neighbors.
   */
  public void setKNN(int knn) {

    m_kNN = knn;
    if (knn <= 0) {
      m_kNN = 0;
      m_UseAllK = true;
    } else {
      m_UseAllK = false;
    }
  }

  /**
   * Gets the number of neighbours used for kernel bandwidth setting.
   * The bandwidth is taken as the distance to the kth neighbour.
   *
   * @return the number of neighbours included inside the kernel
   * bandwidth, or 0 for all neighbours
   */
  public int getKNN() {

    return m_kNN;
  }

  /**
   * Sets the kernel weighting method to use. Must be one of LINEAR,
   * INVERSE, or GAUSS, other values are ignored.
   *
   * @param kernel the new kernel method to use. Must be one of LINEAR,
   * INVERSE, or GAUSS
   */
  public void setWeightingKernel(int kernel) {

    if ((kernel != LINEAR)
	&& (kernel != INVERSE)
	&& (kernel != GAUSS)) {
      return;
    }
    m_WeightKernel = kernel;
  }

  /**
   * Gets the kernel weighting method to use.
   *
   * @return the new kernel method to use. Will be one of LINEAR,
   * INVERSE, or GAUSS
   */
  public int getWeightingKernel() {

    return m_WeightKernel;
  }

  /**
   * Gets an attributes minimum observed value
   *
   * @param index the index of the attribute
   * @return the minimum observed value
   */
  protected double getAttributeMin(int index) {

    return m_Min[index];
  }

  /**
   * Gets an attributes maximum observed value
   *
   * @param index the index of the attribute
   * @return the maximum observed value
   */
  protected double getAttributeMax(int index) {

    return m_Max[index];
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (instances.classIndex() < 0) {
      throw new Exception("No class attribute assigned to instances");
    }
    if (instances.classAttribute().type() != Attribute.NUMERIC) {
      throw new Exception("Class attribute must be numeric");
    }
    if (instances.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    // Throw away training instances with missing class
    m_Train = new Instances(instances, 0, instances.numInstances());
    m_Train.deleteWithMissingClass();

    // Calculate the minimum and maximum values (NaN marks "not yet seen")
    m_Min = new double [m_Train.numAttributes()];
    m_Max = new double [m_Train.numAttributes()];
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      m_Min[i] = m_Max[i] = Double.NaN;
    }
    for (int i = 0; i < m_Train.numInstances(); i++) {
      updateMinMax(m_Train.instance(i));
    }
  }

  /**
   * Adds the supplied instance to the training set
   *
   * @param instance the instance to add
   * @exception Exception if instance could not be incorporated
   * successfully
   */
  public void updateClassifier(Instance instance) throws Exception {

    if (m_Train.equalHeaders(instance.dataset()) == false) {
      throw new Exception("Incompatible instance types");
    }
    if (!instance.classIsMissing()) {
      updateMinMax(instance);
      m_Train.add(instance);
    }
  }

  /**
   * Predicts the class value for the given test instance by fitting a
   * linear regression model to a distance-weighted copy of the training
   * data and querying that model.
   *
   * @param instance the instance to be classified
   * @return the predicted class value
   * @exception Exception if an error occurred during the prediction
   */
  public double classifyInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
      throw new Exception("No training instances!");
    }

    updateMinMax(instance);

    // Get the distances to each training instance
    double [] distance = new double [m_Train.numInstances()];
    for (int i = 0; i < m_Train.numInstances(); i++) {
      distance[i] = distance(instance, m_Train.instance(i));
    }
    int [] sortKey = Utils.sort(distance);

    if (m_Debug) {
      System.out.println("Instance Distances");
      for (int i = 0; i < distance.length; i++) {
	System.out.println("" + distance[sortKey[i]]);
      }
    }

    // Determine the bandwidth: the distance to the kth nearest neighbour.
    int k = sortKey.length - 1;
    if (!m_UseAllK && (m_kNN < k)) {
      k = m_kNN;
    }
    double bandwidth = distance[sortKey[k]];

    // If the kth neighbour is no further than the nearest, widen the
    // bandwidth to the next larger distance so the kernel discriminates.
    if (bandwidth == distance[sortKey[0]]) {
      for (int i = k; i < sortKey.length; i++) {
	if (distance[sortKey[i]] > bandwidth) {
	  bandwidth = distance[sortKey[i]];
	  break;
	}
      }
      if (bandwidth == distance[sortKey[0]]) {
	bandwidth *= 10;  // Include them all
      }
    }

    // Guard against a zero bandwidth (all distances zero): the rescaling
    // below would otherwise divide by zero and produce NaN weights. Any
    // positive value gives every instance full weight in this case.
    if (bandwidth <= 0) {
      bandwidth = 1;
    }

    // Rescale the distances by the bandwidth
    for (int i = 0; i < distance.length; i++) {
      distance[i] = distance[i] / bandwidth;
    }

    // Pass the distances through a weighting kernel
    for (int i = 0; i < distance.length; i++) {
      switch (m_WeightKernel) {
      case LINEAR:
	distance[i] = Math.max(1.0001 - distance[i], 0);
	break;
      case INVERSE:
	distance[i] = 1.0 / (1.0 + distance[i]);
	break;
      case GAUSS:
	distance[i] = Math.exp(-distance[i] * distance[i]);
	break;
      }
    }

    if (m_Debug) {
      System.out.println("Instance Weights");
      for (int i = 0; i < distance.length; i++) {
	System.out.println("" + distance[i]);
      }
    }

    // Set the weights on a copy of the training data, dropping instances
    // whose kernel weight is effectively zero.
    Instances weightedTrain = new Instances(m_Train, 0);
    for (int i = 0; i < distance.length; i++) {
      double weight = distance[sortKey[i]];
      if (weight < 1e-20) {
	break;
      }
      Instance newInst = (Instance) m_Train.instance(sortKey[i]).copy();
      newInst.setWeight(newInst.weight() * weight);
      weightedTrain.add(newInst);
    }
    if (m_Debug) {
      System.out.println("Kept " + weightedTrain.numInstances() + " out of "
			 + m_Train.numInstances() + " instances");
    }

    // Create a weighted linear regression
    LinearRegression weightedRegression = new LinearRegression();
    weightedRegression.buildClassifier(weightedTrain);

    if (m_Debug) {
      System.out.println("Classifying test instance: " + instance);
      System.out.println("Built regression model:\n"
			 + weightedRegression.toString());
    }

    // Return the linear regression's prediction
    return weightedRegression.classifyInstance(instance);
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {

    if (m_Train == null) {
      return "Locally weighted regression: No model built yet.";
    }
    String result = "Locally weighted regression\n"
      + "===========================\n";
    switch (m_WeightKernel) {
    case LINEAR:
      result += "Using linear weighting kernels\n";
      break;
    case INVERSE:
      result += "Using inverse-distance weighting kernels\n";
      break;
    case GAUSS:
      result += "Using gaussian weighting kernels\n";
      break;
    }
    result += "Using " + (m_UseAllK ? "all" : "" + m_kNN) + " neighbours";
    return result;
  }

  /**
   * Calculates the normalized Euclidean distance between two instances,
   * skipping the class attribute. Missing nominal values count as
   * maximally different; missing numeric values are treated
   * pessimistically.
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances, between 0 and 1
   */
  private double distance(Instance first, Instance second) {

    double diff, distance = 0;
    int numAttribsUsed = 0;
    for(int i = 0; i < m_Train.numAttributes(); i++) {
      if (i == m_Train.classIndex()) {
	continue;
      }
      switch (m_Train.attribute(i).type()) {
      case Attribute.NOMINAL:
	// If attribute is nominal
	numAttribsUsed++;
	if (first.isMissing(i) || second.isMissing(i) ||
	    ((int)first.value(i) != (int)second.value(i))) {
	  diff = 1;
	} else {
	  diff = 0;
	}
	break;
      case Attribute.NUMERIC:
	// If attribute is numeric
	numAttribsUsed++;
	if (first.isMissing(i) || second.isMissing(i)) {
	  if (first.isMissing(i) && second.isMissing(i)) {
	    diff = 1;
	  } else {
	    if (second.isMissing(i)) {
	      diff = norm(first.value(i),i);
	    } else {
	      diff = norm(second.value(i),i);
	    }
	    if (diff < 0.5) {
	      diff = 1.0-diff;
	    }
	  }
	} else {
	  diff = norm(first.value(i),i) - norm(second.value(i),i);
	}
	break;
      default:
	diff = 0;
	break;
      }
      distance += diff * diff;
    }

    return Math.sqrt(distance / numAttribsUsed);
  }

  /**
   * Normalizes a given value of a numeric attribute to [0, 1] using the
   * observed min/max; returns 0 when no range has been observed.
   *
   * @param x the value to be normalized
   * @param i the attribute's index
   */
  private double norm(double x,int i) {

    if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i], m_Min[i])) {
      return 0;
    } else {
      return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
    }
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {

    for (int j = 0; j < m_Train.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
	if (Double.isNaN(m_Min[j])) {
	  m_Min[j] = instance.value(j);
	  m_Max[j] = instance.value(j);
	} else if (instance.value(j) < m_Min[j]) {
	  m_Min[j] = instance.value(j);
	} else if (instance.value(j) > m_Max[j]) {
	  m_Max[j] = instance.value(j);
	}
      }
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(
	    new LWR(), argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -