📄 racedincrementallogitboost.java
字号:
// Temporarily unset the class index int classIndex = data.classIndex(); boostData.setClassIndex(-1); boostData.deleteAttributeAt(classIndex); boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex); boostData.setClassIndex(classIndex); double [][] trainFs = new double [numInstances][m_NumClasses]; double [][] trainYs = new double [numInstances][m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { for (int i = 0, k = 0; i < numInstances; i++, k++) { while (data.instance(k).classIsMissing()) k++; trainYs[i][j] = (data.instance(k).classValue() == j) ? 1 : 0; } } // Evaluate / increment trainFs from the classifiers for (int x = 0; x < m_models.size(); x++) { for (int i = 0; i < numInstances; i++) { double [] pred = new double [m_NumClasses]; double predSum = 0; Classifier[] model = (Classifier[]) m_models.elementAt(x); for (int j = 0; j < m_NumClasses; j++) { pred[j] = model[j].classifyInstance(boostData.instance(i)); predSum += pred[j]; } predSum /= m_NumClasses; for (int j = 0; j < m_NumClasses; j++) { trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses-1) / m_NumClasses; } } } for (int j = 0; j < m_NumClasses; j++) { // Set instance pseudoclass and weights for (int i = 0; i < numInstances; i++) { double p = RtoP(trainFs[i], j); Instance current = boostData.instance(i); double z, actual = trainYs[i][j]; if (actual == 1) { z = 1.0 / p; if (z > Z_MAX) { // threshold z = Z_MAX; } } else if (actual == 0) { z = -1.0 / (1.0 - p); if (z < -Z_MAX) { // threshold z = -Z_MAX; } } else { z = (actual - p) / (p * (1 - p)); } double w = (actual - p) / z; current.setValue(classIndex, z); current.setWeight(numInstances * w); } Instances trainData = boostData; if (m_UseResampling) { double[] weights = new double[boostData.numInstances()]; for (int kk = 0; kk < weights.length; kk++) { weights[kk] = boostData.instance(kk).weight(); } trainData = boostData.resampleWithWeights(m_RandomInstance, weights); } // Build the classifier newModel[j].buildClassifier(trainData); } return newModel; } /* outputs description of the committee */ public String toString() { StringBuffer text = new StringBuffer(); text.append("RacedIncrementalLogitBoost: Best committee on validation data\n"); text.append("Base classifiers: \n"); for (int i = 0; i < m_models.size(); i++) { text.append("\nModel "+(i+1)); Classifier[] cModels = (Classifier[]) m_models.elementAt(i); for (int j = 0; j < m_NumClasses; j++) { text.append("\n\tClass " + (j + 1) + " (" + m_ClassAttribute.name() + "=" + m_ClassAttribute.value(j) + ")\n\n" + cModels[j].toString() + "\n"); } } text.append("Number of models: " + m_models.size() + "\n"); text.append("Chunk size per model: " + m_chunkSize + "\n"); return text.toString(); } } /** * Builds the classifier. * * @param instances the instances to train the classifier with * @exception Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { m_RandomInstance = new Random(m_Seed); Instances boostData; int classIndex = data.classIndex(); if (data.classAttribute().isNumeric()) { throw new Exception("LogitBoost can't handle a numeric class!"); } if (m_Classifier == null) { throw new Exception("A base classifier has not been specified!"); } if (!(m_Classifier instanceof WeightedInstancesHandler) && !m_UseResampling) { m_UseResampling = true; } if (data.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } m_NumClasses = data.numClasses(); m_ClassAttribute = data.classAttribute(); // Create a copy of the data with the class transformed into numeric boostData = new Instances(data); boostData.deleteWithMissingClass(); // Temporarily unset the class index boostData.setClassIndex(-1); boostData.deleteAttributeAt(classIndex); boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex); boostData.setClassIndex(classIndex); m_NumericClassData = new Instances(boostData, 0); data = new Instances(data); data.randomize(m_RandomInstance); // create the committees int cSize = m_minChunkSize; m_committees = new FastVector(); while (cSize <= m_maxChunkSize) { m_committees.addElement(new Committee(cSize)); m_maxBatchSizeRequired = cSize; cSize *= 2; } // set up for consumption m_validationSet = new Instances(data, m_validationChunkSize); m_currentSet = new Instances(data, m_maxBatchSizeRequired); m_bestCommittee = null; m_numInstancesConsumed = 0; // start eating what we've been given for (int i=0; i<data.numInstances(); i++) updateClassifier(data.instance(i)); } /** * Updates the classifier. * * @param instance the next instance in the stream of training data * @exception Exception if something goes wrong */ public void updateClassifier(Instance instance) throws Exception { m_numInstancesConsumed++; if (m_validationSet.numInstances() < m_validationChunkSize) { m_validationSet.add(instance); m_validationSetChanged = true; } else { m_currentSet.add(instance); boolean hasChanged = false; // update each committee for (int i=0; i<m_committees.size(); i++) { Committee c = (Committee) m_committees.elementAt(i); if (c.update()) { hasChanged = true; if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) { double oldLL = c.logLikelihood(); double newLL = c.logLikelihoodAfter(); if (newLL >= oldLL && c.committeeSize() > 1) { c.pruneLastModel(); if (m_Debug) System.out.println("Pruning " + c.chunkSize()+ " committee (" + oldLL + " < " + newLL + ")"); } else c.keepLastModel(); } else c.keepLastModel(); // no pruning } } if (hasChanged) { if (m_Debug) System.out.println("After consuming " + m_numInstancesConsumed + " instances... (" + m_validationSet.numInstances() + " + " + m_currentSet.numInstances() + " instances currently in memory)"); // find best committee double lowestError = 1.0; for (int i=0; i<m_committees.size(); i++) { Committee c = (Committee) m_committees.elementAt(i); if (c.committeeSize() > 0) { double err = c.validationError(); double ll = c.logLikelihood(); if (m_Debug) System.out.println("Chunk size " + c.chunkSize() + " with " + c.committeeSize() + " models, has validation error of " + err + ", log likelihood of " + ll); if (err < lowestError) { lowestError = err; m_bestCommittee = c; } } } } if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) { m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired); // reset consumation counts for (int i=0; i<m_committees.size(); i++) { Committee c = (Committee) m_committees.elementAt(i); c.resetConsumed(); } } } } /** * Convert from function responses to probabilities * * @param R an array containing the responses from each function * @param j the class value of interest * @return the probability prediction for j */ protected static double RtoP(double []Fs, int j) throws Exception { double maxF = -Double.MAX_VALUE; for (int i = 0; i < Fs.length; i++) { if (Fs[i] > maxF) { maxF = Fs[i]; } } double sum = 0; double[] probs = new double[Fs.length]; for (int i = 0; i < Fs.length; i++) { probs[i] = Math.exp(Fs[i] - maxF); sum += probs[i]; } if (sum == 0) { throw new Exception("Can't normalize"); } return probs[j] / sum; } /** * Computes class distribution of an instance using the best committee. */ public double[] distributionForInstance(Instance instance) throws Exception { if (m_bestCommittee != null) return m_bestCommittee.distributionForInstance(instance); else { if (m_validationSetChanged || m_zeroR == null) { m_zeroR = new ZeroR(); m_zeroR.buildClassifier(m_validationSet); m_validationSetChanged = false; } return m_zeroR.distributionForInstance(instance); } } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(9); newVector.addElement(new Option( "\tTurn on debugging output.", "D", 0, "-D")); newVector.addElement(new Option( "\tMinimum size of chunks.\n" +"\t(default 500)", "C", 1, "-C <num>")); newVector.addElement(new Option( "\tMaximum size of chunks.\n" +"\t(default 20000)", "M", 1, "-M <num>")); newVector.addElement(new Option( "\tSize of validation set.\n" +"\t(default 5000)", "V", 1, "-V <num>")); newVector.addElement(new Option( "\tFull name of 'weak' learner to boost.\n" +"\teg: weka.classifiers.DecisionStump", "W", 1, "-W <learner class name>")); newVector.addElement(new Option( "\tCommittee pruning to perform.\n" +"\t0=none, 1=log likelihood (default)", "P", 1, "-P <pruning type>")); newVector.addElement(new Option( "\tUse resampling for boosting.", "Q", 0, "-Q")); newVector.addElement(new Option( "\tSeed for resampling. (Default 1)", "S", 1, "-S <num>")); if ((m_Classifier != null) && (m_Classifier instanceof OptionHandler)) { newVector.addElement(new Option( "", "", 0, "\nOptions specific to weak learner " + m_Classifier.getClass().getName() + ":")); Enumeration enum = ((OptionHandler)m_Classifier).listOptions(); while (enum.hasMoreElements()) { newVector.addElement(enum.nextElement()); } } return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String minChunkSize = Utils.getOption('C', options); if (minChunkSize.length() != 0) { setMinChunkSize(Integer.parseInt(minChunkSize)); } else { setMinChunkSize(500); } String maxChunkSize = Utils.getOption('M', options); if (maxChunkSize.length() != 0) { setMaxChunkSize(Integer.parseInt(maxChunkSize)); } else { setMaxChunkSize(20000); } String validationChunkSize = Utils.getOption('V', options); if (validationChunkSize.length() != 0) { setValidationChunkSize(Integer.parseInt(validationChunkSize)); } else { setValidationChunkSize(5000); } String pruneType = Utils.getOption('P', options); if (pruneType.length() != 0) { setPruningType(new SelectedTag(Integer.parseInt(pruneType), TAGS_PRUNETYPE)); } else { setPruningType(new SelectedTag(PRUNETYPE_LOGLIKELIHOOD, TAGS_PRUNETYPE)); } setUseResampling(Utils.getFlag('Q', options)); String seedString = Utils.getOption('S', options); if (seedString.length() != 0) { setSeed(Integer.parseInt(seedString)); } else { setSeed(1); } setDebug(Utils.getFlag('D', options)); String classifierName = Utils.getOption('W', options); if (classifierName.length() == 0) { throw new Exception("A classifier must be specified with" + " the -W option.");
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -