ClassifierPanel.java
m_Summary = sp.getSummary();
if (m_TestInstances != null) {
sp.setInstances(m_TestInstances);
}
// Add a PropertyChangeListener to update m_TestInstances whenever
// it changes in the set-test frame
sp.addPropertyChangeListener(new PropertyChangeListener() {
public void propertyChange(PropertyChangeEvent e) {
m_TestInstances = sp.getInstances();
}
});
m_SetTestFrame = new JFrame("Test Instances");
m_SetTestFrame.getContentPane().setLayout(new BorderLayout());
m_SetTestFrame.getContentPane().add(sp, BorderLayout.CENTER);
m_SetTestFrame.pack();
}
m_SetTestFrame.setVisible(true);
}
/**
* Process a classifier's prediction for an instance and update a
* set of plotting instances and additional plotting info. For nominal
* class datasets the plotting info holds shape types (correctly
* classified data points get an automatically assigned shape type,
* misclassified data points get the error (box) shape type). For
* numeric class datasets, the data points are stored in plotInstances
* and the plotting info stores the prediction error (which is later
* converted to shape size values).
* @param toPredict the actual data point
* @param classifier the classifier
* @param eval the evaluation object to use for evaluating the classifier on
* the instance to predict
* @param predictions a FastVector to add the prediction to
* @param plotInstances a set of plottable instances
* @param plotShape additional plotting information (shape)
* @param plotSize additional plotting information (size)
*/
public static void processClassifierPrediction(Instance toPredict,
Classifier classifier,
Evaluation eval,
FastVector predictions,
Instances plotInstances,
FastVector plotShape,
FastVector plotSize) {
try {
double pred;
// if a predictions vector was supplied, the class is nominal, so
// evaluate on the full class distribution and record the prediction
if (predictions != null) {
Instance classMissing = (Instance)toPredict.copy();
classMissing.setDataset(toPredict.dataset());
classMissing.setClassMissing();
double [] dist = classifier.distributionForInstance(classMissing);
pred = eval.evaluateModelOnce(dist, toPredict);
predictions.addElement(new NominalPrediction(toPredict.classValue(),
dist, toPredict.weight()));
} else {
pred = eval.evaluateModelOnce(classifier,
toPredict);
}
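// Build the attribute values for the plottable copy of this instance:
// the predicted class value is inserted just before the actual class
// value, and attributes after the class index are shifted right by one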
double [] values = new double[plotInstances.numAttributes()];
for (int i = 0; i < plotInstances.numAttributes(); i++) {
if (i < toPredict.classIndex()) {
values[i] = toPredict.value(i);
} else if (i == toPredict.classIndex()) {
values[i] = pred;
values[i+1] = toPredict.value(i);
/* // if the class value of the instances to predict is missing then
// set it to the predicted value
if (toPredict.isMissing(i)) {
values[i+1] = pred;
} */
i++;
} else {
values[i] = toPredict.value(i-1);
}
}
plotInstances.add(new Instance(1.0, values));
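// Record the additional plotting info: a shape type for nominal classes
// (error shape for misclassified points) or the prediction error for
// numeric classes (converted to a shape size later)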
if (toPredict.classAttribute().isNominal()) {
if (toPredict.isMissing(toPredict.classIndex())
|| Instance.isMissingValue(pred)) {
plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
} else if (pred != toPredict.classValue()) {
// set to default error point shape
plotShape.addElement(new Integer(Plot2D.ERROR_SHAPE));
} else {
// otherwise set to constant (automatically assigned) point shape
plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
}
plotSize.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE));
} else {
// store the error (to be converted to a point size later)
Double errd = null;
if (!toPredict.isMissing(toPredict.classIndex()) &&
!Instance.isMissingValue(pred)) {
errd = new Double(pred - toPredict.classValue());
plotShape.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE));
} else {
// missing shape if actual class not present or prediction is missing
plotShape.addElement(new Integer(Plot2D.MISSING_SHAPE));
}
plotSize.addElement(errd);
}
} catch (Exception ex) {
ex.printStackTrace();
}
}
/**
* Post processes numeric class errors into shape sizes for plotting
* in the visualize panel
* @param plotSize a FastVector of numeric class errors
*/
private void postProcessPlotInfo(FastVector plotSize) {
int maxpSize = 20;
double maxErr = Double.NEGATIVE_INFINITY;
double minErr = Double.POSITIVE_INFINITY;
double err;
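// first pass: find the smallest and largest absolute error among the
// non-null entries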
for (int i = 0; i < plotSize.size(); i++) {
Double errd = (Double)plotSize.elementAt(i);
if (errd != null) {
err = Math.abs(errd.doubleValue());
if (err < minErr) {
minErr = err;
}
if (err > maxErr) {
maxErr = err;
}
}
}
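// second pass: linearly rescale each absolute error to an integer shape
// size between 0 and maxpSize; null entries (and the case where all
// errors are equal) get a default size of 1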
for (int i = 0; i < plotSize.size(); i++) {
Double errd = (Double)plotSize.elementAt(i);
if (errd != null) {
err = Math.abs(errd.doubleValue());
if (maxErr - minErr > 0) {
double temp = (((err - minErr) / (maxErr - minErr))
* maxpSize);
plotSize.setElementAt(new Integer((int)temp), i);
} else {
plotSize.setElementAt(new Integer(1), i);
}
} else {
plotSize.setElementAt(new Integer(1), i);
}
}
}
/**
* Sets up the structure for the visualizable instances. This dataset
* contains the original attributes plus the classifier's predictions
* for the class as an attribute called "predicted" + the name of the class attribute.
* @param trainInstances the instances that the classifier is trained on
* @return a new set of instances containing one more attribute (predicted
* class) than the trainInstances
*/
public static Instances setUpVisualizableInstances(Instances trainInstances) {
FastVector hv = new FastVector();
Attribute predictedClass;
Attribute classAt = trainInstances.attribute(trainInstances.classIndex());
if (classAt.isNominal()) {
FastVector attVals = new FastVector();
for (int i = 0; i < classAt.numValues(); i++) {
attVals.addElement(classAt.value(i));
}
predictedClass = new Attribute("predicted"+classAt.name(), attVals);
} else {
predictedClass = new Attribute("predicted"+classAt.name());
}
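// copy the original attributes, inserting the predicted-class attribute
// immediately before the actual class attribute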
for (int i = 0; i < trainInstances.numAttributes(); i++) {
if (i == trainInstances.classIndex()) {
hv.addElement(predictedClass);
}
hv.addElement(trainInstances.attribute(i).copy());
}
return new Instances(trainInstances.relationName()+"_predicted", hv,
trainInstances.numInstances());
}
/**
* Starts running the currently configured classifier with the current
* settings. This is run in a separate thread, and will only start if
* there is no classifier already running. The classifier output is sent
* to the results history panel.
*/
protected void startClassifier() {
if (m_RunThread == null) {
synchronized (this) {
m_StartBut.setEnabled(false);
m_StopBut.setEnabled(true);
}
m_RunThread = new Thread() {
public void run() {
// Copy the current state of things
m_Log.statusMessage("Setting up...");
CostMatrix costMatrix = null;
Instances inst = new Instances(m_Instances);
Instances userTest = null;
// additional vis info (either shape type or point size)
FastVector plotShape = new FastVector();
FastVector plotSize = new FastVector();
Instances predInstances = null;
// will hold the prediction objects if the class is nominal
FastVector predictions = null;
// for timing
long trainTimeStart = 0, trainTimeElapsed = 0;
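// take copies of the user-supplied test set and cost matrix so the run
// works on a consistent snapshot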
if (m_TestInstances != null) {
userTest = new Instances(m_TestInstances);
}
if (m_EvalWRTCostsBut.isSelected()) {
costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor
.getValue());
}
boolean outputModel = m_OutputModelBut.isSelected();
boolean outputConfusion = m_OutputConfusionBut.isSelected();
boolean outputPerClass = m_OutputPerClassBut.isSelected();
boolean outputSummary = true;
boolean outputEntropy = m_OutputEntropyBut.isSelected();
boolean saveVis = m_StorePredictionsBut.isSelected();
boolean outputPredictionsText = m_OutputPredictionsTextBut.isSelected();
String grph = null;
int testMode = 0;
int numFolds = 10, percent = 66;
int classIndex = m_ClassCombo.getSelectedIndex();
Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
Classifier fullClassifier = null;
StringBuffer outBuff = new StringBuffer();
String name = (new SimpleDateFormat("HH:mm:ss - "))
.format(new Date());
String cname = classifier.getClass().getName();
if (cname.startsWith("weka.classifiers.")) {
name += cname.substring("weka.classifiers.".length());
} else {
name += cname;
}
try {
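// determine the evaluation mode from the selected radio button and
// validate its parameters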
if (m_CVBut.isSelected()) {
testMode = 1;
numFolds = Integer.parseInt(m_CVText.getText());
if (numFolds <= 1) {
throw new Exception("Number of folds must be greater than 1");
}
} else if (m_PercentBut.isSelected()) {
testMode = 2;
percent = Integer.parseInt(m_PercentText.getText());
if ((percent <= 0) || (percent >= 100)) {
throw new Exception("Percentage must be between 0 and 100");
}
} else if (m_TrainBut.isSelected()) {
testMode = 3;
} else if (m_TestSplitBut.isSelected()) {
testMode = 4;
// Check the test instance compatibility
if (userTest == null) {
throw new Exception("No user test set has been opened");
}
if (!inst.equalHeaders(userTest)) {
throw new Exception("Train and test set are not compatible");
}
userTest.setClassIndex(classIndex);
} else {
throw new Exception("Unknown test mode");
}
inst.setClassIndex(classIndex);
// set up the structure of the plottable instances for
// visualization
predInstances = setUpVisualizableInstances(inst);
predInstances.setClassIndex(inst.classIndex()+1);
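// individual prediction objects (NominalPrediction) are only collected
// when the class is nominal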
if (inst.classAttribute().isNominal()) {
predictions = new FastVector();
}
// Output some header information
m_Log.logMessage("Started " + cname);
if (m_Log instanceof TaskLogger) {
((TaskLogger)m_Log).taskStarted();
}
outBuff.append("=== Run information ===\n\n");
outBuff.append("Scheme: " + cname);
if (classifier instanceof OptionHandler) {
String [] o = ((OptionHandler) classifier).getOptions();
outBuff.append(" " + Utils.joinOptions(o));
}
outBuff.append("\n");
outBuff.append("Relation: " + inst.relationName() + '\n');
outBuff.append("Instances: " + inst.numInstances() + '\n');
outBuff.append("Attributes: " + inst.numAttributes() + '\n');
if (inst.numAttributes() < 100) {
for (int i = 0; i < inst.numAttributes(); i++) {
outBuff.append(" " + inst.attribute(i).name()
+ '\n');
}
} else {
outBuff.append(" [list of attributes omitted]\n");
}
outBuff.append("Test mode: ");
switch (testMode) {
case 3: // Test on training
outBuff.append("evaluate on training data\n");
break;
case 1: // CV mode
outBuff.append("" + numFolds + "-fold cross-validation\n");
break;
case 2: // Percent split
outBuff.append("split " + percent
+ "% train, remainder test\n");
break;
case 4: // Test on user split
outBuff.append("user supplied test set: "
+ userTest.numInstances() + " instances\n");
break;
}
if (costMatrix != null) {
outBuff.append("Evaluation cost matrix:\n")
.append(costMatrix.toString()).append("\n");
}
outBuff.append("\n");
m_History.addResult(name, outBuff);
m_History.setSingle(name);
// Build the model and output it.
if (outputModel || (testMode == 3) || (testMode == 4)) {
m_Log.statusMessage("Building model on training data...");
trainTimeStart = System.currentTimeMillis();
classifier.buildClassifier(inst);
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
}
if (outputModel) {
outBuff.append("=== Classifier model (full training set) ===\n\n");
outBuff.append(classifier.toString() + "\n");
outBuff.append("\nTime taken to build model: " +
Utils.doubleToString(trainTimeElapsed / 1000.0,2)
+ " seconds\n\n");
m_History.updateResult(name);
if (classifier instanceof Drawable) {
grph = null;
try {
grph = ((Drawable)classifier).graph();
} catch (Exception ex) {
}
}
// copy full model for output
SerializedObject so = new SerializedObject(classifier);
fullClassifier = (Classifier) so.getObject();
}
Evaluation eval = null;
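// evaluate according to the selected test mode, collecting predictions
// and plotting info for every processed instance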
switch (testMode) {
case 3: // Test on training
m_Log.statusMessage("Evaluating on training data...");
eval = new Evaluation(inst, costMatrix);
for (int jj = 0; jj < inst.numInstances(); jj++) {
processClassifierPrediction(inst.instance(jj), classifier,
eval, predictions,
predInstances, plotShape,
plotSize);
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on training data. Processed "
+jj+" instances...");
}
}