predictionappender.java
来自「Weka」· Java 代码 · 共 857 行 · 第 1/2 页
JAVA
857 行
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * PredictionAppender.java * Copyright (C) 2003 University of Waikato, Hamilton, New Zealand * */package weka.gui.beans;import weka.clusterers.DensityBasedClusterer;import weka.core.Instance;import weka.core.Instances;import java.awt.BorderLayout;import java.beans.EventSetDescriptor;import java.io.Serializable;import java.util.Enumeration;import java.util.Vector;import javax.swing.JPanel;/** * Bean that can can accept batch or incremental classifier events * and produce dataset or instance events which contain instances with * predictions appended. * * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a> * @version $Revision: 1.15 $ */public class PredictionAppender extends JPanel implements DataSource, TrainingSetProducer, TestSetProducer, Visible, BeanCommon, EventConstraints, BatchClassifierListener, IncrementalClassifierListener, BatchClustererListener, Serializable { /** for serialization */ private static final long serialVersionUID = -2987740065058976673L; /** * Objects listenening for dataset events */ protected Vector m_dataSourceListeners = new Vector(); /** * Objects listening for instances events */ protected Vector m_instanceListeners = new Vector(); /** * Objects listening for training set events */ protected Vector m_trainingSetListeners = new Vector();; /** * Objects listening for test set events */ protected Vector m_testSetListeners = new Vector(); /** * Non null if this object is a target for any events. */ protected Object m_listenee = null; /** * Format of instances to be produced. */ protected Instances m_format; protected BeanVisual m_visual = new BeanVisual("PredictionAppender", BeanVisual.ICON_PATH+"PredictionAppender.gif", BeanVisual.ICON_PATH+"PredictionAppender_animated.gif"); /** * Append classifier's predicted probabilities (if the class is discrete * and the classifier is a distribution classifier) */ protected boolean m_appendProbabilities; protected transient weka.gui.Logger m_logger; /** * Global description of this bean * * @return a <code>String</code> value */ public String globalInfo() { return "Accepts batch or incremental classifier events and " +"produces a new data set with classifier predictions appended."; } /** * Creates a new <code>PredictionAppender</code> instance. */ public PredictionAppender() { setLayout(new BorderLayout()); add(m_visual, BorderLayout.CENTER); } /** * Return a tip text suitable for displaying in a GUI * * @return a <code>String</code> value */ public String appendPredictedProbabilitiesTipText() { return "append probabilities rather than labels for discrete class " +"predictions"; } /** * Return true if predicted probabilities are to be appended rather * than class value * * @return a <code>boolean</code> value */ public boolean getAppendPredictedProbabilities() { return m_appendProbabilities; } /** * Set whether to append predicted probabilities rather than * class value (for discrete class data sets) * * @param ap a <code>boolean</code> value */ public void setAppendPredictedProbabilities(boolean ap) { m_appendProbabilities = ap; } /** * Add a training set listener * * @param tsl a <code>TrainingSetListener</code> value */ public void addTrainingSetListener(TrainingSetListener tsl) { // TODO Auto-generated method stub m_trainingSetListeners.addElement(tsl); // pass on any format that we might have determined so far if (m_format != null) { TrainingSetEvent e = new TrainingSetEvent(this, m_format); tsl.acceptTrainingSet(e); } } /** * Remove a training set listener * * @param tsl a <code>TrainingSetListener</code> value */ public void removeTrainingSetListener(TrainingSetListener tsl) { m_trainingSetListeners.removeElement(tsl); } /** * Add a test set listener * * @param tsl a <code>TestSetListener</code> value */ public void addTestSetListener(TestSetListener tsl) { m_testSetListeners.addElement(tsl);// pass on any format that we might have determined so far if (m_format != null) { TestSetEvent e = new TestSetEvent(this, m_format); tsl.acceptTestSet(e); } } /** * Remove a test set listener * * @param tsl a <code>TestSetListener</code> value */ public void removeTestSetListener(TestSetListener tsl) { m_testSetListeners.removeElement(tsl); } /** * Add a datasource listener * * @param dsl a <code>DataSourceListener</code> value */ public synchronized void addDataSourceListener(DataSourceListener dsl) { m_dataSourceListeners.addElement(dsl); // pass on any format that we might have determined so far if (m_format != null) { DataSetEvent e = new DataSetEvent(this, m_format); dsl.acceptDataSet(e); } } /** * Remove a datasource listener * * @param dsl a <code>DataSourceListener</code> value */ public synchronized void removeDataSourceListener(DataSourceListener dsl) { m_dataSourceListeners.remove(dsl); } /** * Add an instance listener * * @param dsl a <code>InstanceListener</code> value */ public synchronized void addInstanceListener(InstanceListener dsl) { m_instanceListeners.addElement(dsl); // pass on any format that we might have determined so far if (m_format != null) { InstanceEvent e = new InstanceEvent(this, m_format); dsl.acceptInstance(e); } } /** * Remove an instance listener * * @param dsl a <code>InstanceListener</code> value */ public synchronized void removeInstanceListener(InstanceListener dsl) { m_instanceListeners.remove(dsl); } /** * Set the visual for this data source * * @param newVisual a <code>BeanVisual</code> value */ public void setVisual(BeanVisual newVisual) { m_visual = newVisual; } /** * Get the visual being used by this data source. * */ public BeanVisual getVisual() { return m_visual; } /** * Use the default images for a data source * */ public void useDefaultVisual() { m_visual.loadIcons(BeanVisual.ICON_PATH+"PredictionAppender.gif", BeanVisual.ICON_PATH+"PredictionAppender_animated.gif"); } protected InstanceEvent m_instanceEvent; protected double [] m_instanceVals; /** * Accept and process an incremental classifier event * * @param e an <code>IncrementalClassifierEvent</code> value */ public void acceptClassifier(IncrementalClassifierEvent e) { weka.classifiers.Classifier classifier = e.getClassifier(); Instance currentI = e.getCurrentInstance(); int status = e.getStatus(); int oldNumAtts = 0; if (status == IncrementalClassifierEvent.NEW_BATCH) { oldNumAtts = e.getStructure().numAttributes(); } else { oldNumAtts = currentI.dataset().numAttributes(); } if (status == IncrementalClassifierEvent.NEW_BATCH) { m_instanceEvent = new InstanceEvent(this, null, 0); // create new header structure Instances oldStructure = new Instances(e.getStructure(), 0); //String relationNameModifier = oldStructure.relationName() //+"_with predictions"; String relationNameModifier = "_with predictions"; //+"_with predictions"; if (!m_appendProbabilities || oldStructure.classAttribute().isNumeric()) { try { m_format = makeDataSetClass(oldStructure, classifier, relationNameModifier); m_instanceVals = new double [m_format.numAttributes()]; } catch (Exception ex) { ex.printStackTrace(); return; } } else if (m_appendProbabilities) { try { m_format = makeDataSetProbabilities(oldStructure, classifier, relationNameModifier); m_instanceVals = new double [m_format.numAttributes()]; } catch (Exception ex) { ex.printStackTrace(); return; } } // Pass on the structure m_instanceEvent.setStructure(m_format); notifyInstanceAvailable(m_instanceEvent); return; } Instance newInst; try { // process the actual instance for (int i = 0; i < oldNumAtts; i++) { m_instanceVals[i] = currentI.value(i); } if (!m_appendProbabilities || currentI.dataset().classAttribute().isNumeric()) { double predClass = classifier.classifyInstance(currentI); m_instanceVals[m_instanceVals.length - 1] = predClass; } else if (m_appendProbabilities) { double [] preds = classifier.distributionForInstance(currentI); for (int i = oldNumAtts; i < m_instanceVals.length; i++) { m_instanceVals[i] = preds[i-oldNumAtts]; } } } catch (Exception ex) { ex.printStackTrace(); return; } finally { newInst = new Instance(currentI.weight(), m_instanceVals); newInst.setDataset(m_format); m_instanceEvent.setInstance(newInst); m_instanceEvent.setStatus(status); // notify listeners notifyInstanceAvailable(m_instanceEvent); } if (status == IncrementalClassifierEvent.BATCH_FINISHED) { // clean up // m_incrementalStructure = null; m_instanceVals = null; m_instanceEvent = null; } } /** * Accept and process a batch classifier event * * @param e a <code>BatchClassifierEvent</code> value */ public void acceptClassifier(BatchClassifierEvent e) { if (m_dataSourceListeners.size() > 0 || m_trainingSetListeners.size() > 0 || m_testSetListeners.size() > 0) { Instances testSet = e.getTestSet().getDataSet(); Instances trainSet = e.getTrainSet().getDataSet(); int setNum = e.getSetNumber(); int maxNum = e.getMaxSetNumber(); weka.classifiers.Classifier classifier = e.getClassifier(); String relationNameModifier = "_set_"+e.getSetNumber()+"_of_" +e.getMaxSetNumber(); if (!m_appendProbabilities || testSet.classAttribute().isNumeric()) { try { Instances newTestSetInstances = makeDataSetClass(testSet, classifier, relationNameModifier); Instances newTrainingSetInstances = makeDataSetClass(trainSet, classifier, relationNameModifier); if (m_trainingSetListeners.size() > 0) { TrainingSetEvent tse = new TrainingSetEvent(this, new Instances(newTrainingSetInstances, 0)); tse.m_setNumber = setNum; tse.m_maxSetNumber = maxNum; notifyTrainingSetAvailable(tse); // fill in predicted values for (int i = 0; i < trainSet.numInstances(); i++) { double predClass = classifier.classifyInstance(trainSet.instance(i)); newTrainingSetInstances.instance(i).setValue(newTrainingSetInstances.numAttributes()-1, predClass); } tse = new TrainingSetEvent(this, newTrainingSetInstances); tse.m_setNumber = setNum; tse.m_maxSetNumber = maxNum; notifyTrainingSetAvailable(tse); } if (m_testSetListeners.size() > 0) { TestSetEvent tse = new TestSetEvent(this, new Instances(newTestSetInstances, 0)); tse.m_setNumber = setNum; tse.m_maxSetNumber = maxNum; notifyTestSetAvailable(tse); } if (m_dataSourceListeners.size() > 0) { notifyDataSetAvailable(new DataSetEvent(this, new Instances(newTestSetInstances,0))); } if (e.getTestSet().isStructureOnly()) { m_format = newTestSetInstances; } if (m_dataSourceListeners.size() > 0 || m_testSetListeners.size() > 0) { // fill in predicted values for (int i = 0; i < testSet.numInstances(); i++) { double predClass = classifier.classifyInstance(testSet.instance(i)); newTestSetInstances.instance(i).setValue(newTestSetInstances.numAttributes()-1, predClass); } } // notify listeners if (m_testSetListeners.size() > 0) { TestSetEvent tse = new TestSetEvent(this, newTestSetInstances); tse.m_setNumber = setNum; tse.m_maxSetNumber = maxNum; notifyTestSetAvailable(tse); } if (m_dataSourceListeners.size() > 0) { notifyDataSetAvailable(new DataSetEvent(this, newTestSetInstances));
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?