classifierpanel.java
来自「Weka」· Java 代码 · 共 1,867 行 · 第 1/5 页
JAVA
1,867 行
* Sets the Logger to receive informational messages.
 *
 * @param newLog the Logger that will now get info messages
 */
public void setLog(Logger newLog) {
  m_Log = newLog;
}

/**
 * Tells the panel to use a new set of instances.
 * <p>
 * The class combo box is repopulated with one entry per attribute, each
 * prefixed by a short type tag such as "(Nom) ". The dataset's class
 * attribute is preselected, or the last attribute when no class is set.
 * If the dataset has no attributes, the class combo box and the Start
 * button are disabled. The Stop button always reflects whether a
 * classifier run is currently in progress.
 *
 * @param inst a set of Instances
 */
public void setInstances(Instances inst) {
  m_Instances = inst;
  int numAttributes = m_Instances.numAttributes();
  String [] attribNames = new String [numAttributes];
  for (int i = 0; i < numAttributes; i++) {
    Attribute att = m_Instances.attribute(i);
    String type;
    switch (att.type()) {
      case Attribute.NOMINAL:
        type = "(Nom) ";
        break;
      case Attribute.NUMERIC:
        type = "(Num) ";
        break;
      case Attribute.STRING:
        type = "(Str) ";
        break;
      case Attribute.DATE:
        type = "(Dat) ";
        break;
      case Attribute.RELATIONAL:
        type = "(Rel) ";
        break;
      default:
        type = "(???) ";
    }
    attribNames[i] = type + att.name();
  }
  m_ClassCombo.setModel(new DefaultComboBoxModel(attribNames));

  boolean hasAttributes = attribNames.length > 0;
  if (hasAttributes) {
    if (inst.classIndex() == -1) {
      m_ClassCombo.setSelectedIndex(attribNames.length - 1);
    } else {
      m_ClassCombo.setSelectedIndex(inst.classIndex());
    }
  }
  // With no attributes there is nothing to pick as class; don't leave the
  // combo enabled from a previously loaded dataset.
  m_ClassCombo.setEnabled(hasAttributes);
  m_StartBut.setEnabled(hasAttributes && m_RunThread == null);
  // A classifier that is already running must stay stoppable, whatever
  // data has just been loaded.
  m_StopBut.setEnabled(m_RunThread != null);
}

/**
 * Sets the user test set. Information about the current test set
 * is displayed in an InstanceSummaryPanel and the user is given the
 * ability to load another set from a file or url.
* */ protected void setTestSet() { if (m_SetTestFrame == null) { final SetInstancesPanel sp = new SetInstancesPanel(); if (m_TestLoader != null) { try { if (m_TestLoader.getStructure() != null) sp.setInstances(m_TestLoader.getStructure()); } catch (Exception ex) { ex.printStackTrace(); } } sp.addPropertyChangeListener(new PropertyChangeListener() { public void propertyChange(PropertyChangeEvent e) { m_TestLoader = sp.getLoader(); } }); // Add propertychangelistener to update m_TestLoader whenever // it changes in the settestframe m_SetTestFrame = new JFrame("Test Instances"); sp.setParentFrame(m_SetTestFrame); // enable Close-Button m_SetTestFrame.getContentPane().setLayout(new BorderLayout()); m_SetTestFrame.getContentPane().add(sp, BorderLayout.CENTER); m_SetTestFrame.pack(); } m_SetTestFrame.setVisible(true); } /** * Process a classifier's prediction for an instance and update a * set of plotting instances and additional plotting info. plotInfo * for nominal class datasets holds shape types (actual data points have * automatic shape type assignment; classifier error data points have * box shape type). 
For numeric class datasets, the actual data points * are stored in plotInstances and plotInfo stores the error (which is * later converted to shape size values) * @param toPredict the actual data point * @param classifier the classifier * @param eval the evaluation object to use for evaluating the classifier on * the instance to predict * @param plotInstances a set of plottable instances * @param plotShape additional plotting information (shape) * @param plotSize additional plotting information (size) */ public static void processClassifierPrediction(Instance toPredict, Classifier classifier, Evaluation eval, Instances plotInstances, FastVector plotShape, FastVector plotSize) { try { double pred = eval.evaluateModelOnceAndRecordPrediction(classifier, toPredict); if (plotInstances != null) { double [] values = new double[plotInstances.numAttributes()]; for (int i = 0; i < plotInstances.numAttributes(); i++) { if (i < toPredict.classIndex()) { values[i] = toPredict.value(i); } else if (i == toPredict.classIndex()) { values[i] = pred; values[i+1] = toPredict.value(i); /* // if the class value of the instances to predict is missing then // set it to the predicted value if (toPredict.isMissing(i)) { values[i+1] = pred; } */ i++; } else { values[i] = toPredict.value(i-1); } } plotInstances.add(new Instance(1.0, values)); if (toPredict.classAttribute().isNominal()) { if (toPredict.isMissing(toPredict.classIndex()) || Instance.isMissingValue(pred)) { plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE)); } else if (pred != toPredict.classValue()) { // set to default error point shape plotShape.addElement(new Integer(Plot2D.ERROR_SHAPE)); } else { // otherwise set to constant (automatically assigned) point shape plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } plotSize.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE)); } else { // store the error (to be converted to a point size later) Double errd = null; if (!toPredict.isMissing(toPredict.classIndex()) 
&& !Instance.isMissingValue(pred)) { errd = new Double(pred - toPredict.classValue()); plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } else { // missing shape if actual class not present or prediction is missing plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE)); } plotSize.addElement(errd); } } } catch (Exception ex) { ex.printStackTrace(); } } /** * Post processes numeric class errors into shape sizes for plotting * in the visualize panel * @param plotSize a FastVector of numeric class errors */ private void postProcessPlotInfo(FastVector plotSize) { int maxpSize = 20; double maxErr = Double.NEGATIVE_INFINITY; double minErr = Double.POSITIVE_INFINITY; double err; for (int i = 0; i < plotSize.size(); i++) { Double errd = (Double)plotSize.elementAt(i); if (errd != null) { err = Math.abs(errd.doubleValue()); if (err < minErr) { minErr = err; } if (err > maxErr) { maxErr = err; } } } for (int i = 0; i < plotSize.size(); i++) { Double errd = (Double)plotSize.elementAt(i); if (errd != null) { err = Math.abs(errd.doubleValue()); if (maxErr - minErr > 0) { double temp = (((err - minErr) / (maxErr - minErr)) * maxpSize); plotSize.setElementAt(new Integer((int)temp), i); } else { plotSize.setElementAt(new Integer(1), i); } } else { plotSize.setElementAt(new Integer(1), i); } } } /** * Sets up the structure for the visualizable instances. This dataset * contains the original attributes plus the classifier's predictions * for the class as an attribute called "predicted+WhateverTheClassIsCalled". 
* @param trainInstances the instances that the classifier is trained on * @return a new set of instances containing one more attribute (predicted * class) than the trainInstances */ public static Instances setUpVisualizableInstances(Instances trainInstances) { FastVector hv = new FastVector(); Attribute predictedClass; Attribute classAt = trainInstances.attribute(trainInstances.classIndex()); if (classAt.isNominal()) { FastVector attVals = new FastVector(); for (int i = 0; i < classAt.numValues(); i++) { attVals.addElement(classAt.value(i)); } predictedClass = new Attribute("predicted"+classAt.name(), attVals); } else { predictedClass = new Attribute("predicted"+classAt.name()); } for (int i = 0; i < trainInstances.numAttributes(); i++) { if (i == trainInstances.classIndex()) { hv.addElement(predictedClass); } hv.addElement(trainInstances.attribute(i).copy()); } return new Instances(trainInstances.relationName()+"_predicted", hv, trainInstances.numInstances()); } /** * outputs the header for the predictions on the data * * @param outBuff the buffer to add the output to * @param inst the data header * @param title the title to print */ protected void printPredictionsHeader(StringBuffer outBuff, Instances inst, String title) { outBuff.append("=== Predictions on " + title + " ===\n\n"); outBuff.append(" inst#, actual, predicted, error"); if (inst.classAttribute().isNominal()) { outBuff.append(", probability distribution"); } if (m_OutputAdditionalAttributesRange != null) { outBuff.append(" ("); boolean first = true; for (int i = 0; i < inst.numAttributes() - 1; i++) { if (m_OutputAdditionalAttributesRange.isInRange(i)) { if (!first) outBuff.append(","); else first = false; outBuff.append(inst.attribute(i).name()); } } outBuff.append(")"); } outBuff.append("\n"); } /** * Starts running the currently configured classifier with the current * settings. This is run in a separate thread, and will only start if * there is no classifier already running. 
The classifier output is sent * to the results history panel. */ protected void startClassifier() { if (m_RunThread == null) { synchronized (this) { m_StartBut.setEnabled(false); m_StopBut.setEnabled(true); } m_RunThread = new Thread() { public void run() { // Copy the current state of things m_Log.statusMessage("Setting up..."); CostMatrix costMatrix = null; Instances inst = new Instances(m_Instances); DataSource source = null; Instances userTestStructure = null; // additional vis info (either shape type or point size) FastVector plotShape = new FastVector(); FastVector plotSize = new FastVector(); Instances predInstances = null; // for timing long trainTimeStart = 0, trainTimeElapsed = 0; try { if (m_TestLoader != null && m_TestLoader.getStructure() != null) { m_TestLoader.reset(); source = new DataSource(m_TestLoader); userTestStructure = source.getStructure(); } } catch (Exception ex) { ex.printStackTrace(); } if (m_EvalWRTCostsBut.isSelected()) { costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor .getValue()); } boolean outputModel = m_OutputModelBut.isSelected(); boolean outputConfusion = m_OutputConfusionBut.isSelected(); boolean outputPerClass = m_OutputPerClassBut.isSelected(); boolean outputSummary = true; boolean outputEntropy = m_OutputEntropyBut.isSelected(); boolean saveVis = m_StorePredictionsBut.isSelected(); boolean outputPredictionsText = m_OutputPredictionsTextBut.isSelected(); if (m_OutputAdditionalAttributesText.getText().equals("")) { m_OutputAdditionalAttributesRange = null; } else { m_OutputAdditionalAttributesRange = new Range(m_OutputAdditionalAttributesText.getText()); m_OutputAdditionalAttributesRange.setUpper(inst.numAttributes() - 1); } String grph = null; int testMode = 0; int numFolds = 10, percent = 66; int classIndex = m_ClassCombo.getSelectedIndex(); Classifier classifier = (Classifier) m_ClassifierEditor.getValue(); Classifier template = null; try { template = Classifier.makeCopy(classifier); } catch (Exception ex) { 
m_Log.logMessage("Problem copying classifier: " + ex.getMessage()); } Classifier fullClassifier = null; StringBuffer outBuff = new StringBuffer(); String name = (new SimpleDateFormat("HH:mm:ss - ")) .format(new Date()); String cname = classifier.getClass().getName(); if (cname.startsWith("weka.classifiers.")) { name += cname.substring("weka.classifiers.".length()); } else { name += cname; } String cmd = m_ClassifierEditor.getValue().getClass().getName(); if (m_ClassifierEditor.getValue() instanceof OptionHandler) cmd += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions()); Evaluation eval = null; try { if (m_CVBut.isSelected()) { testMode = 1; numFolds = Integer.parseInt(m_CVText.getText()); if (numFolds <= 1) { throw new Exception("Number of folds must be greater than 1"); } } else if (m_PercentBut.isSelected()) { testMode = 2; percent = Integer.parseInt(m_PercentText.getText()); if ((percent <= 0) || (percent >= 100)) { throw new Exception("Percentage must be between 0 and 100"); } } else if (m_TrainBut.isSelected()) { testMode = 3; } else if (m_TestSplitBut.isSelected()) { testMode = 4; // Check the test instance compatibility if (source == null) {
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?