⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 racedincrementallogitboost.java

📁 代码是一个分类器的实现,其中使用了部分weka的源代码。可以将项目导入eclipse运行
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
	  }	}	if (max > 0) {	  return maxIndex;	} else {	  return Instance.missingValue();	}      case Attribute.NUMERIC:	return dist[0];      default:	return Instance.missingValue();      }    }    /**      * returns the distribution the committee generates for an instance (given Fs values)      *      * @param Fs the Fs values     * @return the distribution     * @throws Exception if anything goes wrong     */    public double[] distributionForInstance(double[] Fs) throws Exception {            double [] distribution = new double [m_NumClasses];      for (int j = 0; j < m_NumClasses; j++) {	distribution[j] = RtoP(Fs, j);      }      return distribution;    }        /**      * updates the Fs values given a new model in the committee      *      * @param instance the instance to use     * @param newModel the new model     * @param Fs the Fs values to update     * @return the updated Fs values     * @throws Exception if anything goes wrong     */    public double[] updateFS(Instance instance, Classifier[] newModel, double[] Fs) throws Exception {            instance = (Instance)instance.copy();      instance.setDataset(m_NumericClassData);            double [] Fi = new double [m_NumClasses];      double Fsum = 0;      for (int j = 0; j < m_NumClasses; j++) {	Fi[j] = newModel[j].classifyInstance(instance);	Fsum += Fi[j];      }      Fsum /= m_NumClasses;            double[] newFs = new double[Fs.length];      for (int j = 0; j < m_NumClasses; j++) {	newFs[j] = Fs[j] + ((Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses);      }      return newFs;    }    /**      * returns the distribution the committee generates for an instance     *      * @param instance the instance to get the distribution for     * @return the distribution     * @throws Exception if anything goes wrong     */    public double[] distributionForInstance(Instance instance) throws Exception {      instance = (Instance)instance.copy();      instance.setDataset(m_NumericClassData);      double [] Fs = new 
double [m_NumClasses];       for (int i = 0; i < m_models.size(); i++) {	double [] Fi = new double [m_NumClasses];	double Fsum = 0;	Classifier[] model = (Classifier[]) m_models.elementAt(i);	for (int j = 0; j < m_NumClasses; j++) {	  Fi[j] = model[j].classifyInstance(instance);	  Fsum += Fi[j];	}	Fsum /= m_NumClasses;	for (int j = 0; j < m_NumClasses; j++) {	  Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses;	}      }      double [] distribution = new double [m_NumClasses];      for (int j = 0; j < m_NumClasses; j++) {	distribution[j] = RtoP(Fs, j);      }      return distribution;    }    /**      * performs a boosting iteration, returning a new model for the committee     *      * @param data the data to boost on     * @return the new model     * @throws Exception if anything goes wrong     */    protected Classifier[] boost(Instances data) throws Exception {            Classifier[] newModel = Classifier.makeCopies(m_Classifier, m_NumClasses);            // Create a copy of the data with the class transformed into numeric      Instances boostData = new Instances(data);      boostData.deleteWithMissingClass();      int numInstances = boostData.numInstances();            // Temporarily unset the class index      int classIndex = data.classIndex();      boostData.setClassIndex(-1);      boostData.deleteAttributeAt(classIndex);      boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);      boostData.setClassIndex(classIndex);      double [][] trainFs = new double [numInstances][m_NumClasses];      double [][] trainYs = new double [numInstances][m_NumClasses];      for (int j = 0; j < m_NumClasses; j++) {	for (int i = 0, k = 0; i < numInstances; i++, k++) {	  while (data.instance(k).classIsMissing()) k++;	  trainYs[i][j] = (data.instance(k).classValue() == j) ? 
1 : 0;	}      }            // Evaluate / increment trainFs from the classifiers      for (int x = 0; x < m_models.size(); x++) {	for (int i = 0; i < numInstances; i++) {	  double [] pred = new double [m_NumClasses];	  double predSum = 0;	  Classifier[] model = (Classifier[]) m_models.elementAt(x);	  for (int j = 0; j < m_NumClasses; j++) {	    pred[j] = model[j].classifyInstance(boostData.instance(i));	    predSum += pred[j];	  }	  predSum /= m_NumClasses;	  for (int j = 0; j < m_NumClasses; j++) {	    trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses-1) 	      / m_NumClasses;	  }	}      }      for (int j = 0; j < m_NumClasses; j++) {		// Set instance pseudoclass and weights	for (int i = 0; i < numInstances; i++) {	  double p = RtoP(trainFs[i], j);	  Instance current = boostData.instance(i);	  double z, actual = trainYs[i][j];	  if (actual == 1) {	    z = 1.0 / p;	    if (z > Z_MAX) { // threshold	      z = Z_MAX;	    }	  } else if (actual == 0) {	    z = -1.0 / (1.0 - p);	    if (z < -Z_MAX) { // threshold	      z = -Z_MAX;	    }	  } else {	    z = (actual - p) / (p * (1 - p));	  }	  double w = (actual - p) / z;	  current.setValue(classIndex, z);	  current.setWeight(numInstances * w);	}		Instances trainData = boostData;	if (m_UseResampling) {	  double[] weights = new double[boostData.numInstances()];	  for (int kk = 0; kk < weights.length; kk++) {	    weights[kk] = boostData.instance(kk).weight();	  }	  trainData = boostData.resampleWithWeights(m_RandomInstance, 						    weights);	}		// Build the classifier	newModel[j].buildClassifier(trainData);      }                  return newModel;    }    /**      * outputs description of the committee     *      * @return a string representation of the classifier     */    public String toString() {            StringBuffer text = new StringBuffer();            text.append("RacedIncrementalLogitBoost: Best committee on validation data\n");      text.append("Base classifiers: \n");            for (int i = 0; i < 
m_models.size(); i++) {	text.append("\nModel "+(i+1));	Classifier[] cModels = (Classifier[]) m_models.elementAt(i);	for (int j = 0; j < m_NumClasses; j++) {	  text.append("\n\tClass " + (j + 1) 		      + " (" + m_ClassAttribute.name() 		      + "=" + m_ClassAttribute.value(j) + ")\n\n"		      + cModels[j].toString() + "\n");	}      }      text.append("Number of models: " +		  m_models.size() + "\n");            text.append("Chunk size per model: " + m_chunkSize + "\n");            return text.toString();    }  }  /**   * Returns default capabilities of the classifier.   *   * @return      the capabilities of this classifier   */  public Capabilities getCapabilities() {    Capabilities result = super.getCapabilities();    // class    result.disableAllClasses();    result.disableAllClassDependencies();    result.enable(Capability.NOMINAL_CLASS);    // instances    result.setMinimumNumberInstances(0);        return result;  } /**   * Builds the classifier.   *   * @param data the instances to train the classifier with   * @throws Exception if something goes wrong   */  public void buildClassifier(Instances data) throws Exception {    m_RandomInstance = new Random(m_Seed);    Instances boostData;    int classIndex = data.classIndex();    // can classifier handle the data?    
getCapabilities().testWithFail(data);    // remove instances with missing class    data = new Instances(data);    data.deleteWithMissingClass();        if (m_Classifier == null) {      throw new Exception("A base classifier has not been specified!");    }    if (!(m_Classifier instanceof WeightedInstancesHandler) &&	!m_UseResampling) {      m_UseResampling = true;    }    m_NumClasses = data.numClasses();    m_ClassAttribute = data.classAttribute();    // Create a copy of the data with the class transformed into numeric    boostData = new Instances(data);    // Temporarily unset the class index    boostData.setClassIndex(-1);    boostData.deleteAttributeAt(classIndex);    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);    boostData.setClassIndex(classIndex);    m_NumericClassData = new Instances(boostData, 0);    data.randomize(m_RandomInstance);    // create the committees    int cSize = m_minChunkSize;    m_committees = new FastVector();    while (cSize <= m_maxChunkSize) {      m_committees.addElement(new Committee(cSize));      m_maxBatchSizeRequired = cSize;      cSize *= 2;    }    // set up for consumption    m_validationSet = new Instances(data, m_validationChunkSize);    m_currentSet = new Instances(data, m_maxBatchSizeRequired);    m_bestCommittee = null;    m_numInstancesConsumed = 0;    // start eating what we've been given    for (int i=0; i<data.numInstances(); i++) updateClassifier(data.instance(i));  } /**   * Updates the classifier.   
*   * @param instance the next instance in the stream of training data   * @throws Exception if something goes wrong   */  public void updateClassifier(Instance instance) throws Exception {    m_numInstancesConsumed++;    if (m_validationSet.numInstances() < m_validationChunkSize) {      m_validationSet.add(instance);      m_validationSetChanged = true;    } else {      m_currentSet.add(instance);      boolean hasChanged = false;            // update each committee      for (int i=0; i<m_committees.size(); i++) {	Committee c = (Committee) m_committees.elementAt(i);	if (c.update()) {	  	  hasChanged = true;	  	  if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) {	    double oldLL = c.logLikelihood();	    double newLL = c.logLikelihoodAfter();	    if (newLL >= oldLL && c.committeeSize() > 1) {	      c.pruneLastModel();	      if (m_Debug) System.out.println("Pruning " + c.chunkSize()+ " committee (" +					      oldLL + " < " + newLL + ")");	    } else c.keepLastModel();	  } else c.keepLastModel(); // no pruning	}       }      if (hasChanged) {	if (m_Debug) System.out.println("After consuming " + m_numInstancesConsumed					+ " instances... 
(" + m_validationSet.numInstances()					+ " + " + m_currentSet.numInstances()					+ " instances currently in memory)");		// find best committee	double lowestError = 1.0;	for (int i=0; i<m_committees.size(); i++) {	  Committee c = (Committee) m_committees.elementAt(i);	  if (c.committeeSize() > 0) {	    double err = c.validationError();	    double ll = c.logLikelihood();	    if (m_Debug) System.out.println("Chunk size " + c.chunkSize() + " with "					    + c.committeeSize() + " models, has validation error of "					    + err + ", log likelihood of " + ll);	    if (err < lowestError) {	      lowestError = err;	      m_bestCommittee = c;	    }	  }	}      }      if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) {	m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired);	// reset consumation counts	for (int i=0; i<m_committees.size(); i++) {	  Committee c = (Committee) m_committees.elementAt(i);	  c.resetConsumed();	}      }    }  }  /**   * Convert from function responses to probabilities   *   * @param Fs an array containing the responses from each function   * @param j the class value of interest   * @return the probability prediction for j   * @throws Exception if can't normalize   */  protected static double RtoP(double []Fs, int j)     throws Exception {    double maxF = -Double.MAX_VALUE;    for (int i = 0; i < Fs.length; i++) {      if (Fs[i] > maxF) {	maxF = Fs[i];      }    }    double sum = 0;    double[] probs = new double[Fs.length];    for (int i = 0; i < Fs.length; i++) {      probs[i] = Math.exp(Fs[i] - maxF);      sum += probs[i];    }    if (sum == 0) {      throw new Exception("Can't normalize");    }    return probs[j] / sum;  }  /**   * Computes class distribution of an instance using the best committee.   
*
   * While no best committee has been selected, the distribution comes from
   * a ZeroR model built on the validation set; that model is rebuilt
   * whenever the validation set has changed since it was last built.
   *
   * @param instance the instance to get the distribution for
   * @return the distribution
   * @throws Exception if anything goes wrong
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    // a committee has been selected: let it answer
    if (m_bestCommittee != null) {
      return m_bestCommittee.distributionForInstance(instance);
    }

    // no committee yet: fall back to ZeroR, refreshing it when stale
    boolean zeroRIsStale = m_validationSetChanged || m_zeroR == null;
    if (zeroRIsStale) {
      m_zeroR = new ZeroR();
      m_zeroR.buildClassifier(m_validationSet);
      m_validationSetChanged = false;
    }
    return m_zeroR.distributionForInstance(instance);
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -