📄 bvdecomposesegcvsub.java
字号:
if (getDataFileName() != null) {
options[current++] = "-t"; options[current++] = "" + getDataFileName();
}
options[current++] = "-T"; options[current++] = "" + getTrainSize();
if (getClassifier() != null) {
options[current++] = "-W";
options[current++] = getClassifier().getClass().getName();
}
options[current++] = "--";
System.arraycopy(classifierOptions, 0, options, current,
classifierOptions.length);
current += classifierOptions.length;
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
 * Sets the classifier whose bias-variance decomposition will be computed.
 *
 * @param newClassifier the Classifier to analyse.
 */
public void setClassifier(Classifier newClassifier) {
  this.m_Classifier = newClassifier;
}
/**
 * Returns the classifier currently being analysed.
 *
 * @return the Classifier under analysis.
 */
public Classifier getClassifier() {
  return this.m_Classifier;
}
/**
 * Turns debugging output on or off.
 *
 * @param debug true to enable printing of debug output
 */
public void setDebug(boolean debug) {
  this.m_Debug = debug;
}
/**
 * Reports whether debugging output is enabled.
 *
 * @return true if debug output will be printed
 */
public boolean getDebug() {
  return this.m_Debug;
}
/**
 * Sets the seed used to initialise the random number generator.
 *
 * @param seed the random number seed
 */
public void setSeed(int seed) {
  this.m_Seed = seed;
}
/**
 * Returns the seed used to initialise the random number generator.
 *
 * @return the random number seed
 */
public int getSeed() {
  return this.m_Seed;
}
/**
 * Sets how many times each instance is classified during the decomposition.
 *
 * @param classifyIterations the number of classifications per instance
 */
public void setClassifyIterations(int classifyIterations) {
  this.m_ClassifyIterations = classifyIterations;
}
/**
 * Returns how many times each instance is classified during the decomposition.
 *
 * @return the number of classifications per instance
 */
public int getClassifyIterations() {
  return this.m_ClassifyIterations;
}
/**
 * Sets the name of the ARFF dataset file used for the decomposition.
 *
 * @param dataFileName the name of the dataset file.
 */
public void setDataFileName(String dataFileName) {
  this.m_DataFileName = dataFileName;
}
/**
 * Returns the name of the dataset file used for the decomposition.
 *
 * @return the name of the data file
 */
public String getDataFileName() {
  return this.m_DataFileName;
}
/**
 * Returns the index of the attribute used as the class.
 *
 * @return the index of the class attribute, starting from 1
 */
public int getClassIndex() {
  // The index is stored 0-based internally but exposed 1-based.
  return this.m_ClassIndex + 1;
}
/**
 * Sets which attribute is used as the class.
 *
 * @param classIndex the index (starting from 1) of the class attribute
 */
public void setClassIndex(int classIndex) {
  // Callers pass a 1-based index; store it 0-based internally.
  this.m_ClassIndex = classIndex - 1;
}
/**
 * Returns the bias squared computed according to the Kohavi and Wolpert
 * definition.
 *
 * @return the Kohavi-Wolpert bias squared
 */
public double getKWBias() {
  return this.m_KWBias;
}
/**
 * Returns the bias computed according to the Webb definition.
 *
 * @return the Webb bias
 */
public double getWBias() {
  return this.m_WBias;
}
/**
 * Returns the variance computed according to the Kohavi and Wolpert
 * definition.
 *
 * @return the Kohavi-Wolpert variance
 */
public double getKWVariance() {
  return this.m_KWVariance;
}
/**
 * Returns the variance computed according to the Webb definition.
 *
 * @return the Webb variance
 */
public double getWVariance() {
  return this.m_WVariance;
}
/**
 * Returns the sigma computed according to the Kohavi and Wolpert definition.
 *
 * @return the Kohavi-Wolpert sigma
 */
public double getKWSigma() {
  return this.m_KWSigma;
}
/**
 * Sets the number of instances in each training set.
 *
 * @param size the size of the training set
 */
public void setTrainSize(int size) {
  this.m_TrainSize = size;
}
/**
 * Returns the number of instances in each training set.
 *
 * @return the size of the training set
 */
public int getTrainSize() {
  return this.m_TrainSize;
}
/**
 * Sets the proportion of instances that any two training sets used to
 * train a classifier have in common.
 *
 * @param proportion the proportion of instances shared between training
 * sets.
 */
public void setP(double proportion) {
  this.m_P = proportion;
}
/**
 * Returns the proportion of instances that any two training sets have in
 * common.
 *
 * @return the proportion of shared instances
 */
public double getP() {
  return this.m_P;
}
/**
 * Returns the error rate computed by the decomposition.
 *
 * @return the calculated error rate
 */
public double getError() {
  return this.m_Error;
}
/**
* Carry out the bias-variance decomposition using the sub-sampled cross-validation method.
*
* @exception Exception if the decomposition couldn't be carried out
*/
public void decompose() throws Exception {
Reader dataReader;
Instances data;
int tps; // training pool size, size of segment E.
int k; // number of folds in segment E.
int q; // number of segments of size tps.
dataReader = new BufferedReader(new FileReader(m_DataFileName)); //open file
data = new Instances(dataReader); // encapsulate in wrapper class called weka.Instances()
if (m_ClassIndex < 0) {
data.setClassIndex(data.numAttributes() - 1);
} else {
data.setClassIndex(m_ClassIndex);
}
if (data.classAttribute().type() != Attribute.NOMINAL) {
throw new Exception("Class attribute must be nominal");
}
int numClasses = data.numClasses();
data.deleteWithMissingClass();
if ( data.checkForStringAttributes() ) {
throw new Exception("Can't handle string attributes!");
}
// Dataset size must be greater than 2
if ( data.numInstances() <= 2 ){
throw new Exception("Dataset size must be greater than 2.");
}
if ( m_TrainSize == -1 ){ // default value
m_TrainSize = (int) Math.floor( (double) data.numInstances() / 2.0 );
}else if ( m_TrainSize < 0 || m_TrainSize >= data.numInstances() - 1 ) { // Check if 0 < training Size < D - 1
throw new Exception("Training set size of "+m_TrainSize+" is invalid.");
}
if ( m_P == -1 ){ // default value
m_P = (double) m_TrainSize / ( (double)data.numInstances() - 1 );
}else if ( m_P < ( m_TrainSize / ( (double)data.numInstances() - 1 ) ) || m_P >= 1.0 ) { //Check if p is in range: m/(|D|-1) <= p < 1.0
throw new Exception("Proportion is not in range: "+ (m_TrainSize / ((double) data.numInstances() - 1 )) +" <= p < 1.0 ");
}
//roundup tps from double to integer
tps = (int) Math.ceil( ((double)m_TrainSize / (double)m_P) + 1 );
k = (int) Math.ceil( tps / (tps - (double) m_TrainSize));
// number of folds cannot be more than the number of instances in the training pool
if ( k > tps ) {
throw new Exception("The required number of folds is too many."
+ "Change p or the size of the training set.");
}
// calculate the number of segments, round down.
q = (int) Math.floor( (double) data.numInstances() / (double)tps );
//create confusion matrix, columns = number of instances in data set, as all will be used, by rows = number of classes.
double [][] instanceProbs = new double [data.numInstances()][numClasses];
int [][] foldIndex = new int [ k ][ 2 ];
Vector segmentList = new Vector(q + 1);
//Set random seed
Random random = new Random(m_Seed);
data.randomize(random);
//create index arrays for different segments
int currentDataIndex = 0;
for( int count = 1; count <= (q + 1); count++ ){
if( count > q){
int [] segmentIndex = new int [ (data.numInstances() - (q * tps)) ];
for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
segmentIndex[index] = currentDataIndex;
}
segmentList.add(segmentIndex);
} else {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -