ClassifierPanel.java

来自「Weka」· Java 代码 · 共 1,867 行 · 第 1/5 页

JAVA
1,867
字号
		throw new Exception("No user test set has been specified");	      }	      if (!inst.equalHeaders(userTestStructure)) {		throw new Exception("Train and test set are not compatible");	      }              userTestStructure.setClassIndex(classIndex);	    } else {	      throw new Exception("Unknown test mode");	    }	    inst.setClassIndex(classIndex);	    // set up the structure of the plottable instances for 	    // visualization            if (saveVis) {              predInstances = setUpVisualizableInstances(inst);              predInstances.setClassIndex(inst.classIndex()+1);            } 	    // Output some header information	    m_Log.logMessage("Started " + cname);	    m_Log.logMessage("Command: " + cmd);	    if (m_Log instanceof TaskLogger) {	      ((TaskLogger)m_Log).taskStarted();	    }	    outBuff.append("=== Run information ===\n\n");	    outBuff.append("Scheme:       " + cname);	    if (classifier instanceof OptionHandler) {	      String [] o = ((OptionHandler) classifier).getOptions();	      outBuff.append(" " + Utils.joinOptions(o));	    }	    outBuff.append("\n");	    outBuff.append("Relation:     " + inst.relationName() + '\n');	    outBuff.append("Instances:    " + inst.numInstances() + '\n');	    outBuff.append("Attributes:   " + inst.numAttributes() + '\n');	    if (inst.numAttributes() < 100) {	      for (int i = 0; i < inst.numAttributes(); i++) {		outBuff.append("              " + inst.attribute(i).name()			       + '\n');	      }	    } else {	      outBuff.append("              [list of attributes omitted]\n");	    }	    outBuff.append("Test mode:    ");	    switch (testMode) {	      case 3: // Test on training		outBuff.append("evaluate on training data\n");		break;	      case 1: // CV mode		outBuff.append("" + numFolds + "-fold cross-validation\n");		break;	      case 2: // Percent split		outBuff.append("split " + percent		    + "% train, remainder test\n");		break;	      case 4: // Test on user split		if (source.isIncremental())		  
outBuff.append("user supplied test set: "		      + " size unknown (reading incrementally)\n");		else		  outBuff.append("user supplied test set: "		      + source.getDataSet().numInstances() + " instances\n");		break;	    }            if (costMatrix != null) {               outBuff.append("Evaluation cost matrix:\n")               .append(costMatrix.toString()).append("\n");            }	    outBuff.append("\n");	    m_History.addResult(name, outBuff);	    m_History.setSingle(name);	    	    // Build the model and output it.	    if (outputModel || (testMode == 3) || (testMode == 4)) {	      m_Log.statusMessage("Building model on training data...");	      trainTimeStart = System.currentTimeMillis();	      classifier.buildClassifier(inst);	      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;	    }	    if (outputModel) {	      outBuff.append("=== Classifier model (full training set) ===\n\n");	      outBuff.append(classifier.toString() + "\n");	      outBuff.append("\nTime taken to build model: " +			     Utils.doubleToString(trainTimeElapsed / 1000.0,2)			     + " seconds\n\n");	      m_History.updateResult(name);	      if (classifier instanceof Drawable) {		grph = null;		try {		  grph = ((Drawable)classifier).graph();		} catch (Exception ex) {		}	      }	      // copy full model for output	      SerializedObject so = new SerializedObject(classifier);	      fullClassifier = (Classifier) so.getObject();	    }	    	    switch (testMode) {	      case 3: // Test on training	      m_Log.statusMessage("Evaluating on training data...");	      eval = new Evaluation(inst, costMatrix);	      	      if (outputPredictionsText) {		printPredictionsHeader(outBuff, inst, "training set");	      }	      for (int jj=0;jj<inst.numInstances();jj++) {		processClassifierPrediction(inst.instance(jj), classifier,					    eval, predInstances, plotShape, 					    plotSize);				if (outputPredictionsText) { 		  outBuff.append(predictionText(classifier, inst.instance(jj), jj+1));		
}		if ((jj % 100) == 0) {		  m_Log.statusMessage("Evaluating on training data. Processed "				      +jj+" instances...");		}	      }	      if (outputPredictionsText) {		outBuff.append("\n");	      } 	      outBuff.append("=== Evaluation on training set ===\n");	      break;	      case 1: // CV mode	      m_Log.statusMessage("Randomizing instances...");	      int rnd = 1;	      try {		rnd = Integer.parseInt(m_RandomSeedText.getText().trim());		// System.err.println("Using random seed "+rnd);	      } catch (Exception ex) {		m_Log.logMessage("Trouble parsing random seed value");		rnd = 1;	      }	      Random random = new Random(rnd);	      inst.randomize(random);	      if (inst.attribute(classIndex).isNominal()) {		m_Log.statusMessage("Stratifying instances...");		inst.stratify(numFolds);	      }	      eval = new Evaluation(inst, costMatrix);      	      if (outputPredictionsText) {		printPredictionsHeader(outBuff, inst, "test data");	      }	      // Make some splits and do a CV	      for (int fold = 0; fold < numFolds; fold++) {		m_Log.statusMessage("Creating splits for fold "				    + (fold + 1) + "...");		Instances train = inst.trainCV(numFolds, fold, random);		eval.setPriors(train);		m_Log.statusMessage("Building model for fold "				    + (fold + 1) + "...");		Classifier current = null;		try {		  current = Classifier.makeCopy(template);		} catch (Exception ex) {		  m_Log.logMessage("Problem copying classifier: " + ex.getMessage());		}		current.buildClassifier(train);		Instances test = inst.testCV(numFolds, fold);		m_Log.statusMessage("Evaluating model for fold "				    + (fold + 1) + "...");		for (int jj=0;jj<test.numInstances();jj++) {		  processClassifierPrediction(test.instance(jj), current,					      eval, predInstances, plotShape,					      plotSize);		  if (outputPredictionsText) { 		    outBuff.append(predictionText(current, test.instance(jj), jj+1));		  }		}	      }	      if (outputPredictionsText) {		outBuff.append("\n");	      } 	      if 
(inst.attribute(classIndex).isNominal()) {		outBuff.append("=== Stratified cross-validation ===\n");	      } else {		outBuff.append("=== Cross-validation ===\n");	      }	      break;			      case 2: // Percent split	      if (!m_PreserveOrderBut.isSelected()) {		m_Log.statusMessage("Randomizing instances...");		try {		  rnd = Integer.parseInt(m_RandomSeedText.getText().trim());		} catch (Exception ex) {		  m_Log.logMessage("Trouble parsing random seed value");		  rnd = 1;		}		inst.randomize(new Random(rnd));	      }	      int trainSize = inst.numInstances() * percent / 100;	      int testSize = inst.numInstances() - trainSize;	      Instances train = new Instances(inst, 0, trainSize);	      Instances test = new Instances(inst, trainSize, testSize);	      m_Log.statusMessage("Building model on training split...");	      Classifier current = null;	      try {		current = Classifier.makeCopy(template);	      } catch (Exception ex) {		m_Log.logMessage("Problem copying classifier: " + ex.getMessage());	      }	      current.buildClassifier(train);	      eval = new Evaluation(train, costMatrix);	      m_Log.statusMessage("Evaluating on test split...");	     	      if (outputPredictionsText) {		printPredictionsHeader(outBuff, inst, "test split");	      }     	      for (int jj=0;jj<test.numInstances();jj++) {		processClassifierPrediction(test.instance(jj), current,					    eval, predInstances, plotShape,					    plotSize);		if (outputPredictionsText) { 		    outBuff.append(predictionText(current, test.instance(jj), jj+1));		}		if ((jj % 100) == 0) {		  m_Log.statusMessage("Evaluating on test split. 
Processed "				      +jj+" instances...");		}	      }	      if (outputPredictionsText) {		outBuff.append("\n");	      } 	      outBuff.append("=== Evaluation on test split ===\n");	      break;			      case 4: // Test on user split	      m_Log.statusMessage("Evaluating on test data...");	      eval = new Evaluation(inst, costMatrix);	      	      if (outputPredictionsText) {		printPredictionsHeader(outBuff, inst, "test set");	      }	      Instance instance;	      int jj = 0;	      while (source.hasMoreElements(userTestStructure)) {		instance = source.nextElement(userTestStructure);		processClassifierPrediction(instance, classifier,		    eval, predInstances, plotShape,		    plotSize);		if (outputPredictionsText) { 		  outBuff.append(predictionText(classifier, instance, jj+1));		}		if ((++jj % 100) == 0) {		  m_Log.statusMessage("Evaluating on test data. Processed "		      +jj+" instances...");		}	      }	      if (outputPredictionsText) {		outBuff.append("\n");	      } 	      outBuff.append("=== Evaluation on test set ===\n");	      break;	      default:	      throw new Exception("Test mode not implemented");	    }	    	    if (outputSummary) {	      outBuff.append(eval.toSummaryString(outputEntropy) + "\n");	    }	    if (inst.attribute(classIndex).isNominal()) {	      if (outputPerClass) {		outBuff.append(eval.toClassDetailsString() + "\n");	      }	      if (outputConfusion) {		outBuff.append(eval.toMatrixString() + "\n");	      }	    }            if (   (fullClassifier instanceof Sourcable)                  && m_OutputSourceCode.isSelected()) {              outBuff.append("=== Source code ===\n\n");              outBuff.append(                Evaluation.wekaStaticWrapper(                    ((Sourcable) fullClassifier),                    m_SourceCodeClass.getText()));            }	    m_History.updateResult(name);	    m_Log.logMessage("Finished " + cname);	    m_Log.statusMessage("OK");	  } catch (Exception ex) {	    ex.printStackTrace();	    
m_Log.logMessage(ex.getMessage());	    JOptionPane.showMessageDialog(ClassifierPanel.this,					  "Problem evaluating classifier:\n"					  + ex.getMessage(),					  "Evaluate classifier",					  JOptionPane.ERROR_MESSAGE);	    m_Log.statusMessage("Problem evaluating classifier");	  } finally {	    try {              if (!saveVis && outputModel) {		  FastVector vv = new FastVector();		  vv.addElement(fullClassifier);		  Instances trainHeader = new Instances(m_Instances, 0);		  trainHeader.setClassIndex(classIndex);		  vv.addElement(trainHeader);                  if (grph != null) {		    vv.addElement(grph);		  }		  m_History.addObject(name, vv);              } else if (saveVis && predInstances != null &&                   predInstances.numInstances() > 0) {		if (predInstances.attribute(predInstances.classIndex())		    .isNumeric()) {		  postProcessPlotInfo(plotSize);		}		m_CurrentVis = new VisualizePanel();		m_CurrentVis.setName(name+" ("+inst.relationName()+")");		m_CurrentVis.setLog(m_Log);		PlotData2D tempd = new PlotData2D(predInstances);		tempd.setShapeSize(plotSize);		tempd.setShapeType(plotShape);		tempd.setPlotName(name+" ("+inst.relationName()+")");		tempd.addInstanceNumberAttribute();				m_CurrentVis.addPlot(tempd);		m_CurrentVis.setColourIndex(predInstances.classIndex()+1);	    		  FastVector vv = new FastVector();		  if (outputModel) {		    vv.addElement(fullClassifier);		    Instances trainHeader = new Instances(m_Instances, 0);		    trainHeader.setClassIndex(classIndex);		    vv.addElement(trainHeader);                    if (grph != null) {                      vv.addElement(grph);                    }		  }		  vv.addElement(m_CurrentVis);		  if ((eval != null) && (eval.predictions() != null)) {		    vv.addElement(eval.predictions());		    vv.addElement(inst.classAttribute());		  }		  m_History.addObject(name, vv);	      }	    } catch (Exception ex) {	      ex.printStackTrace();	    }	    	    if (isInterrupted()) {	      m_Log.logMessage("Interrupted " + 
cname);	      m_Log.statusMessage("Interrupted");	    }	    synchronized (this) {	      m_StartBut.setEnabled(true);	      m_StopBut.setEnabled(false);	      m_RunThread = null;	    }	    if (m_Log instanceof TaskLogger) {              ((TaskLogger)m_Log).taskFinished();            }          }	}      };      m_RunThread.setPriority(Thread.MIN_PRIORITY);

⌨️ 快捷键说明

复制代码：Ctrl + C
搜索代码：Ctrl + F
全屏模式：F11
增大字号：Ctrl + =
减小字号：Ctrl + -
显示快捷键：?