// Source file: BVDecomposeSegCVSub.java
// (web-viewer banner text removed; it was not part of the Java source)
// NOTE(review): this chunk begins mid-way through a method (presumably decompose());
// variables such as tps, k, q, segmentList, foldIndex, data, instanceProbs and the
// m_* fields are declared in code not visible here.
int [] segmentIndex = new int [ tps ];
// Fill this segment with consecutive data indices.
for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
segmentIndex[index] = currentDataIndex;
}
segmentList.add(segmentIndex);
}
}
// Lay out k cross-validation folds over the tps indices of one segment.
int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1.
//foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3, 3.3->4)
int foldSize = (int) Math.ceil( (double)tps /(double) k); //roundup fold size double to integer
int index = 0;
int currentIndex, endIndex; // NOTE(review): endIndex appears unused in this visible chunk.
// foldIndex[count] = { start index within the segment, fold size }.
// After `remainder` folds of the rounded-up size, shrink by 1 so the folds
// exactly cover all tps indices.
for( int count = 0; count < k; count ++){
if( remainder != 0 && count == remainder ){
foldSize -= 1;
}
foldIndex[count][0] = index;
foldIndex[count][1] = foldSize;
index += foldSize;
}
// Repeat the whole segmented cross-validation m_ClassifyIterations times.
for( int l = 0; l < m_ClassifyIterations; l++) {
// Iterate over the q segments (1-based here; segmentList is 0-based).
for(int i = 1; i <= q; i++){
int [] currentSegment = (int[]) segmentList.get(i - 1);
// Shuffle the segment so fold membership differs between iterations.
randomize(currentSegment, random);
//CROSS FOLD VALIDATION for current Segment
for( int j = 1; j <= k; j++){
// TP = training pool: all folds of this segment except fold j.
Instances TP = null;
for(int foldNum = 1; foldNum <= k; foldNum++){
if( foldNum != j){
int startFoldIndex = foldIndex[ foldNum - 1 ][ 0 ]; //start index
foldSize = foldIndex[ foldNum - 1 ][ 1 ];
int endFoldIndex = startFoldIndex + foldSize - 1;
for(int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++){
if( TP == null ){
// First instance seeds the pool (copies dataset header from `data`).
TP = new Instances(data, currentSegment[ currentFoldIndex ], 1);
}else{
TP.add( data.instance( currentSegment[ currentFoldIndex ] ) );
}
}
}
}
// Subsample: train on the first m_TrainSize instances of the shuffled pool.
TP.randomize(random);
if( getTrainSize() > TP.numInstances() ){
throw new Exception("The training set size of " + getTrainSize() + ", is greater than the training pool "
+ TP.numInstances() );
}
Instances train = new Instances(TP, 0, m_TrainSize);
Classifier current = Classifier.makeCopy(m_Classifier);
current.buildClassifier(train); // build a classifier from the instances in train.
// Evaluate on the held-out fold j; tally errors and per-class vote counts.
int currentTestIndex = foldIndex[ j - 1 ][ 0 ]; //start index
int testFoldSize = foldIndex[ j - 1 ][ 1 ]; //size
int endTestIndex = currentTestIndex + testFoldSize - 1;
while( currentTestIndex <= endTestIndex ){
Instance testInst = data.instance( currentSegment[currentTestIndex] );
// NOTE(review): int prediction compared against a double classValue();
// works for nominal classes whose values are whole numbers — confirm.
int pred = (int)current.classifyInstance( testInst );
if(pred != testInst.classValue()) {
m_Error++; // add 1 to mis-classifications.
}
instanceProbs[ currentSegment[ currentTestIndex ] ][ pred ]++;
currentTestIndex++;
}
// The very first classifier (segment 1, fold 1) additionally classifies the
// last segment — presumably a leftover/remainder segment not covered by the
// q-segment loop; confirm against the construction code above this chunk.
if( i == 1 && j == 1){
int[] segmentElast = (int[])segmentList.lastElement();
for( currentIndex = 0; currentIndex < segmentElast.length; currentIndex++){
Instance testInst = data.instance( segmentElast[currentIndex] );
int pred = (int)current.classifyInstance( testInst );
if(pred != testInst.classValue()) {
m_Error++; // add 1 to mis-classifications.
}
instanceProbs[ segmentElast[ currentIndex ] ][ pred ]++;
}
}
}
}
}
// Convert raw error count into an error rate.
m_Error /= (double)( m_ClassifyIterations * data.numInstances() );
// Accumulate the Kohavi-Wolpert (KW*) and Webb (W*) decompositions per instance.
m_KWBias = 0.0;
m_KWVariance = 0.0;
m_KWSigma = 0.0;
m_WBias = 0.0;
m_WVariance = 0.0;
for (int i = 0; i < data.numInstances(); i++) {
Instance current = data.instance( i );
double [] predProbs = instanceProbs[ i ];
double pActual, pPred;
double bsum = 0, vsum = 0, ssum = 0;
double wBSum = 0, wVSum = 0;
// Classes most frequently predicted for this instance (may be several on ties).
Vector centralTendencies = findCentralTendencies( predProbs );
if( centralTendencies == null ){
throw new Exception("Central tendency was null.");
}
// Kohavi-Wolpert terms, summed over classes.
for (int j = 0; j < numClasses; j++) {
pActual = (current.classValue() == j) ? 1 : 0;
pPred = predProbs[j] / m_ClassifyIterations;
// Squared bias with a small-sample variance correction term.
bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_ClassifyIterations - 1);
vsum += pPred * pPred;
ssum += pActual * pActual;
}
m_KWBias += bsum;
m_KWVariance += (1 - vsum);
m_KWSigma += (1 - ssum);
// Webb terms: average over all tied central-tendency classes.
for( int count = 0; count < centralTendencies.size(); count++ ) {
int wB = 0, wV = 0;
int centralTendency = ((Integer)centralTendencies.get(count)).intValue();
// For a single instance xi, find the bias and variance.
for (int j = 0; j < numClasses; j++) {
//Webb definition
// Bias: predictions of the central-tendency class when it is wrong.
if( j != (int)current.classValue() && j == centralTendency ) {
wB += predProbs[j];
}
// Variance: wrong predictions that deviate from the central tendency.
if( j != (int)current.classValue() && j != centralTendency ) {
wV += predProbs[j];
}
}
wBSum += (double) wB;
wVSum += (double) wV;
}
// calculate bias by dividing wBSum by the number of central tendencies and
// total number of iterations (effectively averaging, then normalising to a
// probability).
m_WBias += ( wBSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
// calculate variance by dividing wVSum by the total number of iterations
m_WVariance += ( wVSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
}
// Kohavi-Wolpert quantities carry a conventional factor of 1/2.
m_KWBias /= (2.0 * (double) data.numInstances());
m_KWVariance /= (2.0 * (double) data.numInstances());
m_KWSigma /= (2.0 * (double) data.numInstances());
// bias = bias / number of data instances
m_WBias /= (double) data.numInstances();
// variance = variance / number of data instances.
m_WVariance /= (double) data.numInstances();
if (m_Debug) {
System.err.println("Decomposition finished");
}
}
/** Finds the central tendency, given the classifications for an instance.
 *
 * Where the central tendency is defined as the class that was most commonly
 * selected for a given instance.<p>
 *
 * For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3},
 * so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times
 * and '3' = 3 times. Then the central tendency is '2'. <p>
 *
 * However, it is important to note that this method returns a list of all classes
 * that have the highest number of classifications.
 *
 * In cases where there are several classes with the largest number of classifications, then
 * all of these classes are returned. For example if 'x' is classified '1' = 4 times,
 * '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p>
 *
 * @param predProbs the array of classifications for a single instance.
 *
 * @return a Vector containing Integer objects which store the class(s) which
 * are the central tendency, or null if no class was ever predicted.
 */
public Vector findCentralTendencies(double[] predProbs) {
    // Largest classification count seen so far.
    int centralTValue = 0;
    // Classes tied for the largest number of classifications.
    Vector centralTClasses = new Vector();
    // Scan all classes, tracking the running maximum and its ties.
    for (int i = 0; i < predProbs.length; i++) {
        int currentValue = (int) predProbs[i];
        if (currentValue > centralTValue) {
            // New maximum: discard previous ties and restart with this class.
            centralTClasses.clear();
            centralTClasses.addElement(Integer.valueOf(i));
            centralTValue = currentValue;
        } else if (currentValue != 0 && currentValue == centralTValue) {
            // Ties with the current maximum (zero counts never qualify).
            centralTClasses.addElement(Integer.valueOf(i));
        }
    }
    // If every count was zero there is no central tendency; callers treat
    // null as an error condition.
    return (centralTValue != 0) ? centralTClasses : null;
}
/**
 * Returns description of the bias-variance decomposition results.
 *
 * @return the bias-variance decomposition results as a string
 */
public String toString() {
    if (getClassifier() == null) {
        return "Invalid setup";
    }
    // StringBuilder avoids the quadratic cost of repeated String concatenation;
    // the produced text is identical to the previous += version.
    StringBuilder result = new StringBuilder();
    result.append("\nBias-Variance Decomposition Segmentation, Cross Validation\n")
          .append("with subsampling.\n");
    result.append("\nClassifier    : ").append(getClassifier().getClass().getName());
    if (getClassifier() instanceof OptionHandler) {
        result.append(Utils.joinOptions(((OptionHandler) m_Classifier).getOptions()));
    }
    result.append("\nData File     : ").append(getDataFileName());
    result.append("\nClass Index   : ");
    if (getClassIndex() == 0) {
        result.append("last");
    } else {
        result.append(getClassIndex());
    }
    result.append("\nIterations    : ").append(getClassifyIterations());
    result.append("\np             : ").append(getP());
    result.append("\nTraining Size : ").append(getTrainSize());
    result.append("\nSeed          : ").append(getSeed());
    result.append("\n\nDefinition   : ").append("Kohavi and Wolpert");
    result.append("\nError         :").append(Utils.doubleToString(getError(), 4));
    result.append("\nBias^2        :").append(Utils.doubleToString(getKWBias(), 4));
    result.append("\nVariance      :").append(Utils.doubleToString(getKWVariance(), 4));
    result.append("\nSigma^2       :").append(Utils.doubleToString(getKWSigma(), 4));
    result.append("\n\nDefinition   : ").append("Webb");
    result.append("\nError         :").append(Utils.doubleToString(getError(), 4));
    result.append("\nBias          :").append(Utils.doubleToString(getWBias(), 4));
    result.append("\nVariance      :").append(Utils.doubleToString(getWVariance(), 4));
    return result.toString();
}
/**
 * Test method for this class: parses the command-line options, runs the
 * decomposition and prints the results. On an option-parsing error the
 * usage text (all available options) is printed instead.
 *
 * @param args the command line arguments
 */
public static void main(String[] args) {
    try {
        BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub();
        try {
            bvd.setOptions(args);
            Utils.checkForRemainingOptions(args);
        } catch (Exception ex) {
            // Build the usage message with StringBuilder instead of += in a
            // loop; append(String) renders null as "null", matching the old
            // concatenation behavior when getMessage() is null.
            StringBuilder usage = new StringBuilder();
            usage.append(ex.getMessage()).append("\nBVDecompose Options:\n\n");
            Enumeration enu = bvd.listOptions();
            while (enu.hasMoreElements()) {
                Option option = (Option) enu.nextElement();
                usage.append(option.synopsis()).append('\n')
                     .append(option.description()).append('\n');
            }
            throw new Exception(usage.toString());
        }
        bvd.decompose();
        System.out.println(bvd.toString());
    } catch (Exception ex) {
        System.err.println(ex.getMessage());
    }
}
/**
 * Shuffles the values of the given array in place (Fisher-Yates),
 * using the supplied random number generator.
 *
 * @param index the array of integers to shuffle
 * @param random the random number generator to draw swap positions from
 */
public final void randomize(int[] index, Random random) {
    // Walk from the last slot down to the second, swapping each slot with a
    // uniformly chosen slot at or before it.
    for (int pos = index.length - 1; pos > 0; pos--) {
        int swapWith = random.nextInt(pos + 1);
        int held = index[pos];
        index[pos] = index[swapWith];
        index[swapWith] = held;
    }
}
}
// (trailing web-viewer keyboard-shortcut help text removed; it was not part of the Java source)