📄 cnnregressionmodel.java
字号:
} if (key.equals("Wts")) { if (!(obj instanceof Double[])) { throw new QSARModelException("The class of the 'Wts' object must be Double[]"); } } if (key.equals("mask")) { if (!(obj instanceof Boolean[])) { throw new QSARModelException("The class of the 'mask' object must be Boolean[]"); } } if (key.equals("linout") || key.equals("entropy") || key.equals("softmax") || key.equals("censored") || key.equals("skip") || key.equals("Hess") || key.equals("trace")) { if (!(obj instanceof Boolean)) { throw new QSARModelException("The class of the 'trace|skip|Hess|linout|entropy|softmax|censored' object must be Boolean"); } } if (key.equals("rang") || key.equals("decay") || key.equals("abstol") || key.equals("reltol")) { if (!(obj instanceof Double)) { throw new QSARModelException("The class of the 'reltol|abstol|decay|rang' object must be Double"); } } if (key.equals("maxit") || key.equals("MaxNWts")) { if (!(obj instanceof Integer)) { throw new QSARModelException("The class of the 'maxit|MaxNWts' object must be Integer"); } } if (key.equals("newdata")) { if (!(obj instanceof Double[][])) { throw new QSARModelException("The class of the 'newdata' object must be Double[][]"); } } params.put(key, obj); } /** * Fits a CNN regression model. * <p/> * This method calls the R function to fit a CNN regression model * to the specified dependent and independent variables. If an error * occurs in the R session, an exception is thrown. * <p/> * Note that, this method should be called prior to calling the various get * methods to obtain information regarding the fit. */ public void build() throws QSARModelException { Double[][] x; Double[][] y; x = (Double[][]) this.params.get("x"); y = (Double[][]) this.params.get("y"); if (x.length != y.length) throw new QSARModelException("Number of observations does not match number of rows in the design matrix"); if (nvar == 0) nvar = x[0].length; // lets build the model String paramVarName = loadParametersIntoRSession(); String cmd = "buildCNN(\"" + getModelName() + "\", " + paramVarName + ")"; REXP ret = rengine.eval(cmd); if (ret == null) { CNNRegressionModel.logger.debug("Error in buildCNN"); throw new QSARModelException("Error in buildCNN"); } // remove the parameter list rengine.eval("rm(" + paramVarName + ")"); // save the model object on the Java side modelObject = ret.asList(); } /** * Uses a fitted model to predict the response for new observations. * <p/> * This function uses a previously fitted model to obtain predicted values * for a new set of observations. If the model has not been fitted prior to this * call an exception will be thrown. Use <code>setParameters</code> * to set the values of the independent variable for the new observations and the * interval type. * * @throws org.openscience.cdk.qsar.model.QSARModelException * if the model has not been built prior to a call * to this method. Also if the number of independent variables specified for prediction * is not the same as specified during model building */ public void predict() throws QSARModelException { if (modelObject == null) throw new QSARModelException("Before calling predict() you must fit the model using build()"); Double[][] newx = (Double[][]) params.get("newdata"); if (newx[0].length != nvar) { throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting"); } String pn = loadParametersIntoRSession(); REXP ret = rengine.eval("predicCNN(\"" + getModelName() + "\", " + pn + ")"); if (ret == null) throw new QSARModelException("Error occured in prediction"); // remove the parameter list rengine.eval("rm(" + pn + ")"); modelPredict = ret.asDoubleMatrix(); } /** * Get the matrix of predicted values obtained from <code>predict.nnet<code>. * * @return The result of the prediction. */ public double[][] getPredictions() { return modelPredict; } /** * Returns an <code>RList</code> object summarizing the nnet regression model. * <p/> * The return object can be queried via the <code>RList</code> methods to extract the * required components. * * @return A summary for the nnet regression model * @throws org.openscience.cdk.qsar.model.QSARModelException * if the model has not been built prior to a call * to this method */ public RList summary() throws QSARModelException { if (modelObject == null) throw new QSARModelException("Before calling summary() you must fit the model using build()"); REXP ret = rengine.eval("summary(" + getModelName() + ")"); if (ret == null) { logger.debug("Error in summary()"); throw new QSARModelException("Error in summary()"); } return ret.asList(); } /** * Loads a <code>'nnet'</code> object from disk in to the current session. * * @param fileName The disk file containing the model * @throws org.openscience.cdk.qsar.model.QSARModelException * if the model being loaded is not a <code>'nnet'</code> model * object or the file does not exist */ public void loadModel(String fileName) throws QSARModelException { File f = new File(fileName); if (!f.exists()) throw new QSARModelException(fileName + " does not exist"); rengine.assign("tmpFileName", fileName); REXP ret = rengine.eval("loadModel(tmpFileName)"); if (ret == null) throw new QSARModelException("Model could not be loaded"); String name = ret.asList().at("name").asString(); if (!isOfClass(name, "nnet")) { removeObject(name); throw new QSARModelException("Loaded object was not of class \'nnet\'"); } modelObject = ret.asList().at("model").asList(); setModelName(name); nvar = (int) getN()[0]; noutput = (int) getN()[2]; } /** * Loads a <code>'nnet'</code> object from a serialized string into the current session. * * @param serializedModel A String containing the serialized version of the model * @param modelName A String indicating the name of the model in the R session * @throws org.openscience.cdk.qsar.model.QSARModelException * if the model being loaded is not a <code>'nnet'</code> model * object */ public void loadModel(String serializedModel, String modelName) throws QSARModelException { rengine.assign("tmpSerializedModel", serializedModel); rengine.assign("tmpModelName", modelName); REXP ret = rengine.eval("unserializeModel(tmpSerializedModel, tmpModelName)"); if (ret == null) throw new QSARModelException("Model could not be unserialized"); String name = ret.asList().at("name").asString(); if (!isOfClass(name, "nnet")) { removeObject(name); throw new QSARModelException("Loaded object was not of class \'nnet\'"); } modelObject = ret.asList().at("model").asList(); setModelName(name); nvar = (int) getN()[0]; noutput = (int) getN()[2]; }// Autogenerated code: assumes that 'modelObject' is// a RList object /** * Gets the <code>censored</code> field of an <code>'nnet'</code> object. * * @return The value of the censored field */ public RBool getCensored() { return modelObject.at("censored").asBool(); } /** * Gets the <code>conn</code> field of an <code>'nnet'</code> object. * * @return The value of the conn field */ public double[] getConn() { return modelObject.at("conn").asDoubleArray(); } /** * Gets the <code>decay</code> field of an <code>'nnet'</code> object. * * @return The value of the decay field */ public double getDecay() { return modelObject.at("decay").asDouble(); } /** * Gets the <code>entropy</code> field of an <code>'nnet'</code> object. * * @return The value of the entropy field */ public RBool getEntropy() { return modelObject.at("entropy").asBool(); } /** * Gets the <code>fitted.values</code> field of an <code>'nnet'</code> object. * * @return The value of the fitted.values field */ public double[][] getFittedValues() { return modelObject.at("fitted.values").asDoubleMatrix(); } /** * Gets the <code>n</code> field of an <code>'nnet'</code> object. * * @return The value of the n field */ public double[] getN() { return modelObject.at("n").asDoubleArray(); } /** * Gets the <code>nconn</code> field of an <code>'nnet'</code> object. * * @return The value of the nconn field */ public double[] getNconn() { return modelObject.at("nconn").asDoubleArray(); } /** * Gets the <code>nsunits</code> field of an <code>'nnet'</code> object. * * @return The value of the nsunits field */ public double getNsunits() { return modelObject.at("nsunits").asDouble(); } /** * Gets the <code>nunits</code> field of an <code>'nnet'</code> object. * * @return The value of the nunits field */ public double getNunits() { return modelObject.at("nunits").asDouble(); } /** * Gets the <code>residuals</code> field of an <code>'nnet'</code> object. * * @return The value of the residuals field */ public double[][] getResiduals() { return modelObject.at("residuals").asDoubleMatrix(); } /** * Gets the <code>softmax</code> field of an <code>'nnet'</code> object. * * @return The value of the softmax field */ public RBool getSoftmax() { return modelObject.at("softmax").asBool(); } /** * Gets the <code>value</code> field of an <code>'nnet'</code> object. * * @return The value of the value field */ public double getValue() { return modelObject.at("value").asDouble(); } /** * Gets the <code>wts</code> field of an <code>'nnet'</code> object. * * @return The value of the wts field */ public double[] getWts() { return modelObject.at("wts").asDoubleArray(); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -