⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cnnregressionmodel.java

📁 化学图形处理软件
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
        }        if (key.equals("Wts")) {            if (!(obj instanceof Double[])) {                throw new QSARModelException("The class of the 'Wts' object must be Double[]");            }        }        if (key.equals("mask")) {            if (!(obj instanceof Boolean[])) {                throw new QSARModelException("The class of the 'mask' object must be Boolean[]");            }        }        if (key.equals("linout") ||                key.equals("entropy") ||                key.equals("softmax") ||                key.equals("censored") ||                key.equals("skip") ||                key.equals("Hess") ||                key.equals("trace")) {            if (!(obj instanceof Boolean)) {                throw new QSARModelException("The class of the 'trace|skip|Hess|linout|entropy|softmax|censored' object must be Boolean");            }        }        if (key.equals("rang") ||                key.equals("decay") ||                key.equals("abstol") ||                key.equals("reltol")) {            if (!(obj instanceof Double)) {                throw new QSARModelException("The class of the 'reltol|abstol|decay|rang' object must be Double");            }        }        if (key.equals("maxit") ||                key.equals("MaxNWts")) {            if (!(obj instanceof Integer)) {                throw new QSARModelException("The class of the 'maxit|MaxNWts' object must be Integer");            }        }        if (key.equals("newdata")) {            if (!(obj instanceof Double[][])) {                throw new QSARModelException("The class of the 'newdata' object must be Double[][]");            }        }        params.put(key, obj);    }    /**     * Fits a CNN regression model.     * <p/>     * This method calls the R function to fit a CNN regression model     * to the specified dependent and independent variables. If an error     * occurs in the R session, an exception is thrown.     * <p/>     * Note that, this method should be called prior to calling the various get     * methods to obtain information regarding the fit.     */    public void build() throws QSARModelException {        Double[][] x;        Double[][] y;        x = (Double[][]) this.params.get("x");        y = (Double[][]) this.params.get("y");        if (x.length != y.length)            throw new QSARModelException("Number of observations does not match number of rows in the design matrix");        if (nvar == 0) nvar = x[0].length;        // lets build the model        String paramVarName = loadParametersIntoRSession();        String cmd = "buildCNN(\"" + getModelName() + "\", " + paramVarName + ")";        REXP ret = rengine.eval(cmd);        if (ret == null) {            CNNRegressionModel.logger.debug("Error in buildCNN");            throw new QSARModelException("Error in buildCNN");        }        // remove the parameter list        rengine.eval("rm(" + paramVarName + ")");        // save the model object on the Java side        modelObject = ret.asList();    }    /**     * Uses a fitted model to predict the response for new observations.     * <p/>     * This function uses a previously fitted model to obtain predicted values     * for a new set of observations. If the model has not been fitted prior to this     * call an exception will be thrown. Use <code>setParameters</code>     * to set the values of the independent variable for the new observations and the     * interval type.     *     * @throws org.openscience.cdk.qsar.model.QSARModelException     *          if the model has not been built prior to a call     *          to this method. Also if the number of independent variables specified for prediction     *          is not the same as specified during model building     */    public void predict() throws QSARModelException {        if (modelObject == null)            throw new QSARModelException("Before calling predict() you must fit the model using build()");        Double[][] newx = (Double[][]) params.get("newdata");        if (newx[0].length != nvar) {            throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting");        }        String pn = loadParametersIntoRSession();        REXP ret = rengine.eval("predicCNN(\"" + getModelName() + "\", " + pn + ")");        if (ret == null) throw new QSARModelException("Error occured in prediction");        // remove the parameter list        rengine.eval("rm(" + pn + ")");        modelPredict = ret.asDoubleMatrix();    }    /**     * Get the matrix of predicted values obtained from <code>predict.nnet<code>.     *     * @return The result of the prediction.     */    public double[][] getPredictions() {        return modelPredict;    }    /**     * Returns an <code>RList</code> object summarizing the nnet regression model.     * <p/>     * The return object can be queried via the <code>RList</code> methods to extract the     * required components.     *     * @return A summary for the nnet regression model     * @throws org.openscience.cdk.qsar.model.QSARModelException     *          if the model has not been built prior to a call     *          to this method     */    public RList summary() throws QSARModelException {        if (modelObject == null)            throw new QSARModelException("Before calling summary() you must fit the model using build()");        REXP ret = rengine.eval("summary(" + getModelName() + ")");        if (ret == null) {            logger.debug("Error in summary()");            throw new QSARModelException("Error in summary()");        }        return ret.asList();    }    /**     * Loads a <code>'nnet'</code> object from disk in to the current session.     *     * @param fileName The disk file containing the model     * @throws org.openscience.cdk.qsar.model.QSARModelException     *          if the model being loaded is not a <code>'nnet'</code> model     *          object  or the file does not exist     */    public void loadModel(String fileName) throws QSARModelException {        File f = new File(fileName);        if (!f.exists()) throw new QSARModelException(fileName + " does not exist");        rengine.assign("tmpFileName", fileName);        REXP ret = rengine.eval("loadModel(tmpFileName)");        if (ret == null) throw new QSARModelException("Model could not be loaded");        String name = ret.asList().at("name").asString();        if (!isOfClass(name, "nnet")) {            removeObject(name);            throw new QSARModelException("Loaded object was not of class \'nnet\'");        }        modelObject = ret.asList().at("model").asList();        setModelName(name);        nvar = (int) getN()[0];        noutput = (int) getN()[2];    }    /**     * Loads a  <code>'nnet'</code> object from a serialized string into the current session.     *     * @param serializedModel A String containing the serialized version of the model     * @param modelName       A String indicating the name of the model in the R session     * @throws org.openscience.cdk.qsar.model.QSARModelException     *          if the model being loaded is not a <code>'nnet'</code> model     *          object     */    public void loadModel(String serializedModel, String modelName) throws QSARModelException {        rengine.assign("tmpSerializedModel", serializedModel);        rengine.assign("tmpModelName", modelName);        REXP ret = rengine.eval("unserializeModel(tmpSerializedModel, tmpModelName)");        if (ret == null) throw new QSARModelException("Model could not be unserialized");        String name = ret.asList().at("name").asString();        if (!isOfClass(name, "nnet")) {            removeObject(name);            throw new QSARModelException("Loaded object was not of class \'nnet\'");        }        modelObject = ret.asList().at("model").asList();        setModelName(name);        nvar = (int) getN()[0];        noutput = (int) getN()[2];    }// Autogenerated code: assumes that 'modelObject' is// a RList object    /**     * Gets the <code>censored</code> field of an <code>'nnet'</code> object.     *     * @return The value of the censored field     */    public RBool getCensored() {        return modelObject.at("censored").asBool();    }    /**     * Gets the <code>conn</code> field of an <code>'nnet'</code> object.     *     * @return The value of the conn field     */    public double[] getConn() {        return modelObject.at("conn").asDoubleArray();    }    /**     * Gets the <code>decay</code> field of an <code>'nnet'</code> object.     *     * @return The value of the decay field     */    public double getDecay() {        return modelObject.at("decay").asDouble();    }    /**     * Gets the <code>entropy</code> field of an <code>'nnet'</code> object.     *     * @return The value of the entropy field     */    public RBool getEntropy() {        return modelObject.at("entropy").asBool();    }    /**     * Gets the <code>fitted.values</code> field of an <code>'nnet'</code> object.     *     * @return The value of the fitted.values field     */    public double[][] getFittedValues() {        return modelObject.at("fitted.values").asDoubleMatrix();    }    /**     * Gets the <code>n</code> field of an <code>'nnet'</code> object.     *     * @return The value of the n field     */    public double[] getN() {        return modelObject.at("n").asDoubleArray();    }    /**     * Gets the <code>nconn</code> field of an <code>'nnet'</code> object.     *     * @return The value of the nconn field     */    public double[] getNconn() {        return modelObject.at("nconn").asDoubleArray();    }    /**     * Gets the <code>nsunits</code> field of an <code>'nnet'</code> object.     *     * @return The value of the nsunits field     */    public double getNsunits() {        return modelObject.at("nsunits").asDouble();    }    /**     * Gets the <code>nunits</code> field of an <code>'nnet'</code> object.     *     * @return The value of the nunits field     */    public double getNunits() {        return modelObject.at("nunits").asDouble();    }    /**     * Gets the <code>residuals</code> field of an <code>'nnet'</code> object.     *     * @return The value of the residuals field     */    public double[][] getResiduals() {        return modelObject.at("residuals").asDoubleMatrix();    }    /**     * Gets the <code>softmax</code> field of an <code>'nnet'</code> object.     *     * @return The value of the softmax field     */    public RBool getSoftmax() {        return modelObject.at("softmax").asBool();    }    /**     * Gets the <code>value</code> field of an <code>'nnet'</code> object.     *     * @return The value of the value field     */    public double getValue() {        return modelObject.at("value").asDouble();    }    /**     * Gets the <code>wts</code> field of an <code>'nnet'</code> object.     *     * @return The value of the wts field     */    public double[] getWts() {        return modelObject.at("wts").asDoubleArray();    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -