⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bnctester.java

📁 bayes network classifier toolbox 贝叶斯网络分类工具箱
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
            System.out.println("Number of nodes = " + graphNodes.size());            System.out.print("Node names: ");            for (int i = 0; i < graphNodes.size(); ++i) {                System.out.print(((InferenceGraphNode) graphNodes.get(i)).get_name() + " ");            }            System.out.println("");        }    }    /*     *     */    protected void loadDataset() throws Exception {        if (namesFileName != null) {            // Load data in C4.5 format            dataset = new Dataset();            // Load names file            if (debugMode) {                System.out.println("\nLoading names file: " + namesFileName);            }            dataset.names = NamesReader.open(namesFileName);            if (dataset.names == null || dataset.names.length == 0) {                throw new Exception("'.names' is empty");            }            int nbAttributes = dataset.names.length - 1;            int nbClasses = dataset.names[nbAttributes].getStates().length;            if (debugMode) {//        System.out.println("Number of classes    = "+nbClasses);                System.out.println("Number of attributes = " + nbAttributes + " + one class.");            }            // Verify that least one attribute is present and that it is discrete            int nbIgnored = 0;            for (int i = 0; i < nbAttributes; ++i) {                AttributeType type = dataset.names[i].getType();                if (type == AttributeType.CONTINUOUS) {                    throw new Exception("Attribute '"                            + dataset.names[i].getName()                            + "' is continuous.");                } else if (type == AttributeType.IGNORE) {                    ++nbIgnored;                }            }            if (debugMode) {                System.out.println("Number of ignored attributes = " + nbIgnored + ".");            }            if (nbIgnored == nbAttributes) {                throw new Exception("All attributes are specified to be ignored by the '.names' file.");            }            // Read test data file            if (debugMode) {                System.out.println("\nLoading test data file: " + dataFileName);            }            DatasetReader datasetReader = new DatasetReader();            datasetReader.setDiscardIncompleteCases(true);            dataset.cases = datasetReader.open(dataFileName, dataset.names);            if (debugMode) {                System.out.println("Number of cases = " + dataset.cases.size());            }        } else {            // Load data in from comma delimited file with a header            DatasetReader datasetReader = new DatasetReader();            datasetReader.setDiscardIncompleteCases(true);            dataset = datasetReader.open(dataFileName, className);            if (dataset == null) {                throw new Exception("Data set is empty.");            }        }    }    /*     *     */    protected void verifyDatasetVsNetwork() throws Exception {        // Create reverse lookup for nodes and find the class node        classNode = null;        HashMap nodesHashMap = new HashMap();        Vector graphNodes = graph.get_nodes();        for (int i = 0; i < graphNodes.size(); ++i) {            InferenceGraphNode node = (InferenceGraphNode) graphNodes.get(i);            String nodeName = node.get_name();            if (classNode == null && className.equals(nodeName)) {                classNode = node;            } else {                nodesHashMap.put(nodeName, node);            }        }        if (classNode == null) {            throw new Exception("The network does not have node with class name '"                    + className + "'.");        }        // Check names vs. network        int nbAttributes = dataset.names.length - 1;        attribNodes = new InferenceGraphNode[nbAttributes];        for (int i = 0; i < nbAttributes; i++) {            if (dataset.names[i].getType() != AttributeType.IGNORE) {                InferenceGraphNode n = (InferenceGraphNode) nodesHashMap.get(                        dataset.names[i].getName());                if (n == null) {                    throw new Exception("The network does not have node with attribute name '"                            + dataset.names[i].getName() + "'");                }                attribNodes[i] = n;            } else {                attribNodes[i] = null;            }        }    }    /**     * @param  net            Description of Parameter     * @exception  Exception  Description of Exception     */    protected void testClassifier(BayesNet net) throws Exception {        int nbVars = dataset.names.length;        int nbAttrib = nbVars - 1;        int[] caseIndexes = new int[nbVars];        jbnc.graphs.BNCInference inference = new jbnc.graphs.BNCInference(net);        // Test each case        nbPass = 0;        nbFail = 0;        for (int caseNb = 0; caseNb < dataset.cases.size(); ++caseNb) {            // Set node values            int[] thisCase = (int[]) dataset.cases.get(caseNb);            for (int j = 0; j < nbAttrib; ++j) {                caseIndexes[j] = thisCase[j];            }            // Do inference            double[] classProb = inference.getCondClassProb(caseIndexes);            // Find the most probable classification index            int maxIndex = -1;            double maxPr = -1;            for (int j = 0; j < classProb.length; ++j) {                if (classProb[j] > maxPr) {                    maxPr = classProb[j];                    maxIndex = j;                }            }            // Record result            int trueClassIndex = thisCase[nbAttrib];            if (maxIndex == trueClassIndex) {                ++nbPass;            } else {                ++nbFail;            }            /*             *  System.out.print(caseNb+": ( ");             *  for(int i=0; i<classProb.length; ++i)             *  System.out.print(classProb[i]+" ");             *  System.out.println(") " + (maxIndex == trueClassIndex) );             */        }    }    /*     *     */    protected void testClassifier_old() throws Exception {        int nbAttributes = dataset.names.length - 1;        boolean do_produce_clusters = false;        QBInference inference = new QBInference(graph.get_bayes_net(), do_produce_clusters);        // Test each case        nbPass = 0;        nbFail = 0;        for (int caseNb = 0; caseNb < dataset.cases.size(); ++caseNb) {            // Find index for current cases class            // Set node values            Vector thisCase = (Vector) dataset.cases.get(caseNb);            for (int j = 0; j < nbAttributes; ++j) {                InferenceGraphNode thisNode = attribNodes[j];                if (thisNode != null) {                    // set value                    int observedIndex = ((Integer) thisCase.get(j)).intValue();                    String observedValue = dataset.names[j].getState(observedIndex);                    thisNode.set_observation_value(observedValue);                    // verify that the value was actually set                    int v = thisNode.get_observed_value();                    if (v >= 0) {                        String[] setValues = thisNode.get_values();                        if (!observedValue.equals(setValues[v])) {                            throw new Exception("Mishmash in setting observation value: "                                    + "case #" + (caseNb + 1)                                    + ", attribute #" + (j + 1)                                    + "\n       Got '" + setValues[v]                                    + "' instead of requested " + observedValue + ".");                        }                    } else {                        throw new Exception("Failed to set observation value '"                                + observedValue                                + "', case #" + (caseNb + 1)                                + ", attribute #" + (j + 1));                    }                }            }            // Do inference            inference.inference(className);            ProbabilityFunction prFunction = inference.get_result();            // Find the most probable classification index            int maxIndex = -1;            double maxPr = -1;            double[] pr = prFunction.get_values();            for (int j = 0; j < pr.length; ++j) {                if (pr[j] > maxPr) {                    maxPr = pr[j];                    maxIndex = j;                }            }            // Get name of most probable class            String classValue = prFunction.get_variable(0).get_value(maxIndex);            // Record result            int trueClassIndex = ((Integer) thisCase.get(nbAttributes)).intValue();            String trueClassValue = dataset.names[nbAttributes].getState(trueClassIndex);            boolean pass = trueClassValue.equals(classValue);            if (pass) {                ++nbPass;            } else {                ++nbFail;            }            if (debugMode) {                System.out.print(caseNb + ": ( ");                for (int i = 0; i < pr.length; ++i) {                    System.out.print(pr[i] + " ");                }                System.out.println(") " + pass);            }        }    }    /*     *     */    protected void printClassificationReport() {        // Print report        System.out.println("");        System.out.println("Pass  = " + nbPass);        System.out.println("Fail  = " + nbFail);        int total = nbPass + nbFail;        System.out.println("Total = " + total);        if (total != 0) {            NumberFormat f = new DecimalFormat("0.###%");            double err = (double) (nbFail) / total;            System.out.println("Error    = " + f.format(err));            double acc = (double) (nbPass) / total;            System.out.println("Accuracy = " + f.format(acc));        }    }    /** */    protected void clear() {        useTimer = false;        debugMode = false;        dataFileName = null;        namesFileName = null;        netFileName = null;        className = null;        dataset = null;        graph = null;        attribNodes = null;        classNode = null;        nbPass = 0;        nbFail = 0;    }    public class Result {        public int nbPass;        public int nbFail;        /**         *  Constructor for the Result object         *         * @param  pass  Description of Parameter         * @param  fail  Description of Parameter         */        public Result(int pass, int fail) {            nbPass = pass;            nbFail = fail;        }        public double getError() throws Exception {            int tot = nbPass + nbFail;            if (tot == 0) {                throw new Exception("pass + fail == zero (division by zero)");            }            return nbFail / (double) tot;        }        public double getAccuracy() throws Exception {            int tot = nbPass + nbFail;            if (tot == 0) {                throw new Exception("pass + fail == zero (division by zero)");            }            return nbPass / (double) tot;        }        /*         *  Compute the estimated standard deviation according to         *  the binomial model, which assumes every test instance is         *  a Bernoulli trial, thus std-dev=sqrt(error*(1-error)/(n-1))         */        public double getTheoreticalStdDev() throws Exception {            int tot = nbPass + nbFail;            if (tot < 2) {                throw new Exception("pass + fail < 2");            }            double err = nbFail / tot;            return Math.sqrt((double) nbFail * (double) nbPass                    / ((double) tot * (double) tot * (double) (tot - 1)));        }        /**         *  Print test results to the standard output in the short form (single         *  line).         *         * @exception  Exception  Description of Exception         */        public void reportShort() throws Exception {            // Print report            int total = nbPass + nbFail;            if (total != 0) {                NumberFormat f = new DecimalFormat("0.###%");                System.out.println("Error = " + f.format(getError())                        + " +- " + f.format(getTheoreticalStdDev())                        + " (" + f.format(getAccuracy()) + ")  ["                        + nbPass + "/" + nbFail + "/"                        + total + "]");            } else {                System.out.println("Error = ? (?)  [0/0/0]");            }        }    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -