⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classifierpanel.java

📁 wekaUT是 university texas austin 开发的基于weka的半指导学习(semi supervised learning)的分类器
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
    m_SetTestFrame.setVisible(true);
  }

  /**
   * Process a classifier's prediction for an instance and update a
   * set of plotting instances and additional plotting info. plotInfo
   * for nominal class datasets holds shape types (actual data points have
   * automatic shape type assignment; classifier error data points have
   * box shape type). For numeric class datasets, the actual data points
   * are stored in plotInstances and plotInfo stores the error (which is
   * later converted to shape size values by postProcessPlotInfo).
   *
   * @param toPredict the actual data point
   * @param classifier the classifier
   * @param eval the evaluation object to use for evaluating the classifier
   * on the instance to predict
   * @param predictions a fastvector to add the prediction to, or null when
   * per-instance distribution predictions are not being collected
   * @param plotInstances a set of plottable instances
   * @param plotShape additional plotting information (shape)
   * @param plotSize additional plotting information (size)
   */
  private void processClassifierPrediction(Instance toPredict,
                                           Classifier classifier,
                                           Evaluation eval,
                                           FastVector predictions,
                                           Instances plotInstances,
                                           FastVector plotShape,
                                           FastVector plotSize) {
    try {
      double pred;
      // classifier is a distribution classifier and class is nominal
      if (predictions != null) {
        // evaluate on a copy whose class is set to missing, so the
        // classifier cannot see the true class value
        Instance classMissing = (Instance)toPredict.copy();
        classMissing.setDataset(toPredict.dataset());
        classMissing.setClassMissing();
        DistributionClassifier dc =
          (DistributionClassifier)classifier;
        double [] dist =
          dc.distributionForInstance(classMissing);
        pred = eval.evaluateModelOnce(dist, toPredict);
        int actual = (int)toPredict.classValue();
        predictions.addElement(new
          NominalPrediction(actual, dist, toPredict.weight()));
      } else {
        pred = eval.evaluateModelOnce(classifier,
                                      toPredict);
      }

      // Build the plottable instance: same values as toPredict, but with
      // the predicted class inserted at the original class index; the
      // actual class and all later attributes shift one slot to the right
      // (matching the layout created by setUpVisualizableInstances).
      double [] values = new double[plotInstances.numAttributes()];
      for (int i = 0; i < plotInstances.numAttributes(); i++) {
        if (i < toPredict.classIndex()) {
          values[i] = toPredict.value(i);
        } else if (i == toPredict.classIndex()) {
          values[i] = pred;                 // predicted class value
          values[i+1] = toPredict.value(i); // actual class value
          /* // if the class value of the instances to predict is missing then
          // set it to the predicted value
          if (toPredict.isMissing(i)) {
            values[i+1] = pred;
            } */
          i++; // skip the slot just filled with the actual class
        } else {
          // attributes after the class are shifted right by one
          values[i] = toPredict.value(i-1);
        }
      }
      plotInstances.add(new Instance(1.0, values));

      if (toPredict.classAttribute().isNominal()) {
        if (toPredict.isMissing(toPredict.classIndex())) {
          plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
        } else if (pred != toPredict.classValue()) {
          // set to default error point shape
          plotShape.addElement(new Integer(Plot2D.ERROR_SHAPE));
        } else {
          // otherwise set to constant (automatically assigned) point shape
          plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
        }
        plotSize.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE));
      } else {
        // store the error (to be converted to a point size later);
        // a null entry marks a missing actual class value
        Double errd = null;
        if (!toPredict.isMissing(toPredict.classIndex())) {
          errd = new Double(pred - toPredict.classValue());
          plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
        } else {
          // missing shape if actual class not present
          plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
        }
        plotSize.addElement(errd);
      }
    } catch (Exception ex) {
      // NOTE(review): failures are only reported to the console; the
      // caller is not notified that this prediction was skipped
      ex.printStackTrace();
    }
  }

  /**
   * Post processes numeric class errors into shape sizes for plotting
   * in the visualize panel. Absolute errors are scaled linearly into
   * the range 0..maxpSize; entries that are null (missing actual class)
   * or that cannot be scaled (all errors identical) get size 1.
   *
   * @param plotSize a FastVector of numeric class errors (Double or null),
   * replaced in place with Integer point sizes
   */
  private void postProcessPlotInfo(FastVector plotSize) {
    int maxpSize = 20; // largest point size assigned
    double maxErr = Double.NEGATIVE_INFINITY;
    double minErr = Double.POSITIVE_INFINITY;
    double err;

    // first pass: find the range of absolute errors
    for (int i = 0; i < plotSize.size(); i++) {
      Double errd = (Double)plotSize.elementAt(i);
      if (errd != null) {
        err = Math.abs(errd.doubleValue());
        if (err < minErr) {
          minErr = err;
        }
        if (err > maxErr) {
          maxErr = err;
        }
      }
    }

    // second pass: replace each error with its scaled point size
    for (int i = 0; i < plotSize.size(); i++) {
      Double errd = (Double)plotSize.elementAt(i);
      if (errd != null) {
        err = Math.abs(errd.doubleValue());
        if (maxErr - minErr > 0) {
          double temp = (((err - minErr) / (maxErr - minErr))
                         * maxpSize);
          plotSize.setElementAt(new Integer((int)temp), i);
        } else {
          // degenerate range (all errors equal): minimal constant size
          plotSize.setElementAt(new Integer(1), i);
        }
      } else {
        // no actual class value was available for this point
        plotSize.setElementAt(new Integer(1), i);
      }
    }
  }

  /**
   * Sets up the structure for the visualizable instances. This dataset
   * contains the original attributes plus the classifier's predictions
   * for the class as an attribute called "predicted+WhateverTheClassIsCalled",
   * inserted immediately before the original class attribute.
   *
   * @param trainInstances the instances that the classifier is trained on
   * @return a new (empty) set of instances containing one more attribute
   * (predicted class) than trainInstances
   */
  private Instances setUpVisualizableInstances(Instances trainInstances) {
    FastVector hv = new FastVector();
    Attribute predictedClass;
    Attribute classAt = trainInstances.attribute(trainInstances.classIndex());
    if (classAt.isNominal()) {
      // nominal class: prediction attribute reuses the class labels
      FastVector attVals = new FastVector();
      for (int i = 0; i < classAt.numValues(); i++) {
        attVals.addElement(classAt.value(i));
      }
      predictedClass = new Attribute("predicted"+classAt.name(), attVals);
    } else {
      // numeric class: prediction attribute is numeric as well
      predictedClass = new Attribute("predicted"+classAt.name());
    }
    for (int i = 0; i < trainInstances.numAttributes(); i++) {
      if (i == trainInstances.classIndex()) {
        // insert the prediction attribute just before the class attribute
        hv.addElement(predictedClass);
      }
      hv.addElement(trainInstances.attribute(i).copy());
    }
    return new Instances(trainInstances.relationName()+"_predicted", hv,
                         trainInstances.numInstances());
  }

  /**
   * Starts running the currently configured classifier with the current
   * settings. This is run in a separate thread, and will only start if
   * there is no classifier already running. The classifier output is sent
   * to the results history panel.
*/  protected void startClassifier() {    if (m_RunThread == null) {      synchronized (this) {	m_StartBut.setEnabled(false);	m_StopBut.setEnabled(true);      }      m_RunThread = new Thread() {	public void run() {	  // Copy the current state of things	  m_Log.statusMessage("Setting up...");	  CostMatrix costMatrix = null;	  Instances inst = new Instances(m_Instances);	  Instances userTest = null;	  // additional vis info (either shape type or point size)	  FastVector plotShape = new FastVector();	  FastVector plotSize = new FastVector();	  Instances predInstances = null;	  // will hold the prediction objects if the class is nominal	  FastVector predictions = null;	 	  // for timing	  long trainTimeStart = 0, trainTimeElapsed = 0;	  if (m_TestInstances != null) {	    userTest = new Instances(m_TestInstancesCopy);	  }	  if (m_EvalWRTCostsBut.isSelected()) {	    costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor					.getValue());	  }	  boolean outputModel = m_OutputModelBut.isSelected();	  boolean outputConfusion = m_OutputConfusionBut.isSelected();	  boolean outputPerClass = m_OutputPerClassBut.isSelected();	  boolean outputSummary = true;          boolean outputEntropy = m_OutputEntropyBut.isSelected();	  boolean saveVis = m_StorePredictionsBut.isSelected();	  boolean outputPredictionsText = m_OutputPredictionsTextBut.isSelected();	  String grph = null;	  int testMode = 0;	  int numFolds = 10, percent = 66;	  int classIndex = m_ClassCombo.getSelectedIndex();	  Classifier classifier = (Classifier) m_ClassifierEditor.getValue();	  StringBuffer outBuff = new StringBuffer();	  String name = (new SimpleDateFormat("HH:mm:ss - "))	  .format(new Date());	  String cname = classifier.getClass().getName();	  if (cname.startsWith("weka.classifiers.")) {	    name += cname.substring("weka.classifiers.".length());	  } else {	    name += cname;	  }	  try {	    if (m_CVBut.isSelected()) {	      testMode = 1;	      numFolds = Integer.parseInt(m_CVText.getText());	      if 
(numFolds <= 1) {		throw new Exception("Number of folds must be greater than 1");	      }	    } else if (m_PercentBut.isSelected()) {	      testMode = 2;	      percent = Integer.parseInt(m_PercentText.getText());	      if ((percent <= 0) || (percent >= 100)) {		throw new Exception("Percentage must be between 0 and 100");	      }	    } else if (m_TrainBut.isSelected()) {	      testMode = 3;	    } else if (m_TestSplitBut.isSelected()) {	      testMode = 4;	      // Check the test instance compatibility	      if (userTest == null) {		throw new Exception("No user test set has been opened");	      }	      if (!inst.equalHeaders(userTest)) {		throw new Exception("Train and test set are not compatible");	      }	      userTest.setClassIndex(classIndex);	    } else {	      throw new Exception("Unknown test mode");	    }	    inst.setClassIndex(classIndex);	    // set up the structure of the plottable instances for 	    // visualization	    predInstances = setUpVisualizableInstances(inst);	    predInstances.setClassIndex(inst.classIndex()+1);	    	    if (inst.classAttribute().isNominal() && 		classifier instanceof DistributionClassifier) {	      predictions = new FastVector();	    }	    // Output some header information	    m_Log.logMessage("Started " + cname);	    if (m_Log instanceof TaskLogger) {	      ((TaskLogger)m_Log).taskStarted();	    }	    outBuff.append("=== Run information ===\n\n");	    outBuff.append("Scheme:       " + cname);	    if (classifier instanceof OptionHandler) {	      String [] o = ((OptionHandler) classifier).getOptions();	      outBuff.append(" " + Utils.joinOptions(o));	    }	    outBuff.append("\n");	    outBuff.append("Relation:     " + inst.relationName() + '\n');	    outBuff.append("Instances:    " + inst.numInstances() + '\n');	    outBuff.append("Attributes:   " + inst.numAttributes() + '\n');	    if (inst.numAttributes() < 100) {	      for (int i = 0; i < inst.numAttributes(); i++) {		outBuff.append("              " + 
inst.attribute(i).name()			       + '\n');	      }	    } else {	      outBuff.append("              [list of attributes omitted]\n");	    }	    outBuff.append("Test mode:    ");	    switch (testMode) {	      case 3: // Test on training	      outBuff.append("evaluate on training data\n");	      break;	      case 1: // CV mode	      outBuff.append("" + numFolds + "-fold cross-validation\n");	      break;	      case 2: // Percent split	      outBuff.append("split " + percent			       + "% train, remainder test\n");	      break;	      case 4: // Test on user split	      outBuff.append("user supplied test set: "			     + userTest.numInstances() + " instances\n");	      break;	    }            if (costMatrix != null) {               outBuff.append("Evaluation cost matrix:\n")               .append(costMatrix.toString()).append("\n");            }	    outBuff.append("\n");	    m_History.addResult(name, outBuff);	    m_History.setSingle(name);	    	    // Build the model and output it.	    if (outputModel || (testMode == 3) || (testMode == 4)) {	      m_Log.statusMessage("Building model on training data...");	      trainTimeStart = System.currentTimeMillis();	      classifier.buildClassifier(inst);	      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;	    }	    if (outputModel) {	      outBuff.append("=== Classifier model (full training set) ===\n\n");	      outBuff.append(classifier.toString() + "\n");	      outBuff.append("\nTime taken to build model: " +			     Utils.doubleToString(trainTimeElapsed / 1000.0,2)			     + " seconds\n\n");	      m_History.updateResult(name);	      if (classifier instanceof Drawable) {		grph = null;		try {		  grph = ((Drawable)classifier).graph();		} catch (Exception ex) {		}	      }	    }	    	    Evaluation eval = null;	    switch (testMode) {	      case 3: // Test on training	      m_Log.statusMessage("Evaluating on training data...");	      eval = new Evaluation(inst, costMatrix);	      for (int 
jj=0;jj<inst.numInstances();jj++) {		processClassifierPrediction(inst.instance(jj), classifier,					    eval, predictions,					    predInstances, plotShape, 					    plotSize);				if ((jj % 100) == 0) {		  m_Log.statusMessage("Evaluating on training data. Processed "				      +jj+" instances...");		}	      }	      outBuff.append("=== Evaluation on training set ===\n");	      break;	      case 1: // CV mode	      m_Log.statusMessage("Randomizing instances...");	      int rnd = 1;	      try {		rnd = Integer.parseInt(m_RandomSeedText.getText().trim());		System.err.println("Using random seed "+rnd);	      } catch (Exception ex) {		m_Log.logMessage("Trouble parsing random seed value");		rnd = 1;	      }	      inst.randomize(new Random(rnd));	      if (inst.attribute(classIndex).isNominal()) {		m_Log.statusMessage("Stratifying instances...");		inst.stratify(numFolds);	      }	      eval = new Evaluation(inst, costMatrix);	      // Make some splits and do a CV	      for (int fold = 0; fold < numFolds; fold++) {		m_Log.statusMessage("Creating splits for fold "				    + (fold + 1) + "...");		Instances train = inst.trainCV(numFolds, fold);		Instances test = inst.testCV(numFolds, fold);		m_Log.statusMessage("Building model for fold "				    + (fold + 1) + "...");		classifier.buildClassifier(train);		m_Log.statusMessage("Evaluating model for fold "				    + (fold + 1) + "...");

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -