📄 ensembleselectionlibrarymodel.java
字号:
throw new EnsembleModelMismatchException("Seeds " + newModel.getSeed() + " and " + getSeed() + " not equal"); if (newModel.getFolds() != getFolds()) throw new EnsembleModelMismatchException("Folds " + newModel.getFolds() + " and " + getFolds() + " not equal"); if (newModel.getValidationRatio() != getValidationRatio()) throw new EnsembleModelMismatchException("Validation Ratios " + newModel.getValidationRatio() + " and " + getValidationRatio() + " not equal"); // setFileName(modelFileName); m_models = newModel.getModels(); m_validationPredictions = newModel.getValidationPredictions(); Date endTime = new Date(); int diff = (int) (endTime.getTime() - startTime.getTime()); if (m_Debug) System.out.println("Time to load " + modelFileName + " was: " + diff); } } /** * The purpose of this method is to "rehydrate" the classifier object fot * this library model from the filesystem. * * @param workingDirectory the working directory to use */ public void rehydrateModel(String workingDirectory) { if (m_models == null) { File file = new File(workingDirectory, m_fileName); if (m_Debug) System.out.println("Rehydrating Model: " + file.getPath()); EnsembleSelectionLibraryModel model = EnsembleSelectionLibraryModel .loadModel(file.getPath()); m_models = model.getModels(); } } /** * Releases the model from memory. TODO - need to be saving these so we can * retrieve them later!! */ public void releaseModel() { /* * if (m_unsaved) { saveModel(); } */ m_models = null; } /** * Train the classifier for the specified fold on the given data * * @param trainData the data to train with * @param fold the fold number * @throws Exception if something goes wrong, e.g., out of memory */ public void train(Instances trainData, int fold) throws Exception { if (m_models != null) { try { // OK, this is it... this is the point where our code surrenders // to the weka classifiers. m_models[fold].buildClassifier(trainData); } catch (Throwable t) { m_models[fold] = null; throw new Exception( "Exception caught while training: (null could mean out of memory)" + t.getMessage()); } } else { throw new Exception("Cannot train: model was null"); // TODO: throw Exception? } } /** * Set the seed * * @param seed the seed value */ public void setSeed(int seed) { m_seed = seed; } /** * Get the seed * * @return the seed value */ public int getSeed() { return m_seed; } /** * Sets the validation set ratio (only meaningful if folds == 1) * * @param validationRatio the new ration */ public void setValidationRatio(double validationRatio) { m_validationRatio = validationRatio; } /** * get validationRatio * * @return the current ratio */ public double getValidationRatio() { return m_validationRatio; } /** * Set the number of folds for cross validation. The number of folds also * indicates how many classifiers will be built to represent this model. * * @param folds the number of folds to use */ public void setFolds(int folds) { m_folds = folds; } /** * get the number of folds * * @return the current number of folds */ public int getFolds() { return m_folds; } /** * set the checksum * * @param instancesChecksum the new checksum */ public void setChecksum(String instancesChecksum) { m_checksum = instancesChecksum; } /** * get the checksum * * @return the current checksum */ public String getChecksum() { return m_checksum; } /** * Returs the array of classifiers * * @return the current models */ public Classifier[] getModels() { return m_models; } /** * Sets the .elm file name for this library model * * @param fileName the new filename */ public void setFileName(String fileName) { m_fileName = fileName; } /** * Gets a checksum for the string defining this classifier. This is used to * preserve uniqueness in the classifier names. * * @param string the classifier definition * @return the checksum string */ public static String getStringChecksum(String string) { String checksumString = null; try { Adler32 checkSummer = new Adler32(); byte[] utf8 = string.toString().getBytes("UTF8"); ; checkSummer.update(utf8); checksumString = Long.toHexString(checkSummer.getValue()); } catch (UnsupportedEncodingException e) { // TODO Auto-generated catch block e.printStackTrace(); } return checksumString; } /** * The purpose of this method is to get an appropriate file name for a model * based on its string representation of a model. All generated filenames * are limited to less than 128 characters and all of them will end with a * 64 bit checksum value of their string representation to try to maintain * some uniqueness of file names. * * @param stringRepresentation string representation of model * @return unique filename */ public static String getFileName(String stringRepresentation) { // Get rid of space and quote marks(windows doesn't lke them) String fileName = stringRepresentation.trim().replace(' ', '_') .replace('"', '_'); if (fileName.length() > 115) { fileName = fileName.substring(0, 115); } fileName += getStringChecksum(stringRepresentation) + EnsembleSelectionLibraryModel.FILE_EXTENSION; return fileName; } /** * Saves the given model to the specified file. * * @param directory the directory to save the model to * @param model the model to save */ public static void saveModel(String directory, EnsembleSelectionLibraryModel model) { try { String fileName = getFileName(model.getStringRepresentation()); File file = new File(directory, fileName); // System.out.println("Saving model: "+file.getPath()); // model.setFileName(new String(file.getPath())); // Serialize to a file ObjectOutput out = new ObjectOutputStream( new FileOutputStream(file)); out.writeObject(model); out.close(); } catch (IOException e) { e.printStackTrace(); } } /** * loads the specified model * * @param modelFilePath the path of the model * @return the model */ public static EnsembleSelectionLibraryModel loadModel(String modelFilePath) { EnsembleSelectionLibraryModel model = null; try { File file = new File(modelFilePath); ObjectInputStream in = new ObjectInputStream(new FileInputStream( file)); model = (EnsembleSelectionLibraryModel) in.readObject(); in.close(); } catch (ClassNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } return model; } /* * Problems persist in this code so we left it commented out. The intent was * to create the methods necessary for custom serialization to allow for * forwards/backwards compatability of .elm files accross multiple versions * of this classifier. The main problem however is that these methods do not * appear to be called. I'm not sure what the problem is, but this would be * a great feature. If anyone is a seasoned veteran of this serialization * stuff, please help! * * private void writeObject(ObjectOutputStream stream) throws IOException { * //stream.defaultWriteObject(); //stream.writeObject(b); * * //first serialize the LibraryModel fields * * //super.writeObject(stream); * * //now serialize the LibraryModel fields * * stream.writeObject(m_Classifier); * * stream.writeObject(m_DescriptionText); * * stream.writeObject(m_ErrorText); * * stream.writeObject(new Boolean(m_OptionsWereValid)); * * stream.writeObject(m_StringRepresentation); * * stream.writeObject(m_models); * * * //now serialize the EnsembleLibraryModel fields //stream.writeObject(new * String("blah")); * * stream.writeObject(new Integer(m_seed)); * * stream.writeObject(m_checksum); * * stream.writeObject(new Double(m_validationRatio)); * * stream.writeObject(new Integer(m_folds)); * * stream.writeObject(m_fileName); * * stream.writeObject(new Boolean(m_isTrained)); * * * if (m_validationPredictions == null) { * } * * if (m_Debug) System.out.println("Saving * "+m_validationPredictions.length+" indexed array"); * stream.writeObject(m_validationPredictions); * } * * private void readObject(ObjectInputStream stream) throws IOException, * ClassNotFoundException { //stream.defaultReadObject(); //b = (String) * stream.readObject(); * * //super.readObject(stream); * * //deserialize the LibraryModel fields m_Classifier = * (Classifier)stream.readObject(); * * m_DescriptionText = (String)stream.readObject(); * * m_ErrorText = (String)stream.readObject(); * * m_OptionsWereValid = ((Boolean)stream.readObject()).booleanValue(); * * m_StringRepresentation = (String)stream.readObject(); * * * * //now deserialize the EnsembleLibraryModel fields m_models = * (Classifier[])stream.readObject(); * * m_seed = ((Integer)stream.readObject()).intValue(); * * m_checksum = (String)stream.readObject(); * * m_validationRatio = ((Double)stream.readObject()).doubleValue(); * * m_folds = ((Integer)stream.readObject()).intValue(); * * m_fileName = (String)stream.readObject(); * * m_isTrained = ((Boolean)stream.readObject()).booleanValue(); * * m_validationPredictions = (double[][])stream.readObject(); * * if (m_Debug) System.out.println("Loaded * "+m_validationPredictions.length+" indexed array"); } * */ /** * getter for validation predictions * * @return the current validation predictions */ public double[][] getValidationPredictions() { return m_validationPredictions; } /** * setter for validation predictions * * @param predictions the new validation predictions */ public void setValidationPredictions(double[][] predictions) { if (m_Debug) System.out.println("Saving validation array of size " + predictions.length); m_validationPredictions = new double[predictions.length][]; System.arraycopy(predictions, 0, m_validationPredictions, 0, predictions.length); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -