📄 racedincrementallogitboost.java
字号:
} } if (max > 0) { return maxIndex; } else { return Instance.missingValue(); } case Attribute.NUMERIC: return dist[0]; default: return Instance.missingValue(); } } /** * returns the distribution the committee generates for an instance (given Fs values) * * @param Fs the Fs values * @return the distribution * @throws Exception if anything goes wrong */ public double[] distributionForInstance(double[] Fs) throws Exception { double [] distribution = new double [m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { distribution[j] = RtoP(Fs, j); } return distribution; } /** * updates the Fs values given a new model in the committee * * @param instance the instance to use * @param newModel the new model * @param Fs the Fs values to update * @return the updated Fs values * @throws Exception if anything goes wrong */ public double[] updateFS(Instance instance, Classifier[] newModel, double[] Fs) throws Exception { instance = (Instance)instance.copy(); instance.setDataset(m_NumericClassData); double [] Fi = new double [m_NumClasses]; double Fsum = 0; for (int j = 0; j < m_NumClasses; j++) { Fi[j] = newModel[j].classifyInstance(instance); Fsum += Fi[j]; } Fsum /= m_NumClasses; double[] newFs = new double[Fs.length]; for (int j = 0; j < m_NumClasses; j++) { newFs[j] = Fs[j] + ((Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses); } return newFs; } /** * returns the distribution the committee generates for an instance * * @param instance the instance to get the distribution for * @return the distribution * @throws Exception if anything goes wrong */ public double[] distributionForInstance(Instance instance) throws Exception { instance = (Instance)instance.copy(); instance.setDataset(m_NumericClassData); double [] Fs = new double [m_NumClasses]; for (int i = 0; i < m_models.size(); i++) { double [] Fi = new double [m_NumClasses]; double Fsum = 0; Classifier[] model = (Classifier[]) m_models.elementAt(i); for (int j = 0; j < m_NumClasses; j++) { Fi[j] = model[j].classifyInstance(instance); Fsum += Fi[j]; } Fsum /= m_NumClasses; for (int j = 0; j < m_NumClasses; j++) { Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses; } } double [] distribution = new double [m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { distribution[j] = RtoP(Fs, j); } return distribution; } /** * performs a boosting iteration, returning a new model for the committee * * @param data the data to boost on * @return the new model * @throws Exception if anything goes wrong */ protected Classifier[] boost(Instances data) throws Exception { Classifier[] newModel = Classifier.makeCopies(m_Classifier, m_NumClasses); // Create a copy of the data with the class transformed into numeric Instances boostData = new Instances(data); boostData.deleteWithMissingClass(); int numInstances = boostData.numInstances(); // Temporarily unset the class index int classIndex = data.classIndex(); boostData.setClassIndex(-1); boostData.deleteAttributeAt(classIndex); boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex); boostData.setClassIndex(classIndex); double [][] trainFs = new double [numInstances][m_NumClasses]; double [][] trainYs = new double [numInstances][m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { for (int i = 0, k = 0; i < numInstances; i++, k++) { while (data.instance(k).classIsMissing()) k++; trainYs[i][j] = (data.instance(k).classValue() == j) ? 1 : 0; } } // Evaluate / increment trainFs from the classifiers for (int x = 0; x < m_models.size(); x++) { for (int i = 0; i < numInstances; i++) { double [] pred = new double [m_NumClasses]; double predSum = 0; Classifier[] model = (Classifier[]) m_models.elementAt(x); for (int j = 0; j < m_NumClasses; j++) { pred[j] = model[j].classifyInstance(boostData.instance(i)); predSum += pred[j]; } predSum /= m_NumClasses; for (int j = 0; j < m_NumClasses; j++) { trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses-1) / m_NumClasses; } } } for (int j = 0; j < m_NumClasses; j++) { // Set instance pseudoclass and weights for (int i = 0; i < numInstances; i++) { double p = RtoP(trainFs[i], j); Instance current = boostData.instance(i); double z, actual = trainYs[i][j]; if (actual == 1) { z = 1.0 / p; if (z > Z_MAX) { // threshold z = Z_MAX; } } else if (actual == 0) { z = -1.0 / (1.0 - p); if (z < -Z_MAX) { // threshold z = -Z_MAX; } } else { z = (actual - p) / (p * (1 - p)); } double w = (actual - p) / z; current.setValue(classIndex, z); current.setWeight(numInstances * w); } Instances trainData = boostData; if (m_UseResampling) { double[] weights = new double[boostData.numInstances()]; for (int kk = 0; kk < weights.length; kk++) { weights[kk] = boostData.instance(kk).weight(); } trainData = boostData.resampleWithWeights(m_RandomInstance, weights); } // Build the classifier newModel[j].buildClassifier(trainData); } return newModel; } /** * outputs description of the committee * * @return a string representation of the classifier */ public String toString() { StringBuffer text = new StringBuffer(); text.append("RacedIncrementalLogitBoost: Best committee on validation data\n"); text.append("Base classifiers: \n"); for (int i = 0; i < m_models.size(); i++) { text.append("\nModel "+(i+1)); Classifier[] cModels = (Classifier[]) m_models.elementAt(i); for (int j = 0; j < m_NumClasses; j++) { text.append("\n\tClass " + (j + 1) + " (" + m_ClassAttribute.name() + "=" + m_ClassAttribute.value(j) + ")\n\n" + cModels[j].toString() + "\n"); } } text.append("Number of models: " + m_models.size() + "\n"); text.append("Chunk size per model: " + m_chunkSize + "\n"); return text.toString(); } } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // class result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.NOMINAL_CLASS); // instances result.setMinimumNumberInstances(0); return result; } /** * Builds the classifier. * * @param data the instances to train the classifier with * @throws Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { m_RandomInstance = new Random(m_Seed); Instances boostData; int classIndex = data.classIndex(); // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); if (m_Classifier == null) { throw new Exception("A base classifier has not been specified!"); } if (!(m_Classifier instanceof WeightedInstancesHandler) && !m_UseResampling) { m_UseResampling = true; } m_NumClasses = data.numClasses(); m_ClassAttribute = data.classAttribute(); // Create a copy of the data with the class transformed into numeric boostData = new Instances(data); // Temporarily unset the class index boostData.setClassIndex(-1); boostData.deleteAttributeAt(classIndex); boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex); boostData.setClassIndex(classIndex); m_NumericClassData = new Instances(boostData, 0); data.randomize(m_RandomInstance); // create the committees int cSize = m_minChunkSize; m_committees = new FastVector(); while (cSize <= m_maxChunkSize) { m_committees.addElement(new Committee(cSize)); m_maxBatchSizeRequired = cSize; cSize *= 2; } // set up for consumption m_validationSet = new Instances(data, m_validationChunkSize); m_currentSet = new Instances(data, m_maxBatchSizeRequired); m_bestCommittee = null; m_numInstancesConsumed = 0; // start eating what we've been given for (int i=0; i<data.numInstances(); i++) updateClassifier(data.instance(i)); } /** * Updates the classifier. * * @param instance the next instance in the stream of training data * @throws Exception if something goes wrong */ public void updateClassifier(Instance instance) throws Exception { m_numInstancesConsumed++; if (m_validationSet.numInstances() < m_validationChunkSize) { m_validationSet.add(instance); m_validationSetChanged = true; } else { m_currentSet.add(instance); boolean hasChanged = false; // update each committee for (int i=0; i<m_committees.size(); i++) { Committee c = (Committee) m_committees.elementAt(i); if (c.update()) { hasChanged = true; if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) { double oldLL = c.logLikelihood(); double newLL = c.logLikelihoodAfter(); if (newLL >= oldLL && c.committeeSize() > 1) { c.pruneLastModel(); if (m_Debug) System.out.println("Pruning " + c.chunkSize()+ " committee (" + oldLL + " < " + newLL + ")"); } else c.keepLastModel(); } else c.keepLastModel(); // no pruning } } if (hasChanged) { if (m_Debug) System.out.println("After consuming " + m_numInstancesConsumed + " instances... (" + m_validationSet.numInstances() + " + " + m_currentSet.numInstances() + " instances currently in memory)"); // find best committee double lowestError = 1.0; for (int i=0; i<m_committees.size(); i++) { Committee c = (Committee) m_committees.elementAt(i); if (c.committeeSize() > 0) { double err = c.validationError(); double ll = c.logLikelihood(); if (m_Debug) System.out.println("Chunk size " + c.chunkSize() + " with " + c.committeeSize() + " models, has validation error of " + err + ", log likelihood of " + ll); if (err < lowestError) { lowestError = err; m_bestCommittee = c; } } } } if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) { m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired); // reset consumation counts for (int i=0; i<m_committees.size(); i++) { Committee c = (Committee) m_committees.elementAt(i); c.resetConsumed(); } } } } /** * Convert from function responses to probabilities * * @param Fs an array containing the responses from each function * @param j the class value of interest * @return the probability prediction for j * @throws Exception if can't normalize */ protected static double RtoP(double []Fs, int j) throws Exception { double maxF = -Double.MAX_VALUE; for (int i = 0; i < Fs.length; i++) { if (Fs[i] > maxF) { maxF = Fs[i]; } } double sum = 0; double[] probs = new double[Fs.length]; for (int i = 0; i < Fs.length; i++) { probs[i] = Math.exp(Fs[i] - maxF); sum += probs[i]; } if (sum == 0) { throw new Exception("Can't normalize"); } return probs[j] / sum; } /** * Computes class distribution of an instance using the best committee. * * @param instance the instance to get the distribution for * @return the distribution * @throws Exception if anything goes wrong */ public double[] distributionForInstance(Instance instance) throws Exception { if (m_bestCommittee != null) return m_bestCommittee.distributionForInstance(instance); else { if (m_validationSetChanged || m_zeroR == null) { m_zeroR = new ZeroR(); m_zeroR.buildClassifier(m_validationSet); m_validationSetChanged = false; } return m_zeroR.distributionForInstance(instance); } } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -