⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bntools.java

📁 bayes network classifier toolbox 贝叶斯网络分类工具箱
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
                int xSize = varSize[i_x];                if (i_x == nbAttrib) {                    throw new Exception("Class node cannot have parents.");                }                if (varIndx[1] == nbAttrib) {                    int i_y = varIndx[2];                    int i_z = nbAttrib;                    int ySize = varSize[i_y];                    int zSize = varSize[i_z];                    int[][][] freqXYZ = fc.freqXYZ[i_x][i_y];                    int[] Nij = new int[ySize * zSize];                    for (int k_x = 0; k_x < xSize; ++k_x) {                        int j = 0;                        for (int k_z = 0; k_z < zSize; ++k_z) {                            for (int k_y = 0; k_y < ySize; ++k_y) {                                Nij[j++] += freqXYZ[k_x][k_y][k_z];                            }                        }                    }                    int index = 0;                    double alpha = alpha_ijk * xSize;                    for (int k_x = 0; k_x < xSize; ++k_x) {                        int j = 0;                        for (int k_z = 0; k_z < zSize; ++k_z) {                            for (int k_y = 0; k_y < ySize; ++k_y) {                                vals[index] = (freqXYZ[k_x][k_y][k_z] + alpha_ijk)                                        / (Nij[j++] + alpha);                                ++index;                            }                        }                    }                } else if (varIndx[2] == nbAttrib) {                    int i_y = varIndx[1];                    int i_z = nbAttrib;                    int ySize = varSize[i_y];                    int zSize = varSize[i_z];                    int[][][] freqXYZ = fc.freqXYZ[i_x][i_y];                    int[] Nij = new int[ySize * zSize];                    for (int k_x = 0; k_x < xSize; ++k_x) {                        int j = 0;                        for (int k_y = 0; k_y < ySize; ++k_y) {                            for (int k_z = 0; k_z < zSize; ++k_z) {    
                            Nij[j++] += freqXYZ[k_x][k_y][k_z];                            }                        }                    }                    int index = 0;                    double alpha = alpha_ijk * xSize;                    for (int k_x = 0; k_x < xSize; ++k_x) {                        int j = 0;                        for (int k_y = 0; k_y < ySize; ++k_y) {                            for (int k_z = 0; k_z < zSize; ++k_z) {                                vals[index] = (freqXYZ[k_x][k_y][k_z] + alphaK) / (Nij[j++] + alpha);                                ++index;                            }                        }                    }                } else {                    throw new Exception("If variable has two parents one of them " +                            "must be class variable.");                }            } else {                throw new Exception("Variable cannot have more then two parents.");            }        }    }    /**     * Learns parameters for the current network structure. Existing network     * parameters are replaced with the new ones. This method can use "uniform"     * Dirihlet priors.     *     * @param net         Bayesian network.     * @param useDirihlet Indicates whether Dirihlet priors should be used for     *                    network parameters.     * @param alphaK      alpha<sub>k</sub> parameter for Dirihlet priors. All     *                    alpha<sub>k</sub> are assumed to be the same and     *                    greater than zero.     
* @param data        training data set. Every non-null entry of
     *                    {@code data.cases} is an {@code int[]} holding one
     *                    state index per variable, indexed like
     *                    {@code data.names}.
     * @throws Exception if {@code useDirihlet} is set and {@code alphaK <= 0},
     *                   or if the data set and the network have different
     *                   numbers of variables.
     */
    public static void learnParameters(BayesNet net,
                                       DatasetInt data,
                                       boolean useDirihlet,
                                       double alphaK) throws Exception {
        if (useDirihlet && (alphaK <= 0)) {
            throw new Exception("When using Dirihlet priors alphaK must be greater than zero.");
        }

        int nbCases = data.cases.size();
        int[] varSize = variableStateCounts(data.names);
        checkVariableCount(net, varSize.length);

        // Re-estimate every probability table from frequencies in the training data.
        ProbabilityFunction[] funcs = net.get_probability_functions();
        for (int funcNb = 0; funcNb < funcs.length; ++funcNb) {
            if (funcs[funcNb] == null) {
                continue;
            }
            int[] varIndx = funcs[funcNb].get_indexes();
            double[] vals = funcs[funcNb].get_values();
            int[] varCycle = computeIndexStrides(varIndx, varSize);

            // vCount: number of cases per table cell.
            // count:  number of cases per configuration of varIndx[1..].
            int[] vCount = new int[vals.length];
            int[] count = new int[varCycle[0]];
            for (int caseNb = 0; caseNb < nbCases; ++caseNb) {
                int[] thisCase = (int[]) data.cases.get(caseNb);
                if (thisCase == null) {
                    continue;
                }
                int index = 0;
                for (int varNb = 0; varNb < varIndx.length; ++varNb) {
                    index += varCycle[varNb] * thisCase[varIndx[varNb]];
                }
                ++vCount[index];
                // Dropping the most significant digit (varIndx[0]) leaves the
                // configuration of the remaining variables.
                ++count[index % varCycle[0]];
            }

            assignConditionalProbabilities(vals, vCount, count, varCycle[0],
                    useDirihlet, alphaK);
        }
    }

    /**
     * Ad-hoc check that prints {@code Double.MIN_VALUE} and the values of
     * {@code gammaLn(x)} for x = 1..10 to standard output.
     *
     * @param args command-line arguments (ignored).
     */
    public static void main(String[] args) {
        try {
            System.out.println("Double.MIN_VALUE = " + Double.MIN_VALUE);
            // Test gamma
            System.out.println("Testing function gammaLn.");
            for (int x = 1; x <= 10; ++x) {
                double g = gammaLn(x);
                System.out.println("gammaLn(" + x + ") = " + g);
            }
        } catch (Exception e) {
            // Report the failure instead of silently discarding it.
            e.printStackTrace(System.err);
        }
    }

    /**
     * Looks up, for every attribute of {@code dataset}, the graph node with the
     * same name, and returns the nodes in the order the attributes appear in
     * the data set.
     *
     * @param dataset data set whose attribute names are matched against the graph.
     * @param graph   inference graph whose nodes are searched by name.
     * @return array with one entry per attribute of {@code dataset}; entries for
     *         attributes of type {@code AttributeType.IGNORE} are {@code null}.
     * @throws Exception if a non-ignored attribute has no node with its name.
     */
    protected static InferenceGraphNode[] getNodes(Dataset dataset,
                                                   InferenceGraph graph) throws Exception {
        // Build a name -> node lookup table for the graph.
        HashMap nodesHashMap = new HashMap();
        Vector graphNodes = graph.get_nodes();
        Iterator ni = graphNodes.iterator();
        while (ni.hasNext()) {
            InferenceGraphNode node = (InferenceGraphNode) ni.next();
            nodesHashMap.put(node.get_name(), node);
        }

        // Check names vs. network
        int nbNodes = dataset.names.length;
        InferenceGraphNode[] nodes = new InferenceGraphNode[nbNodes];
        for (int i = 0; i < nbNodes; ++i) {
            if (dataset.names[i].getType() != AttributeType.IGNORE) {
                InferenceGraphNode n =
                        (InferenceGraphNode) nodesHashMap.get(dataset.names[i].getName());
                if (n == null) {
                    throw new Exception("The network does not have node with attribute name '"
                            + dataset.names[i].getName() + "'");
                }
                nodes[i] = n;
            } else {
                // Ignored attributes have no corresponding node.
                nodes[i] = null;
            }
        }
        return nodes;
    }

    /**
     * Learns parameters for the current network structure from a data set whose
     * cases are stored as {@code Vector}s of {@code Integer} state indices.
     * Existing network parameters are replaced with the new ones. This is the
     * older counterpart of
     * {@code learnParameters(BayesNet, DatasetInt, boolean, double)} and computes
     * the same estimates; it additionally requires every attribute to be
     * {@code AttributeType.DISCRETE}.
     *
     * @param net         Bayesian network whose probability tables are overwritten.
     * @param data        training data set; every non-null entry of
     *                    {@code data.cases} is a {@code Vector} of
     *                    {@code Integer} state indices, one per variable.
     * @param useDirihlet whether uniform Dirichlet priors with pseudo-count
     *                    {@code alphaK} are used.
     * @param alphaK      Dirichlet pseudo-count added to every table cell; must
     *                    be greater than zero when {@code useDirihlet} is true.
     * @throws Exception if an attribute is not discrete, if {@code useDirihlet}
     *                   is set and {@code alphaK <= 0}, or if the data set and
     *                   the network have different numbers of variables.
     */
    protected static void learnParameters_old(BayesNet net,
                                              Dataset data,
                                              boolean useDirihlet,
                                              double alphaK) throws Exception {
        // Verify that all attributes are discrete
        AttributeSpecs[] names = data.names;
        for (int n = 0; n < names.length; ++n) {
            if (names[n].getType() != AttributeType.DISCRETE) {
                throw new Exception("All attributes in the data set have to be discrete.");
            }
        }
        if (useDirihlet && (alphaK <= 0)) {
            throw new Exception("When using Dirihlet priors alphaK must be greater than zero.");
        }

        int nbCases = data.cases.size();
        int[] varSize = variableStateCounts(names);
        checkVariableCount(net, varSize.length);

        // Re-estimate every probability table from frequencies in the training data.
        ProbabilityFunction[] funcs = net.get_probability_functions();
        for (int funcNb = 0; funcNb < funcs.length; ++funcNb) {
            if (funcs[funcNb] == null) {
                continue;
            }
            int[] varIndx = funcs[funcNb].get_indexes();
            double[] vals = funcs[funcNb].get_values();
            int[] varCycle = computeIndexStrides(varIndx, varSize);

            int[] vCount = new int[vals.length];
            int[] count = new int[varCycle[0]];
            for (int caseNb = 0; caseNb < nbCases; ++caseNb) {
                Vector thisCase = (Vector) data.cases.get(caseNb);
                if (thisCase == null) {
                    continue;
                }
                int index = 0;
                for (int varNb = 0; varNb < varIndx.length; ++varNb) {
                    index += varCycle[varNb]
                            * ((Integer) thisCase.get(varIndx[varNb])).intValue();
                }
                ++vCount[index];
                ++count[index % varCycle[0]];
            }

            assignConditionalProbabilities(vals, vCount, count, varCycle[0],
                    useDirihlet, alphaK);
        }
    }

    /**
     * Returns the number of states of every attribute.
     *
     * @param names attribute descriptions.
     * @return array whose entry {@code i} is {@code names[i].getStates().length}.
     */
    private static int[] variableStateCounts(AttributeSpecs[] names) {
        int[] varSize = new int[names.length];
        for (int i = 0; i < names.length; ++i) {
            varSize[i] = names[i].getStates().length;
        }
        return varSize;
    }

    /**
     * Throws if the data set's variable count differs from the network's.
     *
     * @param net    Bayesian network.
     * @param nbVars number of variables in the data set.
     * @throws Exception if {@code nbVars != net.number_variables()}.
     */
    private static void checkVariableCount(BayesNet net, int nbVars) throws Exception {
        if (nbVars != net.number_variables()) {
            throw new Exception("Number of variables in the data set and in the network do no agree ("
                    + nbVars + "!=" + net.number_variables() + ").");
        }
    }

    /**
     * Computes mixed-radix strides for indexing a probability table over the
     * variables {@code varIndx}, with {@code varIndx[0]} as the most significant
     * digit. Entry {@code i} is the product of the state counts of
     * {@code varIndx[i+1..]}, so entry 0 is the number of configurations of the
     * remaining variables (presumably the parents) and the last entry is 1.
     *
     * @param varIndx variable indexes of the table; must be non-empty.
     * @param varSize number of states of each variable.
     * @return the stride array, same length as {@code varIndx}.
     */
    private static int[] computeIndexStrides(int[] varIndx, int[] varSize) {
        int[] varCycle = new int[varIndx.length];
        varCycle[varCycle.length - 1] = 1;
        for (int i = varCycle.length - 2; i >= 0; --i) {
            varCycle[i] = varCycle[i + 1] * varSize[varIndx[i + 1]];
        }
        return varCycle;
    }

    /**
     * Overwrites {@code vals} with smoothed conditional probability estimates
     * {@code (vCount[i] + prior) / (count[i % nbConfigs] + prior * nbStates)},
     * where {@code nbStates = vals.length / nbConfigs}.
     *
     * @param vals        probability table to overwrite.
     * @param vCount      number of cases falling in each cell of {@code vals}.
     * @param count       number of cases per configuration of the non-leading
     *                    variables.
     * @param nbConfigs   {@code count.length}.
     * @param useDirihlet if true the pseudo-count is {@code alphaK}; otherwise it
     *                    is the class constant {@code beta_ijk} (defined outside
     *                    this view — presumably a small value that avoids 0/0 for
     *                    unseen configurations; confirm).
     * @param alphaK      Dirichlet pseudo-count used when {@code useDirihlet}.
     */
    private static void assignConditionalProbabilities(double[] vals,
                                                       int[] vCount,
                                                       int[] count,
                                                       int nbConfigs,
                                                       boolean useDirihlet,
                                                       double alphaK) {
        // The two branches are kept separate so each expression has exactly the
        // operand types of the original code.
        if (useDirihlet) {
            double alpha = alphaK * vals.length / nbConfigs;
            for (int i = 0; i < vals.length; ++i) {
                vals[i] = (vCount[i] + alphaK)
                        / (count[i % nbConfigs] + alpha);
            }
        } else {
            double beta_ij = beta_ijk * vals.length / nbConfigs;
            for (int i = 0; i < vals.length; ++i) {
                int c = count[i % nbConfigs];
                vals[i] = (vCount[i] + beta_ijk) / (c + beta_ij);
            }
        }
    }
    //}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -