⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 logistic.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 Weka 的半指导学习 (semi-supervised learning) 的分类器
💻 JAVA
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Logistic.java
 *    Copyright (C) 2002 Eibe Frank
 *
 */

package weka.classifiers.functions;

import weka.classifiers.meta.LogitBoost;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.Evaluation;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Classifier;
import weka.core.UnsupportedClassTypeException;
import weka.core.Instances;
import weka.core.Instance;
import weka.core.OptionHandler;
import weka.core.WeightedInstancesHandler;
import weka.core.SelectedTag;
import weka.core.Utils;
import weka.core.Attribute;
import weka.core.Option;
import weka.core.UnsupportedAttributeTypeException;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.Filter;
import java.util.Enumeration;
import java.util.Vector;

/**
 * Implements linear logistic regression using LogitBoost and
 * LinearRegression.<p>
 *
 * Missing values are replaced using ReplaceMissingValues, and
 * nominal attributes are transformed into numeric attributes using
 * NominalToBinary.<p>
 *
 * -P precision <br>
 * Set the precision of stopping criterion based on average loglikelihood.
 * (default 1.0e-13) <p>
 *
 * -R ridge <br>
 * Set the ridge parameter for the linear regression models.
 * (default 1.0e-8)<p>
 *
 * -M num <br>
 * Set the maximum number of iterations.
 * (default 200)<p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.1.1.1 $ 
 */
public class Logistic extends DistributionClassifier 
  implements OptionHandler, WeightedInstancesHandler {

  /**
   * The coefficients, one row per class value. Each row has one slot per
   * attribute index of the filtered training data (the slot at the class
   * index is never written) followed by a final slot holding the constant
   * term.
   */
  private double[][] m_Coefficients = null;

  /** The index of the class attribute in the filtered training data. */
  private int m_ClassIndex = -1;

  /** The filter that removes the zero-variance attributes found in training. */
  private Remove m_AttFilter = null;

  /** The header (no instances) of the filtered training data. */
  private Instances m_Header = null;

  /** The filter used to make attributes numeric. */
  private NominalToBinary m_NominalToBinary = null;

  /** The filter used to get rid of missing values. */
  private ReplaceMissingValues m_ReplaceMissingValues = null;

  /** The ridge parameter. */
  private double m_Ridge = 1e-8;

  /** The precision parameter (passed to LogitBoost as likelihood threshold). */
  private double m_Precision = 1.0e-13;

  /** The maximum number of iterations. */
  private int m_MaxIts = 200;

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(3);

    newVector.addElement(new Option(
        "\tSet the precision of stopping criterion based on\n"
        + "\tchange in average loglikelihood (default 1.0e-13).",
        "P", 1, "-P <precision>"));
    newVector.addElement(new Option(
        "\tSet the ridge for the linear regression models (default 1.0e-8).",
        "R", 1, "-R <ridge>"));
    newVector.addElement(new Option(
        "\tSet the maximum number of iterations (default 200).",
        "M", 1, "-M <number>"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -P precision <br>
   * Set the precision of stopping criterion based on average loglikelihood.
   * (default 1.0e-13) <p>
   *
   * -R ridge <br>
   * Set the ridge parameter for the linear regression models.
   * (default 1.0e-8)<p>
   *
   * -M num <br>
   * Set the maximum number of iterations.
   * (default 200)<p>
   *
   * An option that is absent resets the corresponding parameter to its
   * default value.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String precisionString = Utils.getOption('P', options);
    if (precisionString.length() != 0) {
      m_Precision = Double.parseDouble(precisionString);
    } else {
      m_Precision = 1.0e-13;
    }

    String ridgeString = Utils.getOption('R', options);
    if (ridgeString.length() != 0) {
      m_Ridge = Double.parseDouble(ridgeString);
    } else {
      m_Ridge = 1.0e-8;
    }

    String maxItsString = Utils.getOption('M', options);
    if (maxItsString.length() != 0) {
      m_MaxIts = Integer.parseInt(maxItsString);
    } else {
      m_MaxIts = 200;
    }
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    // All six slots are always filled, so no padding with "" is needed.
    String [] options = new String [6];
    int current = 0;

    options[current++] = "-P";
    options[current++] = "" + m_Precision;
    options[current++] = "-R";
    options[current++] = "" + m_Ridge;
    options[current++] = "-M";
    options[current++] = "" + m_MaxIts;

    return options;
  }

  /**
   * Builds the model.
   *
   * The training data is copied, instances with a missing class are
   * dropped, missing values are replaced, nominal attributes are
   * binarized and zero-variance attributes are removed. The remaining
   * non-class attributes are standardized, LogitBoost with
   * LinearRegression as base learner is run on the result, and the
   * summed regression coefficients are converted back to the original
   * (unstandardized) scale.
   *
   * @param data the training data; the caller's object is not modified
   * @exception Exception if the class attribute is not nominal, if the
   * data contains string attributes, or if filtering or model building
   * fails
   */
  public void buildClassifier(Instances data) throws Exception {

    if (data.classAttribute().type() != Attribute.NOMINAL) {
      throw new UnsupportedClassTypeException("Class attribute must be nominal.");
    }
    if (data.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }

    // Work on a copy so the caller's data stays untouched
    data = new Instances(data);
    data.deleteWithMissingClass();

    m_ReplaceMissingValues = new ReplaceMissingValues();
    m_ReplaceMissingValues.setInputFormat(data);
    data = Filter.useFilter(data, m_ReplaceMissingValues);

    m_NominalToBinary = new NominalToBinary();
    m_NominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_NominalToBinary);

    // Find attributes that should be deleted because of
    // zero variance
    int[] indices = new int[data.numAttributes() - 1];
    int numDeleted = 0;
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j != data.classIndex()) {
        double var = data.variance(j);
        if (var == 0) {
          indices[numDeleted++] = j;
        }
      }
    }
    // Trim the index array to the number of attributes actually found
    int[] temp = new int[numDeleted];
    System.arraycopy(indices, 0, temp, 0, numDeleted);
    indices = temp;

    // Remove useless attributes
    m_AttFilter = new Remove();
    m_AttFilter.setAttributeIndicesArray(indices);
    m_AttFilter.setInvertSelection(false);
    m_AttFilter.setInputFormat(data);
    data = Filter.useFilter(data, m_AttFilter);

    // Set class index
    m_ClassIndex = data.classIndex();

    // Standardize data (the class column is copied unchanged)
    double[][] values = 
      new double[data.numInstances()][data.numAttributes()];
    double[] means = new double[data.numAttributes()];
    double[] stdDevs = new double[data.numAttributes()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j != data.classIndex()) {
        means[j] = data.meanOrMode(j);
        stdDevs[j] = Math.sqrt(data.variance(j));
        for (int i = 0; i < data.numInstances(); i++) {
          values[i][j] = (data.instance(i).value(j) - means[j]) / 
            stdDevs[j];
        }
      } else {
        for (int i = 0; i < data.numInstances(); i++) {
          values[i][j] = data.instance(i).value(j);
        }
      }
    }
    Instances newData = new Instances(data, data.numInstances());
    for (int i = 0; i < data.numInstances(); i++) {
      // Keep the original instance weight
      newData.add(new Instance(data.instance(i).weight(), values[i]));
    }

    // Use LogitBoost to build model
    LogitBoost boostedModel = new LogitBoost();
    boostedModel.setLikelihoodThreshold(m_Precision);
    boostedModel.setMaxIterations(m_MaxIts);
    LinearRegression lr = new LinearRegression();
    lr.setEliminateColinearAttributes(false);
    lr.setAttributeSelectionMethod(new SelectedTag(LinearRegression.
                                                   SELECTION_NONE,
                                                   LinearRegression.
                                                   TAGS_SELECTION));
    lr.turnChecksOff();
    lr.setRidge(m_Ridge);
    boostedModel.setClassifier(lr);
    boostedModel.buildClassifier(newData);

    // Extract coefficients: for every class, sum the coefficients of
    // all regression models built for that class
    Classifier[][] models = boostedModel.classifiers();
    m_Coefficients = new double[newData.numClasses()]
      [newData.numAttributes() + 1];
    for (int j = 0; j < newData.numClasses(); j++) {
      for (int i = 0; i < models[j].length; i++) {
        double[] locCoefficients = 
          ((LinearRegression)models[j][i]).coefficients();
        for (int k = 0; k <= newData.numAttributes(); k++) {
          if (k != newData.classIndex()) {
            m_Coefficients[j][k] += locCoefficients[k];
          }
        }
      }
    }

    // Convert coefficients into original scale: divide each slope by the
    // attribute's standard deviation and fold the mean shift into the
    // constant term stored in the last slot
    for (int j = 0; j < data.numClasses(); j++) {
      for (int i = 0; i < data.numAttributes(); i++) {
        if ((i != newData.classIndex()) &&
            (stdDevs[i] > 0)) {
          m_Coefficients[j][i] /= stdDevs[i];
          m_Coefficients[j][data.numAttributes()] -= 
            m_Coefficients[j][i] * means[i];
        }
      }
    }
    m_Header = new Instances(data, 0);
  }

  /**
   * Classifies an instance.
   *
   * The instance is passed through the filters configured during
   * training; a linear score is then computed per class and the scores
   * are turned into probabilities.
   *
   * @param inst the instance to classify
   * @return the estimated class membership probabilities
   * @exception Exception if the instance cannot be filtered
   */
  public double[] distributionForInstance(Instance inst) 
    throws Exception {

    // Filter instance
    m_ReplaceMissingValues.input(inst);
    inst = m_ReplaceMissingValues.output();
    m_NominalToBinary.input(inst);
    inst = m_NominalToBinary.output();
    m_AttFilter.input(inst);
    m_AttFilter.batchFinished();
    inst = m_AttFilter.output();

    // Compute prediction
    double[] preds = new double[m_Coefficients.length];
    for (int j = 0; j < inst.numClasses(); j++) {
      for (int i = 0; i < inst.numAttributes(); i++) {
        if (i != inst.classIndex()) {
          preds[j] += inst.value(i) * m_Coefficients[j][i];
        }
      }
      // Constant term lives in the slot after the last attribute
      preds[j] += m_Coefficients[j][inst.numAttributes()];
    }
    return probs(preds);
  }

  /**
   * Computes probabilities from F scores.
   *
   * The largest score is subtracted from every score before
   * exponentiating, so the largest exponent is always zero.
   *
   * @param Fs the per-class scores
   * @return the normalized probabilities, one per score
   */
  private double[] probs(double[] Fs) {

    double maxF = -Double.MAX_VALUE;
    for (int i = 0; i < Fs.length; i++) {
      if (Fs[i] > maxF) {
        maxF = Fs[i];
      }
    }
    double sum = 0;
    double[] probs = new double[Fs.length];
    for (int i = 0; i < Fs.length; i++) {
      probs[i] = Math.exp(Fs[i] - maxF);
      sum += probs[i];
    }
    Utils.normalize(probs, sum);
    return probs;
  }

  /**
   * Prints the model.
   *
   * For every class value the linear function is printed one term per
   * line; the last term is the constant and carries no attribute name.
   *
   * @return a textual description of the model, or a notice if no model
   * has been built yet
   */
  public String toString() {

    if (m_Coefficients == null) {
      return "No model has been built yet!";
    } 
    StringBuffer text = new StringBuffer();
    for (int j = 0; j < m_Coefficients.length; j++) {
      text.append("\nModel for class: " + 
                  m_Header.classAttribute().value(j) + "\n\n");
      // Track the first printed term explicitly: testing "i > 0" would
      // prefix the first term with " + " whenever the class attribute
      // sits at index 0 and is therefore skipped.
      boolean firstTerm = true;
      for (int i = 0; i < m_Coefficients[j].length; i++) {
        if (i != m_ClassIndex) {
          if (firstTerm) {
            text.append("   ");
            firstTerm = false;
          } else {
            text.append(" + ");
          }
          text.append(Utils.doubleToString(m_Coefficients[j][i], 12, 4));
          if (i < m_Coefficients[j].length - 1) {
            text.append(" * " 
                        + m_Header.attribute(i).name() + "\n");
          }
        }
      }
      text.append("\n");
    }
    return text.toString();
  }

  /**
   * Get the value of MaxIts.
   *
   * @return Value of MaxIts.
   */
  public int getMaxIts() {

    return m_MaxIts;
  }

  /**
   * Set the value of MaxIts.
   *
   * @param newMaxIts Value to assign to MaxIts.
   */
  public void setMaxIts(int newMaxIts) {

    m_MaxIts = newMaxIts;
  }

  /**
   * Sets the precision of the stopping criterion; the value is handed to
   * LogitBoost as its likelihood threshold when the model is built.
   *
   * @param precision the precision
   */
  public void setPrecision(double precision) {

    m_Precision = precision;
  }

  /**
   * Gets the precision of the stopping criterion that is handed to
   * LogitBoost as its likelihood threshold.
   *
   * @return the precision
   */
  public double getPrecision() {

    return m_Precision;
  }

  /**
   * Sets the ridge parameter.
   *
   * @param ridge the ridge
   */
  public void setRidge(double ridge) {

    m_Ridge = ridge;
  }

  /**
   * Gets the ridge parameter.
   *
   * @return the ridge
   */
  public double getRidge() {

    return m_Ridge;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the command-line options, passed on to Evaluation
   */
  public static void main(String[] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new Logistic(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -