// LeastMedSq.java
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LeastMedSq.java
*
* Copyright (C) 2001 Tony Voyle
*/
package weka.classifiers.functions;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.filters.supervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.instance.RemoveRange;
import weka.filters.Filter;
import weka.core.*;
import java.io.*;
import java.util.*;
/**
 * Implements a least median squared linear regression utilising the
 * existing weka LinearRegression class to form predictions.
 * The basis of the algorithm is Robust regression and outlier detection,
 * Peter J. Rousseeuw, Annick M. Leroy. c1987
 *
 * @author Tony Voyle (tv6@waikato.ac.nz)
 * @version $Revision: 1.1 $
 */
public class LeastMedSq extends Classifier implements OptionHandler {
// Squared residuals of the current regression over the full training data
private double[] m_Residuals;
// 0/1 weight per training instance; 0 marks outliers (see buildWeight)
private double[] m_weight;
// Sum of squared residuals accumulated in findResiduals
private double m_SSR;
// Preliminary robust scale estimate (Rousseeuw & Leroy) used to flag outliers
private double m_scalefactor;
// Lowest median squared residual seen so far across all subsamples
private double m_bestMedian = Double.POSITIVE_INFINITY;
// Regression built from the current random subsample
private LinearRegression m_currentRegression;
// Regression with the lowest median squared residual found so far
private LinearRegression m_bestRegression;
// Final least-squares regression rebuilt on the outlier-free data
private LinearRegression m_ls;
// Training data after nominal-to-binary and missing-value filtering
private Instances m_Data;
// Reweighted data (outliers removed) used to build the final model
private Instances m_RLSData;
// Current random subsample drawn from m_Data
private Instances m_SubSample;
// Filters fitted in cleanUpData and reused at classification time
private ReplaceMissingValues m_MissingFilter;
private NominalToBinary m_TransformFilter;
private RemoveRange m_SplitFilter;
// Number of instances per random subsample (default 4)
private int m_samplesize = 4;
// Number of random subsamples to draw (set in getSamples)
private int m_samples;
// Whether subsampling should be fully random
private boolean m_israndom = false;
// Emit progress/diagnostic output when true
private boolean m_debug = false;
// Random number generator seeded from m_randomseed (see setRandom)
private Random m_random;
private long m_randomseed = 0;
/**
 * Returns a string describing this classifier.
 *
 * @return a description of the classifier suitable for
 *         displaying in the explorer/experimenter gui
 */
public String globalInfo() {
  // Fixed typos in the user-visible description: "sqaured" -> "squared",
  // "meadian" -> "median".
  return "Implements a least median squared linear regression utilising the "
    + "existing weka LinearRegression class to form predictions. "
    + "Least squared regression functions are generated from random subsamples of "
    + "the data. The least squared regression with the lowest median squared error "
    + "is chosen as the final model.\n\n"
    + "The basis of the algorithm is \n\nRobust regression and outlier detection "
    + "Peter J. Rousseeuw, Annick M. Leroy. c1987";
}
/**
 * Builds the least median squared regression model from training data.
 *
 * @param data training data
 * @exception Exception if an error occurs
 */
public void buildClassifier(Instances data) throws Exception {
  // Work on a copy so the caller's data set is left untouched.
  Instances train = new Instances(data);
  train.deleteWithMissingClass();

  if (!train.classAttribute().isNumeric()) {
    throw new UnsupportedClassTypeException("Class attribute has to be numeric for regression!");
  }
  if (train.numInstances() == 0) {
    throw new Exception("No instances in training file!");
  }
  if (train.checkForStringAttributes()) {
    throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  }

  // Filter the data, pick the best subsample regression, then refit
  // a least-squares model on the reweighted (outlier-free) data.
  cleanUpData(train);
  getSamples();
  findBestRegression();
  buildRLSRegression();
} // buildClassifier
/**
 * Classifies a given instance using the best generated
 * LinearRegression classifier.
 *
 * @param instance instance to be classified
 * @return class value
 * @exception Exception if an error occurs
 */
public double classifyInstance(Instance instance) throws Exception {
  // Push the instance through the same filters fitted at training time.
  m_TransformFilter.input(instance);
  Instance filtered = m_TransformFilter.output();
  m_MissingFilter.input(filtered);
  filtered = m_MissingFilter.output();
  return m_ls.classifyInstance(filtered);
} // classifyInstance
/**
 * Cleans up the training data: binarises nominal attributes,
 * replaces missing values, and drops instances with a missing class.
 * Stores the fitted filters for reuse at classification time.
 *
 * @param data data to be cleaned up
 * @exception Exception if an error occurs
 */
private void cleanUpData(Instances data) throws Exception {
  m_TransformFilter = new NominalToBinary();
  m_TransformFilter.setInputFormat(data);
  Instances cleaned = Filter.useFilter(data, m_TransformFilter);

  m_MissingFilter = new ReplaceMissingValues();
  m_MissingFilter.setInputFormat(cleaned);
  cleaned = Filter.useFilter(cleaned, m_MissingFilter);

  cleaned.deleteWithMissingClass();
  m_Data = cleaned;
}
/**
 * Determines the number of random subsamples to draw.
 *
 * For subsample sizes below 7 the count is m_samplesize * 500, unless
 * the data set is small enough that enumerating every possible
 * subsample is cheaper; for larger subsample sizes a fixed 3000 is used.
 *
 * @exception Exception if an error occurs
 */
private void getSamples() throws Exception {
  // Data-set sizes below which exhaustive enumeration beats random
  // sampling, indexed by (m_samplesize - 1).
  // NOTE(review): assumes 1 <= m_samplesize when m_samplesize < 7 —
  // m_samplesize == 0 would index out of bounds; presumably guarded by
  // the option setter. (Removed the unused local "x" the original computed.)
  int[] stuf = new int[] {500, 50, 22, 17, 15, 14};
  if (m_samplesize < 7) {
    if (m_Data.numInstances() < stuf[m_samplesize - 1]) {
      m_samples = combinations(m_Data.numInstances(), m_samplesize);
    } else {
      m_samples = m_samplesize * 500;
    }
  } else {
    m_samples = 3000;
  }
  if (m_debug) {
    System.out.println("m_samplesize: " + m_samplesize);
    System.out.println("m_samples: " + m_samples);
    System.out.println("m_randomseed: " + m_randomseed);
  }
}
/**
 * Sets up the random number generator, seeded so that runs
 * with the same seed are reproducible.
 */
private void setRandom() {
  long seed = getRandomSeed();
  m_random = new Random(seed);
}
/**
 * Finds the best regression generated from m_samples
 * random samples of the training data, i.e. the one with
 * the lowest median squared residual.
 *
 * @exception Exception if an error occurs
 */
private void findBestRegression() throws Exception {
  setRandom();
  m_bestMedian = Double.POSITIVE_INFINITY;
  if (m_debug) {
    System.out.println("Starting:");
  }
  // Progress marker roughly every 1% of samples. The original computed
  // s % (m_samples / 100), which throws ArithmeticException whenever
  // fewer than 100 samples are drawn and debugging is on; guard the
  // divisor. (Also dropped the unused loop counter "r".)
  int tick = m_samples / 100;
  for (int s = 0; s < m_samples; s++) {
    if (m_debug && (tick == 0 || s % tick == 0)) {
      System.out.print("*");
    }
    genRegression();
    getMedian();
  }
  if (m_debug) {
    System.out.println("");
  }
  m_currentRegression = m_bestRegression;
}
/**
 * Generates a LinearRegression classifier from the current
 * m_SubSample (drawn via selectSubSample).
 *
 * @exception Exception if an error occurs
 */
private void genRegression() throws Exception {
  LinearRegression regression = new LinearRegression();
  // "-S 1" turns attribute selection off for the subsample fit.
  regression.setOptions(new String[] {"-S", "1"});
  m_currentRegression = regression;
  selectSubSample(m_Data);
  regression.buildClassifier(m_SubSample);
}
/**
 * Computes the squared residual of the current regression on every
 * training instance, storing them in m_Residuals and their sum in m_SSR.
 *
 * @exception Exception if an error occurs
 */
private void findResiduals() throws Exception {
  int numInstances = m_Data.numInstances();
  m_SSR = 0;
  m_Residuals = new double[numInstances];
  for (int i = 0; i < numInstances; i++) {
    Instance inst = m_Data.instance(i);
    double error = m_currentRegression.classifyInstance(inst)
      - inst.value(m_Data.classAttribute());
    m_Residuals[i] = error * error;
    m_SSR += m_Residuals[i];
  }
}
/**
 * Finds the median squared residual of the current regression and,
 * if it beats the best found so far, records this regression as best.
 *
 * @exception Exception if an error occurs
 */
private void getMedian() throws Exception {
  findResiduals();
  int n = m_Residuals.length;
  // Partial selection: places the (n/2)-th order statistic at index n/2.
  select(m_Residuals, 0, n - 1, n / 2);
  double median = m_Residuals[n / 2];
  if (median < m_bestMedian) {
    m_bestMedian = median;
    m_bestRegression = m_currentRegression;
  }
}
/**
 * Returns a string representing the best LinearRegression
 * classifier found, or a placeholder if no model has been built.
 *
 * @return String representing the regression
 */
public String toString() {
  return (m_ls == null) ? "model has not been built" : m_ls.toString();
}
/**
 * Builds a weight function flagging instances with an abnormally
 * high scaled residual: weight 1.0 keeps an instance, 0.0 drops it.
 *
 * Uses Rousseeuw &amp; Leroy's preliminary scale estimate
 *   s0 = 1.4826 * (1 + 5/(n - p)) * sqrt(median squared residual)
 * and rejects instances whose standardised residual is &gt;= 2.5.
 *
 * @exception Exception if an error occurs
 */
private void buildWeight() throws Exception {
  findResiduals();
  // BUG FIX: the finite-sample correction must be computed in floating
  // point. The original "5 / (n - p)" was integer division, which
  // truncates to 0 whenever n - p > 5, silently dropping the correction.
  m_scalefactor = 1.4826 * (1 + 5.0 / (m_Data.numInstances()
                                       - m_Data.numAttributes()))
    * Math.sqrt(m_bestMedian);
  m_weight = new double[m_Residuals.length];
  for (int i = 0; i < m_Residuals.length; i++) {
    // keep the instance iff its standardised residual is below 2.5
    m_weight[i] = ((Math.sqrt(m_Residuals[i]) / m_scalefactor < 2.5) ? 1.0 : 0.0);
  }
}
/**
* Builds a new LinearRegression without the 'bad' data
* found by buildWeight
*
 */
// NOTE(review): the source view is truncated here — the body of
// buildRLSRegression() and the remainder of the class are missing
// and were replaced by web-page residue; recover them from the
// original Weka source.