⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 nbtreesplit.java

📁 一个很好的LIBSVM的JAVA源码。对于要研究和改进SVM算法的学者。可以参考。来自数据挖掘工具YALE工具包。
💻 JAVA
字号:
/*
 *  YALE - Yet Another Learning Environment
 *  Copyright (C) 2001-2004
 *      Simon Fischer, Ralf Klinkenberg, Ingo Mierswa, 
 *          Katharina Morik, Oliver Ritthoff
 *      Artificial Intelligence Unit
 *      Computer Science Department
 *      University of Dortmund
 *      44221 Dortmund,  Germany
 *  email: yale-team@lists.sourceforge.net
 *  web:   http://yale.cs.uni-dortmund.de/
 *
 *  This program is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU General Public License as 
 *  published by the Free Software Foundation; either version 2 of the
 *  License, or (at your option) any later version. 
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 *  USA.
 */
package edu.udo.cs.yale.operator.learner.decisiontree.y45.j48;

import java.util.*;
import weka.core.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.filters.supervised.attribute.Discretize;
import weka.filters.Filter;

/**
 * Class implementing a NBTree split on an attribute.
 *
 * @author Mark Hall (mhall@cs.waikato.ac.nz)
 * @version $Revision: 1.1 $
 */
public class NBTreeSplit extends ClassifierSplitModel{

  /** Desired number of branches. */
  private int m_complexityIndex;  

  /** Attribute to split on. */
  private int m_attIndex;         

  /** Minimum number of objects in a split.   */
  private int m_minNoObj;         

  /** Value of split point. */
  private double m_splitPoint;   

  /** The sum of the weights of the instances. */
  private double m_sumOfWeights;  

  /** The weight of the instances incorrectly classified by the 
      naive bayes models arising from this split*/
  private double m_errors;

  private C45Split m_c45S;

  /** The global naive bayes model for this node */
  NBTreeNoSplit m_globalNB;

  /**
   * Initializes the split model.
   */
  public NBTreeSplit(int attIndex, int minNoObj, double sumOfWeights) {
    
    // Get index of attribute to split on.
    m_attIndex = attIndex;
        
    // Set minimum number of objects.
    m_minNoObj = minNoObj;

    // Set the sum of the weights
    m_sumOfWeights = sumOfWeights;
    
  }

  /**
   * Creates a NBTree-type split on the given data. Assumes that none of
   * the class values is missing.
   *
   * @exception Exception if something goes wrong
   */
  public void buildClassifier(Instances trainInstances) 
       throws Exception {

    // Initialize the remaining instance variables.
    m_numSubsets = 0;
    m_splitPoint = Double.MAX_VALUE;
    m_errors = 0;
    if (m_globalNB != null) {
      m_errors = m_globalNB.getErrors();
    }

    // Different treatment for enumerated and numeric
    // attributes.
    if (trainInstances.attribute(m_attIndex).isNominal()) {
      m_complexityIndex = trainInstances.attribute(m_attIndex).numValues();
      handleEnumeratedAttribute(trainInstances);
    }else{
      m_complexityIndex = 2;
      trainInstances.sort(trainInstances.attribute(m_attIndex));
      handleNumericAttribute(trainInstances);
    }
  }

  /**
   * Returns index of attribute for which split was generated.
   */
  public final int attIndex() {
    
    return m_attIndex;
  }

  /**
   * Creates split on enumerated attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleEnumeratedAttribute(Instances trainInstances)
       throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      return;
    }
    m_errors = 0;
    Instance instance;

    Instances [] trainingSets = new Instances [m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      trainingSets[i] = new Instances(trainInstances, 0);
    }
    /*    m_distribution = new Distribution(m_complexityIndex,
	  trainInstances.numClasses()); */
    int subset;
    for (int i = 0; i < trainInstances.numInstances(); i++) {
      instance = trainInstances.instance(i);
      subset = m_c45S.whichSubset(instance);
      if (subset > -1) {
	trainingSets[subset].add((Instance)instance.copy());
      } else {
	double [] weights = m_c45S.weights(instance);
	for (int j = 0; j < m_complexityIndex; j++) {
	  try {
	    Instance temp = (Instance) instance.copy();
	    if (weights.length == m_complexityIndex) {
	      temp.setWeight(temp.weight() * weights[j]);
	    } else {
	      temp.setWeight(temp.weight() / m_complexityIndex);
	    }
	    trainingSets[j].add(temp);
	  } catch (Exception ex) {
	    ex.printStackTrace();
	    System.err.println("*** "+m_complexityIndex);
	    System.err.println(weights.length);
	    System.exit(1);
	  }
	}
      }
    }

    /*    // compute weights (weights of instances per subset
    m_weights = new double [m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      m_weights[i] = trainingSets[i].sumOfWeights();
    }
    Utils.normalize(m_weights); */

    /*
    // Only Instances with known values are relevant.
    Enumeration enum = trainInstances.enumerateInstances();
    while (enum.hasMoreElements()) {
      instance = (Instance) enum.nextElement();
      if (!instance.isMissing(m_attIndex)) {
	//	m_distribution.add((int)instance.value(m_attIndex),instance);
	trainingSets[(int)instances.value(m_attIndex)].add(instance);
      } else {
	// add these to the error count
	m_errors += instance.weight();
      }
      } */

    Random r = new Random(1);
    int minNumCount = 0;
    for (int i = 0; i < m_complexityIndex; i++) {
      if (trainingSets[i].numInstances() >= 5) {
	minNumCount++;
	// Discretize the sets
	Discretize disc = new Discretize();
	disc.setInputFormat(trainingSets[i]);
	trainingSets[i] = Filter.useFilter(trainingSets[i], disc);

	trainingSets[i].randomize(r);
	trainingSets[i].stratify(5);
	NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
	fullModel.buildClassifier(trainingSets[i]);

	// add the errors for this branch of the split
	m_errors += NBTreeNoSplit.crossValidate(fullModel, trainingSets[i], r);
      } else {
	// if fewer than min obj then just count them as errors
	for (int j = 0; j < trainingSets[i].numInstances(); j++) {
	  m_errors += trainingSets[i].instance(j).weight();
	}
      }
    }
    
    // Check if there are at least five instances in at least two of the subsets
    // subsets.
    if (minNumCount > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Creates split on numeric attribute.
   *
   * @exception Exception if something goes wrong
   */
  private void handleNumericAttribute(Instances trainInstances)
       throws Exception {

    m_c45S = new C45Split(m_attIndex, 2, m_sumOfWeights);
    m_c45S.buildClassifier(trainInstances);
    if (m_c45S.numSubsets() == 0) {
      return;
    }
    m_errors = 0;

    Instances [] trainingSets = new Instances [m_complexityIndex];
    trainingSets[0] = new Instances(trainInstances, 0);
    trainingSets[1] = new Instances(trainInstances, 0);
    int subset = -1;
    
    // populate the subsets
    for (int i = 0; i < trainInstances.numInstances(); i++) {
      Instance instance = trainInstances.instance(i);
      subset = m_c45S.whichSubset(instance);
      if (subset != -1) {
	trainingSets[subset].add((Instance)instance.copy());
      } else {
	double [] weights = m_c45S.weights(instance);
	for (int j = 0; j < m_complexityIndex; j++) {
	  Instance temp = (Instance)instance.copy();
	  if (weights.length == m_complexityIndex) {
	    temp.setWeight(temp.weight() * weights[j]);
	  } else {
	    temp.setWeight(temp.weight() / m_complexityIndex);
	  }
	  trainingSets[j].add(temp); 
	}
      }
    }
    
    /*    // compute weights (weights of instances per subset
    m_weights = new double [m_complexityIndex];
    for (int i = 0; i < m_complexityIndex; i++) {
      m_weights[i] = trainingSets[i].sumOfWeights();
    }
    Utils.normalize(m_weights); */

    Random r = new Random(1);
    int minNumCount = 0;
    for (int i = 0; i < m_complexityIndex; i++) {
      if (trainingSets[i].numInstances() > 5) {
	minNumCount++;
	// Discretize the sets
		Discretize disc = new Discretize();
	disc.setInputFormat(trainingSets[i]);
	trainingSets[i] = Filter.useFilter(trainingSets[i], disc);

	trainingSets[i].randomize(r);
	trainingSets[i].stratify(5);
	NaiveBayesUpdateable fullModel = new NaiveBayesUpdateable();
	fullModel.buildClassifier(trainingSets[i]);

	// add the errors for this branch of the split
	m_errors += NBTreeNoSplit.crossValidate(fullModel, trainingSets[i], r);
      } else {
	for (int j = 0; j < trainingSets[i].numInstances(); j++) {
	  m_errors += trainingSets[i].instance(j).weight();
	}
      }
    }
    
    // Check if minimum number of Instances in at least two
    // subsets.
    if (minNumCount > 1) {
      m_numSubsets = m_complexityIndex;
    }
  }

  /**
   * Returns index of subset instance is assigned to.
   * Returns -1 if instance is assigned to more than one subset.
   *
   * @exception Exception if something goes wrong
   */
  public final int whichSubset(Instance instance) 
    throws Exception {
    
    return m_c45S.whichSubset(instance);
  }

  /**
   * Returns weights if instance is assigned to more than one subset.
   * Returns null if instance is only assigned to one subset.
   */
  public final double [] weights(Instance instance) {
    return m_c45S.weights(instance);
    //     return m_weights;
  }

  /**
   * Returns a string containing java source code equivalent to the test
   * made at this node. The instance being tested is called "i".
   *
   * @param index index of the nominal value tested
   * @param data the data containing instance structure info
   * @return a value of type 'String'
   */
  public final String sourceExpression(int index, Instances data) {
    return m_c45S.sourceExpression(index, data);
  }

  /**
   * Prints the condition satisfied by instances in a subset.
   *
   * @param index of subset 
   * @param data training set.
   */
  public final String rightSide(int index,Instances data) {
    return m_c45S.rightSide(index, data);
  }

  /**
   * Prints left side of condition..
   *
   * @param data training set.
   */
  public final String leftSide(Instances data) {

    return m_c45S.leftSide(data);
  }

  /**
   * Return the probability for a class value
   *
   * @param classIndex the index of the class value
   * @param instance the instance to generate a probability for
   * @param theSubset the subset to consider
   * @return a probability
   * @exception Exception if an error occurs
   */
  public double classProb(int classIndex, Instance instance, int theSubset) 
    throws Exception {

    // use the global naive bayes model
    if (theSubset > -1) {
      return m_globalNB.classProb(classIndex, instance, theSubset);
    } else {
      throw new Exception("This shouldn't happen!!!");
    }
  }

  /**
   * Return the global naive bayes model for this node
   *
   * @return a <code>NBTreeNoSplit</code> value
   */
  public NBTreeNoSplit getGlobalModel() {
    return m_globalNB;
  }

  /**
   * Set the global naive bayes model for this node
   *
   * @param global a <code>NBTreeNoSplit</code> value
   */
  public void setGlobalModel(NBTreeNoSplit global) {
    m_globalNB = global;
  }

  /**
   * Return the errors made by the naive bayes models arising
   * from this split.
   *
   * @return a <code>double</code> value
   */
  public double getErrors() {
    return m_errors;
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -