📄 rbfnetwork.java
字号:
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* RBFNetwork.java
* Copyright (C) 2004 Mark Hall
*
*/
package weka.classifiers.functions;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.clusterers.MakeDensityBasedClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ClusterMembership;
/**
 * Class that implements a radial basis function network.
 * It uses the K-Means clustering algorithm to provide the basis
 * functions and learns either a logistic regression (discrete
 * class problems) or linear regression (numeric class problems)
 * on top of that. <p>
 *
 * Valid options are:<p>
 *
 * -B num <br>
 * Set the number of clusters (basis functions) to use.
 * (default 2)<p>
 *
 * -R ridge <br>
 * Set the ridge parameter for the logistic regression or linear regression.
 * (default 1.0e-8)<p>
 *
 * -M num <br>
 * Set the maximum number of iterations for logistic regression.
 * (default -1, until convergence)<p>
 *
 * -S seed <br>
 * Set the random seed used by K-means when generating clusters.
 * (default 1). <p>
 *
 * -D <br>
 * Set the debug flag inherited from Classifier. <p>
 *
 * @author Mark Hall
 * @version $Revision$
 */
public class RBFNetwork extends Classifier implements OptionHandler {

  /** For serialization. */
  private static final long serialVersionUID = -7841548138113026179L;

  /** Default number of clusters (basis functions). */
  private static final int DEFAULT_NUM_CLUSTERS = 2;

  /** Default ridge parameter for the logistic or linear regression. */
  private static final double DEFAULT_RIDGE = 1.0e-8;

  /** Default maximum iterations for logistic regression (-1 = until convergence). */
  private static final int DEFAULT_MAX_ITS = -1;

  /** Default random seed passed on to K-means. */
  private static final int DEFAULT_CLUSTERING_SEED = 1;

  /** The logistic regression for nominal-class problems; null otherwise or before building. */
  private Logistic m_logistic;

  /** The linear regression for non-nominal-class problems; null otherwise or before building. */
  private LinearRegression m_linear;

  /**
   * The filter that maps instances to the basis-function representation.
   * It is non-null only once a complete model has been built successfully.
   */
  private ClusterMembership m_basisFilter;

  /** The number of clusters (basis functions to generate). */
  private int m_numClusters = DEFAULT_NUM_CLUSTERS;

  /** The ridge parameter for the logistic or linear regression. */
  protected double m_ridge = DEFAULT_RIDGE;

  /** The maximum number of iterations for logistic regression. */
  private int m_maxIts = DEFAULT_MAX_ITS;

  /** The seed to pass on to K-means. */
  private int m_clusteringSeed = DEFAULT_CLUSTERING_SEED;

  /**
   * Returns a string describing this classifier.
   *
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class that implements a radial basis function network. "
      +"It uses the K-Means clustering algorithm to provide the basis "
      +"functions and learns either a logistic regression (discrete "
      +"class problems) or linear regression (numeric class problems) "
      +"on top of that.";
  }

  /**
   * Builds the classifier: clusters the data with K-means, transforms it
   * through a ClusterMembership filter driven by that clusterer, and fits a
   * logistic regression (nominal class) or linear regression (otherwise)
   * on the transformed data.
   *
   * The model fields are published only after every step succeeds, so a
   * failed build leaves the classifier in the "not built" state rather
   * than pairing a new filter with a stale regression model.
   *
   * @param instances the training data
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    // Drop any previous model up front so a failure below cannot leave
    // a half-replaced model behind.
    m_basisFilter = null;
    m_logistic = null;
    m_linear = null;

    SimpleKMeans sk = new SimpleKMeans();
    sk.setNumClusters(m_numClusters);
    sk.setSeed(m_clusteringSeed);
    // ClusterMembership takes a density-based clusterer, hence the wrapper.
    MakeDensityBasedClusterer dc = new MakeDensityBasedClusterer();
    dc.setClusterer(sk);

    ClusterMembership basisFilter = new ClusterMembership();
    basisFilter.setDensityBasedClusterer(dc);
    basisFilter.setInputFormat(instances);
    Instances transformed = Filter.useFilter(instances, basisFilter);

    if (instances.classAttribute().isNominal()) {
      Logistic logistic = new Logistic();
      logistic.setRidge(m_ridge);
      logistic.setMaxIts(m_maxIts);
      logistic.buildClassifier(transformed);
      m_logistic = logistic;
    } else {
      LinearRegression linear = new LinearRegression();
      linear.setRidge(m_ridge);
      linear.buildClassifier(transformed);
      m_linear = linear;
    }

    // Publish the filter last: a non-null filter marks a complete model.
    m_basisFilter = basisFilter;
  }

  /**
   * Computes the distribution for a given instance by passing it through
   * the basis-function filter and then through the trained regression.
   *
   * @param instance the instance for which distribution is computed
   * @return the distribution computed by the underlying logistic or
   * linear regression model
   * @exception IllegalStateException if no model has been built yet
   * @exception Exception if the distribution can't be computed successfully
   */
  public double [] distributionForInstance(Instance instance)
    throws Exception {
    if (m_basisFilter == null) {
      throw new IllegalStateException("No model built yet.");
    }
    m_basisFilter.input(instance);
    Instance transformed = m_basisFilter.output();
    // Dispatch on the model that was actually trained, rather than
    // re-inspecting the header of the incoming instance.
    return (m_logistic != null)
      ? m_logistic.distributionForInstance(transformed)
      : m_linear.distributionForInstance(transformed);
  }

  /**
   * Returns a description of this classifier as a String.
   *
   * @return a description of this classifier, or a short notice if no
   * model has been built yet
   */
  public String toString() {
    if (m_basisFilter == null) {
      return "RBFNetwork: No model built yet.";
    }
    StringBuilder sb = new StringBuilder();
    sb.append("Radial basis function network\n");
    sb.append((m_linear == null)
              ? "(Logistic regression "
              : "(Linear regression ");
    sb.append("applied to K-means clusters as basis functions):\n\n");
    sb.append((m_linear == null)
              ? m_logistic.toString()
              : m_linear.toString());
    return sb.toString();
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxItsTipText() {
    return "Maximum number of iterations for the logistic regression to perform. "
      +"Only applied to discrete class problems.";
  }

  /**
   * Get the value of MaxIts.
   *
   * @return Value of MaxIts.
   */
  public int getMaxIts() {
    return m_maxIts;
  }

  /**
   * Set the value of MaxIts (-1 means iterate until convergence).
   *
   * @param newMaxIts Value to assign to MaxIts.
   */
  public void setMaxIts(int newMaxIts) {
    m_maxIts = newMaxIts;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String ridgeTipText() {
    return "Set the Ridge value for the logistic or linear regression.";
  }

  /**
   * Sets the ridge value for logistic or linear regression.
   *
   * @param ridge the ridge
   */
  public void setRidge(double ridge) {
    m_ridge = ridge;
  }

  /**
   * Gets the ridge value.
   *
   * @return the ridge
   */
  public double getRidge() {
    return m_ridge;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "The number of clusters for K-Means to generate.";
  }

  /**
   * Set the number of clusters for K-means to generate. Non-positive
   * values are ignored and leave the current setting unchanged.
   *
   * @param numClusters the number of clusters to generate.
   */
  public void setNumClusters(int numClusters) {
    if (numClusters > 0) {
      m_numClusters = numClusters;
    }
  }

  /**
   * Return the number of clusters to generate.
   *
   * @return the number of clusters to generate.
   */
  public int getNumClusters() {
    return m_numClusters;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String clusteringSeedTipText() {
    return "The random seed to pass on to K-means.";
  }

  /**
   * Set the random seed to be passed on to K-means.
   *
   * @param seed a seed value.
   */
  public void setClusteringSeed(int seed) {
    m_clusteringSeed = seed;
  }

  /**
   * Get the random seed used by K-means.
   *
   * @return the seed value.
   */
  public int getClusteringSeed() {
    return m_clusteringSeed;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options
   */
  public Enumeration<Option> listOptions() {
    Vector<Option> newVector = new Vector<Option>(5);
    newVector.addElement(new Option("\tSet the number of clusters (basis functions) "
                                    +"to generate. (default = 2).",
                                    "B", 1, "-B <number>"));
    newVector.addElement(new Option("\tSet the random seed to be used by K-means. "
                                    +"(default = 1).",
                                    "S", 1, "-S <seed>"));
    newVector.addElement(new Option("\tSet the ridge value for the logistic or "
                                    +"linear regression.",
                                    "R", 1, "-R <ridge>"));
    newVector.addElement(new Option("\tSet the maximum number of iterations "
                                    +"for the logistic regression."
                                    + " (default -1, until convergence).",
                                    "M", 1, "-M <number>"));
    // -D is parsed by setOptions, so advertise it here as well.
    newVector.addElement(new Option("\tIf set, classifier is run in debug mode and\n"
                                    +"\tmay output additional info to the console",
                                    "D", 0, "-D"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options. Every option is set, or reset to its
   * default when absent, so one call fully determines the configuration.
   * Valid options are:<p>
   *
   * -B num <br>
   * Set the number of clusters (basis functions) to use. (default 2)<p>
   *
   * -R ridge <br>
   * Set the ridge parameter for the logistic regression or linear regression.
   * (default 1.0e-8)<p>
   *
   * -M num <br>
   * Set the maximum number of iterations for logistic regression.
   * (default -1, until convergence)<p>
   *
   * -S seed <br>
   * Set the random seed used by K-means when generating clusters.
   * (default 1). <p>
   *
   * -D <br>
   * Set the debug flag. <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    String ridgeString = Utils.getOption('R', options);
    setRidge(ridgeString.length() != 0
             ? Double.parseDouble(ridgeString)
             : DEFAULT_RIDGE);

    String maxItsString = Utils.getOption('M', options);
    setMaxIts(maxItsString.length() != 0
              ? Integer.parseInt(maxItsString)
              : DEFAULT_MAX_ITS);

    // Reset first so that an absent -B (or one the setter rejects)
    // leaves the default rather than a value from an earlier call.
    m_numClusters = DEFAULT_NUM_CLUSTERS;
    String numClustersString = Utils.getOption('B', options);
    if (numClustersString.length() != 0) {
      setNumClusters(Integer.parseInt(numClustersString));
    }

    String seedString = Utils.getOption('S', options);
    setClusteringSeed(seedString.length() != 0
                      ? Integer.parseInt(seedString)
                      : DEFAULT_CLUSTERING_SEED);

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier, including -D when the
   * debug flag is set, so the result round-trips through setOptions.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {
    Vector<String> options = new Vector<String>(9);
    options.add("-B");
    options.add(String.valueOf(m_numClusters));
    options.add("-S");
    options.add(String.valueOf(m_clusteringSeed));
    options.add("-R");
    options.add(String.valueOf(m_ridge));
    options.add("-M");
    options.add(String.valueOf(m_maxIts));
    if (getDebug()) {
      options.add("-D");
    }
    return options.toArray(new String[options.size()]);
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String [] argv) {
    try {
      System.out.println(Evaluation.evaluateModel(new RBFNetwork(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -