📄 winnow.java
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Winnow.java
 *    Copyright (C) 2002 J. Lindgren
 *
 */

package weka.classifiers.functions;

import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.Filter;
import weka.classifiers.*;
import weka.core.*;
import java.util.*;

/**
 * Implements the Winnow and Balanced Winnow algorithms by
 * N. Littlestone. Only two-class problems whose (non-class) attributes
 * are all nominal are accepted; the training data is passed through
 * ReplaceMissingValues and NominalToBinary before learning. The
 * classifier can also be updated incrementally, one instance at a time.
 * For more information, see<p>
 *
 * N. Littlestone (1988). <i>Learning quickly when irrelevant
 * attributes abound: A new linear-threshold algorithm</i>.
 * Machine Learning 2, pp. 285-318.<p>
 *
 * and
 *
 * N. Littlestone (1989). <i>Mistake bounds and logarithmic
 * linear-threshold learning algorithms</i>. Technical report
 * UCSC-CRL-89-11, University of California, Santa Cruz.<p>
 *
 * Valid options are:<p>
 *
 * -L <br>
 * Use the baLanced variant (default: false)<p>
 *
 * -I num <br>
 * The number of iterations (passes over the training data) to be
 * performed. (default 1)<p>
 *
 * -A double <br>
 * Promotion coefficient alpha. (default 2.0)<p>
 *
 * -B double <br>
 * Demotion coefficient beta. (default 0.5)<p>
 *
 * -W double <br>
 * Starting weights of the prediction coefficients. (default 2.0)<p>
 *
 * -H double <br>
 * Prediction threshold. (default -1.0: any negative value means
 * "use the number of non-class attributes after nominal-to-binary
 * conversion")<p>
 *
 * -S int <br>
 * Random seed used to shuffle the training data. (default 1);
 * -1 disables shuffling<p>
 *
 * @author J. Lindgren (jtlindgr<at>cs.helsinki.fi)
 * @version $Revision: 1.1.1.1 $
 */
public class Winnow extends Classifier implements OptionHandler, UpdateableClassifier {

  /** Use the balanced variant (a second, "negative" weight vector)? **/
  protected boolean m_Balanced;

  /** The number of passes buildClassifier makes over the training data **/
  protected int m_numIterations = 1;

  /** The promotion coefficient **/
  protected double m_Alpha = 2.0;

  /** The demotion coefficient **/
  protected double m_Beta = 0.5;

  /**
   * Prediction threshold; a negative value (< 0) means "number of
   * attributes" and is resolved into m_actualThreshold by buildClassifier.
   **/
  protected double m_Threshold = -1.0;

  /** Random seed used for shuffling the dataset, -1 == disable **/
  protected int m_Seed = 1;

  /** Accumulated mistake count (for statistics); reset to 0 by buildClassifier **/
  protected int m_Mistakes;

  /** Starting value assigned to every entry of the weight vector(s) **/
  protected double m_defaultWeight = 2.0;

  /**
   * The positive weight vector used for prediction; one entry per
   * attribute of the filtered training data.
   **/
  private double[] m_predPosVector = null;

  /**
   * The negative weight vector; buildClassifier only allocates it when
   * the balanced variant is selected.
   **/
  private double[] m_predNegVector = null;

  /** The threshold actually used for prediction (derived from m_Threshold) **/
  private double m_actualThreshold;

  /** The filtered copy of the training instances */
  private Instances m_Train = null;

  /** The filter used to make attributes numeric (nominal to binary). */
  private NominalToBinary m_NominalToBinary;

  /** The filter used to get rid of missing values, both when training
   * and when updating the classifier with new instances.
*/ private ReplaceMissingValues m_ReplaceMissingValues; /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(7); newVector.addElement(new Option("\tUse the baLanced version\n" + "\t(default false)", "L", 0, "-L")); newVector.addElement(new Option("\tThe number of iterations to be performed.\n" + "\t(default 1)", "I", 1, "-I <int>")); newVector.addElement(new Option("\tPromotion coefficient alpha.\n" + "\t(default 2.0)", "A", 1, "-A <double>")); newVector.addElement(new Option("\tDemotion coefficient beta.\n" + "\t(default 0.5)", "B", 1, "-B <double>")); newVector.addElement(new Option("\tPrediction threshold.\n" + "\t(default -1.0 == number of attributes)", "H", 1, "-H <double>")); newVector.addElement(new Option("\tStarting weights.\n" + "\t(default 2.0)", "W", 1, "-W <double>")); newVector.addElement(new Option("\tDefault random seed.\n" + "\t(default 1)", "S", 1, "-S <int>")); return newVector.elements(); } /** * Parses a given list of options.<p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { m_Balanced = Utils.getFlag('L', options); String iterationsString = Utils.getOption('I', options); if (iterationsString.length() != 0) { m_numIterations = Integer.parseInt(iterationsString); } String alphaString = Utils.getOption('A', options); if (alphaString.length() != 0) { m_Alpha = (new Double(alphaString)).doubleValue(); } String betaString = Utils.getOption('B', options); if (betaString.length() != 0) { m_Beta = (new Double(betaString)).doubleValue(); } String tString = Utils.getOption('H', options); if (tString.length() != 0) { m_Threshold = (new Double(tString)).doubleValue(); } String wString = Utils.getOption('W', options); if (wString.length() != 0) { m_defaultWeight = (new 
Double(wString)).doubleValue(); } String rString = Utils.getOption('S', options); if (rString.length() != 0) { m_Seed = Integer.parseInt(rString); } } /** * Gets the current settings of the classifier. * * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { String[] options = new String [20]; int current = 0; if(m_Balanced) { options[current++] = "-L"; } options[current++] = "-I"; options[current++] = "" + m_numIterations; options[current++] = "-A"; options[current++] = "" + m_Alpha; options[current++] = "-B"; options[current++] = "" + m_Beta; options[current++] = "-H"; options[current++] = "" + m_Threshold; options[current++] = "-W"; options[current++] = "" + m_defaultWeight; options[current++] = "-S"; options[current++] = "" + m_Seed; while (current < options.length) { options[current++] = ""; } return options; } /** * Builds the classifier * * @exception Exception if something goes wrong during building */ public void buildClassifier(Instances insts) throws Exception { if (insts.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Can't handle string attributes!"); } if (insts.numClasses() > 2) { throw new Exception("Can only handle two-class datasets!"); } if (insts.classAttribute().isNumeric()) { throw new UnsupportedClassTypeException("Can't handle a numeric class!"); } Enumeration enum = insts.enumerateAttributes(); while (enum.hasMoreElements()) { Attribute attr = (Attribute) enum.nextElement(); if (!attr.isNominal()) { throw new UnsupportedAttributeTypeException("Winnow: only nominal attributes, please."); } } // Filter data m_Train = new Instances(insts); m_Train.deleteWithMissingClass(); m_ReplaceMissingValues = new ReplaceMissingValues(); m_ReplaceMissingValues.setInputFormat(m_Train); m_Train = Filter.useFilter(m_Train, m_ReplaceMissingValues); m_NominalToBinary = new NominalToBinary(); m_NominalToBinary.setInputFormat(m_Train); m_Train = Filter.useFilter(m_Train, m_NominalToBinary); 
/** Randomize training data */ if(m_Seed!=-1) m_Train.randomize(new Random(m_Seed)); /** Make space to store weights */ m_predPosVector = new double[m_Train.numAttributes()]; if(m_Balanced) m_predNegVector = new double[m_Train.numAttributes()]; /** Initialize the weights to starting values **/ for(int i=0;i<m_Train.numAttributes();i++) m_predPosVector[i]=m_defaultWeight; if(m_Balanced) for(int i=0;i<m_Train.numAttributes();i++) m_predNegVector[i]=m_defaultWeight; /** Set actual prediction threshold **/ if(m_Threshold<0) m_actualThreshold = (double)m_Train.numAttributes()-1; else m_actualThreshold = m_Threshold; m_Mistakes=0; /** Compute the weight vectors **/ if(m_Balanced) for (int it = 0; it < m_numIterations; it++) { for (int i = 0; i < m_Train.numInstances(); i++) { actualUpdateClassifierBalanced(m_Train.instance(i)); } } else for (int it = 0; it < m_numIterations; it++) { for (int i = 0; i < m_Train.numInstances(); i++) { actualUpdateClassifier(m_Train.instance(i)); } } } /** * Updates the classifier with a new learning example * * @exception Exception if something goes wrong */ public void updateClassifier(Instance instance) throws Exception { m_ReplaceMissingValues.input(instance); m_ReplaceMissingValues.batchFinished(); Instance filtered = m_ReplaceMissingValues.output(); m_NominalToBinary.input(filtered); m_NominalToBinary.batchFinished(); filtered = m_NominalToBinary.output(); if(m_Balanced) actualUpdateClassifierBalanced(filtered); else actualUpdateClassifier(filtered); } /** * Actual update routine for prefiltered instances * * @exception Exception if something goes wrong */ private void actualUpdateClassifier(Instance inst) throws Exception { double posmultiplier; if (!inst.classIsMissing()) { double prediction = makePrediction(inst); if (prediction != inst.classValue()) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -