classifierpanel.java

来自「Weka」· Java 代码 · 共 1,867 行 · 第 1/5 页

JAVA
1,867
字号
   * Sets the Logger to receive informational messages   *   * @param newLog the Logger that will now get info messages   */  public void setLog(Logger newLog) {    m_Log = newLog;  }  /**   * Tells the panel to use a new set of instances.   *   * @param inst a set of Instances   */  public void setInstances(Instances inst) {    m_Instances = inst;    String [] attribNames = new String [m_Instances.numAttributes()];    for (int i = 0; i < attribNames.length; i++) {      String type = "";      switch (m_Instances.attribute(i).type()) {      case Attribute.NOMINAL:	type = "(Nom) ";	break;      case Attribute.NUMERIC:	type = "(Num) ";	break;      case Attribute.STRING:	type = "(Str) ";	break;      case Attribute.DATE:	type = "(Dat) ";	break;      case Attribute.RELATIONAL:	type = "(Rel) ";	break;      default:	type = "(???) ";      }      attribNames[i] = type + m_Instances.attribute(i).name();    }    m_ClassCombo.setModel(new DefaultComboBoxModel(attribNames));    if (attribNames.length > 0) {      if (inst.classIndex() == -1)	m_ClassCombo.setSelectedIndex(attribNames.length - 1);      else	m_ClassCombo.setSelectedIndex(inst.classIndex());      m_ClassCombo.setEnabled(true);      m_StartBut.setEnabled(m_RunThread == null);      m_StopBut.setEnabled(m_RunThread != null);    } else {      m_StartBut.setEnabled(false);      m_StopBut.setEnabled(false);    }  }  /**   * Sets the user test set. Information about the current test set   * is displayed in an InstanceSummaryPanel and the user is given the   * ability to load another set from a file or url.   *   */  protected void setTestSet() {    if (m_SetTestFrame == null) {      final SetInstancesPanel sp = new SetInstancesPanel();      if (m_TestLoader != null) {        try {          if (m_TestLoader.getStructure() != null)            sp.setInstances(m_TestLoader.getStructure());        } catch (Exception ex) {          ex.printStackTrace();        }      }      sp.addPropertyChangeListener(new PropertyChangeListener() {	public void propertyChange(PropertyChangeEvent e) {	  m_TestLoader = sp.getLoader();	}      });      // Add propertychangelistener to update m_TestLoader whenever      // it changes in the settestframe      m_SetTestFrame = new JFrame("Test Instances");      sp.setParentFrame(m_SetTestFrame);   // enable Close-Button      m_SetTestFrame.getContentPane().setLayout(new BorderLayout());      m_SetTestFrame.getContentPane().add(sp, BorderLayout.CENTER);      m_SetTestFrame.pack();    }    m_SetTestFrame.setVisible(true);  }  /**   * Process a classifier's prediction for an instance and update a   * set of plotting instances and additional plotting info. plotInfo   * for nominal class datasets holds shape types (actual data points have   * automatic shape type assignment; classifier error data points have   * box shape type). For numeric class datasets, the actual data points   * are stored in plotInstances and plotInfo stores the error (which is   * later converted to shape size values)   * @param toPredict the actual data point   * @param classifier the classifier   * @param eval the evaluation object to use for evaluating the classifier on   * the instance to predict   * @param plotInstances a set of plottable instances   * @param plotShape additional plotting information (shape)   * @param plotSize additional plotting information (size)   */  public static void processClassifierPrediction(Instance toPredict,                                           Classifier classifier,					   Evaluation eval,					   Instances plotInstances,					   FastVector plotShape,					   FastVector plotSize) {    try {      double pred = eval.evaluateModelOnceAndRecordPrediction(classifier, 							      toPredict);      if (plotInstances != null) {        double [] values = new double[plotInstances.numAttributes()];        for (int i = 0; i < plotInstances.numAttributes(); i++) {          if (i < toPredict.classIndex()) {            values[i] = toPredict.value(i);          } else if (i == toPredict.classIndex()) {            values[i] = pred;            values[i+1] = toPredict.value(i);            /* // if the class value of the instances to predict is missing then            // set it to the predicted value            if (toPredict.isMissing(i)) {	    values[i+1] = pred;	    } */            i++;          } else {            values[i] = toPredict.value(i-1);          }        }        plotInstances.add(new Instance(1.0, values));        if (toPredict.classAttribute().isNominal()) {          if (toPredict.isMissing(toPredict.classIndex())               || Instance.isMissingValue(pred)) {            plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));          } else if (pred != toPredict.classValue()) {            // set to default error point shape            plotShape.addElement(new Integer(Plot2D.ERROR_SHAPE));          } else {            // otherwise set to constant (automatically assigned) point shape            plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));          }          plotSize.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE));        } else {          // store the error (to be converted to a point size later)          Double errd = null;          if (!toPredict.isMissing(toPredict.classIndex()) &&               !Instance.isMissingValue(pred)) {            errd = new Double(pred - toPredict.classValue());            plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));          } else {            // missing shape if actual class not present or prediction is missing            plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));          }          plotSize.addElement(errd);        }      }    } catch (Exception ex) {      ex.printStackTrace();    }  }  /**   * Post processes numeric class errors into shape sizes for plotting   * in the visualize panel   * @param plotSize a FastVector of numeric class errors   */  private void postProcessPlotInfo(FastVector plotSize) {    int maxpSize = 20;    double maxErr = Double.NEGATIVE_INFINITY;    double minErr = Double.POSITIVE_INFINITY;    double err;        for (int i = 0; i < plotSize.size(); i++) {      Double errd = (Double)plotSize.elementAt(i);      if (errd != null) {	err = Math.abs(errd.doubleValue());        if (err < minErr) {	  minErr = err;	}	if (err > maxErr) {	  maxErr = err;	}      }    }        for (int i = 0; i < plotSize.size(); i++) {      Double errd = (Double)plotSize.elementAt(i);      if (errd != null) {	err = Math.abs(errd.doubleValue());	if (maxErr - minErr > 0) {	  double temp = (((err - minErr) / (maxErr - minErr)) 			 * maxpSize);	  plotSize.setElementAt(new Integer((int)temp), i);	} else {	  plotSize.setElementAt(new Integer(1), i);	}      } else {	plotSize.setElementAt(new Integer(1), i);      }    }  }  /**   * Sets up the structure for the visualizable instances. This dataset   * contains the original attributes plus the classifier's predictions   * for the class as an attribute called "predicted+WhateverTheClassIsCalled".   * @param trainInstances the instances that the classifier is trained on   * @return a new set of instances containing one more attribute (predicted   * class) than the trainInstances   */  public static Instances setUpVisualizableInstances(Instances trainInstances) {    FastVector hv = new FastVector();    Attribute predictedClass;    Attribute classAt = trainInstances.attribute(trainInstances.classIndex());    if (classAt.isNominal()) {      FastVector attVals = new FastVector();      for (int i = 0; i < classAt.numValues(); i++) {	attVals.addElement(classAt.value(i));      }      predictedClass = new Attribute("predicted"+classAt.name(), attVals);    } else {      predictedClass = new Attribute("predicted"+classAt.name());    }    for (int i = 0; i < trainInstances.numAttributes(); i++) {      if (i == trainInstances.classIndex()) {	hv.addElement(predictedClass);      }      hv.addElement(trainInstances.attribute(i).copy());    }    return new Instances(trainInstances.relationName()+"_predicted", hv, 			 trainInstances.numInstances());  }  /**   * outputs the header for the predictions on the data   *    * @param outBuff	the buffer to add the output to   * @param inst	the data header   * @param title	the title to print   */  protected void printPredictionsHeader(StringBuffer outBuff, Instances inst, String title) {    outBuff.append("=== Predictions on " + title + " ===\n\n");    outBuff.append(" inst#,    actual, predicted, error");    if (inst.classAttribute().isNominal()) {      outBuff.append(", probability distribution");    }    if (m_OutputAdditionalAttributesRange != null) {      outBuff.append(" (");      boolean first = true;      for (int i = 0; i < inst.numAttributes() - 1; i++) {	if (m_OutputAdditionalAttributesRange.isInRange(i)) {	  if (!first)	    outBuff.append(",");	  else	    first = false;	  outBuff.append(inst.attribute(i).name());	}      }      outBuff.append(")");    }    outBuff.append("\n");  }    /**   * Starts running the currently configured classifier with the current   * settings. This is run in a separate thread, and will only start if   * there is no classifier already running. The classifier output is sent   * to the results history panel.   */  protected void startClassifier() {    if (m_RunThread == null) {      synchronized (this) {	m_StartBut.setEnabled(false);	m_StopBut.setEnabled(true);      }      m_RunThread = new Thread() {	public void run() {	  // Copy the current state of things	  m_Log.statusMessage("Setting up...");	  CostMatrix costMatrix = null;	  Instances inst = new Instances(m_Instances);	  DataSource source = null;          Instances userTestStructure = null;	  // additional vis info (either shape type or point size)	  FastVector plotShape = new FastVector();	  FastVector plotSize = new FastVector();	  Instances predInstances = null;	 	  // for timing	  long trainTimeStart = 0, trainTimeElapsed = 0;          try {            if (m_TestLoader != null && m_TestLoader.getStructure() != null) {              m_TestLoader.reset();              source = new DataSource(m_TestLoader);              userTestStructure = source.getStructure();            }          } catch (Exception ex) {            ex.printStackTrace();          }	  if (m_EvalWRTCostsBut.isSelected()) {	    costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor					.getValue());	  }	  boolean outputModel = m_OutputModelBut.isSelected();	  boolean outputConfusion = m_OutputConfusionBut.isSelected();	  boolean outputPerClass = m_OutputPerClassBut.isSelected();	  boolean outputSummary = true;          boolean outputEntropy = m_OutputEntropyBut.isSelected();	  boolean saveVis = m_StorePredictionsBut.isSelected();	  boolean outputPredictionsText = m_OutputPredictionsTextBut.isSelected();	  if (m_OutputAdditionalAttributesText.getText().equals("")) {	    m_OutputAdditionalAttributesRange = null;	  }	  else {	    m_OutputAdditionalAttributesRange = new Range(m_OutputAdditionalAttributesText.getText());	    m_OutputAdditionalAttributesRange.setUpper(inst.numAttributes() - 1);	  }	  String grph = null;	  int testMode = 0;	  int numFolds = 10, percent = 66;	  int classIndex = m_ClassCombo.getSelectedIndex();	  Classifier classifier = (Classifier) m_ClassifierEditor.getValue();	  Classifier template = null;	  try {	    template = Classifier.makeCopy(classifier);	  } catch (Exception ex) {	    m_Log.logMessage("Problem copying classifier: " + ex.getMessage());	  }	  Classifier fullClassifier = null;	  StringBuffer outBuff = new StringBuffer();	  String name = (new SimpleDateFormat("HH:mm:ss - "))	  .format(new Date());	  String cname = classifier.getClass().getName();	  if (cname.startsWith("weka.classifiers.")) {	    name += cname.substring("weka.classifiers.".length());	  } else {	    name += cname;	  }          String cmd = m_ClassifierEditor.getValue().getClass().getName();          if (m_ClassifierEditor.getValue() instanceof OptionHandler)            cmd += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions());	  Evaluation eval = null;	  try {	    if (m_CVBut.isSelected()) {	      testMode = 1;	      numFolds = Integer.parseInt(m_CVText.getText());	      if (numFolds <= 1) {		throw new Exception("Number of folds must be greater than 1");	      }	    } else if (m_PercentBut.isSelected()) {	      testMode = 2;	      percent = Integer.parseInt(m_PercentText.getText());	      if ((percent <= 0) || (percent >= 100)) {		throw new Exception("Percentage must be between 0 and 100");	      }	    } else if (m_TrainBut.isSelected()) {	      testMode = 3;	    } else if (m_TestSplitBut.isSelected()) {	      testMode = 4;	      // Check the test instance compatibility	      if (source == null) {

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?