cnnclassificationmodel.java

     *
     * This function allows the caller to set the various parameters available
     * for the
     * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>
     * and
     * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/predict.nnet.html" target="_top">predict.nnet</a>
     * R routines. See the R help pages for details of the available
     * parameters.
     *
     * @param key A String containing the name of the parameter as described in the
     * R help pages
     * @param obj An Object containing the value of the parameter
     * @throws QSARModelException if the type of the supplied value does not match the expected type
     */
    public void setParameters(String key, Object obj) throws QSARModelException {
        // Since we know the possible values of key, we check the corresponding
        // objects and throw errors if required. Note that this cannot check the
        // values themselves (such as whether the number of variables in the X matrix
        // used to build the model matches the X matrix used for new predictions);
        // those checks belong in the methods that use these parameters. The main
        // checking done here is for the class of obj, plus some cases where the
        // value of obj does not depend on what was set before it.
        if (key.equals("y")) {
            if (!(obj instanceof String[][])) {
                throw new QSARModelException("The class of the 'y' object must be String[][]");
            } else {
                noutput = ((String[][]) obj)[0].length;
            }
        }
        if (key.equals("x")) {
            if (!(obj instanceof Double[][])) {
                throw new QSARModelException("The class of the 'x' object must be Double[][]");
            } else {
                nvar = ((Double[][]) obj)[0].length;
            }
        }
        if (key.equals("weights")) {
            if (!(obj instanceof Double[])) {
                throw new QSARModelException("The class of the 'weights' object must be Double[]");
            }
        }
        if (key.equals("size")) {
            if (!(obj instanceof Integer)) {
                throw new QSARModelException("The class of the 'size' object must be Integer");
            }
        }
        if (key.equals("subset")) {
            if (!(obj instanceof Integer[])) {
                throw new QSARModelException("The class of the 'subset' object must be Integer[]");
            }
        }
        if (key.equals("Wts")) {
            if (!(obj instanceof Double[])) {
                throw new QSARModelException("The class of the 'Wts' object must be Double[]");
            }
        }
        if (key.equals("mask")) {
            if (!(obj instanceof Boolean[])) {
                throw new QSARModelException("The class of the 'mask' object must be Boolean[]");
            }
        }
        if (key.equals("linout") ||
            key.equals("entropy") ||
            key.equals("softmax") ||
            key.equals("censored") ||
            key.equals("skip") ||
            key.equals("Hess") ||
            key.equals("trace")) {
            if (!(obj instanceof Boolean)) {
                throw new QSARModelException("The class of the 'trace|skip|Hess|linout|entropy|softmax|censored' object must be Boolean");
            }
        }
        if (key.equals("rang") ||
            key.equals("decay") ||
            key.equals("abstol") ||
            key.equals("reltol")) {
            if (!(obj instanceof Double)) {
                throw new QSARModelException("The class of the 'reltol|abstol|decay|rang' object must be Double");
            }
        }
        if (key.equals("maxit") ||
            key.equals("MaxNWts")) {
            if (!(obj instanceof Integer)) {
                throw new QSARModelException("The class of the 'maxit|MaxNWts' object must be Integer");
            }
        }
        if (key.equals("newdata")) {
            if (!(obj instanceof Double[][])) {
                throw new QSARModelException("The class of the 'newdata' object must be Double[][]");
            }
        }
        this.params.put(key, obj);
    }

    /**
     * Fits a CNN classification model.
     *
     * This method calls the R function to fit a CNN classification model
     * to the specified dependent and independent variables. If an error
     * occurs in the R session, an exception is thrown.
     * <p>
     * Note that this method should be called before the various get
     * methods that return information about the fit.
     */
    public void build() throws QSARModelException {
        try {
            this.modelfit = (CNNClassificationModelFit) revaluator.call("buildCNNClass",
                    new Object[]{ getModelName(), this.params });
        } catch (Exception re) {
            throw new QSARModelException(re.toString());
        }
    }

    /**
     * Uses a fitted model to predict the response for new observations.
     *
     * This function uses a previously fitted model to obtain predicted values
     * for a new set of observations. If the model has not been fitted prior to this
     * call, an exception will be thrown.
     * Use <code>setParameters</code> with the key <code>newdata</code>
     * to set the values of the independent variables for the new observations. You can also
     * set the <code>type</code> argument (see <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">here</a>).
     * However, since this class performs CNN classification, the default setting (<code>type='raw'</code>) is sufficient.
     *
     * @throws QSARModelException if the model has not been built, if <code>newdata</code> has not been set,
     * if the number of independent variables does not match the fitted model, or if an error occurs in the R session
     */
    public void predict() throws QSARModelException {
        if (this.modelfit == null) {
            throw new QSARModelException("Before calling predict() you must fit the model using build()");
        }
        Double[][] newx = (Double[][]) this.params.get("newdata");
        if (newx == null || newx.length == 0) {
            throw new QSARModelException("Before calling predict() you must set 'newdata' using setParameters()");
        }
        if (newx[0].length != this.nvar) {
            throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting");
        }
        try {
            this.modelpredict = (CNNClassificationModelPredict) revaluator.call("predictCNNClass",
                    new Object[]{ getModelName(), this.params });
        } catch (Exception re) {
            throw new QSARModelException(re.toString());
        }
    }

    /**
     * Loads a CNNClassificationModel object from disk into the current session.
     *
     * @param fileName The disk file containing the model
     * @throws QSARModelException if the model that was loaded was not a CNNClassification
     * model
     */
    public void loadModel(String fileName) throws QSARModelException {
        // should probably check that the file exists
        Object model = revaluator.call("loadModel", new Object[]{ fileName });
        String modelName = (String) revaluator.call("loadModel.getName", new Object[]{ fileName });
        if (model.getClass().getName().equals("org.openscience.cdk.qsar.model.R.CNNClassificationModelFit")) {
            this.modelfit = (CNNClassificationModelFit) model;
            this.setModelName(modelName);
            Double tmp = (Double) revaluator.eval(modelName + "$n[1]");
            nvar = (int) tmp.doubleValue();
        } else {
            throw new QSARModelException("The loaded model was not a CNNClassificationModel");
        }
    }

    /**
     * Loads a CNNClassificationModel object from a serialized string into the current session.
     *
     * @param serializedModel A String containing the serialized version of the model
     * @param modelName A String indicating the name of the model in the R session
     * @throws QSARModelException if the model being loaded is not a CNN classification model
     * object
     */
    public void loadModel(String serializedModel, String modelName) throws QSARModelException {
        Object model = revaluator.call("unserializeModel", new Object[]{ serializedModel, modelName });
        if (model.getClass().getName().equals("org.openscience.cdk.qsar.model.R.CNNClassificationModelFit")) {
            this.modelfit = (CNNClassificationModelFit) model;
            this.setModelName(modelName);
            Double tmp = (Double) revaluator.eval(modelName + "$n[1]");
            nvar = (int) tmp.doubleValue();
        } else {
            throw new QSARModelException("The loaded model was not a CNNClassificationModel");
        }
    }

    /**
     * Gets the final value of the fitting criterion.
     *
     * This method only returns meaningful results if the <code>build</code>
     * method of this class has been previously called.
     *
     * @return A double indicating the value of the fitting criterion plus the weight decay term.
     */
    public double getFitValue() {
        return this.modelfit.getValue();
    }

    /**
     * Gets the optimized weights for the model.
     *
     * This method only returns meaningful results if the <code>build</code>
     * method of this class has been previously called.
     *
     * @return A double[] containing the weights. Without skip-layer connections
     * (<code>skip</code> not set to TRUE), the number of weights will be
     * equal to <center>(Ni * Nh) + (Nh * No) + Nh + No</center> where Ni, Nh and No
     * are the numbers of input, hidden and output neurons.
     */
    public double[] getFitWeights() {
        return this.modelfit.getWeights();
    }

    /**
     * Gets fitted values from the final model.
     *
     * This method only returns meaningful results if the <code>build</code>
     * method of this class has been previously called.
     *
     * @return A double[][] containing the fitted values for each output neuron
     * in the columns. Note that even if a single output neuron was specified during
     * model building, the return value is still a 2D array (with a single column).
     */
    public double[][] getFitFitted() {
        return this.modelfit.getFitted();
    }

    /**
     * Gets residuals for the fitted values from the final model.
     *
     * This method only returns meaningful results if the <code>build</code>
     * method of this class has been previously called.
     *
     * @return A double[][] containing the residuals for each output neuron
     * in the columns. Note that even if a single output neuron was specified during
     * model building, the return value is still a 2D array (with a single column).
     */
    public double[][] getFitResiduals() {
        return this.modelfit.getResiduals();
    }

    /**
     * Gets the Hessian of the measure of fit.
     *
     * If the <code>Hess</code> option was set to TRUE before the call to <code>build</code>,
     * the CNN routine returns the Hessian of the measure of fit at the best set of
     * weights found.
     * This method only returns meaningful results if the <code>build</code>
     * method of this class has been previously called.
     *
     * @return A double[][] containing the Hessian. It will be a square array
     * of dimensions Nwt x Nwt, where Nwt is the total number of weights
     * in the CNN model.
     */
    public double[][] getFitHessian() {
        return this.modelfit.getHessian();
    }

    /**
     * Gets predicted values for new data using a previously built model.
     *
     * This method only returns meaningful results if the <code>predict</code>
     * method of this class has been previously called.
     * Since this is a classification
     * model, the values represent the probability that an observation belongs to the given
     * class.
     *
     * @return A double[][] containing the predicted values for each output neuron
     * in the columns. Note that even if a single output neuron was specified during
     * model building, the return value is still a 2D array (with a single column).
     */
    public double[][] getPredictPredictedRaw() {
        return this.modelpredict.getPredictedRaw();
    }

    /**
     * Gets predicted class assignments for new data using a previously built model.
     *
     * This method only returns meaningful results if the <code>predict</code>
     * method of this class has been previously called. It returns an
     * array of Strings indicating the class assignments of the observations, rather than
     * the raw probabilities.
     *
     * @return A String[] containing the class assigned to each observation.
     */
    public String[] getPredictPredictedClass() {
        return this.modelpredict.getPredictedClass();
    }
}
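
The chain of if blocks in setParameters amounts to a lookup from parameter name to the Java class its value must have. The sketch below shows one way to express the same checks as a table. It is a hypothetical refactoring, not CDK code: the class name NnetParamTypes and the method check are invented for illustration, the method returns an error message instead of throwing QSARModelException so that it compiles without CDK, and it leaves out the side effects on nvar and noutput.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper (not part of CDK): maps each nnet parameter name that
// setParameters() checks to the Java class its value must have.
final class NnetParamTypes {
    private static final Map<String, Class<?>> EXPECTED = new HashMap<>();

    static {
        EXPECTED.put("y", String[][].class);
        EXPECTED.put("x", Double[][].class);
        EXPECTED.put("newdata", Double[][].class);
        EXPECTED.put("weights", Double[].class);
        EXPECTED.put("Wts", Double[].class);
        EXPECTED.put("subset", Integer[].class);
        EXPECTED.put("mask", Boolean[].class);
        for (String k : new String[] {"size", "maxit", "MaxNWts"}) {
            EXPECTED.put(k, Integer.class);
        }
        for (String k : new String[] {"linout", "entropy", "softmax", "censored", "skip", "Hess", "trace"}) {
            EXPECTED.put(k, Boolean.class);
        }
        for (String k : new String[] {"rang", "decay", "abstol", "reltol"}) {
            EXPECTED.put(k, Double.class);
        }
    }

    private NnetParamTypes() {}

    // Returns null if obj has the expected class for key (or key is not checked);
    // otherwise returns an error message built from the key and the expected class.
    static String check(String key, Object obj) {
        Class<?> expected = EXPECTED.get(key);
        if (expected == null || expected.isInstance(obj)) {
            return null;
        }
        return "The class of the '" + key + "' object must be " + expected.getSimpleName();
    }

    public static void main(String[] args) {
        System.out.println(check("size", 4));   // null: an Integer is accepted
        System.out.println(check("subset", 4)); // message naming 'subset' and Integer[]
    }
}
```

Because the error text is generated from the key, adding a parameter means adding one table entry, and a message can no longer name the wrong key.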
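
predict() compares only the first row of newdata with nvar before handing the parameters to R. A slightly stricter guard could check every row, as in the sketch below. NewDataCheck and its check method are hypothetical names, and the standard IllegalStateException and IllegalArgumentException stand in for QSARModelException so the example runs on its own.

```java
// Hypothetical sketch (not CDK code) of the checks predict() performs before
// calling into R: the model must be fitted, 'newdata' must be set, and every
// row must have as many columns as the training matrix 'x' (nvar).
final class NewDataCheck {
    private NewDataCheck() {}

    static void check(boolean fitted, Double[][] newx, int nvar) {
        if (!fitted) {
            throw new IllegalStateException("Fit the model with build() before calling predict()");
        }
        if (newx == null || newx.length == 0) {
            throw new IllegalArgumentException("'newdata' must be set to a non-empty Double[][]");
        }
        // Checking each row also catches ragged Java arrays, which a
        // first-row comparison would miss.
        for (int i = 0; i < newx.length; i++) {
            if (newx[i] == null || newx[i].length != nvar) {
                throw new IllegalArgumentException("Row " + i + " of 'newdata' must have " + nvar + " columns");
            }
        }
    }

    public static void main(String[] args) {
        check(true, new Double[][] {{1.0, 2.0}, {3.0, 4.0}}, 2); // passes silently
        try {
            check(true, new Double[][] {{1.0, 2.0}, {3.0}}, 2);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Row 1 of 'newdata' must have 2 columns
        }
    }
}
```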
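
Both loadModel methods decide whether the object returned from R is a classification fit by comparing model.getClass().getName() with a hard-coded string. An instanceof test is an alternative worth considering: the compiler checks the type name, and a null result raises the intended error instead of a NullPointerException. The sketch below uses hypothetical stand-in classes (ModelFit, ClassificationFit, RegressionFit) rather than the real CDK types, so it illustrates the pattern only.

```java
// Hypothetical stand-ins for the R-backed fit objects; the real CDK classes are
// not referenced so that the sketch compiles on its own.
final class LoadedModelCheck {
    static class ModelFit {}
    static final class ClassificationFit extends ModelFit {}
    static final class RegressionFit extends ModelFit {}

    private LoadedModelCheck() {}

    // Accepts the loaded object only if it is a classification fit.
    // instanceof is false for null, so a missing model is rejected cleanly.
    static ClassificationFit asClassificationFit(Object model) {
        if (!(model instanceof ClassificationFit)) {
            throw new IllegalArgumentException("The loaded model was not a CNNClassificationModel");
        }
        return (ClassificationFit) model;
    }

    public static void main(String[] args) {
        ClassificationFit fit = asClassificationFit(new ClassificationFit());
        System.out.println(fit != null); // true
        try {
            asClassificationFit(new RegressionFit());
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // The loaded model was not a CNNClassificationModel
        }
    }
}
```

One behavioral difference to keep in mind: instanceof also accepts subclasses of the target type, whereas the exact class-name comparison accepts only that one class.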
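
The weight count in the getFitWeights() Javadoc can be checked with a small calculation. For a hypothetical network with Ni = 4 inputs, Nh = 3 hidden units and No = 2 outputs, (4 * 3) + (3 * 2) + 3 + 2 = 23, so getFitWeights() would return 23 values and, if Hess was set to TRUE, getFitHessian() would return a 23 x 23 array. The sketch assumes no skip-layer connections: with skip set to TRUE, nnet adds direct input-to-output links (presumably Ni * No extra weights), which this formula does not count. The class and method names are illustrative.

```java
// Weight count for a single-hidden-layer network as given in the getFitWeights()
// Javadoc: one weight per input->hidden and hidden->output link, plus one bias
// per hidden unit and per output unit. Skip-layer connections are not counted.
final class NnetWeightCount {
    private NnetWeightCount() {}

    static int weightCount(int ni, int nh, int no) {
        return ni * nh + nh * no + nh + no;
    }

    public static void main(String[] args) {
        int nwt = weightCount(4, 3, 2);        // 12 + 6 + 3 + 2
        System.out.println(nwt);               // 23
        System.out.println(nwt + " x " + nwt); // 23 x 23, the shape of the Hessian
    }
}
```

Computing the count before build() also shows whether it exceeds the MaxNWts limit passed to nnet (1000 by default), in which case nnet refuses to fit the model.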
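
getPredictPredictedRaw() returns one column of probabilities per output neuron, while getPredictPredictedClass() returns labels computed on the R side. To show how the two relate, the sketch below assigns each observation the label of its highest-probability column. It assumes one output column per class, with hypothetical labels given in column order, and it breaks ties in favor of the first column. R's predict.nnet may treat ties and the single-output two-class case differently, so this is an illustration rather than a reimplementation.

```java
// Hypothetical sketch: converts per-class probabilities (one column per class)
// into class labels by taking the column with the largest value in each row.
final class RawToClass {
    private RawToClass() {}

    static String[] assign(double[][] raw, String[] labels) {
        if (labels.length == 0) {
            throw new IllegalArgumentException("At least one label is required");
        }
        String[] out = new String[raw.length];
        for (int i = 0; i < raw.length; i++) {
            if (raw[i].length != labels.length) {
                throw new IllegalArgumentException("Row " + i + " has " + raw[i].length
                        + " columns but there are " + labels.length + " labels");
            }
            int best = 0; // index of the largest probability seen so far
            for (int j = 1; j < raw[i].length; j++) {
                if (raw[i][j] > raw[i][best]) { // strict '>' keeps the first column on ties
                    best = j;
                }
            }
            out[i] = labels[best];
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] raw = {{0.1, 0.7, 0.2}, {0.6, 0.3, 0.1}};
        String[] labels = {"active", "inactive", "unknown"};
        System.out.println(String.join(",", assign(raw, labels))); // inactive,active
    }
}
```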
