📄 classifierpanel.java
字号:
}); } else { visTree.setEnabled(false); } resultListMenu.add(visTree); JMenuItem visMargin = new JMenuItem("Visualize margin curve"); if (preds != null) { visMargin.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { try { MarginCurve tc = new MarginCurve(); Instances result = tc.getCurve(preds); VisualizePanel vmc = new VisualizePanel(); vmc.setName(result.relationName()); vmc.setLog(m_Log); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); visualizeClassifierErrors(vmc); } catch (Exception ex) { ex.printStackTrace(); } } }); } else { visMargin.setEnabled(false); } resultListMenu.add(visMargin); JMenu visThreshold = new JMenu("Visualize threshold curve"); if (preds != null && classAtt != null) { for (int i = 0; i < classAtt.numValues(); i++) { JMenuItem clv = new JMenuItem(classAtt.value(i)); final int classValue = i; clv.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { try { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(preds, classValue); VisualizePanel vmc = new VisualizePanel(); vmc.setLog(m_Log); vmc.setName(result.relationName()+". Class value "+ classAtt.value(classValue)+")"); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); visualizeClassifierErrors(vmc); } catch (Exception ex) { ex.printStackTrace(); } } }); visThreshold.add(clv); } } else { visThreshold.setEnabled(false); } resultListMenu.add(visThreshold); JMenu visCost = new JMenu("Visualize cost curve"); if (preds != null && classAtt != null) { for (int i = 0; i < classAtt.numValues(); i++) { JMenuItem clv = new JMenuItem(classAtt.value(i)); final int classValue = i; clv.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { try { CostCurve cc = new CostCurve(); Instances result = cc.getCurve(preds, classValue); VisualizePanel vmc = new VisualizePanel(); vmc.setLog(m_Log); vmc.setName(result.relationName()+". Class value "+ classAtt.value(classValue)+")"); PlotData2D tempd = new PlotData2D(result); tempd.m_displayAllPoints = true; tempd.setPlotName(result.relationName()); boolean [] connectPoints = new boolean [result.numInstances()]; for (int jj = 1; jj < connectPoints.length; jj+=2) { connectPoints[jj] = true; } tempd.setConnectPoints(connectPoints); // tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); visualizeClassifierErrors(vmc); } catch (Exception ex) { ex.printStackTrace(); } } }); visCost.add(clv); } } else { visCost.setEnabled(false); } resultListMenu.add(visCost); resultListMenu.show(m_History.getList(), x, y); } /** * Pops up a TreeVisualizer for the classifier from the currently * selected item in the results list * @param dottyString the description of the tree in dotty format * @param treeName the title to assign to the display */ protected void visualizeTree(String dottyString, String treeName) { final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Tree Visualizer: "+treeName); jf.setSize(500,400); jf.getContentPane().setLayout(new BorderLayout()); TreeVisualizer tv = new TreeVisualizer(null, dottyString, new PlaceNode2()); jf.getContentPane().add(tv, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); tv.fitToScreen(); } /** * Pops up a VisualizePanel for visualizing the data and errors for * the classifier from the currently selected item in the results list * @param sp the VisualizePanel to pop up. */ protected void visualizeClassifierErrors(VisualizePanel sp) { if (sp != null) { String plotName = sp.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: "+plotName); jf.setSize(500,400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(sp, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } } /** * Save the currently selected classifier output to a file. * @param name the name of the buffer to save */ protected void saveBuffer(String name) { StringBuffer sb = m_History.getNamedBuffer(name); if (sb != null) { if (m_SaveOut.save(sb)) { m_Log.logMessage("Save successful."); } } } /** * Stops the currently running classifier (if any). */ protected void stopClassifier() { if (m_RunThread != null) { m_RunThread.interrupt(); // This is deprecated (and theoretically the interrupt should do). m_RunThread.stop(); } } /** * Saves the currently selected classifier */ protected void saveClassifier(String name, Classifier classifier, Instances trainHeader) { File sFile = null; boolean saveOK = true; int returnVal = m_FileChooser.showSaveDialog(this); if (returnVal == JFileChooser.APPROVE_OPTION) { sFile = m_FileChooser.getSelectedFile(); m_Log.statusMessage("Saving model to file..."); try { OutputStream os = new FileOutputStream(sFile); if (sFile.getName().endsWith(".gz")) { os = new GZIPOutputStream(os); } ObjectOutputStream objectOutputStream = new ObjectOutputStream(os); objectOutputStream.writeObject(classifier); if (trainHeader != null) objectOutputStream.writeObject(trainHeader); objectOutputStream.flush(); objectOutputStream.close(); } catch (Exception e) { JOptionPane.showMessageDialog(null, e, "Save Failed", JOptionPane.ERROR_MESSAGE); saveOK = false; } if (saveOK) m_Log.logMessage("Saved model (" + name + ") to file '" + sFile.getName() + "'"); m_Log.statusMessage("OK"); } } /** * Loads a classifier */ protected void loadClassifier() { int returnVal = m_FileChooser.showOpenDialog(this); if (returnVal == JFileChooser.APPROVE_OPTION) { File selected = m_FileChooser.getSelectedFile(); Classifier classifier = null; Instances trainHeader = null; m_Log.statusMessage("Loading model from file..."); try { InputStream is = new FileInputStream(selected); if (selected.getName().endsWith(".gz")) { is = new GZIPInputStream(is); } ObjectInputStream objectInputStream = new ObjectInputStream(is); classifier = (Classifier) objectInputStream.readObject(); try { // see if we can load the header trainHeader = (Instances) objectInputStream.readObject(); } catch (Exception e) {} // don't fuss if we can't objectInputStream.close(); } catch (Exception e) { JOptionPane.showMessageDialog(null, e, "Load Failed", JOptionPane.ERROR_MESSAGE); } m_Log.statusMessage("OK"); if (classifier != null) { m_Log.logMessage("Loaded model from file '" + selected.getName()+ "'"); String name = (new SimpleDateFormat("HH:mm:ss - ")).format(new Date()); String cname = classifier.getClass().getName(); if (cname.startsWith("weka.classifiers.")) cname = cname.substring("weka.classifiers.".length()); name += cname + " from file '" + selected.getName() + "'"; StringBuffer outBuff = new StringBuffer(); outBuff.append("=== Model information ===\n\n"); outBuff.append("Filename: " + selected.getName() + "\n"); outBuff.append("Scheme: " + cname); if (classifier instanceof OptionHandler) { String [] o = ((OptionHandler) classifier).getOptions(); outBuff.append(" " + Utils.joinOptions(o)); } outBuff.append("\n"); if (trainHeader != null) { outBuff.append("Relation: " + trainHeader.relationName() + '\n'); outBuff.append("Attributes: " + trainHeader.numAttributes() + '\n'); if (trainHeader.numAttributes() < 100) { for (int i = 0; i < trainHeader.numAttributes(); i++) { outBuff.append(" " + trainHeader.attribute(i).name() + '\n'); } } else { outBuff.append(" [list of attributes omitted]\n"); } } else { outBuff.append("\nTraining data unknown\n"); } outBuff.append("\n=== Classifier model ===\n\n"); outBuff.append(classifier.toString() + "\n"); m_History.addResult(name, outBuff); m_History.setSingle(name); FastVector vv = new FastVector(); vv.addElement(classifier); if (trainHeader != null) vv.addElement(trainHeader); // allow visualization of graphable classifiers String grph = null; if (classifier instanceof Drawable) { try { grph = ((Drawable)classifier).graph(); } catch (Exception ex) { } } if (grph != null) vv.addElement(grph); m_History.addObject(name, vv); } } } /** * Re-evaluates the named classifier with the current test set. Unpredictable * things will happen if the data set is not compatible with the classifier. * * @param name the name of the classifier entry * @param classifier the classifier to evaluate */ protected void reevaluateModel(String name, Classifier classifier, Instances trainHeader) { StringBuffer outBuff = m_History.getNamedBuffer(name); Instances userTest = null; // additional vis info (either shape type or point size) FastVector plotShape = new FastVector(); FastVector plotSize = new FastVector(); Instances predInstances = null; // will hold the prediction objects if the class is nominal FastVector predictions = null; CostMatrix costMatrix = null; if (m_EvalWRTCostsBut.isSelected()) { costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor .getValue()); } boolean outputConfusion = m_OutputConfusionBut.isSelected(); boolean outputPerClass = m_OutputPerClassBut.isSelected(); boolean outputSummary = true; boolean outputEntropy = m_OutputEntropyBut.isSelected(); boolean saveVis = m_StorePredictionsBut.isSelected(); String grph = null; try { if (m_TestInstances != null) { userTest = new Instances(m_TestInstancesCopy); } // Check the test instance compatibility if (userTest == null) { throw new Exception("No user test set has been opened"); } if (trainHeader != null) { if (trainHeader.classIndex() > userTest.numAttributes()-1) throw new Exception("Train and test set are not compatible"); userTest.setClassIndex(trainHeader.classIndex()); if (!trainHeader.equalHeaders(userTest)) { throw new Exception("Train and test set are not compatible"); } } else { userTest.setClassIndex(userTest.numAttributes()-1); } m_Log.statusMessage("Evaluating on test data..."); m_Log.logMessage("Re-evaluating classifier (" + name + ") on test set"); Evaluation eval = new Evaluation(userTest, costMatrix); // set up the structure of the plottable instances for // visualization predInstances = setUpVisualizableInstances(userTest); predInstances.setClassIndex(userTest.classIndex()+1); if (userTest.classAttribute().isNominal() && classifier instanceof DistributionClassifier) { predictions = new FastVector(); } for (int jj=0;jj<userTest.numInstances();jj++) { processClassifierPrediction(userTest.instance(jj), classifier, eval, predictions, predInstances, plotShape, plotSize); if ((jj % 100) == 0) { m_Log.statusMessage("Evaluating on test data. Processed " +jj+" instances..."); } } outBuff.append("\n=== Re-evaluation on test set ===\n\n"); outBuff.append("User supplied test set\n"); outBuff.append("Relation: " + userTest.relationName() + '\n'); outBuff.append("Instances: " + userTest.numInstances() + '\n'); outBuff.append("Attributes: " + userTest.numAttributes() + "\n\n"); if (trainHeader == null) outBuff.append("NOTE - if test set is not compatible then results are " + "unpredictable\n\n");
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -