⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 globalscoresearchalgorithm.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
	 * @param bayesNet : Bayes Network containing structure to evaluate
	 * @param nNrOfFolds : the number of folds k to perform k-fold cv
	 * @return accuracy (in interval 0..1) measured using k-fold cv.
	 * @throws Exception passed on by updateClassifier
	 */
	public double kFoldCV(BayesNet bayesNet, int nNrOfFolds) throws Exception {
		// A fold count of 0 would divide by zero, and a negative count would
		// make the fold boundaries run away from the data instead of covering it.
		if (nNrOfFolds < 1) {
			throw new IllegalArgumentException(
				"Number of folds must be at least 1, but was " + nNrOfFolds);
		}
		m_BayesNet = bayesNet;
		double fAccuracy = 0.0;
		double fWeight = 0.0;
		Instances instances = bayesNet.m_Instances;
		int nInstances = instances.numInstances();
		// estimate CPTs based on complete data set
		bayesNet.estimateCPTs();
		// Fold iFold covers instances [(iFold-1)*n/k, iFold*n/k); the last fold
		// ends exactly at n. A fold is taken out of (and later put back into)
		// the model by feeding its instances to updateClassifier with negated
		// weights.
		int nFoldStart = 0;
		for (int iFold = 1; nFoldStart < nInstances; iFold++) {
			// long arithmetic: iFold * nInstances can overflow an int
			int nFoldEnd = (int) ((long) iFold * nInstances / nNrOfFolds);

			// remove influence of fold iFold from the probability distribution
			for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
				Instance instance = instances.instance(iInstance);
				instance.setWeight(-instance.weight());
				bayesNet.updateClassifier(instance);
			}

			// measure accuracy on fold iFold. Weights are negative at this point:
			// flip them positive, score AND accumulate the weight while they are
			// positive, then flip them back so the restore loop below works.
			// (Accumulating after flipping back would add negative weights and
			// yield a negative accuracy.)
			for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
				Instance instance = instances.instance(iInstance);
				instance.setWeight(-instance.weight());
				fAccuracy += accuracyIncrease(instance);
				fWeight += instance.weight();
				instance.setWeight(-instance.weight());
			}

			// restore influence of fold iFold on the probability distribution
			for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
				Instance instance = instances.instance(iInstance);
				instance.setWeight(-instance.weight());
				bayesNet.updateClassifier(instance);
			}

			// go to next fold
			nFoldStart = nFoldEnd;
		}
		// NOTE: with no instances this is 0/0 (NaN), as before
		return fAccuracy / fWeight;
	} // kFoldCV
	
	/** accuracyIncrease computes the contribution of one instance to the
	 * accuracy estimate, scaled by the instance weight.
	 * With probabilistic scoring (m_bUseProb set) the contribution is the
	 * probability the network assigns to the instance's class value times the
	 * weight; otherwise it is the weight when the predicted class equals the
	 * instance's class value, and 0 when it does not.
	 * 
	 * @param instance : instance for which to calculate the accuracy increase.
	 * @return increase in accuracy due to given instance. 
	 * @throws Exception passed on by distributionForInstance and classifyInstance
	 */
	double accuracyIncrease(Instance instance) throws Exception {
		if (!m_bUseProb) {
			// 0/1 scoring: credit the weight only for an exact class match
			double fPredicted = m_BayesNet.classifyInstance(instance);
			return (fPredicted == instance.classValue()) ? instance.weight() : 0;
		}
		// probabilistic scoring: credit the probability of the actual class
		double [] fDistribution = m_BayesNet.distributionForInstance(instance);
		int iClass = (int) instance.classValue();
		return fDistribution[iClass] * instance.weight();
	} // accuracyIncrease

	/**
	 * Tells whether the accuracy estimate credits the predicted probability of
	 * the actual class (true) or only exact class matches (false).
	 * @return use probabilities or not in accuracy estimate
	 */
	public boolean getUseProb() {
		final boolean bUseProb = m_bUseProb;
		return bUseProb;
	} // getUseProb

	/**
	 * Chooses between probabilistic (true) and 0/1 (false) scoring for the
	 * accuracy estimate.
	 * @param useProb : use probabilities or not in accuracy estimate
	 */
	public void setUseProb(boolean useProb) {
		this.m_bUseProb = useProb;
	} // setUseProb
	
	/**
	 * Sets the cross validation strategy to be used in searching for networks.
	 * The tag is only accepted when it was built from TAGS_CV_TYPE (checked by
	 * reference); any other tag is silently ignored.
	 * @param newCVType : cross validation strategy 
	 */
	public void setCVType(SelectedTag newCVType) {
		if (newCVType.getTags() != TAGS_CV_TYPE) {
			// tag does not come from this algorithm's tag set: leave setting as is
			return;
		}
		m_nCVType = newCVType.getSelectedTag().getID();
	} // setCVType

	/**
	 * Returns the cross validation strategy used in searching for networks.
	 * @return the current strategy, as a tag selected from TAGS_CV_TYPE
	 */
	public SelectedTag getCVType() {
		SelectedTag selected = new SelectedTag(m_nCVType, TAGS_CV_TYPE);
		return selected;
	} // getCVType

  /**
   * Turns the Markov Blanket correction (option -mbc) on or off by delegating
   * to the parent class implementation.
   * @param bMarkovBlanketClassifier true to apply the correction
   */
  public void setMarkovBlanketClassifier(boolean bMarkovBlanketClassifier) {
    // delegate unchanged to the superclass
    super.setMarkovBlanketClassifier(bMarkovBlanketClassifier);
  }

  /**
   * Reports whether the Markov Blanket correction (option -mbc) is enabled,
   * as answered by the parent class implementation.
   * @return true if the correction is applied
   */
  public boolean getMarkovBlanketClassifier() {
    boolean bEnabled = super.getMarkovBlanketClassifier();
    return bEnabled;
  }

	/**
	 * Returns an enumeration describing the available options: -mbc, -S and
	 * -Q of this class, followed by the options of the parent class.
	 * 
	 * @return an enumeration of all the available options
	 */
	public Enumeration listOptions() {
		Vector newVector = new Vector();

		// Each description line starts with "\t" and ends with "\n"; the previous
		// text joined its fragments with ", \t", putting tabs mid-line.
		newVector.addElement(new Option(
				"\tApplies a Markov Blanket correction to the network structure,\n"
				+ "\tafter a network structure is learned. This ensures that all\n"
				+ "\tnodes in the network are part of the Markov blanket of the\n"
				+ "\tclassifier node.\n",
				"mbc", 0, "-mbc"));

		newVector.addElement(
			new Option(
				"\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)\n",
				"S",
				1,
				"-S [LOO-CV|k-Fold-CV|Cumulative-CV]"));

		newVector.addElement(new Option("\tUse probabilistic or 0/1 scoring.\n\t(default probabilistic scoring)", "Q", 0, "-Q"));

		// append the parent's options after our own
		Enumeration enu = super.listOptions();
		while (enu.hasMoreElements()) {
			newVector.addElement(enu.nextElement());
		}
		return newVector.elements();
	} // listOptions

	/**
	 * Parses a given list of options. Valid options are:<p>
	 *
	 * -mbc <br>
	 * Applies a Markov Blanket correction to the learned network structure.<p>
	 *
	 * -S [LOO-CV|k-Fold-CV|Cumulative-CV] <br>
	 * Cross validation strategy; an unrecognized or absent value leaves the
	 * current strategy unchanged.<p>
	 *
	 * -Q <br>
	 * Use 0/1 scoring; without it probabilistic scoring is used.<p>
	 *
	 * Remaining options are handed to the parent search algorithm.
	 *
	 * @param options the list of options as an array of strings
	 * @exception Exception if an option is not supported
	 */
	public void setOptions(String[] options) throws Exception {
		boolean bMBC = Utils.getFlag("mbc", options);
		setMarkovBlanketClassifier(bMBC);

		String sCVType = Utils.getOption('S', options);
		if (sCVType.equals("LOO-CV")) {
			setCVType(new SelectedTag(LOOCV, TAGS_CV_TYPE));
		} else if (sCVType.equals("k-Fold-CV")) {
			setCVType(new SelectedTag(KFOLDCV, TAGS_CV_TYPE));
		} else if (sCVType.equals("Cumulative-CV")) {
			setCVType(new SelectedTag(CUMCV, TAGS_CV_TYPE));
		}

		// -Q switches probabilistic scoring off
		boolean bZeroOneScoring = Utils.getFlag('Q', options);
		setUseProb(!bZeroOneScoring);
		super.setOptions(options);
	} // setOptions

	/**
	 * Gets the current settings of the search algorithm, mirroring setOptions:
	 * -mbc (when the Markov Blanket correction is on), -S followed by the cross
	 * validation type, -Q (when 0/1 scoring is used instead of probabilistic
	 * scoring), then the options of the parent class. Unused trailing slots
	 * are filled with empty strings.
	 *
	 * @return an array of strings suitable for passing to setOptions
	 */
	public String[] getOptions() {
		String[] superOptions = super.getOptions();
		// at most 4 own entries: -mbc, -S, <cv type>, -Q
		String[] options = new String[4 + superOptions.length];
		int current = 0;

		if (getMarkovBlanketClassifier()) {
			options[current++] = "-mbc";
		}

		String sCVType = null;
		switch (m_nCVType) {
			case LOOCV :
				sCVType = "LOO-CV";
				break;
			case KFOLDCV :
				sCVType = "k-Fold-CV";
				break;
			case CUMCV :
				sCVType = "Cumulative-CV";
				break;
		}
		// emit -S only together with its value, so the option list never
		// contains a dangling -S
		if (sCVType != null) {
			options[current++] = "-S";
			options[current++] = sCVType;
		}

		// setOptions reads -Q as "do NOT use probabilities" (useProb = false),
		// so -Q must be written only when probabilistic scoring is off;
		// otherwise a getOptions/setOptions round trip flips the setting.
		if (!getUseProb()) {
			options[current++] = "-Q";
		}

		// insert options from parent class
		for (int iOption = 0; iOption < superOptions.length; iOption++) {
			options[current++] = superOptions[iOption];
		}

		// Fill up rest with empty strings, not nulls!
		while (current < options.length) {
			options[current++] = "";
		}
		return options;
	} // getOptions

	/**
	 * Returns the tip text for the CVType property: a one-sentence summary
	 * followed by one line per supported cross validation strategy.
	 *
	 * @return a string to describe the CVType option.
	 */
	public String CVTypeTipText() {
	  // "\n" after the summary so it does not run into "LOO-CV = ..."
	  return "Select cross validation strategy to be used in searching for networks.\n" +
	  "LOO-CV = Leave one out cross validation\n" +
	  "k-Fold-CV = k fold cross validation\n" +
	  "Cumulative-CV = cumulative cross validation.";
	} // CVTypeTipText

	/**
	 * Returns the tip text for the UseProb property.
	 *
	 * @return a string to describe the UseProb option.
	 */
	public String useProbTipText() {
	  return "If set to true, the probability of the class is used in the estimate of the "+
	  "accuracy. If set to false, the accuracy estimate is only increased if the classifier returns " +
	  "exactly the correct class.";
	} // useProbTipText

	/**
	 * Describes this search algorithm for display in the user interface.
	 * @return a one-sentence description of the algorithm
	 */
	public String globalInfo() {
	  final String sInfo =
	    "This Bayes Network learning algorithm uses cross validation to estimate classification accuracy.";
	  return sInfo;
	} // globalInfo

  /**
   * Returns the tip text for the MarkovBlanketClassifier property, taken
   * unchanged from the parent class.
   * @return a string to describe the MarkovBlanketClassifier option.
   */
  public String markovBlanketClassifierTipText() {
    String sTip = super.markovBlanketClassifierTipText();
    return sTip;
  }

} // class CVSearchAlgorithm

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -