classifierpanel.java
来自「Weka」· Java 代码 · 共 1,867 行 · 第 1/5 页
JAVA
1,867 行
throw new Exception("No user test set has been specified"); } if (!inst.equalHeaders(userTestStructure)) { throw new Exception("Train and test set are not compatible"); } userTestStructure.setClassIndex(classIndex); } else { throw new Exception("Unknown test mode"); } inst.setClassIndex(classIndex); // set up the structure of the plottable instances for // visualization if (saveVis) { predInstances = setUpVisualizableInstances(inst); predInstances.setClassIndex(inst.classIndex()+1); } // Output some header information m_Log.logMessage("Started " + cname); m_Log.logMessage("Command: " + cmd); if (m_Log instanceof TaskLogger) { ((TaskLogger)m_Log).taskStarted(); } outBuff.append("=== Run information ===\n\n"); outBuff.append("Scheme: " + cname); if (classifier instanceof OptionHandler) { String [] o = ((OptionHandler) classifier).getOptions(); outBuff.append(" " + Utils.joinOptions(o)); } outBuff.append("\n"); outBuff.append("Relation: " + inst.relationName() + '\n'); outBuff.append("Instances: " + inst.numInstances() + '\n'); outBuff.append("Attributes: " + inst.numAttributes() + '\n'); if (inst.numAttributes() < 100) { for (int i = 0; i < inst.numAttributes(); i++) { outBuff.append(" " + inst.attribute(i).name() + '\n'); } } else { outBuff.append(" [list of attributes omitted]\n"); } outBuff.append("Test mode: "); switch (testMode) { case 3: // Test on training outBuff.append("evaluate on training data\n"); break; case 1: // CV mode outBuff.append("" + numFolds + "-fold cross-validation\n"); break; case 2: // Percent split outBuff.append("split " + percent + "% train, remainder test\n"); break; case 4: // Test on user split if (source.isIncremental()) outBuff.append("user supplied test set: " + " size unknown (reading incrementally)\n"); else outBuff.append("user supplied test set: " + source.getDataSet().numInstances() + " instances\n"); break; } if (costMatrix != null) { outBuff.append("Evaluation cost matrix:\n") .append(costMatrix.toString()).append("\n"); } outBuff.append("\n"); m_History.addResult(name, outBuff); m_History.setSingle(name); // Build the model and output it. if (outputModel || (testMode == 3) || (testMode == 4)) { m_Log.statusMessage("Building model on training data..."); trainTimeStart = System.currentTimeMillis(); classifier.buildClassifier(inst); trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; } if (outputModel) { outBuff.append("=== Classifier model (full training set) ===\n\n"); outBuff.append(classifier.toString() + "\n"); outBuff.append("\nTime taken to build model: " + Utils.doubleToString(trainTimeElapsed / 1000.0,2) + " seconds\n\n"); m_History.updateResult(name); if (classifier instanceof Drawable) { grph = null; try { grph = ((Drawable)classifier).graph(); } catch (Exception ex) { } } // copy full model for output SerializedObject so = new SerializedObject(classifier); fullClassifier = (Classifier) so.getObject(); } switch (testMode) { case 3: // Test on training m_Log.statusMessage("Evaluating on training data..."); eval = new Evaluation(inst, costMatrix); if (outputPredictionsText) { printPredictionsHeader(outBuff, inst, "training set"); } for (int jj=0;jj<inst.numInstances();jj++) { processClassifierPrediction(inst.instance(jj), classifier, eval, predInstances, plotShape, plotSize); if (outputPredictionsText) { outBuff.append(predictionText(classifier, inst.instance(jj), jj+1)); } if ((jj % 100) == 0) { m_Log.statusMessage("Evaluating on training data. Processed " +jj+" instances..."); } } if (outputPredictionsText) { outBuff.append("\n"); } outBuff.append("=== Evaluation on training set ===\n"); break; case 1: // CV mode m_Log.statusMessage("Randomizing instances..."); int rnd = 1; try { rnd = Integer.parseInt(m_RandomSeedText.getText().trim()); // System.err.println("Using random seed "+rnd); } catch (Exception ex) { m_Log.logMessage("Trouble parsing random seed value"); rnd = 1; } Random random = new Random(rnd); inst.randomize(random); if (inst.attribute(classIndex).isNominal()) { m_Log.statusMessage("Stratifying instances..."); inst.stratify(numFolds); } eval = new Evaluation(inst, costMatrix); if (outputPredictionsText) { printPredictionsHeader(outBuff, inst, "test data"); } // Make some splits and do a CV for (int fold = 0; fold < numFolds; fold++) { m_Log.statusMessage("Creating splits for fold " + (fold + 1) + "..."); Instances train = inst.trainCV(numFolds, fold, random); eval.setPriors(train); m_Log.statusMessage("Building model for fold " + (fold + 1) + "..."); Classifier current = null; try { current = Classifier.makeCopy(template); } catch (Exception ex) { m_Log.logMessage("Problem copying classifier: " + ex.getMessage()); } current.buildClassifier(train); Instances test = inst.testCV(numFolds, fold); m_Log.statusMessage("Evaluating model for fold " + (fold + 1) + "..."); for (int jj=0;jj<test.numInstances();jj++) { processClassifierPrediction(test.instance(jj), current, eval, predInstances, plotShape, plotSize); if (outputPredictionsText) { outBuff.append(predictionText(current, test.instance(jj), jj+1)); } } } if (outputPredictionsText) { outBuff.append("\n"); } if (inst.attribute(classIndex).isNominal()) { outBuff.append("=== Stratified cross-validation ===\n"); } else { outBuff.append("=== Cross-validation ===\n"); } break; case 2: // Percent split if (!m_PreserveOrderBut.isSelected()) { m_Log.statusMessage("Randomizing instances..."); try { rnd = Integer.parseInt(m_RandomSeedText.getText().trim()); } catch (Exception ex) { m_Log.logMessage("Trouble parsing random seed value"); rnd = 1; } inst.randomize(new Random(rnd)); } int trainSize = inst.numInstances() * percent / 100; int testSize = inst.numInstances() - trainSize; Instances train = new Instances(inst, 0, trainSize); Instances test = new Instances(inst, trainSize, testSize); m_Log.statusMessage("Building model on training split..."); Classifier current = null; try { current = Classifier.makeCopy(template); } catch (Exception ex) { m_Log.logMessage("Problem copying classifier: " + ex.getMessage()); } current.buildClassifier(train); eval = new Evaluation(train, costMatrix); m_Log.statusMessage("Evaluating on test split..."); if (outputPredictionsText) { printPredictionsHeader(outBuff, inst, "test split"); } for (int jj=0;jj<test.numInstances();jj++) { processClassifierPrediction(test.instance(jj), current, eval, predInstances, plotShape, plotSize); if (outputPredictionsText) { outBuff.append(predictionText(current, test.instance(jj), jj+1)); } if ((jj % 100) == 0) { m_Log.statusMessage("Evaluating on test split. Processed " +jj+" instances..."); } } if (outputPredictionsText) { outBuff.append("\n"); } outBuff.append("=== Evaluation on test split ===\n"); break; case 4: // Test on user split m_Log.statusMessage("Evaluating on test data..."); eval = new Evaluation(inst, costMatrix); if (outputPredictionsText) { printPredictionsHeader(outBuff, inst, "test set"); } Instance instance; int jj = 0; while (source.hasMoreElements(userTestStructure)) { instance = source.nextElement(userTestStructure); processClassifierPrediction(instance, classifier, eval, predInstances, plotShape, plotSize); if (outputPredictionsText) { outBuff.append(predictionText(classifier, instance, jj+1)); } if ((++jj % 100) == 0) { m_Log.statusMessage("Evaluating on test data. Processed " +jj+" instances..."); } } if (outputPredictionsText) { outBuff.append("\n"); } outBuff.append("=== Evaluation on test set ===\n"); break; default: throw new Exception("Test mode not implemented"); } if (outputSummary) { outBuff.append(eval.toSummaryString(outputEntropy) + "\n"); } if (inst.attribute(classIndex).isNominal()) { if (outputPerClass) { outBuff.append(eval.toClassDetailsString() + "\n"); } if (outputConfusion) { outBuff.append(eval.toMatrixString() + "\n"); } } if ( (fullClassifier instanceof Sourcable) && m_OutputSourceCode.isSelected()) { outBuff.append("=== Source code ===\n\n"); outBuff.append( Evaluation.wekaStaticWrapper( ((Sourcable) fullClassifier), m_SourceCodeClass.getText())); } m_History.updateResult(name); m_Log.logMessage("Finished " + cname); m_Log.statusMessage("OK"); } catch (Exception ex) { ex.printStackTrace(); m_Log.logMessage(ex.getMessage()); JOptionPane.showMessageDialog(ClassifierPanel.this, "Problem evaluating classifier:\n" + ex.getMessage(), "Evaluate classifier", JOptionPane.ERROR_MESSAGE); m_Log.statusMessage("Problem evaluating classifier"); } finally { try { if (!saveVis && outputModel) { FastVector vv = new FastVector(); vv.addElement(fullClassifier); Instances trainHeader = new Instances(m_Instances, 0); trainHeader.setClassIndex(classIndex); vv.addElement(trainHeader); if (grph != null) { vv.addElement(grph); } m_History.addObject(name, vv); } else if (saveVis && predInstances != null && predInstances.numInstances() > 0) { if (predInstances.attribute(predInstances.classIndex()) .isNumeric()) { postProcessPlotInfo(plotSize); } m_CurrentVis = new VisualizePanel(); m_CurrentVis.setName(name+" ("+inst.relationName()+")"); m_CurrentVis.setLog(m_Log); PlotData2D tempd = new PlotData2D(predInstances); tempd.setShapeSize(plotSize); tempd.setShapeType(plotShape); tempd.setPlotName(name+" ("+inst.relationName()+")"); tempd.addInstanceNumberAttribute(); m_CurrentVis.addPlot(tempd); m_CurrentVis.setColourIndex(predInstances.classIndex()+1); FastVector vv = new FastVector(); if (outputModel) { vv.addElement(fullClassifier); Instances trainHeader = new Instances(m_Instances, 0); trainHeader.setClassIndex(classIndex); vv.addElement(trainHeader); if (grph != null) { vv.addElement(grph); } } vv.addElement(m_CurrentVis); if ((eval != null) && (eval.predictions() != null)) { vv.addElement(eval.predictions()); vv.addElement(inst.classAttribute()); } m_History.addObject(name, vv); } } catch (Exception ex) { ex.printStackTrace(); } if (isInterrupted()) { m_Log.logMessage("Interrupted " + cname); m_Log.statusMessage("Interrupted"); } synchronized (this) { m_StartBut.setEnabled(true); m_StopBut.setEnabled(false); m_RunThread = null; } if (m_Log instanceof TaskLogger) { ((TaskLogger)m_Log).taskFinished(); } } } }; m_RunThread.setPriority(Thread.MIN_PRIORITY);
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?