classifierpanel.java

来自「Weka」· Java 代码 · 共 1,867 行 · 第 1/5 页

JAVA
1,867
字号
      m_RunThread.start();    }  }  /**   * generates a prediction row for an instance   *    * @param classifier the classifier to use for making the prediction   * @param inst the instance to predict   * @param instNum the index of the instance   * @throws Exception if something goes wrong   * @return the generated row   */  protected String predictionText(Classifier classifier, Instance inst, int instNum) throws Exception {    //> inst#   actual   predicted   error  probability distribution    StringBuffer text = new StringBuffer();    // inst #    text.append(Utils.padLeft("" + instNum, 6) + " ");    if (inst.classAttribute().isNominal()) {      // actual      if (inst.classIsMissing()) text.append(Utils.padLeft("?", 10) + " ");      else text.append(Utils.padLeft("" + ((int) inst.classValue()+1) + ":"				+ inst.stringValue(inst.classAttribute()), 10) + " ");      // predicted      double[] probdist = null;      double pred;      if (inst.classAttribute().isNominal()) {	probdist = classifier.distributionForInstance(inst);	pred = (double) Utils.maxIndex(probdist);	if (probdist[(int) pred] <= 0.0) pred = Instance.missingValue();      } else {	pred = classifier.classifyInstance(inst);      }      text.append(Utils.padLeft((Instance.isMissingValue(pred) ? "?" :				 (((int) pred+1) + ":"				 + inst.classAttribute().value((int) pred))), 10) + " ");      // error      if (pred == inst.classValue()) text.append(Utils.padLeft(" ", 6) + " ");      else text.append(Utils.padLeft("+", 6) + " ");      // prob dist      if (inst.classAttribute().type() == Attribute.NOMINAL) {	for (int i=0; i<probdist.length; i++) {	  if (i == (int) pred) text.append(" *");	  else text.append("  ");	  text.append(Utils.doubleToString(probdist[i], 5, 3));	}      }    } else {      // actual      if (inst.classIsMissing()) text.append(Utils.padLeft("?", 10) + " ");      else text.append(Utils.doubleToString(inst.classValue(), 10, 3) + " ");      // predicted      double pred = classifier.classifyInstance(inst);      if (Instance.isMissingValue(pred)) text.append(Utils.padLeft("?", 10) + " ");      else text.append(Utils.doubleToString(pred, 10, 3) + " ");      // err      if (!inst.classIsMissing() && !Instance.isMissingValue(pred))	text.append(Utils.doubleToString(pred - inst.classValue(), 10, 3));    }        // additional Attributes    if (m_OutputAdditionalAttributesRange != null) {      text.append(" (");      boolean first = true;      for (int i = 0; i < inst.numAttributes() - 1; i++) {	if (m_OutputAdditionalAttributesRange.isInRange(i)) {	  if (!first)	    text.append(",");	  else	    first = false;	  text.append(inst.toString(i));	}      }      text.append(")");    }        text.append("\n");    return text.toString();  }  /**   * Handles constructing a popup menu with visualization options.   * @param name the name of the result history list entry clicked on by   * the user   * @param x the x coordinate for popping up the menu   * @param y the y coordinate for popping up the menu   */  protected void visualize(String name, int x, int y) {    final String selectedName = name;    JPopupMenu resultListMenu = new JPopupMenu();        JMenuItem visMainBuffer = new JMenuItem("View in main window");    if (selectedName != null) {      visMainBuffer.addActionListener(new ActionListener() {	  public void actionPerformed(ActionEvent e) {	    m_History.setSingle(selectedName);	  }	});    } else {      visMainBuffer.setEnabled(false);    }    resultListMenu.add(visMainBuffer);        JMenuItem visSepBuffer = new JMenuItem("View in separate window");    if (selectedName != null) {      visSepBuffer.addActionListener(new ActionListener() {	public void actionPerformed(ActionEvent e) {	  m_History.openFrame(selectedName);	}      });    } else {      visSepBuffer.setEnabled(false);    }    resultListMenu.add(visSepBuffer);        JMenuItem saveOutput = new JMenuItem("Save result buffer");    if (selectedName != null) {      saveOutput.addActionListener(new ActionListener() {	  public void actionPerformed(ActionEvent e) {	    saveBuffer(selectedName);	  }	});    } else {      saveOutput.setEnabled(false);    }    resultListMenu.add(saveOutput);        JMenuItem deleteOutput = new JMenuItem("Delete result buffer");    if (selectedName != null) {      deleteOutput.addActionListener(new ActionListener() {	public void actionPerformed(ActionEvent e) {	  m_History.removeResult(selectedName);	}      });    } else {      deleteOutput.setEnabled(false);    }    resultListMenu.add(deleteOutput);    resultListMenu.addSeparator();        JMenuItem loadModel = new JMenuItem("Load model");    loadModel.addActionListener(new ActionListener() {	public void actionPerformed(ActionEvent e) {	  loadClassifier();	}      });    resultListMenu.add(loadModel);    FastVector o = null;    if (selectedName != null) {      o = (FastVector)m_History.getNamedObject(selectedName);    }    VisualizePanel temp_vp = null;    String temp_grph = null;    FastVector temp_preds = null;    Attribute temp_classAtt = null;    Classifier temp_classifier = null;    Instances temp_trainHeader = null;          if (o != null) {       for (int i = 0; i < o.size(); i++) {	Object temp = o.elementAt(i);	if (temp instanceof Classifier) {	  temp_classifier = (Classifier)temp;	} else if (temp instanceof Instances) { // training header	  temp_trainHeader = (Instances)temp;	} else if (temp instanceof VisualizePanel) { // normal errors	  temp_vp = (VisualizePanel)temp;	} else if (temp instanceof String) { // graphable output	  temp_grph = (String)temp;	} else if (temp instanceof FastVector) { // predictions	  temp_preds = (FastVector)temp;	} else if (temp instanceof Attribute) { // class attribute	  temp_classAtt = (Attribute)temp;	}      }    }    final VisualizePanel vp = temp_vp;    final String grph = temp_grph;    final FastVector preds = temp_preds;    final Attribute classAtt = temp_classAtt;    final Classifier classifier = temp_classifier;    final Instances trainHeader = temp_trainHeader;        JMenuItem saveModel = new JMenuItem("Save model");    if (classifier != null) {      saveModel.addActionListener(new ActionListener() {	  public void actionPerformed(ActionEvent e) {	    saveClassifier(selectedName, classifier, trainHeader);	  }	});    } else {      saveModel.setEnabled(false);    }    resultListMenu.add(saveModel);    JMenuItem reEvaluate =      new JMenuItem("Re-evaluate model on current test set");    if (classifier != null && m_TestLoader != null) {      reEvaluate.addActionListener(new ActionListener() {	  public void actionPerformed(ActionEvent e) {	    reevaluateModel(selectedName, classifier, trainHeader);	  }	});    } else {      reEvaluate.setEnabled(false);    }    resultListMenu.add(reEvaluate);        resultListMenu.addSeparator();        JMenuItem visErrors = new JMenuItem("Visualize classifier errors");    if (vp != null) {      visErrors.addActionListener(new ActionListener() {	  public void actionPerformed(ActionEvent e) {	    visualizeClassifierErrors(vp);	  }	});    } else {      visErrors.setEnabled(false);    }    resultListMenu.add(visErrors);    JMenuItem visGrph = new JMenuItem("Visualize tree");    if (grph != null) {	if(((Drawable)temp_classifier).graphType()==Drawable.TREE) {	    visGrph.addActionListener(new ActionListener() {		    public void actionPerformed(ActionEvent e) {			String title;			if (vp != null) title = vp.getName();			else title = selectedName;			visualizeTree(grph, title);		    }		});	}	else if(((Drawable)temp_classifier).graphType()==Drawable.BayesNet) {	    visGrph.setText("Visualize graph");	    visGrph.addActionListener(new ActionListener() {		    public void actionPerformed(ActionEvent e) {			Thread th = new Thread() {				public void run() {				visualizeBayesNet(grph, selectedName);				}			    };			th.start();		    }		});	}	else	    visGrph.setEnabled(false);    } else {      visGrph.setEnabled(false);    }    resultListMenu.add(visGrph);    JMenuItem visMargin = new JMenuItem("Visualize margin curve");    if (preds != null) {      visMargin.addActionListener(new ActionListener() {	  public void actionPerformed(ActionEvent e) {	    try {	      MarginCurve tc = new MarginCurve();	      Instances result = tc.getCurve(preds);	      VisualizePanel vmc = new VisualizePanel();	      vmc.setName(result.relationName());	      vmc.setLog(m_Log);	      PlotData2D tempd = new PlotData2D(result);	      tempd.setPlotName(result.relationName());	      tempd.addInstanceNumberAttribute();	      vmc.addPlot(tempd);	      visualizeClassifierErrors(vmc);	    } catch (Exception ex) {	      ex.printStackTrace();	    }	  }	});    } else {      visMargin.setEnabled(false);    }    resultListMenu.add(visMargin);    JMenu visThreshold = new JMenu("Visualize threshold curve");    if (preds != null && classAtt != null) {      for (int i = 0; i < classAtt.numValues(); i++) {	JMenuItem clv = new JMenuItem(classAtt.value(i));	final int classValue = i;	clv.addActionListener(new ActionListener() {	    public void actionPerformed(ActionEvent e) {	      try {		ThresholdCurve tc = new ThresholdCurve();		Instances result = tc.getCurve(preds, classValue);		//VisualizePanel vmc = new VisualizePanel();		ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();		vmc.setROCString("(Area under ROC = " + 				 Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");		vmc.setLog(m_Log);		vmc.setName(result.relationName()+". (Class value "+			    classAtt.value(classValue)+")");		PlotData2D tempd = new PlotData2D(result);		tempd.setPlotName(result.relationName());		tempd.addInstanceNumberAttribute();		vmc.addPlot(tempd);		visualizeClassifierErrors(vmc);	      } catch (Exception ex) {		ex.printStackTrace();	      }	      }	  });	  visThreshold.add(clv);      }    } else {      visThreshold.setEnabled(false);    }    resultListMenu.add(visThreshold);    JMenu visCost = new JMenu("Visualize cost curve");    if (preds != null && classAtt != null) {      for (int i = 0; i < classAtt.numValues(); i++) {	JMenuItem clv = new JMenuItem(classAtt.value(i));	final int classValue = i;	clv.addActionListener(new ActionListener() {	    public void actionPerformed(ActionEvent e) {	      try {		CostCurve cc = new CostCurve();		Instances result = cc.getCurve(preds, classValue);		VisualizePanel vmc = new VisualizePanel();		vmc.setLog(m_Log);		vmc.setName(result.relationName()+". (Class value "+			    classAtt.value(classValue)+")");		PlotData2D tempd = new PlotData2D(result);		tempd.m_displayAllPoints = true;		tempd.setPlotName(result.relationName());		boolean [] connectPoints = 		  new boolean [result.numInstances()];		for (int jj = 1; jj < connectPoints.length; jj+=2) {		  connectPoints[jj] = true;		}		tempd.setConnectPoints(connectPoints);		//		  tempd.addInstanceNumberAttribute();		vmc.addPlot(tempd);		visualizeClassifierErrors(vmc);	      } catch (Exception ex) {		ex.printStackTrace();	      }	    }	  });	visCost.add(clv);      }    } else {      visCost.setEnabled(false);    }    resultListMenu.add(visCost);        JMenu visPlugins = new JMenu("Plugins");    Vector pluginsVector = GenericObjectEditor.getClassnames(VisualizePlugin.class.getName());    boolean availablePlugins = false;    for (int i=0; i<pluginsVector.size(); i++) {      String className = (String)(pluginsVector.elementAt(i));      try {        VisualizePlugin plugin = (VisualizePlugin) Class.forName(className).newInstance();        if (plugin == n

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?