ClassifierPanel.java
    m_SetTestFrame.setVisible(true);
  }

  /**
   * Process a classifier's prediction for an instance and update a
   * set of plotting instances and additional plotting info. plotInfo
   * for nominal class datasets holds shape types (actual data points have
   * automatic shape type assignment; classifier error data points have
   * box shape type). For numeric class datasets, the actual data points
   * are stored in plotInstances and plotInfo stores the error (which is
   * later converted to shape size values).
   * @param toPredict the actual data point
   * @param classifier the classifier
   * @param eval the evaluation object to use for evaluating the classifier on
   * the instance to predict
   * @param predictions a FastVector to add the prediction to
   * @param plotInstances a set of plottable instances
   * @param plotShape additional plotting information (shape)
   * @param plotSize additional plotting information (size)
   */
  private void processClassifierPrediction(Instance toPredict,
                                           Classifier classifier,
                                           Evaluation eval,
                                           FastVector predictions,
                                           Instances plotInstances,
                                           FastVector plotShape,
                                           FastVector plotSize) {
    try {
      double pred;
      // classifier is a distribution classifier and class is nominal
      if (predictions != null) {
        Instance classMissing = (Instance) toPredict.copy();
        classMissing.setDataset(toPredict.dataset());
        classMissing.setClassMissing();
        DistributionClassifier dc = (DistributionClassifier) classifier;
        double [] dist = dc.distributionForInstance(classMissing);
        pred = eval.evaluateModelOnce(dist, toPredict);
        int actual = (int) toPredict.classValue();
        predictions.addElement(new NominalPrediction(actual, dist,
                                                     toPredict.weight()));
      } else {
        pred = eval.evaluateModelOnce(classifier, toPredict);
      }

      double [] values = new double[plotInstances.numAttributes()];
      for (int i = 0; i < plotInstances.numAttributes(); i++) {
        if (i < toPredict.classIndex()) {
          values[i] = toPredict.value(i);
        } else if (i == toPredict.classIndex()) {
          values[i] = pred;
          values[i+1] = toPredict.value(i);
          /*
          // if the class value of the instances to predict is missing then
          // set it to the predicted value
          if (toPredict.isMissing(i)) {
            values[i+1] = pred;
          }
          */
          i++;
        } else {
          values[i] = toPredict.value(i-1);
        }
      }
      plotInstances.add(new Instance(1.0, values));

      if (toPredict.classAttribute().isNominal()) {
        if (toPredict.isMissing(toPredict.classIndex())) {
          plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
        } else if (pred != toPredict.classValue()) {
          // set to default error point shape
          plotShape.addElement(new Integer(Plot2D.ERROR_SHAPE));
        } else {
          // otherwise set to constant (automatically assigned) point shape
          plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
        }
        plotSize.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE));
      } else {
        // store the error (to be converted to a point size later)
        Double errd = null;
        if (!toPredict.isMissing(toPredict.classIndex())) {
          errd = new Double(pred - toPredict.classValue());
          plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
        } else {
          // missing shape if actual class not present
          plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
        }
        plotSize.addElement(errd);
      }
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }
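
  /*
   * Illustrative sketch only -- not part of the original ClassifierPanel
   * source. Using plain arrays, it shows how the loop above lays out one
   * plottable row: the predicted class value is written at the original
   * class index and the actual class value is shifted one slot to the
   * right, so the plottable dataset has one extra column. The method name,
   * the three-attribute row and the sample values are hypothetical.
   */
  private static double[] illustratePredictedClassLayout() {
    double[] original = { 5.1, 3.5, 0.0 };      // hypothetical row, class value 0.0 at index 2
    double pred = 1.0;                          // hypothetical predicted class value
    int classIndex = 2;
    double[] values = new double[original.length + 1];
    for (int i = 0; i < values.length; i++) {
      if (i < classIndex) {
        values[i] = original[i];                // attributes before the class copied as-is
      } else if (i == classIndex) {
        values[i] = pred;                       // predicted class goes first...
        values[i + 1] = original[i];            // ...actual class follows one slot later
        i++;
      } else {
        values[i] = original[i - 1];            // attributes after the class shift right by one
      }
    }
    return values;                              // { 5.1, 3.5, 1.0, 0.0 }
  }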

  /**
   * Post-processes numeric class errors into shape sizes for plotting
   * in the visualize panel.
   * @param plotSize a FastVector of numeric class errors
   */
  private void postProcessPlotInfo(FastVector plotSize) {
    int maxpSize = 20;
    double maxErr = Double.NEGATIVE_INFINITY;
    double minErr = Double.POSITIVE_INFINITY;
    double err;

    for (int i = 0; i < plotSize.size(); i++) {
      Double errd = (Double) plotSize.elementAt(i);
      if (errd != null) {
        err = Math.abs(errd.doubleValue());
        if (err < minErr) {
          minErr = err;
        }
        if (err > maxErr) {
          maxErr = err;
        }
      }
    }

    for (int i = 0; i < plotSize.size(); i++) {
      Double errd = (Double) plotSize.elementAt(i);
      if (errd != null) {
        err = Math.abs(errd.doubleValue());
        if (maxErr - minErr > 0) {
          double temp = (((err - minErr) / (maxErr - minErr)) * maxpSize);
          plotSize.setElementAt(new Integer((int) temp), i);
        } else {
          plotSize.setElementAt(new Integer(1), i);
        }
      } else {
        plotSize.setElementAt(new Integer(1), i);
      }
    }
  }
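
  /*
   * Illustrative sketch only -- not part of the original ClassifierPanel
   * source. It mirrors the min-max scaling that postProcessPlotInfo above
   * applies to numeric-class prediction errors: each absolute error is
   * mapped linearly onto [0, maxpSize], so larger errors become larger
   * plotted points, and a degenerate range falls back to size 1. The
   * method name and the sample error values are hypothetical.
   */
  private static int[] illustrateErrorToSizeScaling() {
    double[] absErrors = { 0.2, 1.0, 3.4, 1.8 };   // hypothetical |predicted - actual| values
    int maxpSize = 20;                             // same maximum point size as above
    double minErr = Double.POSITIVE_INFINITY;
    double maxErr = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < absErrors.length; i++) {   // first pass: find the error range
      if (absErrors[i] < minErr) { minErr = absErrors[i]; }
      if (absErrors[i] > maxErr) { maxErr = absErrors[i]; }
    }
    int[] sizes = new int[absErrors.length];
    for (int i = 0; i < absErrors.length; i++) {   // second pass: scale into [0, maxpSize]
      if (maxErr - minErr > 0) {
        sizes[i] = (int) (((absErrors[i] - minErr) / (maxErr - minErr)) * maxpSize);
      } else {
        sizes[i] = 1;                              // all errors equal -> constant small size
      }
    }
    return sizes;                                  // { 0, 5, 20, 10 }
  }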

  /**
   * Sets up the structure for the visualizable instances. This dataset
   * contains the original attributes plus the classifier's predictions
   * for the class as an attribute called "predicted+WhateverTheClassIsCalled".
   * @param trainInstances the instances that the classifier is trained on
   * @return a new set of instances containing one more attribute (predicted
   * class) than the trainInstances
   */
  private Instances setUpVisualizableInstances(Instances trainInstances) {
    FastVector hv = new FastVector();
    Attribute predictedClass;
    Attribute classAt = trainInstances.attribute(trainInstances.classIndex());
    if (classAt.isNominal()) {
      FastVector attVals = new FastVector();
      for (int i = 0; i < classAt.numValues(); i++) {
        attVals.addElement(classAt.value(i));
      }
      predictedClass = new Attribute("predicted" + classAt.name(), attVals);
    } else {
      predictedClass = new Attribute("predicted" + classAt.name());
    }
    for (int i = 0; i < trainInstances.numAttributes(); i++) {
      if (i == trainInstances.classIndex()) {
        hv.addElement(predictedClass);
      }
      hv.addElement(trainInstances.attribute(i).copy());
    }
    return new Instances(trainInstances.relationName() + "_predicted", hv,
                         trainInstances.numInstances());
  }

  /**
   * Starts running the currently configured classifier with the current
   * settings. This is run in a separate thread, and will only start if
   * there is no classifier already running. The classifier output is sent
   * to the results history panel.
   */
  protected void startClassifier() {

    if (m_RunThread == null) {
      synchronized (this) {
        m_StartBut.setEnabled(false);
        m_StopBut.setEnabled(true);
      }
      m_RunThread = new Thread() {
        public void run() {
          // Copy the current state of things
          m_Log.statusMessage("Setting up...");
          CostMatrix costMatrix = null;
          Instances inst = new Instances(m_Instances);
          Instances userTest = null;
          // additional vis info (either shape type or point size)
          FastVector plotShape = new FastVector();
          FastVector plotSize = new FastVector();
          Instances predInstances = null;
          // will hold the prediction objects if the class is nominal
          FastVector predictions = null;
          // for timing
          long trainTimeStart = 0, trainTimeElapsed = 0;

          if (m_TestInstances != null) {
            userTest = new Instances(m_TestInstancesCopy);
          }
          if (m_EvalWRTCostsBut.isSelected()) {
            costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor
                                        .getValue());
          }
          boolean outputModel = m_OutputModelBut.isSelected();
          boolean outputConfusion = m_OutputConfusionBut.isSelected();
          boolean outputPerClass = m_OutputPerClassBut.isSelected();
          boolean outputSummary = true;
          boolean outputEntropy = m_OutputEntropyBut.isSelected();
          boolean saveVis = m_StorePredictionsBut.isSelected();
          boolean outputPredictionsText = m_OutputPredictionsTextBut.isSelected();
          String grph = null;
          int testMode = 0;
          int numFolds = 10, percent = 66;
          int classIndex = m_ClassCombo.getSelectedIndex();
          Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
          StringBuffer outBuff = new StringBuffer();
          String name = (new SimpleDateFormat("HH:mm:ss - "))
            .format(new Date());
          String cname = classifier.getClass().getName();
          if (cname.startsWith("weka.classifiers.")) {
            name += cname.substring("weka.classifiers.".length());
          } else {
            name += cname;
          }
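
          // Test-mode encoding used below (chosen from the radio buttons and
          // interpreted again by the switch statements further down):
          //   1 = cross-validation, 2 = percentage split,
          //   3 = evaluate on training data, 4 = user-supplied test set.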
          try {
            if (m_CVBut.isSelected()) {
              testMode = 1;
              numFolds = Integer.parseInt(m_CVText.getText());
              if (numFolds <= 1) {
                throw new Exception("Number of folds must be greater than 1");
              }
            } else if (m_PercentBut.isSelected()) {
              testMode = 2;
              percent = Integer.parseInt(m_PercentText.getText());
              if ((percent <= 0) || (percent >= 100)) {
                throw new Exception("Percentage must be between 0 and 100");
              }
            } else if (m_TrainBut.isSelected()) {
              testMode = 3;
            } else if (m_TestSplitBut.isSelected()) {
              testMode = 4;
              // Check the test instance compatibility
              if (userTest == null) {
                throw new Exception("No user test set has been opened");
              }
              if (!inst.equalHeaders(userTest)) {
                throw new Exception("Train and test set are not compatible");
              }
              userTest.setClassIndex(classIndex);
            } else {
              throw new Exception("Unknown test mode");
            }
            inst.setClassIndex(classIndex);

            // set up the structure of the plottable instances for
            // visualization
            predInstances = setUpVisualizableInstances(inst);
            predInstances.setClassIndex(inst.classIndex() + 1);

            if (inst.classAttribute().isNominal() &&
                classifier instanceof DistributionClassifier) {
              predictions = new FastVector();
            }

            // Output some header information
            m_Log.logMessage("Started " + cname);
            if (m_Log instanceof TaskLogger) {
              ((TaskLogger) m_Log).taskStarted();
            }
            outBuff.append("=== Run information ===\n\n");
            outBuff.append("Scheme: " + cname);
            if (classifier instanceof OptionHandler) {
              String [] o = ((OptionHandler) classifier).getOptions();
              outBuff.append(" " + Utils.joinOptions(o));
            }
            outBuff.append("\n");
            outBuff.append("Relation: " + inst.relationName() + '\n');
            outBuff.append("Instances: " + inst.numInstances() + '\n');
            outBuff.append("Attributes: " + inst.numAttributes() + '\n');
            if (inst.numAttributes() < 100) {
              for (int i = 0; i < inst.numAttributes(); i++) {
                outBuff.append(" " + inst.attribute(i).name() + '\n');
              }
            } else {
              outBuff.append(" [list of attributes omitted]\n");
            }
            outBuff.append("Test mode: ");
            switch (testMode) {
              case 3: // Test on training
                outBuff.append("evaluate on training data\n");
                break;
              case 1: // CV mode
                outBuff.append("" + numFolds + "-fold cross-validation\n");
                break;
              case 2: // Percent split
                outBuff.append("split " + percent + "% train, remainder test\n");
                break;
              case 4: // Test on user split
                outBuff.append("user supplied test set: "
                               + userTest.numInstances() + " instances\n");
                break;
            }
            if (costMatrix != null) {
              outBuff.append("Evaluation cost matrix:\n")
                .append(costMatrix.toString()).append("\n");
            }
            outBuff.append("\n");
            m_History.addResult(name, outBuff);
            m_History.setSingle(name);

            // Build the model and output it.
            if (outputModel || (testMode == 3) || (testMode == 4)) {
              m_Log.statusMessage("Building model on training data...");
              trainTimeStart = System.currentTimeMillis();
              classifier.buildClassifier(inst);
              trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
            }

            if (outputModel) {
              outBuff.append("=== Classifier model (full training set) ===\n\n");
              outBuff.append(classifier.toString() + "\n");
              outBuff.append("\nTime taken to build model: " +
                             Utils.doubleToString(trainTimeElapsed / 1000.0, 2) +
                             " seconds\n\n");
              m_History.updateResult(name);
              if (classifier instanceof Drawable) {
                grph = null;
                try {
                  grph = ((Drawable) classifier).graph();
                } catch (Exception ex) {
                }
              }
            }

            Evaluation eval = null;
            switch (testMode) {
              case 3: // Test on training
                m_Log.statusMessage("Evaluating on training data...");
                eval = new Evaluation(inst, costMatrix);
                for (int jj = 0; jj < inst.numInstances(); jj++) {
                  processClassifierPrediction(inst.instance(jj), classifier,
                                              eval, predictions,
                                              predInstances, plotShape,
                                              plotSize);
                  if ((jj % 100) == 0) {
                    m_Log.statusMessage("Evaluating on training data. Processed "
                                        + jj + " instances...");
                  }
                }
                outBuff.append("=== Evaluation on training set ===\n");
                break;

              case 1: // CV mode
                m_Log.statusMessage("Randomizing instances...");
                int rnd = 1;
                try {
                  rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
                  System.err.println("Using random seed " + rnd);
                } catch (Exception ex) {
                  m_Log.logMessage("Trouble parsing random seed value");
                  rnd = 1;
                }
                inst.randomize(new Random(rnd));
                if (inst.attribute(classIndex).isNominal()) {
                  m_Log.statusMessage("Stratifying instances...");
                  inst.stratify(numFolds);
                }
                eval = new Evaluation(inst, costMatrix);
                // Make some splits and do a CV
                for (int fold = 0; fold < numFolds; fold++) {
                  m_Log.statusMessage("Creating splits for fold "
                                      + (fold + 1) + "...");
                  Instances train = inst.trainCV(numFolds, fold);
                  Instances test = inst.testCV(numFolds, fold);
                  m_Log.statusMessage("Building model for fold "
                                      + (fold + 1) + "...");
                  classifier.buildClassifier(train);
                  m_Log.statusMessage("Evaluating model for fold "
                                      + (fold + 1) + "...");
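                  // Illustrative note, not in the original source: trainCV and
                  // testCV above partition the randomized (and, for a nominal
                  // class, stratified) data into numFolds disjoint blocks --
                  // testCV returns roughly the fold-th block and trainCV the
                  // remainder, so each instance is tested exactly once across
                  // the numFolds iterations.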