📄 classifierpanel.java
字号:
} catch (Exception ex) {
          ex.printStackTrace();
        }
      }
      sp.addPropertyChangeListener(new PropertyChangeListener() {
        public void propertyChange(PropertyChangeEvent e) {
          m_TestLoader = sp.getLoader();
        }
      });
      // Add propertychangelistener to update m_TestLoader whenever
      // it changes in the settestframe
      m_SetTestFrame = new JFrame("Test Instances");
      sp.setParentFrame(m_SetTestFrame);   // enable Close-Button
      m_SetTestFrame.getContentPane().setLayout(new BorderLayout());
      m_SetTestFrame.getContentPane().add(sp, BorderLayout.CENTER);
      m_SetTestFrame.pack();
    }
    m_SetTestFrame.setVisible(true);
  }

  /**
   * Processes a classifier's prediction for a single instance and appends
   * the result to a set of plotting instances plus two parallel vectors of
   * plotting information.
   * <p>
   * The instance added to <code>plotInstances</code> holds the attribute
   * values of <code>toPredict</code> with the predicted value inserted
   * immediately before the actual class value.
   * <p>
   * For a nominal class, <code>plotShape</code> receives
   * <code>Plot2D.MISSING_SHAPE</code> when the actual class or the
   * prediction is missing, <code>Plot2D.ERROR_SHAPE</code> when the
   * prediction differs from the actual class, and
   * <code>Plot2D.CONST_AUTOMATIC_SHAPE</code> otherwise;
   * <code>plotSize</code> receives <code>Plot2D.DEFAULT_SHAPE_SIZE</code>.
   * <p>
   * For any other class type, <code>plotSize</code> receives the signed
   * error (predicted minus actual) as a <code>Double</code>, or
   * <code>null</code> when the actual class or the prediction is missing;
   * these errors are meant to be converted to shape sizes later.
   * <p>
   * Any exception raised while processing is caught and its stack trace
   * printed; nothing is propagated to the caller.
   *
   * @param toPredict the actual data point
   * @param classifier the classifier
   * @param eval the evaluation object to use for evaluating the classifier on
   * the instance to predict
   * @param plotInstances a set of plottable instances
   * @param plotShape additional plotting information (shape)
   * @param plotSize additional plotting information (size)
   */
  public static void processClassifierPrediction(Instance toPredict,
                                                 Classifier classifier,
                                                 Evaluation eval,
                                                 Instances plotInstances,
                                                 FastVector plotShape,
                                                 FastVector plotSize) {
    try {
      double pred = eval.evaluateModelOnceAndRecordPrediction(classifier,
                                                              toPredict);

      // One slot per attribute of the plot dataset; the slot at the class
      // index of toPredict is used for the prediction.
      double [] values = new double[plotInstances.numAttributes()];
      for (int i = 0; i < plotInstances.numAttributes(); i++) {
        if (i < toPredict.classIndex()) {
          // attributes before the class keep their position
          values[i] = toPredict.value(i);
        } else if (i == toPredict.classIndex()) {
          // prediction first, actual class value directly after it
          values[i] = pred;
          values[i+1] = toPredict.value(i);
          /* // if the class value of the instances to predict is missing then
             // set it to the predicted value
             if (toPredict.isMissing(i)) {
               values[i+1] = pred;
             } */
          // skip the slot that was just filled with the actual class value
          i++;
        } else {
          // attributes after the class are shifted right by one
          values[i] = toPredict.value(i-1);
        }
      }

      // NOTE(review): 1.0 is presumably the instance weight - confirm
      // against weka.core.Instance.
      plotInstances.add(new Instance(1.0, values));
      if (toPredict.classAttribute().isNominal()) {
        if (toPredict.isMissing(toPredict.classIndex())
            || Instance.isMissingValue(pred)) {
          plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
        } else if (pred != toPredict.classValue()) {
          // set to default error point shape
          plotShape.addElement(new Integer(Plot2D.ERROR_SHAPE));
        } else {
          // otherwise set to constant (automatically assigned) point shape
          plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
        }
        plotSize.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE));
      } else {
        // store the error (to be converted to a point size later)
        Double errd = null;
        if (!toPredict.isMissing(toPredict.classIndex())
            && !Instance.isMissingValue(pred)) {
          errd = new Double(pred - toPredict.classValue());
          plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
        } else {
          // missing shape if actual class not present or prediction is missing
          plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
        }
        // errd stays null here when no error could be computed
        plotSize.addElement(errd);
      }
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }

  /**
   * Post processes numeric class errors into shape sizes for plotting
   * in the visualize panel.
   * <p>
   * The vector is modified in place. Every <code>Double</code> entry is
   * replaced by an <code>Integer</code> obtained by scaling its absolute
   * value linearly between the smallest and largest absolute error in the
   * vector, so the result lies between 0 and 20. When all absolute errors
   * are equal, and for every <code>null</code> entry, the entry is replaced
   * by <code>Integer</code> 1 instead.
   *
   * @param plotSize a FastVector of numeric class errors
   * (<code>Double</code> objects or <code>null</code>)
   */
  private void postProcessPlotInfo(FastVector plotSize) {
    // upper bound of the scaled point size
    int maxpSize = 20;
    double maxErr = Double.NEGATIVE_INFINITY;
    double minErr = Double.POSITIVE_INFINITY;
    double err;

    // first pass: find the range of absolute errors, ignoring null entries
    for (int i = 0; i < plotSize.size(); i++) {
      Double errd = (Double)plotSize.elementAt(i);
      if (errd != null) {
        err = Math.abs(errd.doubleValue());
        if (err < minErr) {
          minErr = err;
        }
        if (err > maxErr) {
          maxErr = err;
        }
      }
    }

    // second pass: replace each entry by its integer point size
    for (int i = 0; i < plotSize.size(); i++) {
      Double errd = (Double)plotSize.elementAt(i);
      if (errd != null) {
        err = Math.abs(errd.doubleValue());
        if (maxErr - minErr > 0) {
          double temp = (((err - minErr) / (maxErr - minErr))
                         * maxpSize);
          // the cast truncates the scaled value towards zero
          plotSize.setElementAt(new Integer((int)temp), i);
        } else {
          // no spread in the errors: use size 1 for every point
          plotSize.setElementAt(new Integer(1), i);
        }
      } else {
        // no error available for this point
        plotSize.setElementAt(new Integer(1), i);
      }
    }
  }

  /**
   * Sets up the structure for the visualizable instances. This dataset
   * contains the original attributes plus one extra attribute for the
   * classifier's predictions, named "predicted" followed by the name of the
   * class attribute and placed immediately before the class attribute.
   * <p>
   * If the class attribute is nominal, the extra attribute is created with
   * the same list of values; otherwise it is created from its name alone
   * (presumably a numeric attribute - confirm against weka.core.Attribute).
   * The returned dataset is named after the relation of
   * <code>trainInstances</code> with "_predicted" appended; no instances
   * are added to it here.
   *
   * @param trainInstances the instances that the classifier is trained on
   * @return a new set of instances containing one more attribute (predicted
   * class) than the trainInstances
   */
  public static Instances setUpVisualizableInstances(Instances trainInstances) {
    FastVector hv = new FastVector();
    Attribute predictedClass;

    Attribute classAt = trainInstances.attribute(trainInstances.classIndex());
    if (classAt.isNominal()) {
      // mirror the labels of the class attribute
      FastVector attVals = new FastVector();
      for (int i = 0; i < classAt.numValues(); i++) {
        attVals.addElement(classAt.value(i));
      }
      predictedClass = new Attribute("predicted"+classAt.name(), attVals);
    } else {
      predictedClass = new Attribute("predicted"+classAt.name());
    }

    for (int i = 0; i < trainInstances.numAttributes(); i++) {
      if (i == trainInstances.classIndex()) {
        // the prediction attribute goes directly in front of the class
        hv.addElement(predictedClass);
      }
      hv.addElement(trainInstances.attribute(i).copy());
    }
    return new Instances(trainInstances.relationName()+"_predicted", hv,
                         trainInstances.numInstances());
  }

  /**
   * Starts running the currently configured classifier with the current
   * settings. This is run in a separate thread, and will only start if
   * there is no classifier already running. The classifier output is sent
   * to the results history panel.
*/ protected void startClassifier() { if (m_RunThread == null) { synchronized (this) { m_StartBut.setEnabled(false); m_StopBut.setEnabled(true); } m_RunThread = new Thread() { public void run() { // Copy the current state of things m_Log.statusMessage("Setting up..."); CostMatrix costMatrix = null; Instances inst = new Instances(m_Instances); DataSource source = null; Instances userTestStructure = null; // additional vis info (either shape type or point size) FastVector plotShape = new FastVector(); FastVector plotSize = new FastVector(); Instances predInstances = null; // for timing long trainTimeStart = 0, trainTimeElapsed = 0; try { if (m_TestLoader != null && m_TestLoader.getStructure() != null) { m_TestLoader.reset(); source = new DataSource(m_TestLoader); userTestStructure = source.getStructure(); } } catch (Exception ex) { ex.printStackTrace(); } if (m_EvalWRTCostsBut.isSelected()) { costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor .getValue()); } boolean outputModel = m_OutputModelBut.isSelected(); boolean outputConfusion = m_OutputConfusionBut.isSelected(); boolean outputPerClass = m_OutputPerClassBut.isSelected(); boolean outputSummary = true; boolean outputEntropy = m_OutputEntropyBut.isSelected(); boolean saveVis = m_StorePredictionsBut.isSelected(); boolean outputPredictionsText = m_OutputPredictionsTextBut.isSelected(); String grph = null; int testMode = 0; int numFolds = 10, percent = 66; int classIndex = m_ClassCombo.getSelectedIndex(); Classifier classifier = (Classifier) m_ClassifierEditor.getValue(); Classifier template = null; try { template = Classifier.makeCopy(classifier); } catch (Exception ex) { m_Log.logMessage("Problem copying classifier: " + ex.getMessage()); } Classifier fullClassifier = null; StringBuffer outBuff = new StringBuffer(); String name = (new SimpleDateFormat("HH:mm:ss - ")) .format(new Date()); String cname = classifier.getClass().getName(); if (cname.startsWith("weka.classifiers.")) { name += 
cname.substring("weka.classifiers.".length()); } else { name += cname; } String cmd = m_ClassifierEditor.getValue().getClass().getName(); if (m_ClassifierEditor.getValue() instanceof OptionHandler) cmd += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions()); Evaluation eval = null; try { if (m_CVBut.isSelected()) { testMode = 1; numFolds = Integer.parseInt(m_CVText.getText()); if (numFolds <= 1) { throw new Exception("Number of folds must be greater than 1"); } } else if (m_PercentBut.isSelected()) { testMode = 2; percent = Integer.parseInt(m_PercentText.getText()); if ((percent <= 0) || (percent >= 100)) { throw new Exception("Percentage must be between 0 and 100"); } } else if (m_TrainBut.isSelected()) { testMode = 3; } else if (m_TestSplitBut.isSelected()) { testMode = 4; // Check the test instance compatibility if (source == null) { throw new Exception("No user test set has been specified"); } if (!inst.equalHeaders(userTestStructure)) { throw new Exception("Train and test set are not compatible"); } userTestStructure.setClassIndex(classIndex); } else { throw new Exception("Unknown test mode"); } inst.setClassIndex(classIndex); // set up the structure of the plottable instances for // visualization predInstances = setUpVisualizableInstances(inst); predInstances.setClassIndex(inst.classIndex()+1); // Output some header information m_Log.logMessage("Started " + cname); m_Log.logMessage("Command: " + cmd); if (m_Log instanceof TaskLogger) { ((TaskLogger)m_Log).taskStarted(); } outBuff.append("=== Run information ===\n\n"); outBuff.append("Scheme: " + cname); if (classifier instanceof OptionHandler) { String [] o = ((OptionHandler) classifier).getOptions(); outBuff.append(" " + Utils.joinOptions(o)); } outBuff.append("\n"); outBuff.append("Relation: " + inst.relationName() + '\n'); outBuff.append("Instances: " + inst.numInstances() + '\n'); outBuff.append("Attributes: " + inst.numAttributes() + '\n'); if (inst.numAttributes() < 100) 
{ for (int i = 0; i < inst.numAttributes(); i++) { outBuff.append(" " + inst.attribute(i).name() + '\n'); } } else { outBuff.append(" [list of attributes omitted]\n"); } outBuff.append("Test mode: "); switch (testMode) { case 3: // Test on training outBuff.append("evaluate on training data\n"); break; case 1: // CV mode outBuff.append("" + numFolds + "-fold cross-validation\n"); break; case 2: // Percent split outBuff.append("split " + percent + "% train, remainder test\n"); break; case 4: // Test on user split if (source.isIncremental()) outBuff.append("user supplied test set: " + " size unknown (reading incrementally)\n"); else outBuff.append("user supplied test set: " + source.getDataSet().numInstances() + " instances\n"); break; } if (costMatrix != null) { outBuff.append("Evaluation cost matrix:\n") .append(costMatrix.toString()).append("\n"); } outBuff.append("\n"); m_History.addResult(name, outBuff); m_History.setSingle(name); // Build the model and output it. if (outputModel || (testMode == 3) || (testMode == 4)) { m_Log.statusMessage("Building model on training data..."); trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(inst); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } if (outputModel) { outBuff.append("=== Classifier model (full training set) ===\n\n"); outBuff.append(classifier.toString() + "\n"); outBuff.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0,2) + " seconds\n\n"); m_History.updateResult(name); if (classifier instanceof Drawable) { grph = null; try { grph = ((Drawable)classifier).graph(); } catch (Exception ex) { } } // copy full model for output SerializedObject so = new SerializedObject(classifier); fullClassifier = (Classifier) so.getObject(); } switch (testMode) { case 3: // Test on training m_Log.statusMessage("Evaluating on training data..."); eval = new Evaluation(inst, costMatrix); if (outputPredictionsText) { outBuff.append("=== Predictions on training 
set ===\n\n"); outBuff.append(" inst#, actual, predicted, error"); if (inst.classAttribute().isNominal()) { outBuff.append(", probability distribution"); } outBuff.append("\n"); } for (int jj=0;jj<inst.numInstances();jj++) { processClassifierPrediction(inst.instance(jj), classifier, eval, predInstances, plotShape, plotSize); if (outputPredictionsText) { outBuff.append(predictionText(classifier, inst.instance(jj), jj+1));
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -