leastmedsq.java (Java source, page 1 of 2), from the MacroWeka package, an extension of the Weka data mining tool
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    LeastMedSq.java
 *
 *    Copyright (C) 2001 Tony Voyle
 */

package weka.classifiers.functions;

import weka.classifiers.functions.LinearRegression;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.instance.RemoveRange;
import weka.filters.Filter;
import weka.core.*;
import java.io.*;
import java.util.*;


/**
 * Implements a least median squared linear regression utilising the
 * existing weka LinearRegression class to form predictions.
 * The basis of the algorithm is "Robust Regression and Outlier Detection",
 * Peter J. Rousseeuw and Annick M. Leroy, c1987.
 *
 * @author Tony Voyle (tv6@waikato.ac.nz)
 * @version $Revision: 1.1 $
 */
public class LeastMedSq extends Classifier implements OptionHandler {
  
  private double[] m_Residuals;    // squared residuals of the current regression on the training data

  private double[] m_weight;       // 0/1 weights marking the outliers found by buildWeight()

  private double m_SSR;            // sum of the squared residuals

  private double m_scalefactor;    // robust scale estimate used to standardise residuals

  private double m_bestMedian = Double.POSITIVE_INFINITY;    // lowest median squared residual so far

  private LinearRegression m_currentRegression;    // regression built from the current subsample

  private LinearRegression m_bestRegression;       // regression with the lowest median squared residual

  private LinearRegression m_ls;   // final regression used for predictions

  private Instances m_Data;        // filtered training data

  private Instances m_RLSData;     // data for the final regression (presumably built by buildRLSRegression())

  private Instances m_SubSample;   // current random subsample

  private ReplaceMissingValues m_MissingFilter;    // fitted on the training data, reused at prediction time

  private NominalToBinary m_TransformFilter;       // fitted on the training data, reused at prediction time

  private RemoveRange m_SplitFilter;

  private int m_samplesize = 4;    // size of each random subsample

  private int m_samples;           // number of random subsamples to draw

  private boolean m_israndom = false;

  private boolean m_debug = false; // print progress information when true

  private Random m_random;         // random number generator for subsampling

  private long m_randomseed = 0;   // seed for m_random

  /**
   * Returns a string describing this classifier
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Implements a least median sqaured linear regression utilising the "
      +"existing weka LinearRegression class to form predictions. "
      +"Least squared regression functions are generated from random subsamples of "
      +"the data. The least squared regression with the lowest meadian squared error "
      +"is chosen as the final model.\n\n"
      +"The basis of the algorithm is \n\nRobust regression and outlier detection "
      +"Peter J. Rousseeuw, Annick M. Leroy. c1987";
  }
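
  /*
   * A minimal usage sketch, not taken from the original source: it assumes a
   * hypothetical ARFF file "cpu.arff" whose last attribute is the numeric class,
   * and a calling method declared to throw Exception. Everything else is the
   * standard weka.core / weka.classifiers API.
   *
   *   Instances data = new Instances(
   *       new java.io.BufferedReader(new java.io.FileReader("cpu.arff")));
   *   data.setClassIndex(data.numAttributes() - 1);
   *   LeastMedSq lms = new LeastMedSq();
   *   lms.buildClassifier(data);                      // runs the LMS search
   *   double prediction = lms.classifyInstance(data.instance(0));
   */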

  /**
   * Builds the least median squared regression model.
   *
   * @param data training data
   * @exception Exception if an error occurs
   */
  public void buildClassifier(Instances data)throws Exception{

    data = new Instances(data);
    data.deleteWithMissingClass();

    if (!data.classAttribute().isNumeric())
      throw new UnsupportedClassTypeException("Class attribute has to be numeric for regression!");
    if (data.numInstances() == 0)
      throw new Exception("No instances in training file!");
    if (data.checkForStringAttributes())
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
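
    // Four stages follow: filter the data, decide how many random subsamples to
    // draw, search those subsamples for the regression with the lowest median
    // squared residual, then refit on the instances that regression does not
    // flag as outliers.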

    cleanUpData(data);

    getSamples();

    findBestRegression();

    buildRLSRegression();

  } // buildClassifier

  /**
   * Classify a given instance using the best generated
   * LinearRegression Classifier.
   *
   * @param instance instance to be classified
   * @return class value
   * @exception Exception if an error occurs
   */
  public double classifyInstance(Instance instance)throws Exception{
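    // Run the instance through the same NominalToBinary and ReplaceMissingValues
    // filters that were fitted on the training data, then defer to the final
    // regression m_ls.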

    Instance transformedInstance = instance;
    m_TransformFilter.input(transformedInstance);
    transformedInstance = m_TransformFilter.output();
    m_MissingFilter.input(transformedInstance);
    transformedInstance = m_MissingFilter.output();

    return m_ls.classifyInstance(transformedInstance);
  } // classifyInstance

  /**
   * Cleans up data
   *
   * @param data data to be cleaned up
   * @exception Exception if an error occurs
   */
  private void cleanUpData(Instances data)throws Exception{

    m_Data = data;
    m_TransformFilter = new NominalToBinary();
    m_TransformFilter.setInputFormat(m_Data);
    m_Data = Filter.useFilter(m_Data, m_TransformFilter);
    m_MissingFilter = new ReplaceMissingValues();
    m_MissingFilter.setInputFormat(m_Data);
    m_Data = Filter.useFilter(m_Data, m_MissingFilter);
    m_Data.deleteWithMissingClass();
  }

  /**
   * Decides how many random subsamples to draw and stores
   * the result in m_samples.
   *
   * @exception Exception if an error occurs
   */
  private void getSamples()throws Exception{

    // Instance-count thresholds below which all possible subsamples of size
    // m_samplesize are enumerated exactly instead of drawing the default
    // budget of m_samplesize * 500 random subsamples.
    int stuf[] = new int[] {500,50,22,17,15,14};
    if ( m_samplesize < 7){
      if ( m_Data.numInstances() < stuf[m_samplesize - 1])
	m_samples = combinations(m_Data.numInstances(), m_samplesize);
      else
	m_samples = m_samplesize * 500;

    } else m_samples = 3000;
    if (m_debug){
      System.out.println("m_samplesize: " + m_samplesize);
      System.out.println("m_samples: " + m_samples);
      System.out.println("m_randomseed: " + m_randomseed);
    }
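
    // For example, with the default m_samplesize of 4 and 200 training instances,
    // 200 >= stuf[3] (17), so m_samples = 4 * 500 = 2000 random subsamples; with
    // only 10 instances it would instead be combinations(10, 4) = 210, i.e. every
    // distinct subsample of size 4 would be tried.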

  }

  /**
   * Set up the random number generator
   *
   */
  private void setRandom(){

    m_random = new Random(getRandomSeed());
  }

  /**
   * Finds the best regression generated from m_samples
   * random samples from the training data
   *
   * @exception Exception if an error occurs
   */
  private void findBestRegression()throws Exception{

    setRandom();
    m_bestMedian = Double.POSITIVE_INFINITY;
    if (m_debug) {
      System.out.println("Starting:");
    }
    for(int s = 0; s < m_samples; s++){
      if (m_debug) {
	// print a progress marker roughly every 1% of the samples; the extra
	// check avoids a zero divisor when fewer than 100 samples are drawn
	if(m_samples >= 100 && s%(m_samples/100)==0)
	  System.out.print("*");
      }
      genRegression();
      getMedian();
    }
    if (m_debug) {
      System.out.println("");
    }
    m_currentRegression = m_bestRegression;
  }

  /**
   * Generates a LinearRegression classifier from
   * the current m_SubSample
   *
   * @exception Exception if an error occurs
   */
  private void genRegression()throws Exception{

    m_currentRegression = new LinearRegression();
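    // "-S 1" selects LinearRegression's attribute selection method 1 (none),
    // so each subsample regression is fitted on all attributes.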
    m_currentRegression.setOptions(new String[]{"-S", "1"});
    selectSubSample(m_Data);
    m_currentRegression.buildClassifier(m_SubSample);
  }

  /**
   * Finds residuals (squared) for the current
   * regression.
   *
   * @exception Exception if an error occurs
   */
  private void findResiduals()throws Exception{

    m_SSR = 0;
    m_Residuals = new double [m_Data.numInstances()];
    for(int i = 0; i < m_Data.numInstances(); i++){
      m_Residuals[i] = m_currentRegression.classifyInstance(m_Data.instance(i));
      m_Residuals[i] -= m_Data.instance(i).value(m_Data.classAttribute());
      m_Residuals[i] *= m_Residuals[i];
      m_SSR += m_Residuals[i];
    }
  }

  /**
   * Finds the median squared residual for the
   * current regression.
   *
   * @exception Exception if an error occurs
   */
  private void getMedian()throws Exception{

    findResiduals();
    int p = m_Residuals.length;
    select(m_Residuals, 0, p - 1, p / 2);
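    // select(), which is not shown on this page, is presumably a quickselect-style
    // routine: after the call the element at index p / 2 holds the median of the
    // squared residuals without the array having to be fully sorted.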
    if(m_Residuals[p / 2] < m_bestMedian){
      m_bestMedian = m_Residuals[p / 2];
      m_bestRegression = m_currentRegression;
    }
  }

  /**
   * Returns a string representing the best
   * LinearRegression classifier found.
   *
   * @return String representing the regression
   */
  public String toString(){

    if( m_ls == null){
      return "model has not been built";
    }
    return m_ls.toString();
  }

  /**
   * Builds a weight function removing instances with an
   * abnormally high scaled residual
   *
   * @exception Exception if an error occurs
   */
  private void buildWeight()throws Exception{

    findResiduals();
    // Robust scale estimate from Rousseeuw & Leroy:
    //   s0 = 1.4826 * (1 + 5 / (n - p)) * sqrt(median squared residual),
    // written with 5.0 so the correction term is computed in floating point
    // rather than truncated to zero by integer division.
    m_scalefactor = 1.4826 * ( 1 + 5.0 / (m_Data.numInstances()
					- m_Data.numAttributes()))
      * Math.sqrt(m_bestMedian);
    m_weight = new double[m_Residuals.length];
    for (int i = 0; i < m_Residuals.length; i++)
      m_weight[i] = ((Math.sqrt(m_Residuals[i])/m_scalefactor < 2.5)?1.0:0.0);
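
    // For example, with 100 instances, 5 attributes and a best median squared
    // residual of 4.0, the scale is 1.4826 * (1 + 5.0 / 95) * 2.0, about 3.12,
    // so any instance whose absolute residual exceeds 2.5 * 3.12 (about 7.8)
    // gets weight 0 and is excluded from the final regression.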
  }

  /**
   * Builds a new LinearRegression without the 'bad' data
   * found by buildWeight
   *
