📄 cnnclassificationmodel.java
字号:
* * This function allows the caller to set the various parameters available * for the * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a> * and * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/predict.nnet.html" target="_top">predict.nnet</a> * R routines. See the R help pages for the details of the available * parameters. * * @param key A String containing the name of the parameter as described in the * R help pages * @param obj An Object containing the value of the parameter * @throws QSARModelException if the type of the supplied value does not match the expected type */ public void setParameters(String key, Object obj) throws QSARModelException { // since we know the possible values of key we should check the coresponding // objects and throw errors if required. Note that this checking can't really check // for values (such as number of variables in the X matrix to build the model and the // X matrix to make new predictions) - these should be checked in functions that will // use these parameters. 
The main checking done here is for the class of obj and // some cases where the value of obj is not dependent on what is set before it if (key.equals("y")) { if (!(obj instanceof String[][])) { throw new QSARModelException("The class of the 'y' object must be String[][]"); } else { noutput = ((String[][])obj)[0].length; } } if (key.equals("x")) { if (!(obj instanceof Double[][])) { throw new QSARModelException("The class of the 'x' object must be Double[][]"); } else { nvar = ((Double[][])obj)[0].length; } } if (key.equals("weights")) { if (!(obj instanceof Double[])) { throw new QSARModelException("The class of the 'weights' object must be Double[]"); } } if (key.equals("size")) { if (!(obj instanceof Integer)) { throw new QSARModelException("The class of the 'size' object must be Integer"); } } if (key.equals("subset")) { if (!(obj instanceof Integer[])) { throw new QSARModelException("The class of the 'size' object must be Integer[]"); } } if (key.equals("Wts")) { if (!(obj instanceof Double[])) { throw new QSARModelException("The class of the 'Wts' object must be Double[]"); } } if (key.equals("mask")) { if (!(obj instanceof Boolean[])) { throw new QSARModelException("The class of the 'mask' object must be Boolean[]"); } } if (key.equals("linout") || key.equals("entropy") || key.equals("softmax") || key.equals("censored") || key.equals("skip") || key.equals("Hess") || key.equals("trace")) { if (!(obj instanceof Boolean)) { throw new QSARModelException("The class of the 'trace|skip|Hess|linout|entropy|softmax|censored' object must be Boolean"); } } if (key.equals("rang") || key.equals("decay") || key.equals("abstol") || key.equals("reltol")) { if (!(obj instanceof Double)) { throw new QSARModelException("The class of the 'reltol|abstol|decay|rang' object must be Double"); } } if (key.equals("maxit") || key.equals("MaxNWts")) { if (!(obj instanceof Integer)) { throw new QSARModelException("The class of the 'maxit|MaxNWts' object must be Integer"); } } if 
(key.equals("newdata")) { if ( !(obj instanceof Double[][])) { throw new QSARModelException("The class of the 'newdata' object must be Double[][]"); } } this.params.put(key,obj); } /** * Fits a CNN classification model. * * This method calls the R function to fit a CNN classification model * to the specified dependent and independent variables. If an error * occurs in the R session, an exception is thrown. * <p> * Note that, this method should be called prior to calling the various get * methods to obtain information regarding the fit. */ public void build() throws QSARModelException { try { this.modelfit = (CNNClassificationModelFit)revaluator.call("buildCNNClass", new Object[]{ getModelName(), this.params }); } catch (Exception re) { throw new QSARModelException(re.toString()); } } /** * Uses a fitted model to predict the response for new observations. * * This function uses a previously fitted model to obtain predicted values * for a new set of observations. If the model has not been fitted prior to this * call an exception will be thrown. Use <code>setParameters</code> * to set the values of the independent variable for the new observations. You can also * set the <code>type</code> argument (see <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">here</a>). * However, since this class performs CNN classification, the default setting (<code>type='raw'</code>) is sufficient. 
*x */ public void predict() throws QSARModelException { if (this.modelfit == null) throw new QSARModelException("Before calling predict() you must fit the model using build()"); Double[][] newx = (Double[][])this.params.get("newdata"); if (newx[0].length != this.nvar) { throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting"); } try { this.modelpredict = (CNNClassificationModelPredict)revaluator.call("predictCNNClass", new Object[]{ getModelName(), this.params }); } catch (Exception re) { throw new QSARModelException(re.toString()); } } /** * Loads a CNNRegresionModel object from disk in to the current session. * * * @param fileName The disk file containing the model * @throws QSARModelException if the model that was loaded was not a CNNClassification * model */ public void loadModel(String fileName) throws QSARModelException { // should probably check that the filename does exist Object model = (Object)revaluator.call("loadModel", new Object[]{ (Object)fileName }); String modelName = (String)revaluator.call("loadModel.getName", new Object[] { (Object)fileName }); if (model.getClass().getName().equals("org.openscience.cdk.qsar.model.R.CNNClassificationModelFit")) { this.modelfit = (CNNClassificationModelFit)model; this.setModelName(modelName); Double tmp = (Double)revaluator.eval(modelName+"$n[1]"); nvar = (int)tmp.doubleValue(); } else throw new QSARModelException("The loaded model was not a CNNClassificationModel"); } /** * Loads an CNNClassificationModel object from a serialized string into the current session. 
* * @param serializedModel A String containing the serialized version of the model * @param modelName A String indicating the name of the model in the R session * @throws QSARModelException if the model being loaded is not a CNN classification model * object */ public void loadModel(String serializedModel, String modelName) throws QSARModelException { // should probably check that the fileName does exist Object model = (Object)revaluator.call("unserializeModel", new Object[]{ (Object)serializedModel, (Object)modelName }); String modelname = modelName; if (model.getClass().getName().equals("org.openscience.cdk.qsar.model.R.CNNClassificationModelFit")) { this.modelfit =(CNNClassificationModelFit)model; this.setModelName(modelname); Double tmp = (Double)revaluator.eval(modelname+"$n[1]"); nvar = (int)tmp.doubleValue(); } else throw new QSARModelException("The loaded model was not a CNNClassificationModel"); } /** * Gets final value of the fitting criteria. * * This method only returns meaningful results if the <code>build</code> * method of this class has been previously called. * * @return A double indicating the value of the fitting criterion plus weight decay term. */ public double getFitValue() { return(this.modelfit.getValue()); } /** * Gets optimized weights for the model. * * This method only returns meaningful results if the <code>build</code> * method of this class has been previously called. * * @return A double[] containing the weights. The number of weights will be * equal to <center>(Ni * Nh) + (Nh * No) + Nh + No</center> where Ni, Nh and No * are the number of input, hidden and output neurons. */ public double[] getFitWeights() { return(this.modelfit.getWeights()); } /** * Gets fitted values from the final model. * * This method only returns meaningful results if the <code>build</code> * method of this class has been previously called. * * @return A double[][] containing the fitted values for each output neuron * in the columns. 
* The result is always two-dimensional: a model built with a single output
     *         neuron still yields an array with exactly one column.
     */
    public double[][] getFitFitted() {
        return modelfit.getFitted();
    }

    /**
     * Returns the residuals of the fitted values from the final model.
     *
     * The result is only meaningful once <code>build</code> has been called on
     * this object.
     *
     * @return A double[][] holding one column of residuals per output neuron. The array
     *         stays two-dimensional (a single column) even for a single-output model.
     */
    public double[][] getFitResiduals() {
        return modelfit.getResiduals();
    }

    /**
     * Returns the Hessian of the measure of fit.
     *
     * The CNN routine only computes the Hessian (at the best set of weights found)
     * when the <code>Hess</code> parameter was set to TRUE before <code>build</code>
     * was called, and the result is only meaningful once <code>build</code> has run.
     *
     * @return A square double[][] of size Nwt x Nwt, where Nwt is the total number
     *         of weights in the CNN model.
     */
    public double[][] getFitHessian() {
        return modelfit.getHessian();
    }

    /**
     * Returns the raw predictions for the new data.
     *
     * The value comes from the result stored by <code>predict</code>, so it is only
     * meaningful after <code>predict</code> has been called. That in turn requires a
     * model built with <code>build</code> or loaded with <code>loadModel</code>.
     * Because this is a classification model, each value is the probability that an
     * observation belongs to the corresponding class.
     *
     * @return A double[][] holding one column of predictions per output neuron. The
     *         array stays two-dimensional (a single column) even for a single-output model.
     */
    public double[][] getPredictPredictedRaw() {
        return modelpredict.getPredictedRaw();
    }

    /**
     * Returns the predicted class of each new observation.
     *
     * Like {@link #getPredictPredictedRaw()}, this reads the result stored by
     * <code>predict</code> and is only meaningful after <code>predict</code> has been
     * called. Instead of raw probabilities it reports the assigned class labels.
     *
     * @return A String[] with the class assigned to each observation.
     */
    public String[] getPredictPredictedClass() {
        return modelpredict.getPredictedClass();
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -