📄 globalscoresearchalgorithm.java
字号:
public double kFoldCV(BayesNet bayesNet, int nNrOfFolds) throws Exception { m_BayesNet = bayesNet; double fAccuracy = 0.0; double fWeight = 0.0; Instances instances = bayesNet.m_Instances; // estimate CPTs based on complete data set bayesNet.estimateCPTs(); int nFoldStart = 0; int nFoldEnd = instances.numInstances() / nNrOfFolds; int iFold = 1; while (nFoldStart < instances.numInstances()) { // remove influence of fold iFold from the probability distribution for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) { Instance instance = instances.instance(iInstance); instance.setWeight(-instance.weight()); bayesNet.updateClassifier(instance); } // measure accuracy on fold iFold for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) { Instance instance = instances.instance(iInstance); instance.setWeight(-instance.weight()); fAccuracy += accuracyIncrease(instance); instance.setWeight(-instance.weight()); fWeight += instance.weight(); } // restore influence of fold iFold from the probability distribution for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) { Instance instance = instances.instance(iInstance); instance.setWeight(-instance.weight()); bayesNet.updateClassifier(instance); } // go to next fold nFoldStart = nFoldEnd; iFold++; nFoldEnd = iFold * instances.numInstances() / nNrOfFolds; } return fAccuracy / fWeight; } // kFoldCV /** accuracyIncrease determines how much the accuracy estimate should * be increased due to the contribution of a single given instance. * * @param instance : instance for which to calculate the accuracy increase. * @return increase in accuracy due to given instance. * @throws Exception passed on by distributionForInstance and classifyInstance */ double accuracyIncrease(Instance instance) throws Exception { if (m_bUseProb) { double [] fProb = m_BayesNet.distributionForInstance(instance); return fProb[(int) instance.classValue()] * instance.weight(); } else { if (m_BayesNet.classifyInstance(instance) == instance.classValue()) { return instance.weight(); } } return 0; } // accuracyIncrease /** * @return use probabilities or not in accuracy estimate */ public boolean getUseProb() { return m_bUseProb; } // getUseProb /** * @param useProb : use probabilities or not in accuracy estimate */ public void setUseProb(boolean useProb) { m_bUseProb = useProb; } // setUseProb /** * set cross validation strategy to be used in searching for networks. * @param newCVType : cross validation strategy */ public void setCVType(SelectedTag newCVType) { if (newCVType.getTags() == TAGS_CV_TYPE) { m_nCVType = newCVType.getSelectedTag().getID(); } } // setCVType /** * get cross validation strategy to be used in searching for networks. * @return cross validation strategy */ public SelectedTag getCVType() { return new SelectedTag(m_nCVType, TAGS_CV_TYPE); } // getCVType /** * * @param bMarkovBlanketClassifier */ public void setMarkovBlanketClassifier(boolean bMarkovBlanketClassifier) { super.setMarkovBlanketClassifier(bMarkovBlanketClassifier); } /** * * @return */ public boolean getMarkovBlanketClassifier() { return super.getMarkovBlanketClassifier(); } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(); newVector.addElement(new Option( "\tApplies a Markov Blanket correction to the network structure, \n" + "\tafter a network structure is learned. This ensures that all \n" + "\tnodes in the network are part of the Markov blanket of the \n" + "\tclassifier node.", "mbc", 0, "-mbc")); newVector.addElement( new Option( "\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)", "S", 1, "-S [LOO-CV|k-Fold-CV|Cumulative-CV]")); newVector.addElement(new Option("\tUse probabilistic or 0/1 scoring.\n\t(default probabilistic scoring)", "Q", 0, "-Q")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } // listOptions /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -mbc * Applies a Markov Blanket correction to the network structure, * after a network structure is learned. This ensures that all * nodes in the network are part of the Markov blanket of the * classifier node.</pre> * * <pre> -S [LOO-CV|k-Fold-CV|Cumulative-CV] * Score type (LOO-CV,k-Fold-CV,Cumulative-CV)</pre> * * <pre> -Q * Use probabilistic or 0/1 scoring. * (default probabilistic scoring)</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { setMarkovBlanketClassifier(Utils.getFlag("mbc", options)); String sScore = Utils.getOption('S', options); if (sScore.compareTo("LOO-CV") == 0) { setCVType(new SelectedTag(LOOCV, TAGS_CV_TYPE)); } if (sScore.compareTo("k-Fold-CV") == 0) { setCVType(new SelectedTag(KFOLDCV, TAGS_CV_TYPE)); } if (sScore.compareTo("Cumulative-CV") == 0) { setCVType(new SelectedTag(CUMCV, TAGS_CV_TYPE)); } setUseProb(!Utils.getFlag('Q', options)); super.setOptions(options); } // setOptions /** * Gets the current settings of the search algorithm. * * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { String[] superOptions = super.getOptions(); String[] options = new String[4 + superOptions.length]; int current = 0; if (getMarkovBlanketClassifier()) options[current++] = "-mbc"; options[current++] = "-S"; switch (m_nCVType) { case (LOOCV) : options[current++] = "LOO-CV"; break; case (KFOLDCV) : options[current++] = "k-Fold-CV"; break; case (CUMCV) : options[current++] = "Cumulative-CV"; break; } if (!getUseProb()) { options[current++] = "-Q"; } // insert options from parent class for (int iOption = 0; iOption < superOptions.length; iOption++) { options[current++] = superOptions[iOption]; } // Fill up rest with empty strings, not nulls! while (current < options.length) { options[current++] = ""; } return options; } // getOptions /** * @return a string to describe the CVType option. */ public String CVTypeTipText() { return "Select cross validation strategy to be used in searching for networks." + "LOO-CV = Leave one out cross validation\n" + "k-Fold-CV = k fold cross validation\n" + "Cumulative-CV = cumulative cross validation." ; } // CVTypeTipText /** * @return a string to describe the UseProb option. */ public String useProbTipText() { return "If set to true, the probability of the class if returned in the estimate of the "+ "accuracy. If set to false, the accuracy estimate is only increased if the classifier returns " + "exactly the correct class."; } // useProbTipText /** * This will return a string describing the search algorithm. * @return The string. */ public String globalInfo() { return "This Bayes Network learning algorithm uses cross validation to estimate " + "classification accuracy."; } // globalInfo /** * @return a string to describe the MarkovBlanketClassifier option. */ public String markovBlanketClassifierTipText() { return super.markovBlanketClassifierTipText(); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 1.10 $"); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -