⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rbfnetwork.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */


/*
 *    RBFNetwork.java
 *    Copyright (C) 2004 Mark Hall
 *
 */
package weka.classifiers.functions;

import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.clusterers.MakeDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ClusterMembership;

/**
 * Class that implements a radial basis function network. 
 * It uses the K-Means clustering algorithm to provide the basis
 * functions and learns either a logistic regression (discrete
 * class problems) or linear regression (numeric class problems)
 * on top of that. <p>
 *
 * Valid options are:<p>
 *
 * -B num <br>
 * Set the number of clusters (basis functions) to use.<p>
 *
 * -R ridge <br>
 * Set the ridge parameter for the logistic regression or linear regression.<p>
 * 
 * -M num <br>
 * Set the maximum number of iterations for logistic regression.
 * (default -1, until convergence)<p>
 *
 * -S seed <br>
 * Set the random seed used by K-means when generating clusters. 
 * (default 1). <p>
 *
 * @author Mark Hall
 * @version $Revision$
 */
public class RBFNetwork extends Classifier implements OptionHandler {

  /**
	 * 
	 */
	private static final long serialVersionUID = -7841548138113026179L;

/** The logistic regression for classification problems */
  private Logistic m_logistic;

  /** The linear regression for numeric problems */
  private LinearRegression m_linear;

  /** The filter for producing the meta data */
  private ClusterMembership m_basisFilter;

  /** The number of clusters (basis functions to generate) */
  private int m_numClusters = 2;

  /** The ridge parameter for the logistic regression. */
  protected double m_ridge = 1e-8;

  /** The maximum number of iterations for logistic regression. */
  private int m_maxIts = -1;

  /** The seed to pass on to K-means */
  private int m_clusteringSeed = 1;

  /**
   * Returns a string describing this classifier
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class that implements a radial basis function network. "
      +"It uses the K-Means clustering algorithm to provide the basis "
      +"functions and learns either a logistic regression (discrete "
      +"class problems) or linear regression (numeric class problems) "
      +"on top of that.";
  }

  /**
   * Builds the classifier
   *
   * @param instances the training data
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    SimpleKMeans sk = new SimpleKMeans();
    sk.setNumClusters(m_numClusters);
    sk.setSeed(m_clusteringSeed);
    MakeDensityBasedClusterer dc = new MakeDensityBasedClusterer();
    dc.setClusterer(sk);
    m_basisFilter = new ClusterMembership();
    m_basisFilter.setDensityBasedClusterer(dc);
    m_basisFilter.setInputFormat(instances);
    Instances transformed = Filter.useFilter(instances, m_basisFilter);

    if (instances.classAttribute().isNominal()) {
      m_linear = null;
      m_logistic = new Logistic();
      m_logistic.setRidge(m_ridge);
      m_logistic.setMaxIts(m_maxIts);
      m_logistic.buildClassifier(transformed);
    } else {
      m_logistic = null;
      m_linear = new LinearRegression();
      m_linear.setRidge(m_ridge);
      m_linear.buildClassifier(transformed);
    }
  }

  /**
   * Computes the distribution for a given instance
   *
   * @param instance the instance for which distribution is computed
   * @return the distribution
   * @exception Exception if the distribution can't be computed successfully
   */
  public double [] distributionForInstance(Instance instance) 
    throws Exception {
    m_basisFilter.input(instance);
    Instance transformed = m_basisFilter.output();
    
    return ((instance.classAttribute().isNominal()
	     ? m_logistic.distributionForInstance(transformed)
	     : m_linear.distributionForInstance(transformed)));
  }
  
  /**
   * Returns a description of this classifier as a String
   *
   * @return a description of this classifier
   */
  public String toString() {
    StringBuffer sb = new StringBuffer();
    sb.append("Radial basis function network\n");
    sb.append((m_linear == null) 
	      ? "(Logistic regression "
	      : "(Linear regression ");
    sb.append("applied to K-means clusters as basis functions):\n\n");
    sb.append((m_linear == null)
	      ? m_logistic.toString()
	      : m_linear.toString());
    return sb.toString();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxItsTipText() {
    return "Maximum number of iterations for the logistic regression to perform. "
      +"Only applied to discrete class problems.";
  }

  /**
   * Get the value of MaxIts.
   *
   * @return Value of MaxIts.
   */
  public int getMaxIts() {
	
    return m_maxIts;
  }
    
  /**
   * Set the value of MaxIts.
   *
   * @param newMaxIts Value to assign to MaxIts.
   */
  public void setMaxIts(int newMaxIts) {
	
    m_maxIts = newMaxIts;
  }    

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String ridgeTipText() {
    return "Set the Ridge value for the logistic or linear regression.";
  }

  /**
   * Sets the ridge value for logistic or linear regression.
   *
   * @param ridge the ridge
   */
  public void setRidge(double ridge) {
    m_ridge = ridge;
  }
    
  /**
   * Gets the ridge value.
   *
   * @return the ridge
   */
  public double getRidge() {
    return m_ridge;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "The number of clusters for K-Means to generate.";
  }

  /**
   * Set the number of clusters for K-means to generate.
   *
   * @param numClusters the number of clusters to generate.
   */
  public void setNumClusters(int numClusters) {
    if (numClusters > 0) {
      m_numClusters = numClusters;
    }
  }

  /**
   * Return the number of clusters to generate.
   *
   * @return the number of clusters to generate.
   */
  public int getNumClusters() {
    return m_numClusters;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String clusteringSeedTipText() {
    return "The random seed to pass on to K-means.";
  }
  
  /**
   * Set the random seed to be passed on to K-means.
   *
   * @param seed a seed value.
   */
  public void setClusteringSeed(int seed) {
    m_clusteringSeed = seed;
  }

  /**
   * Get the random seed used by K-means.
   *
   * @return the seed value.
   */
  public int getClusteringSeed() {
    return m_clusteringSeed;
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration<Option> listOptions() {
    Vector<Option> newVector = new Vector<Option>(4);

    newVector.addElement(new Option("\tSet the number of clusters (basis functions) "
				    +"to generate. (default = 2).",
				    "B", 1, "-B <number>"));
    newVector.addElement(new Option("\tSet the random seed to be used by K-means. "
				    +"(default = 1).",
				    "S", 1, "-S <seed>"));
    newVector.addElement(new Option("\tSet the ridge value for the logistic or "
				    +"linear regression.",
				    "R", 1, "-R <ridge>"));
    newVector.addElement(new Option("\tSet the maximum number of iterations "
				    +"for the logistic regression."
				    + " (default -1, until convergence).",
				    "M", 1, "-M <number>"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -B num <br>
   * Set the number of clusters (basis functions) to use.<p>
   *
   * -R ridge <br>
   * Set the ridge parameter for the logistic regression or linear regression.<p>
   * 
   * -M num <br>
   * Set the maximum number of iterations for logistic regression.
   * (default -1, until convergence)<p>
   *
   * -S seed <br>
   * Set the random seed used by K-means when generating clusters. 
   * (default 1). <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    String ridgeString = Utils.getOption('R', options);
    if (ridgeString.length() != 0) {
      m_ridge = Double.parseDouble(ridgeString);
    } else {
      m_ridge = 1.0e-8;
    }
	
    String maxItsString = Utils.getOption('M', options);
    if (maxItsString.length() != 0) {
      m_maxIts = Integer.parseInt(maxItsString);
    } else {
      m_maxIts = -1;
    }

    String numClustersString = Utils.getOption('B', options);
    if (numClustersString.length() != 0) {
      setNumClusters(Integer.parseInt(numClustersString));
    }

    String seedString = Utils.getOption('S', options);
    if (seedString.length() != 0) {
      setClusteringSeed(Integer.parseInt(seedString));
    }
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {
	
    String [] options = new String [8];
    int current = 0;
    
    options[current++] = "-B";
    options[current++] = "" + m_numClusters;
    options[current++] = "-S";
    options[current++] = "" + m_clusteringSeed;
    options[current++] = "-R";
    options[current++] = ""+m_ridge;	
    options[current++] = "-M";
    options[current++] = ""+m_maxIts;

    while (current < options.length) 
      options[current++] = "";
    return options;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String [] argv) {
    try {
      System.out.println(Evaluation.evaluateModel(new RBFNetwork(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -