⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 racedincrementallogitboost.java

📁 wekaUT 是美国德克萨斯大学奥斯汀分校（University of Texas at Austin）开发的、基于 Weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
      // Temporarily unset the class index      int classIndex = data.classIndex();      boostData.setClassIndex(-1);      boostData.deleteAttributeAt(classIndex);      boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);      boostData.setClassIndex(classIndex);      double [][] trainFs = new double [numInstances][m_NumClasses];      double [][] trainYs = new double [numInstances][m_NumClasses];      for (int j = 0; j < m_NumClasses; j++) {	for (int i = 0, k = 0; i < numInstances; i++, k++) {	  while (data.instance(k).classIsMissing()) k++;	  trainYs[i][j] = (data.instance(k).classValue() == j) ? 1 : 0;	}      }            // Evaluate / increment trainFs from the classifiers      for (int x = 0; x < m_models.size(); x++) {	for (int i = 0; i < numInstances; i++) {	  double [] pred = new double [m_NumClasses];	  double predSum = 0;	  Classifier[] model = (Classifier[]) m_models.elementAt(x);	  for (int j = 0; j < m_NumClasses; j++) {	    pred[j] = model[j].classifyInstance(boostData.instance(i));	    predSum += pred[j];	  }	  predSum /= m_NumClasses;	  for (int j = 0; j < m_NumClasses; j++) {	    trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses-1) 	      / m_NumClasses;	  }	}      }      for (int j = 0; j < m_NumClasses; j++) {		// Set instance pseudoclass and weights	for (int i = 0; i < numInstances; i++) {	  double p = RtoP(trainFs[i], j);	  Instance current = boostData.instance(i);	  double z, actual = trainYs[i][j];	  if (actual == 1) {	    z = 1.0 / p;	    if (z > Z_MAX) { // threshold	      z = Z_MAX;	    }	  } else if (actual == 0) {	    z = -1.0 / (1.0 - p);	    if (z < -Z_MAX) { // threshold	      z = -Z_MAX;	    }	  } else {	    z = (actual - p) / (p * (1 - p));	  }	  double w = (actual - p) / z;	  current.setValue(classIndex, z);	  current.setWeight(numInstances * w);	}		Instances trainData = boostData;	if (m_UseResampling) {	  double[] weights = new double[boostData.numInstances()];	  for (int kk = 0; kk < weights.length; kk++) {	
    weights[kk] = boostData.instance(kk).weight();	  }	  trainData = boostData.resampleWithWeights(m_RandomInstance, 						    weights);	}		// Build the classifier	newModel[j].buildClassifier(trainData);      }                  return newModel;    }    /* outputs description of the committee */    public String toString() {            StringBuffer text = new StringBuffer();            text.append("RacedIncrementalLogitBoost: Best committee on validation data\n");      text.append("Base classifiers: \n");            for (int i = 0; i < m_models.size(); i++) {	text.append("\nModel "+(i+1));	Classifier[] cModels = (Classifier[]) m_models.elementAt(i);	for (int j = 0; j < m_NumClasses; j++) {	  text.append("\n\tClass " + (j + 1) 		      + " (" + m_ClassAttribute.name() 		      + "=" + m_ClassAttribute.value(j) + ")\n\n"		      + cModels[j].toString() + "\n");	}      }      text.append("Number of models: " +		  m_models.size() + "\n");            text.append("Chunk size per model: " + m_chunkSize + "\n");            return text.toString();    }  } /**   * Builds the classifier.   
*   * @param instances the instances to train the classifier with   * @exception Exception if something goes wrong   */  public void buildClassifier(Instances data) throws Exception {    m_RandomInstance = new Random(m_Seed);    Instances boostData;    int classIndex = data.classIndex();    if (data.classAttribute().isNumeric()) {      throw new Exception("LogitBoost can't handle a numeric class!");    }    if (m_Classifier == null) {      throw new Exception("A base classifier has not been specified!");    }    if (!(m_Classifier instanceof WeightedInstancesHandler) &&	!m_UseResampling) {      m_UseResampling = true;    }    if (data.checkForStringAttributes()) {      throw new Exception("Can't handle string attributes!");    }    m_NumClasses = data.numClasses();    m_ClassAttribute = data.classAttribute();    // Create a copy of the data with the class transformed into numeric    boostData = new Instances(data);    boostData.deleteWithMissingClass();    // Temporarily unset the class index    boostData.setClassIndex(-1);    boostData.deleteAttributeAt(classIndex);    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);    boostData.setClassIndex(classIndex);    m_NumericClassData = new Instances(boostData, 0);    data = new Instances(data);    data.randomize(m_RandomInstance);    // create the committees    int cSize = m_minChunkSize;    m_committees = new FastVector();    while (cSize <= m_maxChunkSize) {      m_committees.addElement(new Committee(cSize));      m_maxBatchSizeRequired = cSize;      cSize *= 2;    }    // set up for consumption    m_validationSet = new Instances(data, m_validationChunkSize);    m_currentSet = new Instances(data, m_maxBatchSizeRequired);    m_bestCommittee = null;    m_numInstancesConsumed = 0;    // start eating what we've been given    for (int i=0; i<data.numInstances(); i++) updateClassifier(data.instance(i));  } /**   * Updates the classifier.   
*   * @param instance the next instance in the stream of training data   * @exception Exception if something goes wrong   */  public void updateClassifier(Instance instance) throws Exception {    m_numInstancesConsumed++;    if (m_validationSet.numInstances() < m_validationChunkSize) {      m_validationSet.add(instance);      m_validationSetChanged = true;    } else {      m_currentSet.add(instance);      boolean hasChanged = false;            // update each committee      for (int i=0; i<m_committees.size(); i++) {	Committee c = (Committee) m_committees.elementAt(i);	if (c.update()) {	  	  hasChanged = true;	  	  if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) {	    double oldLL = c.logLikelihood();	    double newLL = c.logLikelihoodAfter();	    if (newLL >= oldLL && c.committeeSize() > 1) {	      c.pruneLastModel();	      if (m_Debug) System.out.println("Pruning " + c.chunkSize()+ " committee (" +					      oldLL + " < " + newLL + ")");	    } else c.keepLastModel();	  } else c.keepLastModel(); // no pruning	}       }      if (hasChanged) {	if (m_Debug) System.out.println("After consuming " + m_numInstancesConsumed					+ " instances... 
(" + m_validationSet.numInstances()					+ " + " + m_currentSet.numInstances()					+ " instances currently in memory)");		// find best committee	double lowestError = 1.0;	for (int i=0; i<m_committees.size(); i++) {	  Committee c = (Committee) m_committees.elementAt(i);	  if (c.committeeSize() > 0) {	    double err = c.validationError();	    double ll = c.logLikelihood();	    if (m_Debug) System.out.println("Chunk size " + c.chunkSize() + " with "					    + c.committeeSize() + " models, has validation error of "					    + err + ", log likelihood of " + ll);	    if (err < lowestError) {	      lowestError = err;	      m_bestCommittee = c;	    }	  }	}      }      if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) {	m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired);	// reset consumation counts	for (int i=0; i<m_committees.size(); i++) {	  Committee c = (Committee) m_committees.elementAt(i);	  c.resetConsumed();	}      }    }  }  /**   * Convert from function responses to probabilities   *   * @param R an array containing the responses from each function   * @param j the class value of interest   * @return the probability prediction for j   */  protected static double RtoP(double []Fs, int j)     throws Exception {    double maxF = -Double.MAX_VALUE;    for (int i = 0; i < Fs.length; i++) {      if (Fs[i] > maxF) {	maxF = Fs[i];      }    }    double sum = 0;    double[] probs = new double[Fs.length];    for (int i = 0; i < Fs.length; i++) {      probs[i] = Math.exp(Fs[i] - maxF);      sum += probs[i];    }    if (sum == 0) {      throw new Exception("Can't normalize");    }    return probs[j] / sum;  }  /**   * Computes class distribution of an instance using the best committee.   
*/  public double[] distributionForInstance(Instance instance) throws Exception {    if (m_bestCommittee != null) return m_bestCommittee.distributionForInstance(instance);    else {      if (m_validationSetChanged || m_zeroR == null) {	m_zeroR = new ZeroR();	m_zeroR.buildClassifier(m_validationSet);	m_validationSetChanged = false;      }      return m_zeroR.distributionForInstance(instance);    }  }  /**   * Returns an enumeration describing the available options   *   * @return an enumeration of all the available options   */  public Enumeration listOptions() {    Vector newVector = new Vector(9);    newVector.addElement(new Option(	      "\tTurn on debugging output.",	      "D", 0, "-D"));    newVector.addElement(new Option(	      "\tMinimum size of chunks.\n"	      +"\t(default 500)",	      "C", 1, "-C <num>"));    newVector.addElement(new Option(	      "\tMaximum size of chunks.\n"	      +"\t(default 20000)",	      "M", 1, "-M <num>"));    newVector.addElement(new Option(	      "\tSize of validation set.\n"	      +"\t(default 5000)",	      "V", 1, "-V <num>"));    newVector.addElement(new Option(	      "\tFull name of 'weak' learner to boost.\n"	      +"\teg: weka.classifiers.DecisionStump",	      "W", 1, "-W <learner class name>"));    newVector.addElement(new Option(	      "\tCommittee pruning to perform.\n"	      +"\t0=none, 1=log likelihood (default)",	      "P", 1, "-P <pruning type>"));    newVector.addElement(new Option(	      "\tUse resampling for boosting.",	      "Q", 0, "-Q"));    newVector.addElement(new Option(	      "\tSeed for resampling. 
(Default 1)",	      "S", 1, "-S <num>"));    if ((m_Classifier != null) &&	(m_Classifier instanceof OptionHandler)) {      newVector.addElement(new Option(	  "",	  "", 0, "\nOptions specific to weak learner "	  + m_Classifier.getClass().getName() + ":"));      Enumeration enum = ((OptionHandler)m_Classifier).listOptions();      while (enum.hasMoreElements()) {	newVector.addElement(enum.nextElement());      }    }    return newVector.elements();  }  /**   * Parses a given list of options. Valid options are:<p>   *   * @param options the list of options as an array of strings   * @exception Exception if an option is not supported   */  public void setOptions(String[] options) throws Exception {    String minChunkSize = Utils.getOption('C', options);    if (minChunkSize.length() != 0) {      setMinChunkSize(Integer.parseInt(minChunkSize));    } else {      setMinChunkSize(500);    }    String maxChunkSize = Utils.getOption('M', options);    if (maxChunkSize.length() != 0) {      setMaxChunkSize(Integer.parseInt(maxChunkSize));    } else {      setMaxChunkSize(20000);    }    String validationChunkSize = Utils.getOption('V', options);    if (validationChunkSize.length() != 0) {      setValidationChunkSize(Integer.parseInt(validationChunkSize));    } else {      setValidationChunkSize(5000);    }    String pruneType = Utils.getOption('P', options);    if (pruneType.length() != 0) {      setPruningType(new SelectedTag(Integer.parseInt(pruneType), TAGS_PRUNETYPE));    } else {      setPruningType(new SelectedTag(PRUNETYPE_LOGLIKELIHOOD, TAGS_PRUNETYPE));    }    setUseResampling(Utils.getFlag('Q', options));    String seedString = Utils.getOption('S', options);    if (seedString.length() != 0) {      setSeed(Integer.parseInt(seedString));    } else {      setSeed(1);    }    setDebug(Utils.getFlag('D', options));    String classifierName = Utils.getOption('W', options);    if (classifierName.length() == 0) {      throw new Exception("A classifier must be specified with"			
  + " the -W option.");

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -