classifierpanel.java
来自「Weka」· Java 代码 · 共 1,867 行 · 第 1/5 页
JAVA
1,867 行
m_RunThread.start(); } } /** * generates a prediction row for an instance * * @param classifier the classifier to use for making the prediction * @param inst the instance to predict * @param instNum the index of the instance * @throws Exception if something goes wrong * @return the generated row */ protected String predictionText(Classifier classifier, Instance inst, int instNum) throws Exception { //> inst# actual predicted error probability distribution StringBuffer text = new StringBuffer(); // inst # text.append(Utils.padLeft("" + instNum, 6) + " "); if (inst.classAttribute().isNominal()) { // actual if (inst.classIsMissing()) text.append(Utils.padLeft("?", 10) + " "); else text.append(Utils.padLeft("" + ((int) inst.classValue()+1) + ":" + inst.stringValue(inst.classAttribute()), 10) + " "); // predicted double[] probdist = null; double pred; if (inst.classAttribute().isNominal()) { probdist = classifier.distributionForInstance(inst); pred = (double) Utils.maxIndex(probdist); if (probdist[(int) pred] <= 0.0) pred = Instance.missingValue(); } else { pred = classifier.classifyInstance(inst); } text.append(Utils.padLeft((Instance.isMissingValue(pred) ? "?" : (((int) pred+1) + ":" + inst.classAttribute().value((int) pred))), 10) + " "); // error if (pred == inst.classValue()) text.append(Utils.padLeft(" ", 6) + " "); else text.append(Utils.padLeft("+", 6) + " "); // prob dist if (inst.classAttribute().type() == Attribute.NOMINAL) { for (int i=0; i<probdist.length; i++) { if (i == (int) pred) text.append(" *"); else text.append(" "); text.append(Utils.doubleToString(probdist[i], 5, 3)); } } } else { // actual if (inst.classIsMissing()) text.append(Utils.padLeft("?", 10) + " "); else text.append(Utils.doubleToString(inst.classValue(), 10, 3) + " "); // predicted double pred = classifier.classifyInstance(inst); if (Instance.isMissingValue(pred)) text.append(Utils.padLeft("?", 10) + " "); else text.append(Utils.doubleToString(pred, 10, 3) + " "); // err if (!inst.classIsMissing() && !Instance.isMissingValue(pred)) text.append(Utils.doubleToString(pred - inst.classValue(), 10, 3)); } // additional Attributes if (m_OutputAdditionalAttributesRange != null) { text.append(" ("); boolean first = true; for (int i = 0; i < inst.numAttributes() - 1; i++) { if (m_OutputAdditionalAttributesRange.isInRange(i)) { if (!first) text.append(","); else first = false; text.append(inst.toString(i)); } } text.append(")"); } text.append("\n"); return text.toString(); } /** * Handles constructing a popup menu with visualization options. * @param name the name of the result history list entry clicked on by * the user * @param x the x coordinate for popping up the menu * @param y the y coordinate for popping up the menu */ protected void visualize(String name, int x, int y) { final String selectedName = name; JPopupMenu resultListMenu = new JPopupMenu(); JMenuItem visMainBuffer = new JMenuItem("View in main window"); if (selectedName != null) { visMainBuffer.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { m_History.setSingle(selectedName); } }); } else { visMainBuffer.setEnabled(false); } resultListMenu.add(visMainBuffer); JMenuItem visSepBuffer = new JMenuItem("View in separate window"); if (selectedName != null) { visSepBuffer.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { m_History.openFrame(selectedName); } }); } else { visSepBuffer.setEnabled(false); } resultListMenu.add(visSepBuffer); JMenuItem saveOutput = new JMenuItem("Save result buffer"); if (selectedName != null) { saveOutput.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { saveBuffer(selectedName); } }); } else { saveOutput.setEnabled(false); } resultListMenu.add(saveOutput); JMenuItem deleteOutput = new JMenuItem("Delete result buffer"); if (selectedName != null) { deleteOutput.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { m_History.removeResult(selectedName); } }); } else { deleteOutput.setEnabled(false); } resultListMenu.add(deleteOutput); resultListMenu.addSeparator(); JMenuItem loadModel = new JMenuItem("Load model"); loadModel.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { loadClassifier(); } }); resultListMenu.add(loadModel); FastVector o = null; if (selectedName != null) { o = (FastVector)m_History.getNamedObject(selectedName); } VisualizePanel temp_vp = null; String temp_grph = null; FastVector temp_preds = null; Attribute temp_classAtt = null; Classifier temp_classifier = null; Instances temp_trainHeader = null; if (o != null) { for (int i = 0; i < o.size(); i++) { Object temp = o.elementAt(i); if (temp instanceof Classifier) { temp_classifier = (Classifier)temp; } else if (temp instanceof Instances) { // training header temp_trainHeader = (Instances)temp; } else if (temp instanceof VisualizePanel) { // normal errors temp_vp = (VisualizePanel)temp; } else if (temp instanceof String) { // graphable output temp_grph = (String)temp; } else if (temp instanceof FastVector) { // predictions temp_preds = (FastVector)temp; } else if (temp instanceof Attribute) { // class attribute temp_classAtt = (Attribute)temp; } } } final VisualizePanel vp = temp_vp; final String grph = temp_grph; final FastVector preds = temp_preds; final Attribute classAtt = temp_classAtt; final Classifier classifier = temp_classifier; final Instances trainHeader = temp_trainHeader; JMenuItem saveModel = new JMenuItem("Save model"); if (classifier != null) { saveModel.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { saveClassifier(selectedName, classifier, trainHeader); } }); } else { saveModel.setEnabled(false); } resultListMenu.add(saveModel); JMenuItem reEvaluate = new JMenuItem("Re-evaluate model on current test set"); if (classifier != null && m_TestLoader != null) { reEvaluate.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { reevaluateModel(selectedName, classifier, trainHeader); } }); } else { reEvaluate.setEnabled(false); } resultListMenu.add(reEvaluate); resultListMenu.addSeparator(); JMenuItem visErrors = new JMenuItem("Visualize classifier errors"); if (vp != null) { visErrors.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { visualizeClassifierErrors(vp); } }); } else { visErrors.setEnabled(false); } resultListMenu.add(visErrors); JMenuItem visGrph = new JMenuItem("Visualize tree"); if (grph != null) { if(((Drawable)temp_classifier).graphType()==Drawable.TREE) { visGrph.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { String title; if (vp != null) title = vp.getName(); else title = selectedName; visualizeTree(grph, title); } }); } else if(((Drawable)temp_classifier).graphType()==Drawable.BayesNet) { visGrph.setText("Visualize graph"); visGrph.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { Thread th = new Thread() { public void run() { visualizeBayesNet(grph, selectedName); } }; th.start(); } }); } else visGrph.setEnabled(false); } else { visGrph.setEnabled(false); } resultListMenu.add(visGrph); JMenuItem visMargin = new JMenuItem("Visualize margin curve"); if (preds != null) { visMargin.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { try { MarginCurve tc = new MarginCurve(); Instances result = tc.getCurve(preds); VisualizePanel vmc = new VisualizePanel(); vmc.setName(result.relationName()); vmc.setLog(m_Log); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); visualizeClassifierErrors(vmc); } catch (Exception ex) { ex.printStackTrace(); } } }); } else { visMargin.setEnabled(false); } resultListMenu.add(visMargin); JMenu visThreshold = new JMenu("Visualize threshold curve"); if (preds != null && classAtt != null) { for (int i = 0; i < classAtt.numValues(); i++) { JMenuItem clv = new JMenuItem(classAtt.value(i)); final int classValue = i; clv.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { try { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(preds, classValue); //VisualizePanel vmc = new VisualizePanel(); ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")"); vmc.setLog(m_Log); vmc.setName(result.relationName()+". (Class value "+ classAtt.value(classValue)+")"); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); visualizeClassifierErrors(vmc); } catch (Exception ex) { ex.printStackTrace(); } } }); visThreshold.add(clv); } } else { visThreshold.setEnabled(false); } resultListMenu.add(visThreshold); JMenu visCost = new JMenu("Visualize cost curve"); if (preds != null && classAtt != null) { for (int i = 0; i < classAtt.numValues(); i++) { JMenuItem clv = new JMenuItem(classAtt.value(i)); final int classValue = i; clv.addActionListener(new ActionListener() { public void actionPerformed(ActionEvent e) { try { CostCurve cc = new CostCurve(); Instances result = cc.getCurve(preds, classValue); VisualizePanel vmc = new VisualizePanel(); vmc.setLog(m_Log); vmc.setName(result.relationName()+". (Class value "+ classAtt.value(classValue)+")"); PlotData2D tempd = new PlotData2D(result); tempd.m_displayAllPoints = true; tempd.setPlotName(result.relationName()); boolean [] connectPoints = new boolean [result.numInstances()]; for (int jj = 1; jj < connectPoints.length; jj+=2) { connectPoints[jj] = true; } tempd.setConnectPoints(connectPoints); // tempd.addInstanceNumberAttribute(); vmc.addPlot(tempd); visualizeClassifierErrors(vmc); } catch (Exception ex) { ex.printStackTrace(); } } }); visCost.add(clv); } } else { visCost.setEnabled(false); } resultListMenu.add(visCost); JMenu visPlugins = new JMenu("Plugins"); Vector pluginsVector = GenericObjectEditor.getClassnames(VisualizePlugin.class.getName()); boolean availablePlugins = false; for (int i=0; i<pluginsVector.size(); i++) { String className = (String)(pluginsVector.elementAt(i)); try { VisualizePlugin plugin = (VisualizePlugin) Class.forName(className).newInstance(); if (plugin == n
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?