📄 classifier.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * Classifier.java * Copyright (C) 2002 Mark Hall * */package weka.gui.beans;import java.util.Vector;import java.util.Enumeration;import java.util.Hashtable;import javax.swing.JPanel;import javax.swing.JLabel;import javax.swing.JTextField;import java.awt.BorderLayout;import java.awt.event.MouseAdapter;import java.awt.event.MouseEvent;import java.awt.event.InputEvent;import java.awt.*;import java.io.Serializable;import java.io.Reader;import java.io.BufferedReader;import java.io.FileReader;import java.io.File;import javax.swing.ImageIcon;import javax.swing.SwingConstants;import java.beans.EventSetDescriptor;import weka.core.Instance;import weka.core.Instances;import weka.classifiers.*;import weka.classifiers.rules.ZeroR;import weka.gui.Logger;/** * Bean that wraps around weka.classifiers * * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a> * @version $Revision: 1.1.1.1 $ * @since 1.0 * @see JPanel * @see BeanCommon * @see Visible * @see WekaWrapper * @see Serializable * @see UserRequestAcceptor * @see TrainingSetListener * @see TestSetListener */public class Classifier extends JPanel implements BeanCommon, Visible, WekaWrapper, EventConstraints, Serializable, UserRequestAcceptor, TrainingSetListener, TestSetListener, InstanceListener { protected BeanVisual m_visual = new BeanVisual("Classifier", BeanVisual.ICON_PATH+"DefaultClassifier.gif", BeanVisual.ICON_PATH+"DefaultClassifier_animated.gif"); private static int IDLE = 0; private static int BUILDING_MODEL = 1; private static int CLASSIFYING = 2; private int m_state = IDLE; private Thread m_buildThread = null; /** * Objects talking to us */ private Hashtable m_listenees = new Hashtable(); /** * Objects listening for batch classifier events */ private Vector m_batchClassifierListeners = new Vector(); /** * Objects listening for incremental classifier events */ private Vector m_incrementalClassifierListeners = new Vector(); /** * Objects listening for graph events */ private Vector m_graphListeners = new Vector(); /** * Objects listening for text events */ private Vector m_textListeners = new Vector(); /** * Holds training instances for batch training. Not transient because * header is retained for validating any instance events that this * classifier might be asked to predict in the future. */ private Instances m_trainingSet; private transient Instances m_testingSet; private weka.classifiers.Classifier m_Classifier = new ZeroR(); private IncrementalClassifierEvent m_ie = new IncrementalClassifierEvent(this); /** * If the classifier is an incremental classifier, should we * update it (ie train it on incoming instances). This makes it * possible incrementally test on a separate stream of instances * without updating the classifier, or mix batch training/testing * with incremental training/testing */ private boolean m_updateIncrementalClassifier = true; private transient Logger m_log = null; /** * Event to handle when processing incremental updates */ private InstanceEvent m_incrementalEvent; private Double m_dummy = new Double(0.0); /** * Creates a new <code>Classifier</code> instance. */ public Classifier() { setLayout(new BorderLayout()); add(m_visual, BorderLayout.CENTER); setClassifier(m_Classifier); } /** * Set the classifier for this wrapper * * @param c a <code>weka.classifiers.Classifier</code> value */ public void setClassifier(weka.classifiers.Classifier c) { boolean loadImages = true; if (c.getClass().getName(). compareTo(m_Classifier.getClass().getName()) == 0) { loadImages = false; } else { // classifier has changed so any batch training status is now // invalid m_trainingSet = null; } m_Classifier = c; String classifierName = c.getClass().toString(); classifierName = classifierName.substring(classifierName. lastIndexOf('.')+1, classifierName.length()); if (loadImages) { if (!m_visual.loadIcons(BeanVisual.ICON_PATH+classifierName+".gif", BeanVisual.ICON_PATH+classifierName+"_animated.gif")) { useDefaultVisual(); } } m_visual.setText(classifierName); if (!(m_Classifier instanceof weka.classifiers.UpdateableClassifier) && (m_listenees.containsKey("instance"))) { if (m_log != null) { m_log.logMessage("WARNING : "+m_Classifier.getClass().getName() +" is not an incremental classifier (Classifier)"); } } } /** * Get the classifier currently set for this wrapper * * @return a <code>weka.classifiers.Classifier</code> value */ public weka.classifiers.Classifier getClassifier() { return m_Classifier; } /** * Sets the algorithm (classifier) for this bean * * @param algorithm an <code>Object</code> value * @exception IllegalArgumentException if an error occurs */ public void setWrappedAlgorithm(Object algorithm) { if (!(algorithm instanceof weka.classifiers.Classifier)) { throw new IllegalArgumentException(algorithm.getClass()+" : incorrect " +"type of algorithm (Classifier)"); } setClassifier((weka.classifiers.Classifier)algorithm); } /** * Returns the wrapped classifier * * @return an <code>Object</code> value */ public Object getWrappedAlgorithm() { return getClassifier(); } public boolean getUpdateIncrementalClassifier() { return m_updateIncrementalClassifier; } public void setUpdateIncrementalClassifier(boolean update) { m_updateIncrementalClassifier = update; }// public void acceptDataSet(DataSetEvent e) {// // will wrap up data in a TrainingSetEvent and call acceptTrainingSet// // then will do same for TestSetEvent// acceptTrainingSet(new TrainingSetEvent(e.getSource(), e.getDataSet()));// } /** * Accepts an instance for incremental processing. * * @param e an <code>InstanceEvent</code> value */ public void acceptInstance(InstanceEvent e) { /* if (m_buildThread == null) { System.err.println("Starting handler "); startIncrementalHandler(); } */ // if (m_Classifier instanceof weka.classifiers.UpdateableClassifier) { /* synchronized(m_dummy) { m_state = BUILDING_MODEL; m_incrementalEvent = e; m_dummy.notifyAll(); } try { // if (m_state == BUILDING_MODEL && m_buildThread != null) { block(true); // } } catch (Exception ex) { return; } */ m_incrementalEvent = e; handleIncrementalEvent(); // } } /** * Handles initializing and updating an incremental classifier */ private void handleIncrementalEvent() { if (m_buildThread != null) { String messg = "Classifier is currently batch training!"; if (m_log != null) { m_log.logMessage(messg); } else { System.err.println(messg); } return; } if (m_incrementalEvent.getStatus() == InstanceEvent.FORMAT_AVAILABLE) { Instances dataset = m_incrementalEvent.getInstance().dataset(); // default to the last column if no class is set if (dataset.classIndex() < 0) { // System.err.println("Classifier : setting class index..."); dataset.setClassIndex(dataset.numAttributes()-1); } try { // initialize classifier if m_trainingSet is null // otherwise assume that classifier has been pre-trained in batch // mode, *if* headers match if (m_trainingSet == null || (!dataset.equalHeaders(m_trainingSet))) { if (!(m_Classifier instanceof weka.classifiers.UpdateableClassifier)) { if (m_log != null) { String msg = (m_trainingSet == null) ? "ERROR : "+m_Classifier.getClass().getName() +" has not been batch " +"trained; can't process instance events." : "ERROR : instance event's structure is different from " +"the data that " + "was used to batch train this classifier; can't continue."; m_log.logMessage(msg); } return; } if (m_trainingSet != null && (!dataset.equalHeaders(m_trainingSet))) { if (m_log != null) { m_log.logMessage("Warning : structure of instance events differ " +"from data used in batch training this " +"classifier. Resetting classifier..."); } m_trainingSet = null; } if (m_trainingSet == null) { // initialize the classifier if it hasn't been trained yet m_Classifier.buildClassifier(dataset); m_trainingSet = new Instances(dataset, 0); } } } catch (Exception ex) { ex.printStackTrace(); } } try { // test on this instance int status = IncrementalClassifierEvent.WITHIN_BATCH; if (m_incrementalEvent.getStatus() == InstanceEvent.FORMAT_AVAILABLE) { status = IncrementalClassifierEvent.NEW_BATCH; } else if (m_incrementalEvent.getStatus() == InstanceEvent.BATCH_FINISHED) { status = IncrementalClassifierEvent.BATCH_FINISHED; } m_ie.setStatus(status); m_ie.setClassifier(m_Classifier); m_ie.setCurrentInstance(m_incrementalEvent.getInstance()); notifyIncrementalClassifierListeners(m_ie); // now update on this instance (if class is not missing and classifier // is updateable and user has specified that classifier is to be // updated) if (m_Classifier instanceof weka.classifiers.UpdateableClassifier && m_updateIncrementalClassifier == true && !(m_incrementalEvent.getInstance(). isMissing(m_incrementalEvent.getInstance(). dataset().classIndex()))) { ((weka.classifiers.UpdateableClassifier)m_Classifier). updateClassifier(m_incrementalEvent.getInstance()); } if (m_incrementalEvent.getStatus() == InstanceEvent.BATCH_FINISHED) { if (m_textListeners.size() > 0) { String modelString = m_Classifier.toString(); String titleString = m_Classifier.getClass().getName(); titleString = titleString. substring(titleString.lastIndexOf('.') + 1, titleString.length()); titleString = "( "+m_trainingSet.relationName() + ") " + titleString + " model"; TextEvent nt = new TextEvent(this, modelString, titleString); notifyTextListeners(nt); } } } catch (Exception ex) { if (m_log != null) { m_log.logMessage(ex.toString()); } ex.printStackTrace(); } } /** * Unused at present */ private void startIncrementalHandler() { if (m_buildThread == null) { m_buildThread = new Thread() { public void run() { while (true) { synchronized(m_dummy) { try { m_dummy.wait(); } catch (InterruptedException ex) { // m_buildThread = null; // System.err.println("Here"); return; } } Classifier.this.handleIncrementalEvent(); m_state = IDLE; block(false); } } }; m_buildThread.setPriority(Thread.MIN_PRIORITY); m_buildThread.start(); // give thread a chance to start try { Thread.sleep(500); } catch (InterruptedException ex) { } } } /** * Accepts a training set and builds batch classifier * * @param e a <code>TrainingSetEvent</code> value */ public void acceptTrainingSet(final TrainingSetEvent e) { if (m_buildThread == null) { try { if (m_state == IDLE) { synchronized (this) { m_state = BUILDING_MODEL; } m_trainingSet = e.getTrainingSet(); final String oldText = m_visual.getText(); m_buildThread = new Thread() { public void run() { try { if (m_trainingSet != null) { if (m_trainingSet.classIndex() < 0) { // assume last column is the class m_trainingSet.setClassIndex(m_trainingSet.numAttributes()-1); if (m_log != null) { m_log.logMessage("Classifier : assuming last " +"column is the class"); } } m_visual.setAnimated(); m_visual.setText("Building model..."); if (m_log != null) { m_log.statusMessage("Classifier : building model..."); } buildClassifier(); if (m_Classifier instanceof weka.core.Drawable && m_graphListeners.size() > 0) { String grphString = ((weka.core.Drawable)m_Classifier).graph(); String grphTitle = m_Classifier.getClass().getName(); grphTitle = grphTitle.substring(grphTitle. lastIndexOf('.')+1, grphTitle.length()); grphTitle = "Set " + e.getSetNumber() + " (" +e.getTrainingSet().relationName() + ") " +grphTitle; GraphEvent ge = new GraphEvent(Classifier.this, grphString, grphTitle); notifyGraphListeners(ge); } if (m_textListeners.size() > 0) { String modelString = m_Classifier.toString(); String titleString = m_Classifier.getClass().getName(); titleString = titleString. substring(titleString.lastIndexOf('.') + 1, titleString.length()); titleString = "Set "+e.getSetNumber() + " (" + m_trainingSet.relationName() + ") " + titleString + " model"; TextEvent nt = new TextEvent(Classifier.this, modelString, titleString); notifyTextListeners(nt); } } } catch (Exception ex) { ex.printStackTrace(); } finally { m_visual.setText(oldText); m_visual.setStatic(); m_state = IDLE; if (isInterrupted()) { // prevent any classifier events from being fired m_trainingSet = null;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -