📄 classifierpanel.java
字号:
outBuff.append("=== Evaluation on training set ===\n");
break;
case 1: // CV mode
m_Log.statusMessage("Randomizing instances...");
int rnd = 1;
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
// System.err.println("Using random seed "+rnd);
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
Random random = new Random(rnd);
inst.randomize(random);
if (inst.attribute(classIndex).isNominal()) {
m_Log.statusMessage("Stratifying instances...");
inst.stratify(numFolds);
}
eval = new Evaluation(inst, costMatrix);
// Make some splits and do a CV
for (int fold = 0; fold < numFolds; fold++) {
m_Log.statusMessage("Creating splits for fold "
+ (fold + 1) + "...");
Instances train = inst.trainCV(numFolds, fold, random);
eval.setPriors(train);
m_Log.statusMessage("Building model for fold "
+ (fold + 1) + "...");
classifier.buildClassifier(train);
Instances test = inst.testCV(numFolds, fold);
m_Log.statusMessage("Evaluating model for fold "
+ (fold + 1) + "...");
for (int jj=0;jj<test.numInstances();jj++) {
processClassifierPrediction(test.instance(jj), classifier,
eval, predictions,
predInstances, plotShape,
plotSize);
}
}
if (inst.attribute(classIndex).isNominal()) {
outBuff.append("=== Stratified cross-validation ===\n");
} else {
outBuff.append("=== Cross-validation ===\n");
}
break;
case 2: // Percent split
m_Log.statusMessage("Randomizing instances...");
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
inst.randomize(new Random(rnd));
int trainSize = inst.numInstances() * percent / 100;
int testSize = inst.numInstances() - trainSize;
Instances train = new Instances(inst, 0, trainSize);
Instances test = new Instances(inst, trainSize, testSize);
m_Log.statusMessage("Building model on training split...");
classifier.buildClassifier(train);
eval = new Evaluation(train, costMatrix);
m_Log.statusMessage("Evaluating on test split...");
for (int jj=0;jj<test.numInstances();jj++) {
processClassifierPrediction(test.instance(jj), classifier,
eval, predictions,
predInstances, plotShape,
plotSize);
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on test split. Processed "
+jj+" instances...");
}
}
outBuff.append("=== Evaluation on test split ===\n");
break;
case 4: // Test on user split
m_Log.statusMessage("Evaluating on test data...");
eval = new Evaluation(inst, costMatrix);
if (outputPredictionsText) {
outBuff.append("=== Predictions on test set ===\n\n");
outBuff.append(" inst#, actual, predicted, error");
if (inst.classAttribute().isNominal()) {
outBuff.append(", probability distribution");
}
outBuff.append("\n");
}
for (int jj=0;jj<userTest.numInstances();jj++) {
processClassifierPrediction(userTest.instance(jj), classifier,
eval, predictions,
predInstances, plotShape,
plotSize);
if (outputPredictionsText) {
outBuff.append(predictionText(classifier, userTest.instance(jj), jj+1));
}
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on test data. Processed "
+jj+" instances...");
}
}
if (outputPredictionsText) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on test set ===\n");
break;
default:
throw new Exception("Test mode not implemented");
}
if (outputSummary) {
outBuff.append(eval.toSummaryString(outputEntropy) + "\n");
}
if (inst.attribute(classIndex).isNominal()) {
if (outputPerClass) {
outBuff.append(eval.toClassDetailsString() + "\n");
}
if (outputConfusion) {
outBuff.append(eval.toMatrixString() + "\n");
}
}
m_History.updateResult(name);
m_Log.logMessage("Finished " + cname);
m_Log.statusMessage("OK");
} catch (Exception ex) {
ex.printStackTrace();
m_Log.logMessage(ex.getMessage());
JOptionPane.showMessageDialog(ClassifierPanel.this,
"Problem evaluating classifier:\n"
+ ex.getMessage(),
"Evaluate classifier",
JOptionPane.ERROR_MESSAGE);
m_Log.statusMessage("Problem evaluating classifier");
} finally {
try {
if (predInstances != null && predInstances.numInstances() > 0) {
if (predInstances.attribute(predInstances.classIndex())
.isNumeric()) {
postProcessPlotInfo(plotSize);
}
m_CurrentVis = new VisualizePanel();
m_CurrentVis.setName(name+" ("+inst.relationName()+")");
m_CurrentVis.setLog(m_Log);
PlotData2D tempd = new PlotData2D(predInstances);
tempd.setShapeSize(plotSize);
tempd.setShapeType(plotShape);
tempd.setPlotName(name+" ("+inst.relationName()+")");
tempd.addInstanceNumberAttribute();
m_CurrentVis.addPlot(tempd);
m_CurrentVis.setColourIndex(predInstances.classIndex()+1);
m_CurrentVis.setXIndex(m_visXIndex);
m_CurrentVis.setYIndex(m_visYIndex);
m_CurrentVis.addActionListener(new ActionListener() {
public void actionPerformed(ActionEvent e) {
if (m_CurrentVis.getInstances().
relationName().
compareTo(m_Instances.relationName()) == 0) {
setXY_VisualizeIndexes(m_CurrentVis.getXIndex(),
m_CurrentVis.getYIndex());
}
}
});
if (saveVis) {
FastVector vv = new FastVector();
if (outputModel) {
vv.addElement(fullClassifier);
Instances trainHeader = new Instances(m_Instances, 0);
trainHeader.setClassIndex(classIndex);
vv.addElement(trainHeader);
}
vv.addElement(m_CurrentVis);
if (grph != null) {
vv.addElement(grph);
}
if (predictions != null) {
vv.addElement(predictions);
vv.addElement(inst.classAttribute());
}
m_History.addObject(name, vv);
} else if (outputModel) {
FastVector vv = new FastVector();
vv.addElement(fullClassifier);
Instances trainHeader = new Instances(m_Instances, 0);
trainHeader.setClassIndex(classIndex);
vv.addElement(trainHeader);
m_History.addObject(name, vv);
}
}
} catch (Exception ex) {
ex.printStackTrace();
}
if (isInterrupted()) {
m_Log.logMessage("Interrupted " + cname);
m_Log.statusMessage("Interrupted");
}
synchronized (this) {
m_StartBut.setEnabled(true);
m_StopBut.setEnabled(false);
m_RunThread = null;
}
if (m_Log instanceof TaskLogger) {
((TaskLogger)m_Log).taskFinished();
}
}
}
};
m_RunThread.setPriority(Thread.MIN_PRIORITY);
m_RunThread.start();
}
}
/**
 * Formats a single row of the "Predictions on test set" listing.
 *
 * <p>The row's columns are:
 * <ul>
 *   <li>the instance number;</li>
 *   <li>the actual class value;</li>
 *   <li>the predicted class value;</li>
 *   <li>an error marker;</li>
 *   <li>for a nominal class only, the predicted probability distribution.
 *       The predicted class is marked with {@code *}.</li>
 * </ul>
 * Missing actual or predicted values are printed as {@code ?}.
 *
 * @param classifier the trained classifier used to predict {@code inst}
 * @param inst the instance to predict. Whether its class attribute is
 *     nominal decides which row layout is used.
 * @param instNum the row number printed in the first column
 * @return the formatted row, terminated by a newline
 * @throws Exception if the classifier fails to produce a prediction
 */
protected String predictionText(Classifier classifier, Instance inst, int instNum) throws Exception {
  //> inst# actual predicted error probability distribution
  StringBuffer text = new StringBuffer();
  // inst #
  text.append(Utils.padLeft("" + instNum, 6) + " ");
  if (inst.classAttribute().isNominal()) {
    // actual: "<1-based index>:<label>", or "?" when the class is missing
    if (inst.classIsMissing()) {
      text.append(Utils.padLeft("?", 10) + " ");
    } else {
      text.append(Utils.padLeft("" + ((int) inst.classValue() + 1) + ":"
          + inst.stringValue(inst.classAttribute()), 10) + " ");
    }
    // predicted: take the most probable class. If even that class has no
    // positive probability, treat the prediction as missing.
    double[] probdist = classifier.distributionForInstance(inst);
    double pred = (double) Utils.maxIndex(probdist);
    if (probdist[(int) pred] <= 0.0) {
      pred = Instance.missingValue();
    }
    // (int) NaN is 0 in Java, so every later use of (int) pred must be
    // guarded by this flag.
    boolean predMissing = Instance.isMissingValue(pred);
    text.append(Utils.padLeft((predMissing ? "?"
        : (((int) pred + 1) + ":" + inst.classAttribute().value((int) pred))), 10) + " ");
    // error: only meaningful when the actual class is known. A missing
    // prediction for a labelled instance is still flagged as an error.
    boolean isError = !inst.classIsMissing()
        && (predMissing || (int) pred != (int) inst.classValue());
    if (isError) {
      text.append(Utils.padLeft("+", 6) + " ");
    } else {
      text.append(Utils.padLeft(" ", 6) + " ");
    }
    // prob dist: star the predicted class. No class is starred when the
    // prediction is missing; previously index 0 was starred because of the
    // NaN-to-int conversion.
    for (int i = 0; i < probdist.length; i++) {
      if (!predMissing && i == (int) pred) {
        text.append(" *");
      } else {
        text.append(" ");
      }
      text.append(Utils.doubleToString(probdist[i], 5, 3));
    }
  } else {
    // actual
    if (inst.classIsMissing()) {
      text.append(Utils.padLeft("?", 10) + " ");
    } else {
      text.append(Utils.doubleToString(inst.classValue(), 10, 3) + " ");
    }
    // predicted
    double pred = classifier.classifyInstance(inst);
    if (Instance.isMissingValue(pred)) {
      text.append(Utils.padLeft("?", 10) + " ");
    } else {
      text.append(Utils.doubleToString(pred, 10, 3) + " ");
    }
    // err: signed residual (predicted - actual), only when both are known
    if (!inst.classIsMissing() && !Instance.isMissingValue(pred)) {
      text.append(Utils.doubleToString(pred - inst.classValue(), 10, 3));
    }
  }
  text.append("\n");
  return text.toString();
}
/**
* Handles constructing a popup menu with visualization options.
* @param name the name of the result history list entry clicked on by
* the user
* @param x the x coordinate for popping up the menu
* @param y the y coordinate for popping up the menu
*/
protected void visualize(String name, int x, int y) {
final String selectedName = name;
JPopupMenu resultListMenu = new JPopupMenu();
JMenuItem visMainBuffer = new JMenuItem("View in main window");
if (selectedName != null) {
visMainBuffer.addActionListener(new ActionListener() {
public void actionPerformed(ActionEvent e) {
m_History.setSingle(selectedName);
}
});
} else {
visMainBuffer.setEnabled(false);
}
resultListMenu.add(visMainBuffer);
JMenuItem visSepBuffer = new JMenuItem("View in separate window");
if (selectedName != null) {
visSepBuffer.addActionListener(new ActionListener() {
public void actionPerformed(ActionEvent e) {
m_History.openFrame(selectedName);
}
});
} else {
visSepBuffer.setEnabled(false);
}
resultListMenu.add(visSepBuffer);
JMenuItem saveOutput = new JMenuItem("Save result buffer");
if (selectedName != null) {
saveOutput.addActionListener(new ActionListener() {
public void actionPerformed(ActionEvent e) {
saveBuffer(selectedName);
}
});
} else {
saveOutput.setEnabled(false);
}
resultListMenu.add(saveOutput);
resultListMenu.addSeparator();
JMenuItem loadModel = new JMenuItem("Load model");
loadModel.addActionListener(new ActionListener() {
public void actionPerformed(ActionEvent e) {
loadClassifier();
}
});
resultListMenu.add(loadModel);
FastVector o = null;
if (selectedName != null) {
o = (FastVector)m_History.getNamedObject(selectedName);
}
VisualizePanel temp_vp = null;
String temp_grph = null;
FastVector temp_preds = null;
Attribute temp_classAtt = null;
Classifier temp_classifier = null;
Instances temp_trainHeader = null;
if (o != null) {
for (int i = 0; i < o.size(); i++) {
Object temp = o.elementAt(i);
if (temp instanceof Classifier) {
temp_classifier = (Classifier)temp;
} else if (temp instanceof Instances) { // training header
temp_trainHeader = (Instances)temp;
} else if (temp instanceof VisualizePanel) { // normal errors
temp_vp = (VisualizePanel)temp;
} else if (temp instanceof String) { // graphable output
temp_grph = (String)temp;
} else if (temp instanceof FastVector) { // predictions
temp_preds = (FastVector)temp;
} else if (temp instanceof Attribute) { // class attribute
temp_classAtt = (Attribute)temp;
}
}
}
final VisualizePanel vp = temp_vp;
final String grph = temp_grph;
final FastVector preds = temp_preds;
final Attribute classAtt = temp_classAtt;
final Classifier classifier = temp_classifier;
final Instances trainHeader = temp_trainHeader;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -