⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 racedincrementallogitboost.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    /**
     * Builds one new base model per class value from a chunk of training
     * data. Each model is fitted to a numeric "pseudo class" (working
     * response) and instance weights derived from the responses of the
     * models already held in this committee.
     *
     * @param data the chunk of instances to boost on; rows with a missing
     *             class value are ignored
     * @return an array of {@code m_NumClasses} freshly built classifiers,
     *         one per class value
     * @exception Exception if a base classifier cannot be copied, applied
     *                      or built, or a response cannot be normalised
     */
    protected Classifier[] boost(Instances data) throws Exception {

      Classifier[] newModel = Classifier.makeCopies(m_Classifier, m_NumClasses);

      // Create a copy of the data with the class transformed into numeric
      Instances boostData = new Instances(data);
      boostData.deleteWithMissingClass();
      int numInstances = boostData.numInstances();

      // Temporarily unset the class index so the nominal class attribute can
      // be replaced by a numeric 'pseudo class' at the same position
      int classIndex = data.classIndex();
      boostData.setClassIndex(-1);
      boostData.deleteAttributeAt(classIndex);
      boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
      boostData.setClassIndex(classIndex);

      double [][] trainFs = new double [numInstances][m_NumClasses];
      double [][] trainYs = new double [numInstances][m_NumClasses];

      // One-hot encode the true class of every row. Row i of boostData is
      // the i-th row of data whose class is not missing, so missing-class
      // rows are skipped once here rather than rescanned for every class.
      for (int i = 0, k = 0; i < numInstances; i++, k++) {
        while (data.instance(k).classIsMissing()) k++;
        double classValue = data.instance(k).classValue();
        for (int j = 0; j < m_NumClasses; j++) {
          trainYs[i][j] = (classValue == j) ? 1 : 0;
        }
      }

      // Evaluate / increment trainFs from the classifiers already in the
      // committee. The per-model array does not depend on the instance, so
      // it is fetched once per model.
      for (int x = 0; x < m_models.size(); x++) {
        Classifier[] model = (Classifier[]) m_models.elementAt(x);
        for (int i = 0; i < numInstances; i++) {
          double [] pred = new double [m_NumClasses];
          double predSum = 0;
          for (int j = 0; j < m_NumClasses; j++) {
            pred[j] = model[j].classifyInstance(boostData.instance(i));
            predSum += pred[j];
          }
          predSum /= m_NumClasses;
          for (int j = 0; j < m_NumClasses; j++) {
            trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses-1)
              / m_NumClasses;
          }
        }
      }

      for (int j = 0; j < m_NumClasses; j++) {

        // Set instance pseudoclass (working response z) and weights
        for (int i = 0; i < numInstances; i++) {
          double p = RtoP(trainFs[i], j);
          Instance current = boostData.instance(i);
          double z, actual = trainYs[i][j];
          if (actual == 1) {
            z = 1.0 / p;
            if (z > Z_MAX) { // threshold
              z = Z_MAX;
            }
          } else if (actual == 0) {
            z = -1.0 / (1.0 - p);
            if (z < -Z_MAX) { // threshold
              z = -Z_MAX;
            }
          } else {
            // Not reached with the 0/1 targets built above: the two
            // branches before are this formula specialised to actual == 1
            // and actual == 0, plus clamping to +/-Z_MAX.
            z = (actual - p) / (p * (1 - p));
          }

          double w = (actual - p) / z;
          current.setValue(classIndex, z);
          current.setWeight(numInstances * w);
        }

        // Base learners that cannot use weights get a weighted resample
        Instances trainData = boostData;
        if (m_UseResampling) {
          double[] weights = new double[boostData.numInstances()];
          for (int kk = 0; kk < weights.length; kk++) {
            weights[kk] = boostData.instance(kk).weight();
          }
          trainData = boostData.resampleWithWeights(m_RandomInstance,
                                                    weights);
        }

        // Build the classifier for class value j
        newModel[j].buildClassifier(trainData);
      }

      return newModel;
    }

    /**
     * Describes this committee: each stored model (one base classifier per
     * class value), followed by the number of models and the chunk size.
     *
     * @return a multi-line textual description of the committee
     */
    public String toString() {

      StringBuffer out = new StringBuffer();
      out.append("RacedIncrementalLogitBoost: Best committee on validation data\n");
      out.append("Base classifiers: \n");

      int modelCount = m_models.size();
      for (int m = 0; m < modelCount; m++) {
        out.append("\nModel " + (m + 1));
        Classifier[] perClass = (Classifier[]) m_models.elementAt(m);
        for (int c = 0; c < m_NumClasses; c++) {
          out.append("\n\tClass ").append(c + 1)
             .append(" (").append(m_ClassAttribute.name())
             .append("=").append(m_ClassAttribute.value(c)).append(")\n\n")
             .append(perClass[c].toString()).append("\n");
        }
      }

      out.append("Number of models: " + modelCount + "\n");
      out.append("Chunk size per model: " + m_chunkSize + "\n");
      return out.toString();
    }
  }

 /**
   * Builds the classifier by streaming a shuffled copy of the training data
   * through {@link #updateClassifier(Instance)}.
   *
   * @param data the instances to train the classifier with; instances whose
   *             class value is missing are discarded
   * @exception Exception if the class is numeric, no base classifier has
   *                      been set, the minimum chunk size is not positive,
   *                      the data contains string attributes, or training
   *                      fails
   */
  public void buildClassifier(Instances data) throws Exception {

    m_RandomInstance = new Random(m_Seed);

    int classIndex = data.classIndex();

    if (data.classAttribute().isNumeric()) {
      throw new Exception("LogitBoost can't handle a numeric class!");
    }
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    // The chunk sizes below are generated by repeated doubling, which never
    // terminates from a non-positive start.
    if (m_minChunkSize < 1) {
      throw new Exception("Minimum chunk size must be at least 1!");
    }

    // Base learners that cannot use instance weights are given resampled data
    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
        !m_UseResampling) {
      m_UseResampling = true;
    }
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }

    m_NumClasses = data.numClasses();
    m_ClassAttribute = data.classAttribute();

    // Work on a private copy and drop unlabelled instances: boost() discards
    // them anyway, and otherwise they would be streamed into the validation
    // set and the training chunks below.
    data = new Instances(data);
    data.deleteWithMissingClass();

    // Header of the numeric-class version of the data: same attributes, but
    // the nominal class replaced by a numeric 'pseudo class'. Only the
    // header is kept, so there is no need to copy the instances themselves.
    Instances boostData = new Instances(data, 0);
    boostData.setClassIndex(-1);
    boostData.deleteAttributeAt(classIndex);
    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    boostData.setClassIndex(classIndex);
    m_NumericClassData = new Instances(boostData, 0);

    data.randomize(m_RandomInstance);

    // create the committees: chunk sizes double from m_minChunkSize while
    // they do not exceed m_maxChunkSize
    int cSize = m_minChunkSize;
    m_committees = new FastVector();
    while (cSize <= m_maxChunkSize) {
      m_committees.addElement(new Committee(cSize));
      m_maxBatchSizeRequired = cSize;
      if (cSize > m_maxChunkSize / 2) {
        break; // doubling would exceed the maximum (and could overflow int)
      }
      cSize *= 2;
    }

    // set up for consumption
    m_validationSet = new Instances(data, m_validationChunkSize);
    m_currentSet = new Instances(data, m_maxBatchSizeRequired);
    m_bestCommittee = null;
    m_numInstancesConsumed = 0;

    // start eating what we've been given
    for (int i = 0; i < data.numInstances(); i++) {
      updateClassifier(data.instance(i));
    }
  }

 /**
   * Updates the classifier with the next instance of the training stream.
   * Labelled instances first fill the validation set; later ones are added
   * to the current chunk and every committee is asked to update.
   *
   * @param instance the next instance in the stream of training data;
   *                 ignored if its class value is missing
   * @exception Exception if something goes wrong
   */
  public void updateClassifier(Instance instance) throws Exception {

    // An unlabelled instance can neither be scored on the validation set nor
    // boosted on (boost() discards such rows), so it must not occupy a slot
    // in the validation set or the current chunk.
    if (instance.classIsMissing()) {
      return;
    }

    m_numInstancesConsumed++;

    if (m_validationSet.numInstances() < m_validationChunkSize) {
      m_validationSet.add(instance);
      m_validationSetChanged = true;
    } else {
      m_currentSet.add(instance);
      boolean hasChanged = false;

      // update each committee
      for (int i = 0; i < m_committees.size(); i++) {
        Committee c = (Committee) m_committees.elementAt(i);
        if (c.update()) {

          hasChanged = true;

          if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) {
            // The newest model is dropped when logLikelihoodAfter() is not
            // below logLikelihood(), as long as another model remains.
            // NOTE(review): assumes lower values are better - confirm
            // against Committee.logLikelihood().
            double oldLL = c.logLikelihood();
            double newLL = c.logLikelihoodAfter();
            if (newLL >= oldLL && c.committeeSize() > 1) {
              c.pruneLastModel();
              if (m_Debug) System.out.println("Pruning " + c.chunkSize()+ " committee (" +
                                              oldLL + " < " + newLL + ")");
            } else {
              c.keepLastModel();
            }
          } else {
            c.keepLastModel(); // no pruning
          }
        }
      }
      if (hasChanged) {

        if (m_Debug) System.out.println("After consuming " + m_numInstancesConsumed
                                        + " instances... (" + m_validationSet.numInstances()
                                        + " + " + m_currentSet.numInstances()
                                        + " instances currently in memory)");

        // find best committee: the non-empty one with lowest validation error
        double lowestError = 1.0;
        for (int i = 0; i < m_committees.size(); i++) {
          Committee c = (Committee) m_committees.elementAt(i);

          if (c.committeeSize() > 0) {

            double err = c.validationError();
            double ll = c.logLikelihood();

            if (m_Debug) System.out.println("Chunk size " + c.chunkSize() + " with "
                                            + c.committeeSize() + " models, has validation error of "
                                            + err + ", log likelihood of " + ll);
            if (err < lowestError) {
              lowestError = err;
              m_bestCommittee = c;
            }
          }
        }
      }
      // Once the largest chunk is full, start a fresh (empty) chunk
      if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) {
        m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired);

        // reset consumption counts
        for (int i = 0; i < m_committees.size(); i++) {
          Committee c = (Committee) m_committees.elementAt(i);
          c.resetConsumed();
        }
      }
    }
  }

  /**
   * Converts function responses to a probability using a softmax with the
   * maximum response subtracted: exp(Fs[j] - max) / sum_i exp(Fs[i] - max).
   *
   * @param Fs an array containing the responses from each function
   * @param j the index of the class value of interest
   * @return the probability prediction for j
   * @exception Exception if the exponentiated responses sum to zero (for
   *                      example when Fs is empty)
   */
  protected static double RtoP(double []Fs, int j) 
    throws Exception {

    // Subtracting the maximum keeps exponents non-positive for finite
    // inputs, so exp() cannot overflow; the common factor cancels out.
    double maxF = -Double.MAX_VALUE;
    for (int i = 0; i < Fs.length; i++) {
      if (Fs[i] > maxF) {
        maxF = Fs[i];
      }
    }

    // Only the normaliser and the j-th term are needed, so no per-call
    // array of all probabilities is allocated.
    double sum = 0;
    for (int i = 0; i < Fs.length; i++) {
      sum += Math.exp(Fs[i] - maxF);
    }
    if (sum == 0) {
      throw new Exception("Can't normalize");
    }
    return Math.exp(Fs[j] - maxF) / sum;
  }

  /**
   * Computes the class distribution for an instance. The best committee
   * selected so far makes the prediction; until one exists, a ZeroR model
   * fitted to the validation set is used instead.
   *
   * @param instance the instance to classify
   * @return the predicted class distribution
   * @exception Exception if the prediction cannot be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_bestCommittee == null) {
      // (Re)fit the fallback lazily: when it has never been built, or when
      // the validation set has changed since it was last built.
      boolean needsRebuild = m_zeroR == null || m_validationSetChanged;
      if (needsRebuild) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(m_validationSet);
        m_validationSetChanged = false;
      }
      return m_zeroR.distributionForInstance(instance);
    }
    return m_bestCommittee.distributionForInstance(instance);
  }

  /**
   * Lists the command-line options understood by this classifier, followed
   * by those of the configured base learner when it accepts options.
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

    Vector options = new Vector(9);

    options.addElement(new Option("\tTurn on debugging output.",
                                  "D", 0, "-D"));
    options.addElement(new Option("\tMinimum size of chunks.\n\t(default 500)",
                                  "C", 1, "-C <num>"));
    options.addElement(new Option("\tMaximum size of chunks.\n\t(default 20000)",
                                  "M", 1, "-M <num>"));
    options.addElement(new Option("\tSize of validation set.\n\t(default 5000)",
                                  "V", 1, "-V <num>"));
    options.addElement(new Option("\tFull name of 'weak' learner to boost.\n"
                                  + "\teg: weka.classifiers.DecisionStump",
                                  "W", 1, "-W <learner class name>"));
    options.addElement(new Option("\tCommittee pruning to perform.\n"
                                  + "\t0=none, 1=log likelihood (default)",
                                  "P", 1, "-P <pruning type>"));
    options.addElement(new Option("\tUse resampling for boosting.",
                                  "Q", 0, "-Q"));
    options.addElement(new Option("\tSeed for resampling. (Default 1)",
                                  "S", 1, "-S <num>"));

    // instanceof is false for null, so no separate null check is needed
    if (m_Classifier instanceof OptionHandler) {
      String heading = "\nOptions specific to weak learner "
        + m_Classifier.getClass().getName() + ":";
      options.addElement(new Option("", "", 0, heading));
      for (Enumeration learnerOptions = ((OptionHandler) m_Classifier).listOptions();
           learnerOptions.hasMoreElements();) {
        options.addElement(learnerOptions.nextElement());
      }
    }
    return options.elements();
  }


  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String minChunkSize = Utils.getOption('C', options);
    if (minChunkSize.length() != 0) {
      setMinChunkSize(Integer.parseInt(minChunkSize));
    } else {
      setMinChunkSize(500);
    }

    String maxChunkSize = Utils.getOption('M', options);
    if (maxChunkSize.length() != 0) {
      setMaxChunkSize(Integer.parseInt(maxChunkSize));
    } else {
      setMaxChunkSize(20000);
    }

    String validationChunkSize = Utils.getOption('V', options);
    if (validationChunkSize.length() != 0) {
      setValidationChunkSize(Integer.parseInt(validationChunkSize));
    } else {
      setValidationChunkSize(5000);
    }

    String pruneType = Utils.getOption('P', options);
    if (pruneType.length() != 0) {
      setPruningType(new SelectedTag(Integer.parseInt(pruneType), TAGS_PRUNETYPE));
    } else {
      setPruningType(new SelectedTag(PRUNETYPE_LOGLIKELIHOOD, TAGS_PRUNETYPE));
    }

    setUseResampling(Utils.getFlag('Q', options));

    String seedString = Utils.getOption('S', options);
    if (seedString.length() != 0) {
      setSeed(Integer.parseInt(seedString));
    } else {
      setSeed(1);
    }

    setDebug(Utils.getFlag('D', options));

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -