📄 bnctester.java
字号:
System.out.println("Number of nodes = " + graphNodes.size()); System.out.print("Node names: "); for (int i = 0; i < graphNodes.size(); ++i) { System.out.print(((InferenceGraphNode) graphNodes.get(i)).get_name() + " "); } System.out.println(""); } } /* * */ protected void loadDataset() throws Exception { if (namesFileName != null) { // Load data in C4.5 format dataset = new Dataset(); // Load names file if (debugMode) { System.out.println("\nLoading names file: " + namesFileName); } dataset.names = NamesReader.open(namesFileName); if (dataset.names == null || dataset.names.length == 0) { throw new Exception("'.names' is empty"); } int nbAttributes = dataset.names.length - 1; int nbClasses = dataset.names[nbAttributes].getStates().length; if (debugMode) {// System.out.println("Number of classes = "+nbClasses); System.out.println("Number of attributes = " + nbAttributes + " + one class."); } // Verify that least one attribute is present and that it is discrete int nbIgnored = 0; for (int i = 0; i < nbAttributes; ++i) { AttributeType type = dataset.names[i].getType(); if (type == AttributeType.CONTINUOUS) { throw new Exception("Attribute '" + dataset.names[i].getName() + "' is continuous."); } else if (type == AttributeType.IGNORE) { ++nbIgnored; } } if (debugMode) { System.out.println("Number of ignored attributes = " + nbIgnored + "."); } if (nbIgnored == nbAttributes) { throw new Exception("All attributes are specified to be ignored by the '.names' file."); } // Read test data file if (debugMode) { System.out.println("\nLoading test data file: " + dataFileName); } DatasetReader datasetReader = new DatasetReader(); datasetReader.setDiscardIncompleteCases(true); dataset.cases = datasetReader.open(dataFileName, dataset.names); if (debugMode) { System.out.println("Number of cases = " + dataset.cases.size()); } } else { // Load data in from comma delimited file with a header DatasetReader datasetReader = new DatasetReader(); datasetReader.setDiscardIncompleteCases(true); dataset = datasetReader.open(dataFileName, className); if (dataset == null) { throw new Exception("Data set is empty."); } } } /* * */ protected void verifyDatasetVsNetwork() throws Exception { // Create reverse lookup for nodes and find the class node classNode = null; HashMap nodesHashMap = new HashMap(); Vector graphNodes = graph.get_nodes(); for (int i = 0; i < graphNodes.size(); ++i) { InferenceGraphNode node = (InferenceGraphNode) graphNodes.get(i); String nodeName = node.get_name(); if (classNode == null && className.equals(nodeName)) { classNode = node; } else { nodesHashMap.put(nodeName, node); } } if (classNode == null) { throw new Exception("The network does not have node with class name '" + className + "'."); } // Check names vs. network int nbAttributes = dataset.names.length - 1; attribNodes = new InferenceGraphNode[nbAttributes]; for (int i = 0; i < nbAttributes; i++) { if (dataset.names[i].getType() != AttributeType.IGNORE) { InferenceGraphNode n = (InferenceGraphNode) nodesHashMap.get( dataset.names[i].getName()); if (n == null) { throw new Exception("The network does not have node with attribute name '" + dataset.names[i].getName() + "'"); } attribNodes[i] = n; } else { attribNodes[i] = null; } } } /** * @param net Description of Parameter * @exception Exception Description of Exception */ protected void testClassifier(BayesNet net) throws Exception { int nbVars = dataset.names.length; int nbAttrib = nbVars - 1; int[] caseIndexes = new int[nbVars]; jbnc.graphs.BNCInference inference = new jbnc.graphs.BNCInference(net); // Test each case nbPass = 0; nbFail = 0; for (int caseNb = 0; caseNb < dataset.cases.size(); ++caseNb) { // Set node values int[] thisCase = (int[]) dataset.cases.get(caseNb); for (int j = 0; j < nbAttrib; ++j) { caseIndexes[j] = thisCase[j]; } // Do inference double[] classProb = inference.getCondClassProb(caseIndexes); // Find the most probable classification index int maxIndex = -1; double maxPr = -1; for (int j = 0; j < classProb.length; ++j) { if (classProb[j] > maxPr) { maxPr = classProb[j]; maxIndex = j; } } // Record result int trueClassIndex = thisCase[nbAttrib]; if (maxIndex == trueClassIndex) { ++nbPass; } else { ++nbFail; } /* * System.out.print(caseNb+": ( "); * for(int i=0; i<classProb.length; ++i) * System.out.print(classProb[i]+" "); * System.out.println(") " + (maxIndex == trueClassIndex) ); */ } } /* * */ protected void testClassifier_old() throws Exception { int nbAttributes = dataset.names.length - 1; boolean do_produce_clusters = false; QBInference inference = new QBInference(graph.get_bayes_net(), do_produce_clusters); // Test each case nbPass = 0; nbFail = 0; for (int caseNb = 0; caseNb < dataset.cases.size(); ++caseNb) { // Find index for current cases class // Set node values Vector thisCase = (Vector) dataset.cases.get(caseNb); for (int j = 0; j < nbAttributes; ++j) { InferenceGraphNode thisNode = attribNodes[j]; if (thisNode != null) { // set value int observedIndex = ((Integer) thisCase.get(j)).intValue(); String observedValue = dataset.names[j].getState(observedIndex); thisNode.set_observation_value(observedValue); // verify that the value was actually set int v = thisNode.get_observed_value(); if (v >= 0) { String[] setValues = thisNode.get_values(); if (!observedValue.equals(setValues[v])) { throw new Exception("Mishmash in setting observation value: " + "case #" + (caseNb + 1) + ", attribute #" + (j + 1) + "\n Got '" + setValues[v] + "' instead of requested " + observedValue + "."); } } else { throw new Exception("Failed to set observation value '" + observedValue + "', case #" + (caseNb + 1) + ", attribute #" + (j + 1)); } } } // Do inference inference.inference(className); ProbabilityFunction prFunction = inference.get_result(); // Find the most probable classification index int maxIndex = -1; double maxPr = -1; double[] pr = prFunction.get_values(); for (int j = 0; j < pr.length; ++j) { if (pr[j] > maxPr) { maxPr = pr[j]; maxIndex = j; } } // Get name of most probable class String classValue = prFunction.get_variable(0).get_value(maxIndex); // Record result int trueClassIndex = ((Integer) thisCase.get(nbAttributes)).intValue(); String trueClassValue = dataset.names[nbAttributes].getState(trueClassIndex); boolean pass = trueClassValue.equals(classValue); if (pass) { ++nbPass; } else { ++nbFail; } if (debugMode) { System.out.print(caseNb + ": ( "); for (int i = 0; i < pr.length; ++i) { System.out.print(pr[i] + " "); } System.out.println(") " + pass); } } } /* * */ protected void printClassificationReport() { // Print report System.out.println(""); System.out.println("Pass = " + nbPass); System.out.println("Fail = " + nbFail); int total = nbPass + nbFail; System.out.println("Total = " + total); if (total != 0) { NumberFormat f = new DecimalFormat("0.###%"); double err = (double) (nbFail) / total; System.out.println("Error = " + f.format(err)); double acc = (double) (nbPass) / total; System.out.println("Accuracy = " + f.format(acc)); } } /** */ protected void clear() { useTimer = false; debugMode = false; dataFileName = null; namesFileName = null; netFileName = null; className = null; dataset = null; graph = null; attribNodes = null; classNode = null; nbPass = 0; nbFail = 0; } public class Result { public int nbPass; public int nbFail; /** * Constructor for the Result object * * @param pass Description of Parameter * @param fail Description of Parameter */ public Result(int pass, int fail) { nbPass = pass; nbFail = fail; } public double getError() throws Exception { int tot = nbPass + nbFail; if (tot == 0) { throw new Exception("pass + fail == zero (division by zero)"); } return nbFail / (double) tot; } public double getAccuracy() throws Exception { int tot = nbPass + nbFail; if (tot == 0) { throw new Exception("pass + fail == zero (division by zero)"); } return nbPass / (double) tot; } /* * Compute the estimated standard deviation according to * the binomial model, which assumes every test instance is * a Bernoulli trial, thus std-dev=sqrt(error*(1-error)/(n-1)) */ public double getTheoreticalStdDev() throws Exception { int tot = nbPass + nbFail; if (tot < 2) { throw new Exception("pass + fail < 2"); } double err = nbFail / tot; return Math.sqrt((double) nbFail * (double) nbPass / ((double) tot * (double) tot * (double) (tot - 1))); } /** * Print test results to the standard output in the short form (single * line). * * @exception Exception Description of Exception */ public void reportShort() throws Exception { // Print report int total = nbPass + nbFail; if (total != 0) { NumberFormat f = new DecimalFormat("0.###%"); System.out.println("Error = " + f.format(getError()) + " +- " + f.format(getTheoreticalStdDev()) + " (" + f.format(getAccuracy()) + ") [" + nbPass + "/" + nbFail + "/" + total + "]"); } else { System.out.println("Error = ? (?) [0/0/0]"); } } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -