// CVParameterSelection.java
/**
* Parses a given list of options. Valid options are:<p>
*
* -D <br>
* Turn on debugging output.<p>
*
* -W classname <br>
* Specify the full class name of classifier to perform cross-validation
* selection on.<p>
*
* -X num <br>
* Number of folds used for cross validation (default 10). <p>
*
* -S seed <br>
* Random number seed (default 1).<p>
*
* -P "N 1 5 10" <br>
* Sets an optimisation parameter for the classifier with name -N,
* lower bound 1, upper bound 5, and 10 optimisation steps.
* The upper bound may be the character 'A' or 'I' to substitute
* the number of attributes or instances in the training data,
* respectively.
* This parameter may be supplied more than once to optimise over
* several classifier options simultaneously. <p>
*
* Options after -- are passed to the designated sub-classifier. <p>
*
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String foldsString = Utils.getOption('X', options);
if (foldsString.length() != 0) {
setNumFolds(Integer.parseInt(foldsString));
} else {
setNumFolds(10);
}
String cvParam;
m_CVParams = new FastVector();
do {
cvParam = Utils.getOption('P', options);
if (cvParam.length() != 0) {
addCVParameter(cvParam);
}
} while (cvParam.length() != 0);
super.setOptions(options);
}
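// Illustrative sketch of the option format parsed above. The base classifier,
// its -C option and the concrete values are assumptions for the example, not
// taken from this source; -D/-W/-S handling is delegated to super.setOptions()
// (superclass not shown in this excerpt). Exception handling omitted.
//
//   CVParameterSelection cvps = new CVParameterSelection();
//   cvps.setOptions(new String[] {
//     "-X", "5",                          // 5 cross-validation folds
//     "-P", "C 0.1 0.5 5",                // search the base scheme's -C from 0.1 to 0.5 in 5 steps
//     "-W", "weka.classifiers.trees.J48"  // full class name of the base classifier (assumed)
//   });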
/**
* Gets the current settings of the Classifier.
*
* @return an array of strings suitable for passing to setOptions
*/
public String [] getOptions() {
String[] superOptions;
if (m_InitOptions != null) {
try {
m_Classifier.setOptions((String[])m_InitOptions.clone());
superOptions = super.getOptions();
m_Classifier.setOptions((String[])m_BestClassifierOptions.clone());
} catch (Exception e) {
throw new RuntimeException("CVParameterSelection: could not set options " +
"in getOptions().");
}
} else {
superOptions = super.getOptions();
}
String [] options = new String [superOptions.length + m_CVParams.size() * 2 + 2];
int current = 0;
for (int i = 0; i < m_CVParams.size(); i++) {
options[current++] = "-P"; options[current++] = "" + getCVParameter(i);
}
options[current++] = "-X"; options[current++] = "" + getNumFolds();
System.arraycopy(superOptions, 0, options, current,
superOptions.length);
return options;
}
/**
* Generates the classifier.
*
* @param instances set of instances serving as training data
* @exception Exception if the classifier has not been generated successfully
*/
public void buildClassifier(Instances instances) throws Exception {
if (instances.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
Instances trainData = new Instances(instances);
trainData.deleteWithMissingClass();
if (trainData.numInstances() == 0) {
throw new IllegalArgumentException("No training instances without " +
"missing class.");
}
if (trainData.numInstances() < m_NumFolds) {
throw new IllegalArgumentException("Number of training instances " +
"smaller than number of folds.");
}
if (!(m_Classifier instanceof OptionHandler)) {
throw new IllegalArgumentException("Base classifier should be OptionHandler.");
}
m_InitOptions = ((OptionHandler)m_Classifier).getOptions();
m_BestPerformance = -99;
m_NumAttributes = trainData.numAttributes();
Random random = new Random(m_Seed);
trainData.randomize(random);
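// Record the size of a single training fold; this is the value reported by
// toString() when a -P upper bound was given as 'I' (instances).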
m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0, random).numInstances();
// Check whether there are any parameters to optimize
if (m_CVParams.size() == 0) {
m_Classifier.buildClassifier(trainData);
m_BestClassifierOptions = m_InitOptions;
return;
}
if (trainData.classAttribute().isNominal()) {
trainData.stratify(m_NumFolds);
}
m_BestClassifierOptions = null;
// Set up m_ClassifierOptions -- take getOptions() and remove
// those being optimised.
m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();
for (int i = 0; i < m_CVParams.size(); i++) {
Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,
m_ClassifierOptions);
}
findParamsByCrossValidation(0, trainData, random);
String [] options = (String [])m_BestClassifierOptions.clone();
((OptionHandler)m_Classifier).setOptions(options);
m_Classifier.buildClassifier(trainData);
}
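// Minimal programmatic sketch of the build/predict cycle implemented above.
// setClassifier() is inherited from the superclass (not shown in this excerpt),
// and the J48 base scheme with its -C option is an assumption. 'train' is an
// Instances object with the class index set, 'test' a single Instance (both
// placeholders); exception handling omitted.
//
//   CVParameterSelection cvps = new CVParameterSelection();
//   cvps.setClassifier(new weka.classifiers.trees.J48());
//   cvps.addCVParameter("C 0.1 0.5 5");   // optimise the base scheme's -C
//   cvps.setNumFolds(10);
//   cvps.buildClassifier(train);          // cross-validated search, then a final
//                                         // build with the best options found
//   double[] dist = cvps.distributionForInstance(test);
//   System.out.println(cvps.toSummaryString());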
/**
* Predicts the class distribution for the given test instance.
*
* @param instance the instance to be classified
* @return the predicted class distribution
* @exception Exception if an error occurred during the prediction
*/
public double[] distributionForInstance(Instance instance) throws Exception {
return m_Classifier.distributionForInstance(instance);
}
/**
* Adds a scheme parameter to the list of parameters to be set
* by cross-validation
*
* @param cvParam the string representation of a scheme parameter. The
* format is: <br>
* param_char lower_bound upper_bound number_of_steps <br>
* eg to search a parameter -P from 1 to 10 in 10 steps: <br>
* P 1 10 10 <br>
* @exception Exception if the parameter specifier is of the wrong format
*/
public void addCVParameter(String cvParam) throws Exception {
CVParameter newCV = new CVParameter(cvParam);
m_CVParams.addElement(newCV);
}
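// Example specifications accepted by addCVParameter(), given a
// CVParameterSelection instance cvps. Which option letter is meaningful
// depends on the chosen base classifier; -C and -K below are assumptions.
//
//   cvps.addCVParameter("C 0.1 0.5 5");   // search -C from 0.1 to 0.5 in 5 steps
//   cvps.addCVParameter("K 1 I 10");      // upper bound 'I': number of training
//                                         // instances, as documented in setOptions()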
/**
* Gets the scheme parameter with the given index.
*
* @param index the index of the parameter to return
* @return the parameter's string representation, or the empty string if the
* index is out of range
*/
public String getCVParameter(int index) {
if (m_CVParams.size() <= index) {
return "";
}
return ((CVParameter)m_CVParams.elementAt(index)).toString();
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String CVParametersTipText() {
return "Sets the scheme parameters which are to be set "+
"by cross-validation.\n"+
"The format for each string should be:\n"+
"param_char lower_bound upper_bound increment\n"+
"eg to search a parameter -P from 1 to 10 by increments of 2:\n"+
" \"P 1 10 2\" ";
}
/**
* Get method for CVParameters.
*
* @return the current parameter specifications, as an array of strings
*/
public Object[] getCVParameters() {
Object[] CVParams = m_CVParams.toArray();
String[] params = new String[CVParams.length];
for (int i = 0; i < CVParams.length; i++) {
params[i] = CVParams[i].toString();
}
return params;
}
/**
* Set method for CVParameters.
*
* @param params the parameter specifications, as an array of strings
* @exception Exception if a parameter specifier has the wrong format
*/
public void setCVParameters(Object[] params) throws Exception {
FastVector backup = m_CVParams;
m_CVParams = new FastVector();
for (int i = 0; i < params.length; i++) {
try {
addCVParameter((String)params[i]);
} catch (Exception ex) {
m_CVParams = backup;
throw ex;
}
}
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String numFoldsTipText() {
return "Get the number of folds used for cross-validation.";
}
/**
* Gets the number of folds for the cross-validation.
*
* @return the number of folds for the cross-validation
*/
public int getNumFolds() {
return m_NumFolds;
}
/**
* Sets the number of folds for the cross-validation.
*
* @param numFolds the number of folds for the cross-validation
* @exception Exception if parameter illegal
*/
public void setNumFolds(int numFolds) throws Exception {
if (numFolds <= 0) {
throw new IllegalArgumentException("CVParameterSelection: Number of " +
"cross-validation folds must be positive.");
}
m_NumFolds = numFolds;
}
/**
* Returns the type of graph this classifier
* represents.
*
* @return the graph type, as defined in the Drawable interface
*/
public int graphType() {
if (m_Classifier instanceof Drawable)
return ((Drawable)m_Classifier).graphType();
else
return Drawable.NOT_DRAWABLE;
}
/**
* Returns graph describing the classifier (if possible).
*
* @return the graph of the classifier in dotty format
* @exception Exception if the classifier cannot be graphed
*/
public String graph() throws Exception {
if (m_Classifier instanceof Drawable)
return ((Drawable)m_Classifier).graph();
else throw new Exception("Classifier: " +
m_Classifier.getClass().getName() + " " +
Utils.joinOptions(m_BestClassifierOptions)
+ " cannot be graphed");
}
/**
* Returns description of the cross-validated classifier.
*
* @return description of the cross-validated classifier as a string
*/
public String toString() {
if (m_InitOptions == null)
return "CVParameterSelection: No model built yet.";
String result = "Cross-validated Parameter selection.\n"
+ "Classifier: " + m_Classifier.getClass().getName() + "\n";
try {
for (int i = 0; i < m_CVParams.size(); i++) {
CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
result += "Cross-validation Parameter: '-"
+ cvParam.m_ParamChar + "'"
+ " ranged from " + cvParam.m_Lower
+ " to ";
switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
case 1:
result += m_NumAttributes;
break;
case 2:
result += m_TrainFoldSize;
break;
default:
result += cvParam.m_Upper;
break;
}
result += " with " + cvParam.m_Steps + " steps\n";
}
} catch (Exception ex) {
result += ex.getMessage();
}
result += "Classifier Options: "
+ Utils.joinOptions(m_BestClassifierOptions)
+ "\n\n" + m_Classifier.toString();
return result;
}
/**
* Returns a concise summary of the best parameter values found.
*
* @return the selected classifier options, on a single line
*/
public String toSummaryString() {
String result = "Selected values: "
+ Utils.joinOptions(m_BestClassifierOptions);
return result + '\n';
}
/**
* Main method for testing this class.
*
* @param argv the options
*/
public static void main(String [] argv) {
try {
System.out.println(Evaluation.evaluateModel(new CVParameterSelection(),
argv));
} catch (Exception e) {
System.err.println(e.getMessage());
}
}
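// Example command line (illustrative). The dataset path, the base
// classifier's class name and its -C / -M options are assumptions, and the
// package of this class may differ between WEKA releases:
//
//   java weka.classifiers.meta.CVParameterSelection -t data.arff \
//     -X 10 -P "C 0.1 0.5 5" -W weka.classifiers.trees.J48 -- -M 2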
}