📄 ensembleselection.java
字号:
System.out.println("" + modelPerformance[i] + " "
            + modelNamesArray[modelIndexes[i]]); // verbose: performance then model name
      }
    }
    // We're now ready to build our array of the models which were chosen
    // and their associated weights.
    m_chosen_models = new EnsembleSelectionLibraryModel[chosenModels];
    m_chosen_model_weights = new int[chosenModels];
    libraryIndex = 0;
    // chosenIndex indexes over the models which were chosen by
    // EnsembleSelection (those which have non-zero weight).
    int chosenIndex = 0;
    iter = m_library.getModels().iterator();
    while (iter.hasNext()) {
      int weightOfModel = modelWeights[libraryIndex++];
      EnsembleSelectionLibraryModel model = (EnsembleSelectionLibraryModel) iter
          .next();
      if (weightOfModel > 0) {
        // If the model was chosen at least once, add it to our array
        // of chosen models and weights.
        m_chosen_models[chosenIndex] = model;
        m_chosen_model_weights[chosenIndex] = weightOfModel;
        // Note that the EnsembleSelectionLibraryModel may not be "loaded" -
        // that is, its classifier(s) may be null pointers. That's okay -
        // we'll "rehydrate" them later, if and when we need to.
        ++chosenIndex;
      }
    }
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if instance could not be classified
   *           successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    // The instance's string form is the lookup key used by cachePredictions();
    // NOTE(review): this assumes cached instances were stored with the class
    // set missing so the strings match up — see cachePredictions.
    String stringInstance = instance.toString();
    double cachedPreds[][] = null;
    if (m_cachedPredictions != null) {
      // If we have any cached predictions (i.e., if cachePredictions was
      // called), look for a cached set of predictions for this instance.
      if (m_cachedPredictions.containsKey(stringInstance)) {
        cachedPreds = (double[][]) m_cachedPredictions.get(stringInstance);
      }
    }
    double[] prediction = new double[instance.numClasses()];
    for (int i = 0; i < prediction.length; ++i) {
      prediction[i] = 0.0;
    }
    // Now do a weighted average of the predictions of each of our models.
    for (int i = 0; i < m_chosen_models.length; ++i) {
      double[] predictionForThisModel = null;
      if (cachedPreds == null) {
        // If there are no predictions cached, we'll load the model's
        // classifier(s) in to memory and get the predictions.
        m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());
        predictionForThisModel = m_chosen_models[i].getAveragePrediction(instance);
        // We could release the model here to save memory, but we assume
        // that there is enough available since we're not using the
        // prediction caching functionality. If we load and release a model
        // every time we need to get a prediction for an instance, it can be
        // prohibitively slow.
      } else {
        // If it's cached, just get it from the array of cached preds
        // for this instance.
        predictionForThisModel = cachedPreds[i];
      }
      // We have encountered a bug where MultilayerPerceptron returns a null
      // prediction array. If that happens, we just don't count that model in
      // our ensemble prediction.
      if (predictionForThisModel != null) {
        // Okay, the model returned a valid prediction array, so we'll
        // add the appropriate fraction of this model's prediction
        // (weight / total weight of all chosen models).
        for (int j = 0; j < prediction.length; ++j) {
          prediction[j] += m_chosen_model_weights[i] * predictionForThisModel[j]
              / m_total_weight;
        }
      }
    }
    // normalize to add up to 1.
    if (instance.classAttribute().isNominal()) {
      if (Utils.sum(prediction) > 0)
        Utils.normalize(prediction);
    }
    return prediction;
  }

  /**
   * This function tests whether or not a given path is appropriate for being
   * the working directory. Specifically, we care that we can write to the
   * path and that it doesn't point to a "non-directory" file handle.
* * @param dir the directory to test * @return true if the directory is valid */ private boolean validWorkingDirectory(String dir) { boolean valid = false; File f = new File((dir)); if (f.exists()) { if (f.isDirectory() && f.canWrite()) valid = true; } else { if (f.canWrite()) valid = true; } return valid; } /** * This method tries to find a reasonable path name for the ensemble working * directory where models and files will be stored. * * * @return true if m_workingDirectory now has a valid file name */ public static String getDefaultWorkingDirectory() { String defaultDirectory = new String(""); boolean success = false; int i = 1; while (i < MAX_DEFAULT_DIRECTORIES && !success) { File f = new File(System.getProperty("user.home"), "Ensemble-" + i); if (!f.exists() && f.getParentFile().canWrite()) { defaultDirectory = f.getPath(); success = true; } i++; } if (!success) { defaultDirectory = new String(""); // should we print an error or something? } return defaultDirectory; } /** * Output a representation of this classifier * * @return a string representation of the classifier */ public String toString() { // We just print out the models which were selected, and the number // of times each was selected. String result = new String(); if (m_chosen_models != null) { for (int i = 0; i < m_chosen_models.length; ++i) { result += m_chosen_model_weights[i]; result += " " + m_chosen_models[i].getStringRepresentation() + "\n"; } } else { result = "No models selected."; } return result; } /** * Cache predictions for the individual base classifiers in the ensemble * with respect to the given dataset. This is used so that when testing a * large ensemble on a test set, we don't have to keep the models in memory. * * @param test The instances for which to cache predictions. 
* @throws Exception if something goes wrong
   */
  private void cachePredictions(Instances test) throws Exception {
    m_cachedPredictions = new HashMap();
    Evaluation evalModel = null;
    Instances originalInstances = null;
    // If the verbose flag is set, we'll also print out the performances of
    // all the individual models w.r.t. this test set while we're at it.
    boolean printModelPerformances = getVerboseOutput();
    if (printModelPerformances) {
      // To get performances, we need to keep the class attribute.
      originalInstances = new Instances(test);
    }
    // For each model, we'll go through the dataset and get predictions.
    // The idea is we want to only have one model in memory at a time, so
    // we'll load one model in to memory, get all its predictions, and add
    // them to the hash map. Then we can release it from memory and move on
    // to the next.
    for (int i = 0; i < m_chosen_models.length; ++i) {
      if (printModelPerformances) {
        // If we're going to print predictions, we need to make a new
        // Evaluation object.
        evalModel = new Evaluation(originalInstances);
      }
      Date startTime = new Date();
      // Load the model in to memory.
      m_chosen_models[i].rehydrateModel(m_workingDirectory.getAbsolutePath());
      // Now loop through all the instances and get the model's predictions.
      for (int j = 0; j < test.numInstances(); ++j) {
        Instance currentInstance = test.instance(j);
        // When we're looking for a cached prediction later, we'll only
        // have the non-class attributes, so we set the class missing here
        // in order to make the string match up properly.
        // NOTE(review): this mutates the instances of `test` in place —
        // confirm callers do not rely on the class values afterwards.
        currentInstance.setClassMissing();
        String stringInstance = currentInstance.toString();
        // When we come in here with the first model, the instance will not
        // yet be part of the map.
        if (!m_cachedPredictions.containsKey(stringInstance)) {
          // The instance isn't in the map yet, so add it.
          // For each instance, we store a two-dimensional array - the first
          // index is over all the models in the ensemble, and the second
          // index is over the classes (i.e., typical prediction array).
          int predSize = test.classAttribute().isNumeric() ? 1 : test
              .classAttribute().numValues();
          double predictionArray[][] = new double[m_chosen_models.length][predSize];
          m_cachedPredictions.put(stringInstance, predictionArray);
        }
        // Get the array from the map which is associated with this instance
        double predictions[][] = (double[][]) m_cachedPredictions
            .get(stringInstance);
        // And add our model's prediction for it.
        predictions[i] = m_chosen_models[i].getAveragePrediction(test
            .instance(j));
        if (printModelPerformances) {
          evalModel.evaluateModelOnceAndRecordPrediction(
              predictions[i], originalInstances.instance(j));
        }
      }
      // Now we're done with model #i, so we can release it.
      m_chosen_models[i].releaseModel();
      Date endTime = new Date();
      long diff = endTime.getTime() - startTime.getTime();
      if (m_Debug)
        System.out.println("Test time for "
            + m_chosen_models[i].getStringRepresentation() + " was: " + diff);
      if (printModelPerformances) {
        String output = new String(m_chosen_models[i]
            .getStringRepresentation() + ": ");
        output += "\tRMSE:" + evalModel.rootMeanSquaredError();
        output += "\tACC:" + evalModel.pctCorrect();
        if (test.numClasses() == 2) {
          // For multiclass problems, we could print these too, but it's
          // not clear which class we should use in that case... so instead
          // we only print these metrics for binary classification problems.
          output += "\tROC:" + evalModel.areaUnderROC(1);
          output += "\tPREC:" + evalModel.precision(1);
          output += "\tFSCR:" + evalModel.fMeasure(1);
        }
        System.out.println(output);
      }
    }
  }

  /**
   * Return the technical information.
There is actually another
   * paper that describes our current method of CV for this classifier
   * TODO: Cite Technical report when published
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;
    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR,
        "Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes");
    result.setValue(Field.TITLE,
        "Ensemble Selection from Libraries of Models");
    result.setValue(Field.BOOKTITLE,
        "21st International Conference on Machine Learning");
    result.setValue(Field.YEAR, "2004");
    return result;
  }

  /**
   * Executes the classifier from commandline.
   *
   * @param argv
   *          should contain the following arguments: -t training file [-T
   *          test file] [-c class index]
   */
  public static void main(String[] argv) {
    try {
      String options[] = (String[]) argv.clone();
      // do we get the input from XML instead of normal parameters?
      String xml = Utils.getOption("xml", options);
      if (!xml.equals(""))
        options = new XMLOptions(xml).toArray();
      String trainFileName = Utils.getOption('t', options);
      String objectInputFileName = Utils.getOption('l', options);
      String testFileName = Utils.getOption('T', options);
      // Special mode: a test file and a saved model but no training file
      // means "cache predictions for this test set into the saved model".
      if (testFileName.length() != 0 && objectInputFileName.length() != 0
          && trainFileName.length() == 0) {
        System.out.println("Caching predictions");
        EnsembleSelection classifier = null;
        BufferedReader testReader = new BufferedReader(new FileReader(
            testFileName));
        // Set up the Instances Object
        Instances test;
        int classIndex = -1;
        String classIndexString = Utils.getOption('c', options);
        if (classIndexString.length() != 0) {
          classIndex = Integer.parseInt(classIndexString);
        }
        test = new Instances(testReader, 1);
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1); // -c is 1-based on the commandline
        } else {
          test.setClassIndex(test.numAttributes() - 1);
        }
        // NOTE(review): this bounds check runs after setClassIndex has
        // already been called with the possibly-too-large index — confirm
        // whether the check should precede setClassIndex.
        if (classIndex > test.numAttributes()) {
          throw new Exception("Index of class attribute too large.");
        }
        // Read the remaining instances one at a time (header was read above).
        while (test.readInstance(testReader)) {
        }
        testReader.close();
        // Now yoink the EnsembleSelection Object from the fileSystem
        InputStream is = new FileInputStream(objectInputFileName);
        if (objectInputFileName.endsWith(".gz")) {
          is = new GZIPInputStream(is);
        }
        // load from KOML?
        if (!(objectInputFileName.endsWith("UpdateableClassifier.koml") && KOML
            .isPresent())) {
          ObjectInputStream objectInputStream = new ObjectInputStream(is);
          classifier = (EnsembleSelection) objectInputStream.readObject();
          objectInputStream.close();
        } else {
          BufferedInputStream xmlInputStream = new BufferedInputStream(is);
          classifier = (EnsembleSelection) KOML.read(xmlInputStream);
          xmlInputStream.close();
        }
        // NOTE(review): -W, -D and -O are parsed from the raw argv rather
        // than from `options` — if the -xml form was used above, these flags
        // would not be picked up from the XML options. Confirm intent.
        String workingDir = Utils.getOption('W', argv);
        if (!workingDir.equals("")) {
          classifier.setWorkingDirectory(new File(workingDir));
        }
        classifier.setDebug(Utils.getFlag('D', argv));
        classifier.setVerboseOutput(Utils.getFlag('O', argv));
        classifier.cachePredictions(test);
        // Now we write the model back out to the file system.
        String objectOutputFileName = objectInputFileName;
        OutputStream os = new FileOutputStream(objectOutputFileName);
        // binary
        if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName
            .endsWith(".koml") && KOML.isPresent()))) {
          if (objectOutputFileName.endsWith(".gz")) {
            os = new GZIPOutputStream(os);
          }
          ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
          objectOutputStream.writeObject(classifier);
          objectOutputStream.flush();
          objectOutputStream.close();
        }
        // KOML/XML
        else {
          BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
          if (objectOutputFileName.endsWith(".xml")) {
            XMLSerialization xmlSerial = new XMLClassifier();
            xmlSerial.write(xmlOutputStream, classifier);
          } else
            // whether KOML is present has already been checked
            // if not present -> ".koml" is interpreted as binary - see above
            if (objectOutputFileName.endsWith(".koml")) {
              KOML.write(xmlOutputStream, classifier);
            }
          xmlOutputStream.close();
        }
      }
      // Standard Weka commandline evaluation (runs in all modes).
      System.out.println(Evaluation.evaluateModel(
          new EnsembleSelection(), argv));
    } catch (Exception e) {
      // "General options" in the message means the user asked for / needs
      // the usage text, so print just the message; otherwise dump the trace.
      if ((e.getMessage() != null)
          && (e.getMessage().indexOf("General options") == -1))
        e.printStackTrace();
      else
        System.err.println(e.getMessage());
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -