📄 bayesnet.java
字号:
} } // determine cardinality of parent set & reserve space for frequency counts int nCardinality = m_ParentSets[nNode].GetCardinalityOfParents() * m_Instances.attribute(nCandidateParent).numValues(); int numValues = m_Instances.attribute(nNode).numValues(); int[][] nCounts = new int[nCardinality][numValues]; // set up candidate parent m_ParentSets[nNode].AddParent(nCandidateParent, m_Instances); // calculate the score double logScore = CalcNodeScore(nNode); // delete temporarily added parent m_ParentSets[nNode].DeleteLastParent(m_Instances); return logScore; } // CalcScore /** * Calc Node Score for given parent set * * @param nNode node for which the score is calculate * @return log score */ protected double CalcNodeScore(int nNode) { if (m_bUseADTree && m_ADTree != null) { return CalcNodeScoreADTree(nNode, m_Instances); } else { return CalcNodeScore(nNode, m_Instances); } } /** * helper function for CalcNodeScore above using the ADTree data structure * @param nNode node for which the score is calculate * @param instances used to calculate score with * @return log score */ private double CalcNodeScoreADTree(int nNode, Instances instances) { // get set of parents, insert iNode int nNrOfParents = m_ParentSets[nNode].GetNrOfParents(); int [] nNodes = new int [nNrOfParents + 1]; for (int iParent = 0; iParent < nNrOfParents; iParent++) { nNodes[iParent] = m_ParentSets[nNode].GetParent(iParent); } nNodes[nNrOfParents] = nNode; // calculate offsets int [] nOffsets = new int [nNrOfParents + 1]; int nOffset = 1; nOffsets[nNrOfParents] = 1; nOffset *= instances.attribute(nNode).numValues(); for (int iNode = nNrOfParents - 1; iNode >=0; iNode--) { nOffsets[iNode] = nOffset; nOffset *= instances.attribute(nNodes[iNode]).numValues(); } // sort nNodes & offsets for (int iNode = 1; iNode < nNodes.length; iNode++) { int iNode2 = iNode; while (iNode2 > 0 && nNodes[iNode2] < nNodes[iNode2 - 1]) { int h = nNodes[iNode2]; nNodes[iNode2] = nNodes[iNode2 - 1]; nNodes[iNode2 - 1] = h; h = nOffsets[iNode2]; nOffsets[iNode2] = nOffsets[iNode2 - 1]; nOffsets[iNode2 - 1] = h; iNode2--; } } // get counts from ADTree int nCardinality = m_ParentSets[nNode].GetCardinalityOfParents(); int numValues = instances.attribute(nNode).numValues(); int [] nCounts = new int[nCardinality * numValues];//if (nNrOfParents > 1) { /* System.out.println("==========================="); for (int iNode = 0; iNode < nNodes.length; iNode++) { System.out.print(nNodes[iNode] + " " + nOffsets[iNode] + ": "); } System.out.println(); */// int i = 3;//}//CalcNodeScore2(nNode, instances); m_ADTree.getCounts(nCounts, nNodes, nOffsets, 0, 0, false);// for (int iNode = 0; iNode < nCounts.length; iNode++) {// System.out.print(nCounts[iNode] + " ");// }// System.out.println(); return CalcScoreOfCounts(nCounts, nCardinality, numValues, instances); } // CalcNodeScore private double CalcNodeScore(int nNode, Instances instances) { // determine cardinality of parent set & reserve space for frequency counts int nCardinality = m_ParentSets[nNode].GetCardinalityOfParents(); int numValues = instances.attribute(nNode).numValues(); int[][] nCounts = new int[nCardinality][numValues]; // initialize (don't need this?) for (int iParent = 0; iParent < nCardinality; iParent++) { for (int iValue = 0; iValue < numValues; iValue++) { nCounts[iParent][iValue] = 0; } } // estimate distributions Enumeration enumInsts = instances.enumerateInstances(); while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); // updateClassifier; double iCPT = 0; for (int iParent = 0; iParent < m_ParentSets[nNode].GetNrOfParents(); iParent++) { int nParent = m_ParentSets[nNode].GetParent(iParent); iCPT = iCPT * instances.attribute(nParent).numValues() + instance.value(nParent); } nCounts[(int) iCPT][(int) instance.value(nNode)]++; } /* System.out.print("Counts:"); for (int iNode = 0; iNode < nCardinality; iNode++) { for (int iNode2 = 0; iNode2 < numValues; iNode2++) { System.out.print(nCounts[iNode][iNode2] + " "); } } System.out.println();*/ return CalcScoreOfCounts2(nCounts, nCardinality, numValues, instances); } // CalcNodeScore /** * utility function used by CalcScore and CalcNodeScore to determine the score * based on observed frequencies. * * @param nCounts array with observed frequencies * @param nCardinality ardinality of parent set * @param numValues number of values a node can take * @param instances to calc score with * @return log score */ protected double CalcScoreOfCounts(int [] nCounts, int nCardinality, int numValues, Instances instances) { // calculate scores using the distributions double fLogScore = 0.0; for (int iParent = 0; iParent < nCardinality; iParent++) { switch (m_nScoreType) { case (Scoreable.BAYES): { double nSumOfCounts = 0; for (int iSymbol = 0; iSymbol < numValues; iSymbol++) { if (m_fAlpha + nCounts[iParent * numValues + iSymbol] != 0) { fLogScore += Statistics.lnGamma(m_fAlpha + nCounts[iParent * numValues + iSymbol]); nSumOfCounts += m_fAlpha + nCounts[iParent * numValues + iSymbol]; } } if (nSumOfCounts != 0) { fLogScore -= Statistics.lnGamma(nSumOfCounts); } if (m_fAlpha != 0) { fLogScore -= numValues * Statistics.lnGamma(m_fAlpha); fLogScore += Statistics.lnGamma(numValues * m_fAlpha); } } break; case (Scoreable.MDL): case (Scoreable.AIC): case (Scoreable.ENTROPY): { double nSumOfCounts = 0; for (int iSymbol = 0; iSymbol < numValues; iSymbol++) { nSumOfCounts += nCounts[iParent * numValues + iSymbol]; } for (int iSymbol = 0; iSymbol < numValues; iSymbol++) { if (nCounts[iParent * numValues + iSymbol] > 0) { fLogScore += nCounts[iParent * numValues + iSymbol] * Math.log(nCounts[iParent * numValues + iSymbol] / nSumOfCounts); } } } break; default: {} } } switch (m_nScoreType) { case (Scoreable.MDL): { fLogScore -= 0.5 * nCardinality * (numValues - 1) * Math.log(instances.numInstances()); // it seems safe to assume that numInstances>0 here } break; case (Scoreable.AIC): { fLogScore -= nCardinality * (numValues - 1); } break; } return fLogScore; } // CalcNodeScore protected double CalcScoreOfCounts2(int[][] nCounts, int nCardinality, int numValues, Instances instances) { // calculate scores using the distributions double fLogScore = 0.0; for (int iParent = 0; iParent < nCardinality; iParent++) { switch (m_nScoreType) { case (Scoreable.BAYES): { double nSumOfCounts = 0; for (int iSymbol = 0; iSymbol < numValues; iSymbol++) { if (m_fAlpha + nCounts[iParent][iSymbol] != 0) { fLogScore += Statistics.lnGamma(m_fAlpha + nCounts[iParent][iSymbol]); nSumOfCounts += m_fAlpha + nCounts[iParent][iSymbol]; } } if (nSumOfCounts != 0) { fLogScore -= Statistics.lnGamma(nSumOfCounts); } if (m_fAlpha != 0) { fLogScore -= numValues * Statistics.lnGamma(m_fAlpha); fLogScore += Statistics.lnGamma(numValues * m_fAlpha); } } break; case (Scoreable.MDL): case (Scoreable.AIC): case (Scoreable.ENTROPY): { double nSumOfCounts = 0; for (int iSymbol = 0; iSymbol < numValues; iSymbol++) { nSumOfCounts += nCounts[iParent][iSymbol]; } for (int iSymbol = 0; iSymbol < numValues; iSymbol++) { if (nCounts[iParent][iSymbol] > 0) { fLogScore += nCounts[iParent][iSymbol] * Math.log(nCounts[iParent][iSymbol] / nSumOfCounts); } } } break; default: {} } } switch (m_nScoreType) { case (Scoreable.MDL): { fLogScore -= 0.5 * nCardinality * (numValues - 1) * Math.log(instances.numInstances()); // it seems safe to assume that numInstances>0 here } break; case (Scoreable.AIC): { fLogScore -= nCardinality * (numValues - 1); } break; } return fLogScore; } // CalcNodeScore /** * @return a string to describe the ScoreType option. */ public String scoreTypeTipText() { return "The score type determines the measure used to judge the quality of a" + " network structure. It can be one of Bayes, Minimum Description Length (MDL)," + " Akaike Information Criterion (AIC), and Entropy."; } /** * @return a string to describe the Alpha option. */ public String alphaTipText() { return "Alpha is used for estimating the probability tables and can be interpreted" + " as the initial count on each value."; } /** * @return a string to describe the InitAsNaiveBayes option. */ public String initAsNaiveBayesTipText() { return "When set to true (default), the initial network used for structure learning" + " is a Naive Bayes Network, that is, a network with an arrow from the classifier" + " node to each other node. When set to false, an empty network is used as initial"+ " network structure"; } /** * @return a string to describe the UseADTreeoption. */ public String useADTreeTipText() { return "When ADTree (the data structure for increasing speed on counts," + " not to be confused with the classifier under the same name) is used" + " learning time goes down typically. However, because ADTrees are memory" + " intensive, memory problems may occur. Switching this option off makes" + " the structure learning algorithms slower, and run with less memory." + " By default, ADTrees are used."; } /** * @return a string to describe the MaxNrOfParentsoption. */ public String maxNrOfParentsTipText() { return "Set the maximum number of parents a node in the Bayes net can have." + " When initialized as Naive Bayes, setting this parameter to 1 results in" + " a Naive Bayes classifier. When set to 2, a Tree Augmented Bayes Network (TAN)" + " is learned, and when set >2, a Bayes Net Augmented Bayes Network (BAN)" + " is learned. By setting it to a value much larger than the number of nodes" + " in the network (the default of 100000 pretty much guarantees this), no" + " restriction on the number of parents is enforced"; } /** * Main method for testing this class. * * @param argv the options */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new BayesNet(), argv)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } } // main } // class BayesNet
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -