⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classifierpanel.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 5 页
字号:
      m_Summary = sp.getSummary();
      if (m_TestInstances != null) {
	sp.setInstances(m_TestInstances);
      }
      sp.addPropertyChangeListener(new PropertyChangeListener() {
	public void propertyChange(PropertyChangeEvent e) {
	  m_TestInstances = sp.getInstances();
	}
      });
      // Add propertychangelistener to update m_TestInstances whenever
      // it changes in the settestframe
      m_SetTestFrame = new JFrame("Test Instances");
      m_SetTestFrame.getContentPane().setLayout(new BorderLayout());
      m_SetTestFrame.getContentPane().add(sp, BorderLayout.CENTER);
      m_SetTestFrame.pack();
    }
    m_SetTestFrame.setVisible(true);
  }

  /**
   * Evaluates a classifier on one instance and records the outcome for
   * later plotting.
   * <p>
   * One row is appended to <code>plotInstances</code>. The row holds the
   * attribute values of <code>toPredict</code> with the predicted class
   * value inserted at the class index, immediately followed by the actual
   * class value. One entry is appended to each of <code>plotShape</code>
   * and <code>plotSize</code>:
   * <ul>
   * <li>nominal class: the shape is the missing shape when the actual or
   * predicted class is missing, the error shape when they differ, and the
   * automatic shape otherwise; the size is the default shape size.</li>
   * <li>otherwise: the shape is the missing shape when the actual or
   * predicted class is missing, else the automatic shape; the size entry
   * is the signed error (predicted minus actual) as a Double, or null when
   * no error can be computed. These errors are expected to be converted
   * to point sizes later.</li>
   * </ul>
   * Any exception raised along the way is caught and its stack trace is
   * printed; nothing is propagated to the caller.
   *
   * @param toPredict the actual data point
   * @param classifier the classifier
   * @param eval the evaluation object to use for evaluating the classifier on
   * the instance to predict
   * @param predictions a fastvector to add the prediction to; when null the
   * classifier is evaluated directly and no prediction object is stored
   * @param plotInstances a set of plottable instances
   * @param plotShape additional plotting information (shape)
   * @param plotSize additional plotting information (size)
   */
  public static void processClassifierPrediction(Instance toPredict,
                                                 Classifier classifier,
                                                 Evaluation eval,
                                                 FastVector predictions,
                                                 Instances plotInstances,
                                                 FastVector plotShape,
                                                 FastVector plotSize) {
    try {
      double predicted;
      if (predictions == null) {
        // no prediction objects requested: let the evaluation query the
        // classifier itself
        predicted = eval.evaluateModelOnce(classifier, toPredict);
      } else {
        // obtain the class distribution for a copy of the instance whose
        // class value has been blanked out
        Instance unlabeled = (Instance) toPredict.copy();
        unlabeled.setDataset(toPredict.dataset());
        unlabeled.setClassMissing();
        double[] distribution = classifier.distributionForInstance(unlabeled);
        predicted = eval.evaluateModelOnce(distribution, toPredict);
        predictions.addElement(new NominalPrediction(toPredict.classValue(),
                                                     distribution,
                                                     toPredict.weight()));
      }

      // Build the plot row: attributes before the class are copied as is,
      // the class slot receives the prediction followed by the actual
      // value, and every later attribute is shifted one place to the right.
      int width = plotInstances.numAttributes();
      int classPos = toPredict.classIndex();
      double[] row = new double[width];
      int col = 0;
      while (col < width) {
        if (col < classPos) {
          row[col] = toPredict.value(col);
          col++;
        } else if (col > classPos) {
          row[col] = toPredict.value(col - 1);
          col++;
        } else {
          row[col] = predicted;
          row[col + 1] = toPredict.value(col);
          col += 2;
        }
      }
      plotInstances.add(new Instance(1.0, row));

      boolean nominalClass = toPredict.classAttribute().isNominal();
      boolean undetermined = toPredict.isMissing(toPredict.classIndex())
        || Instance.isMissingValue(predicted);

      Object shape;
      Object size;
      if (nominalClass) {
        if (undetermined) {
          shape = new Integer(Plot2D.MISSING_SHAPE);
        } else if (predicted != toPredict.classValue()) {
          // misclassified: default error point shape
          shape = new Integer(Plot2D.ERROR_SHAPE);
        } else {
          // correct: constant (automatically assigned) point shape
          shape = new Integer(Plot2D.CONST_AUTOMATIC_SHAPE);
        }
        size = new Integer(Plot2D.DEFAULT_SHAPE_SIZE);
      } else if (undetermined) {
        // actual class not present or prediction missing: no error value
        shape = new Integer(Plot2D.MISSING_SHAPE);
        size = null;
      } else {
        // keep the raw error; it is turned into a point size later
        size = new Double(predicted - toPredict.classValue());
        shape = new Integer(Plot2D.CONST_AUTOMATIC_SHAPE);
      }
      plotShape.addElement(shape);
      plotSize.addElement(size);
    } catch (Exception ex) {
      ex.printStackTrace();
    }
  }

  /**
   * Converts the numeric class errors stored in <code>plotSize</code> into
   * integer point sizes, in place.
   * <p>
   * Each non-null entry is read as a Double error. Its absolute value is
   * scaled linearly between the smallest and largest absolute error found
   * in the vector, onto the range 0..20, and truncated to an int. Null
   * entries, and every entry when the absolute errors do not span a
   * positive range, become size 1.
   *
   * @param plotSize a FastVector of numeric class errors (Double or null);
   * on return every element is an Integer point size
   */
  private void postProcessPlotInfo(FastVector plotSize) {
    final int largestPointSize = 20;
    final int count = plotSize.size();

    // first pass: range of the absolute errors (null entries are ignored)
    double smallest = Double.POSITIVE_INFINITY;
    double largest = Double.NEGATIVE_INFINITY;
    for (int idx = 0; idx < count; idx++) {
      Double entry = (Double) plotSize.elementAt(idx);
      if (entry == null) {
        continue;
      }
      double magnitude = Math.abs(entry.doubleValue());
      if (magnitude < smallest) {
        smallest = magnitude;
      }
      if (magnitude > largest) {
        largest = magnitude;
      }
    }

    // second pass: overwrite every entry with its point size
    final double spread = largest - smallest;
    for (int idx = 0; idx < count; idx++) {
      Double entry = (Double) plotSize.elementAt(idx);
      int pointSize = 1;
      if (entry != null && spread > 0) {
        double magnitude = Math.abs(entry.doubleValue());
        pointSize = (int) (((magnitude - smallest) / spread)
                           * largestPointSize);
      }
      plotSize.setElementAt(new Integer(pointSize), idx);
    }
  }

  /**
   * Builds the header for the visualizable instances. The new dataset has
   * copies of all the original attributes plus one extra attribute, named
   * "predicted" followed by the class attribute's name, inserted directly
   * in front of the class attribute. For a nominal class the extra
   * attribute carries the same values as the class attribute. The relation
   * name is the original relation name with "_predicted" appended.
   *
   * @param trainInstances the instances that the classifier is trained on
   * @return a new set of instances containing one more attribute (predicted
   * class) than the trainInstances
   */
  public static Instances setUpVisualizableInstances(Instances trainInstances) {
    int classPos = trainInstances.classIndex();
    Attribute classAtt = trainInstances.attribute(classPos);
    String predictedName = "predicted" + classAtt.name();

    Attribute predictedAtt;
    if (!classAtt.isNominal()) {
      predictedAtt = new Attribute(predictedName);
    } else {
      // mirror the class attribute's values on the prediction attribute
      FastVector labels = new FastVector();
      int numLabels = classAtt.numValues();
      for (int k = 0; k < numLabels; k++) {
        labels.addElement(classAtt.value(k));
      }
      predictedAtt = new Attribute(predictedName, labels);
    }

    FastVector header = new FastVector();
    int numAtts = trainInstances.numAttributes();
    for (int k = 0; k < numAtts; k++) {
      if (k == classPos) {
        // the prediction goes immediately before the actual class
        header.addElement(predictedAtt);
      }
      header.addElement(trainInstances.attribute(k).copy());
    }

    String relation = trainInstances.relationName() + "_predicted";
    return new Instances(relation, header, trainInstances.numInstances());
  }

  /**
   * Starts running the currently configured classifier with the current
   * settings. This is run in a separate thread, and will only start if
   * there is no classifier already running. The classifier output is sent
   * to the results history panel.
   */
  protected void startClassifier() {

    if (m_RunThread == null) {
      synchronized (this) {
	m_StartBut.setEnabled(false);
	m_StopBut.setEnabled(true);
      }
      m_RunThread = new Thread() {
	public void run() {
	  // Copy the current state of things
	  m_Log.statusMessage("Setting up...");
	  CostMatrix costMatrix = null;
	  Instances inst = new Instances(m_Instances);
	  Instances userTest = null;
	  // additional vis info (either shape type or point size)
	  FastVector plotShape = new FastVector();
	  FastVector plotSize = new FastVector();
	  Instances predInstances = null;

	  // will hold the prediction objects if the class is nominal
	  FastVector predictions = null;
	 
	  // for timing
	  long trainTimeStart = 0, trainTimeElapsed = 0;

	  if (m_TestInstances != null) {
	    userTest = new Instances(m_TestInstances);
	  }
	  if (m_EvalWRTCostsBut.isSelected()) {
	    costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor
					.getValue());
	  }
	  boolean outputModel = m_OutputModelBut.isSelected();
	  boolean outputConfusion = m_OutputConfusionBut.isSelected();
	  boolean outputPerClass = m_OutputPerClassBut.isSelected();
	  boolean outputSummary = true;
          boolean outputEntropy = m_OutputEntropyBut.isSelected();
	  boolean saveVis = m_StorePredictionsBut.isSelected();
	  boolean outputPredictionsText = m_OutputPredictionsTextBut.isSelected();

	  String grph = null;

	  int testMode = 0;
	  int numFolds = 10, percent = 66;
	  int classIndex = m_ClassCombo.getSelectedIndex();
	  Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
	  Classifier fullClassifier = null;
	  StringBuffer outBuff = new StringBuffer();
	  String name = (new SimpleDateFormat("HH:mm:ss - "))
	  .format(new Date());
	  String cname = classifier.getClass().getName();
	  if (cname.startsWith("weka.classifiers.")) {
	    name += cname.substring("weka.classifiers.".length());
	  } else {
	    name += cname;
	  }
	  try {
	    if (m_CVBut.isSelected()) {
	      testMode = 1;
	      numFolds = Integer.parseInt(m_CVText.getText());
	      if (numFolds <= 1) {
		throw new Exception("Number of folds must be greater than 1");
	      }
	    } else if (m_PercentBut.isSelected()) {
	      testMode = 2;
	      percent = Integer.parseInt(m_PercentText.getText());
	      if ((percent <= 0) || (percent >= 100)) {
		throw new Exception("Percentage must be between 0 and 100");
	      }
	    } else if (m_TrainBut.isSelected()) {
	      testMode = 3;
	    } else if (m_TestSplitBut.isSelected()) {
	      testMode = 4;
	      // Check the test instance compatibility
	      if (userTest == null) {
		throw new Exception("No user test set has been opened");
	      }
	      if (!inst.equalHeaders(userTest)) {
		throw new Exception("Train and test set are not compatible");
	      }
	      userTest.setClassIndex(classIndex);
	    } else {
	      throw new Exception("Unknown test mode");
	    }
	    inst.setClassIndex(classIndex);

	    // set up the structure of the plottable instances for 
	    // visualization
	    predInstances = setUpVisualizableInstances(inst);
	    predInstances.setClassIndex(inst.classIndex()+1);
	    
	    if (inst.classAttribute().isNominal()) {
	      predictions = new FastVector();
	    }

	    // Output some header information
	    m_Log.logMessage("Started " + cname);
	    if (m_Log instanceof TaskLogger) {
	      ((TaskLogger)m_Log).taskStarted();
	    }
	    outBuff.append("=== Run information ===\n\n");
	    outBuff.append("Scheme:       " + cname);
	    if (classifier instanceof OptionHandler) {
	      String [] o = ((OptionHandler) classifier).getOptions();
	      outBuff.append(" " + Utils.joinOptions(o));
	    }
	    outBuff.append("\n");
	    outBuff.append("Relation:     " + inst.relationName() + '\n');
	    outBuff.append("Instances:    " + inst.numInstances() + '\n');
	    outBuff.append("Attributes:   " + inst.numAttributes() + '\n');
	    if (inst.numAttributes() < 100) {
	      for (int i = 0; i < inst.numAttributes(); i++) {
		outBuff.append("              " + inst.attribute(i).name()
			       + '\n');
	      }
	    } else {
	      outBuff.append("              [list of attributes omitted]\n");
	    }

	    outBuff.append("Test mode:    ");
	    switch (testMode) {
	      case 3: // Test on training
	      outBuff.append("evaluate on training data\n");
	      break;
	      case 1: // CV mode
	      outBuff.append("" + numFolds + "-fold cross-validation\n");
	      break;
	      case 2: // Percent split
	      outBuff.append("split " + percent
			       + "% train, remainder test\n");
	      break;
	      case 4: // Test on user split
	      outBuff.append("user supplied test set: "
			     + userTest.numInstances() + " instances\n");
	      break;
	    }
            if (costMatrix != null) {
               outBuff.append("Evaluation cost matrix:\n")
               .append(costMatrix.toString()).append("\n");
            }
	    outBuff.append("\n");
	    m_History.addResult(name, outBuff);
	    m_History.setSingle(name);
	    
	    // Build the model and output it.
	    if (outputModel || (testMode == 3) || (testMode == 4)) {
	      m_Log.statusMessage("Building model on training data...");

	      trainTimeStart = System.currentTimeMillis();
	      classifier.buildClassifier(inst);
	      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
	    }

	    if (outputModel) {
	      outBuff.append("=== Classifier model (full training set) ===\n\n");
	      outBuff.append(classifier.toString() + "\n");
	      outBuff.append("\nTime taken to build model: " +
			     Utils.doubleToString(trainTimeElapsed / 1000.0,2)
			     + " seconds\n\n");
	      m_History.updateResult(name);
	      if (classifier instanceof Drawable) {
		grph = null;
		try {
		  grph = ((Drawable)classifier).graph();
		} catch (Exception ex) {
		}
	      }
	      // copy full model for output
	      SerializedObject so = new SerializedObject(classifier);
	      fullClassifier = (Classifier) so.getObject();
	    }
	    
	    Evaluation eval = null;
	    switch (testMode) {
	      case 3: // Test on training
	      m_Log.statusMessage("Evaluating on training data...");
	      eval = new Evaluation(inst, costMatrix);
	      for (int jj=0;jj<inst.numInstances();jj++) {
		processClassifierPrediction(inst.instance(jj), classifier,
					    eval, predictions,
					    predInstances, plotShape, 
					    plotSize);
		
		if ((jj % 100) == 0) {
		  m_Log.statusMessage("Evaluating on training data. Processed "
				      +jj+" instances...");
		}
	      }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -