predictionappender.java

来自「Weka」· Java 代码 · 共 857 行 · 第 1/2 页

JAVA
857
字号
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    PredictionAppender.java *    Copyright (C) 2003 University of Waikato, Hamilton, New Zealand * */package weka.gui.beans;import weka.clusterers.DensityBasedClusterer;import weka.core.Instance;import weka.core.Instances;import java.awt.BorderLayout;import java.beans.EventSetDescriptor;import java.io.Serializable;import java.util.Enumeration;import java.util.Vector;import javax.swing.JPanel;/** * Bean that can can accept batch or incremental classifier events * and produce dataset or instance events which contain instances with * predictions appended. * * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a> * @version $Revision: 1.15 $ */public class PredictionAppender  extends JPanel  implements DataSource, TrainingSetProducer, TestSetProducer, Visible, BeanCommon,	     EventConstraints, BatchClassifierListener,	     IncrementalClassifierListener, BatchClustererListener, Serializable {  /** for serialization */  private static final long serialVersionUID = -2987740065058976673L;  /**   * Objects listenening for dataset events   */  protected Vector m_dataSourceListeners = new Vector();  /**   * Objects listening for instances events   */  protected Vector m_instanceListeners = new Vector();    /**   * Objects listening for training set events   */  protected Vector m_trainingSetListeners = new Vector();;    /**   * Objects listening for test set events   */  protected Vector m_testSetListeners = new Vector();  /**   * Non null if this object is a target for any events.   */  protected Object m_listenee = null;  /**   * Format of instances to be produced.   */  protected Instances m_format;  protected BeanVisual m_visual =     new BeanVisual("PredictionAppender", 		   BeanVisual.ICON_PATH+"PredictionAppender.gif",		   BeanVisual.ICON_PATH+"PredictionAppender_animated.gif");  /**   * Append classifier's predicted probabilities (if the class is discrete   * and the classifier is a distribution classifier)   */  protected boolean m_appendProbabilities;  protected transient weka.gui.Logger m_logger;  /**   * Global description of this bean   *   * @return a <code>String</code> value   */  public String globalInfo() {    return "Accepts batch or incremental classifier events and "      +"produces a new data set with classifier predictions appended.";  }  /**   * Creates a new <code>PredictionAppender</code> instance.   */  public PredictionAppender() {    setLayout(new BorderLayout());    add(m_visual, BorderLayout.CENTER);  }  /**   * Return a tip text suitable for displaying in a GUI   *   * @return a <code>String</code> value   */  public String appendPredictedProbabilitiesTipText() {    return "append probabilities rather than labels for discrete class "      +"predictions";  }  /**   * Return true if predicted probabilities are to be appended rather   * than class value   *   * @return a <code>boolean</code> value   */  public boolean getAppendPredictedProbabilities() {    return m_appendProbabilities;  }  /**   * Set whether to append predicted probabilities rather than   * class value (for discrete class data sets)   *   * @param ap a <code>boolean</code> value   */  public void setAppendPredictedProbabilities(boolean ap) {    m_appendProbabilities = ap;  }  /**   * Add a training set listener   *   * @param tsl a <code>TrainingSetListener</code> value   */  public void addTrainingSetListener(TrainingSetListener tsl) {    // TODO Auto-generated method stub    m_trainingSetListeners.addElement(tsl);    // pass on any format that we might have determined so far    if (m_format != null) {      TrainingSetEvent e = new TrainingSetEvent(this, m_format);      tsl.acceptTrainingSet(e);    }  }  /**   * Remove a training set listener   *   * @param tsl a <code>TrainingSetListener</code> value   */  public void removeTrainingSetListener(TrainingSetListener tsl) {       m_trainingSetListeners.removeElement(tsl);  }  /**   * Add a test set listener   *   * @param tsl a <code>TestSetListener</code> value   */  public void addTestSetListener(TestSetListener tsl) {    m_testSetListeners.addElement(tsl);//  pass on any format that we might have determined so far    if (m_format != null) {      TestSetEvent e = new TestSetEvent(this, m_format);      tsl.acceptTestSet(e);    }  }  /**   * Remove a test set listener   *   * @param tsl a <code>TestSetListener</code> value   */  public void removeTestSetListener(TestSetListener tsl) {    m_testSetListeners.removeElement(tsl);  }    /**   * Add a datasource listener   *   * @param dsl a <code>DataSourceListener</code> value   */  public synchronized void addDataSourceListener(DataSourceListener dsl) {    m_dataSourceListeners.addElement(dsl);    // pass on any format that we might have determined so far    if (m_format != null) {      DataSetEvent e = new DataSetEvent(this, m_format);      dsl.acceptDataSet(e);    }  }    /**   * Remove a datasource listener   *   * @param dsl a <code>DataSourceListener</code> value   */  public synchronized void removeDataSourceListener(DataSourceListener dsl) {    m_dataSourceListeners.remove(dsl);  }  /**   * Add an instance listener   *   * @param dsl a <code>InstanceListener</code> value   */  public synchronized void addInstanceListener(InstanceListener dsl) {    m_instanceListeners.addElement(dsl);    // pass on any format that we might have determined so far    if (m_format != null) {      InstanceEvent e = new InstanceEvent(this, m_format);      dsl.acceptInstance(e);    }  }    /**   * Remove an instance listener   *   * @param dsl a <code>InstanceListener</code> value   */  public synchronized void removeInstanceListener(InstanceListener dsl) {    m_instanceListeners.remove(dsl);  }  /**   * Set the visual for this data source   *   * @param newVisual a <code>BeanVisual</code> value   */  public void setVisual(BeanVisual newVisual) {    m_visual = newVisual;  }  /**   * Get the visual being used by this data source.   *   */  public BeanVisual getVisual() {    return m_visual;  }  /**   * Use the default images for a data source   *   */  public void useDefaultVisual() {    m_visual.loadIcons(BeanVisual.ICON_PATH+"PredictionAppender.gif",		       BeanVisual.ICON_PATH+"PredictionAppender_animated.gif");  }  protected InstanceEvent m_instanceEvent;  protected double [] m_instanceVals;    /**   * Accept and process an incremental classifier event   *   * @param e an <code>IncrementalClassifierEvent</code> value   */  public void acceptClassifier(IncrementalClassifierEvent e) {    weka.classifiers.Classifier classifier = e.getClassifier();    Instance currentI = e.getCurrentInstance();    int status = e.getStatus();    int oldNumAtts = 0;    if (status == IncrementalClassifierEvent.NEW_BATCH) {      oldNumAtts = e.getStructure().numAttributes();    } else {      oldNumAtts = currentI.dataset().numAttributes();    }    if (status == IncrementalClassifierEvent.NEW_BATCH) {      m_instanceEvent = new InstanceEvent(this, null, 0);      // create new header structure      Instances oldStructure = new Instances(e.getStructure(), 0);      //String relationNameModifier = oldStructure.relationName()	//+"_with predictions";      String relationNameModifier = "_with predictions";	//+"_with predictions";       if (!m_appendProbabilities 	   || oldStructure.classAttribute().isNumeric()) {	 try {	   m_format = makeDataSetClass(oldStructure, classifier,						     relationNameModifier);	   m_instanceVals = new double [m_format.numAttributes()];	 } catch (Exception ex) {	   ex.printStackTrace();	   return;	 }       } else if (m_appendProbabilities) {	 try {	   m_format = 	     makeDataSetProbabilities(oldStructure, classifier,				      relationNameModifier);	   m_instanceVals = new double [m_format.numAttributes()];	 } catch (Exception ex) {	   ex.printStackTrace();	   return;	 }       }       // Pass on the structure       m_instanceEvent.setStructure(m_format);       notifyInstanceAvailable(m_instanceEvent);       return;    }    Instance newInst;    try {      // process the actual instance      for (int i = 0; i < oldNumAtts; i++) {	m_instanceVals[i] = currentI.value(i);      }      if (!m_appendProbabilities 	  || currentI.dataset().classAttribute().isNumeric()) {	double predClass = 	  classifier.classifyInstance(currentI);	m_instanceVals[m_instanceVals.length - 1] = predClass;      } else if (m_appendProbabilities) {	double [] preds = classifier.distributionForInstance(currentI);	for (int i = oldNumAtts; i < m_instanceVals.length; i++) {	  m_instanceVals[i] = preds[i-oldNumAtts];	}            }          } catch (Exception ex) {      ex.printStackTrace();      return;    } finally {      newInst = new Instance(currentI.weight(), m_instanceVals);      newInst.setDataset(m_format);      m_instanceEvent.setInstance(newInst);      m_instanceEvent.setStatus(status);      // notify listeners      notifyInstanceAvailable(m_instanceEvent);    }    if (status == IncrementalClassifierEvent.BATCH_FINISHED) {      // clean up      //      m_incrementalStructure = null;      m_instanceVals = null;      m_instanceEvent = null;    }  }  /**   * Accept and process a batch classifier event   *   * @param e a <code>BatchClassifierEvent</code> value   */  public void acceptClassifier(BatchClassifierEvent e) {    if (m_dataSourceListeners.size() > 0 	|| m_trainingSetListeners.size() > 0	|| m_testSetListeners.size() > 0) {      Instances testSet = e.getTestSet().getDataSet();      Instances trainSet = e.getTrainSet().getDataSet();      int setNum = e.getSetNumber();      int maxNum = e.getMaxSetNumber();      weka.classifiers.Classifier classifier = e.getClassifier();      String relationNameModifier = "_set_"+e.getSetNumber()+"_of_"	+e.getMaxSetNumber();      if (!m_appendProbabilities || testSet.classAttribute().isNumeric()) {	try {	  Instances newTestSetInstances = makeDataSetClass(testSet, classifier,						    relationNameModifier);	  Instances newTrainingSetInstances = makeDataSetClass(trainSet, classifier,		    relationNameModifier);	  	  if (m_trainingSetListeners.size() > 0) {	    TrainingSetEvent tse = new TrainingSetEvent(this,		new Instances(newTrainingSetInstances, 0));	    tse.m_setNumber = setNum;	    tse.m_maxSetNumber = maxNum;	    notifyTrainingSetAvailable(tse);	    // fill in predicted values            for (int i = 0; i < trainSet.numInstances(); i++) {              double predClass =         	classifier.classifyInstance(trainSet.instance(i));              newTrainingSetInstances.instance(i).setValue(newTrainingSetInstances.numAttributes()-1,        	  predClass);            }            tse = new TrainingSetEvent(this,        	newTrainingSetInstances);            tse.m_setNumber = setNum;            tse.m_maxSetNumber = maxNum;            notifyTrainingSetAvailable(tse);	  }	  	  if (m_testSetListeners.size() > 0) {	    TestSetEvent tse = new TestSetEvent(this,		new Instances(newTestSetInstances, 0));	    tse.m_setNumber = setNum;	    tse.m_maxSetNumber = maxNum;	    notifyTestSetAvailable(tse);	  }	  if (m_dataSourceListeners.size() > 0) {	    notifyDataSetAvailable(new DataSetEvent(this, new Instances(newTestSetInstances,0)));	  }          if (e.getTestSet().isStructureOnly()) {	    m_format = newTestSetInstances;	  }          if (m_dataSourceListeners.size() > 0 || m_testSetListeners.size() > 0) {            // fill in predicted values            for (int i = 0; i < testSet.numInstances(); i++) {              double predClass =         	classifier.classifyInstance(testSet.instance(i));              newTestSetInstances.instance(i).setValue(newTestSetInstances.numAttributes()-1,        	  predClass);            }          }	  // notify listeners          if (m_testSetListeners.size() > 0) {            TestSetEvent tse = new TestSetEvent(this, newTestSetInstances);            tse.m_setNumber = setNum;            tse.m_maxSetNumber = maxNum;            notifyTestSetAvailable(tse);          }          if (m_dataSourceListeners.size() > 0) {            notifyDataSetAvailable(new DataSetEvent(this, newTestSetInstances));            

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?