📄 expdt_tc2.java
字号:
Debug.dp(Debug.PROGRESS, "PROGRESS: Attribution of " + i + " complete."); } Debug.dp(Debug.PROGRESS, "PROGRESS: Attribution complete."); // Combine all data sources. For now, globals go in every // one. Combiner c = new Combiner(); ClassStreamAttValVecI[] trainAttsByClass = new ClassStreamAttValVec[numClasses]; ClassStreamAttValVecI[] testAttsByClass = new ClassStreamAttValVec[numClasses]; for(int i=0; i < numClasses; i++){ trainAttsByClass[i] = c.combine(trainGlobalData, trainEventAtts[i]); testAttsByClass[i] = c.combine(testGlobalData, testEventAtts[i]); } // Now we have to do some garbage collection. trainStreamData = null; testStreamData = null; eventClusterers = null; trainEventSEV = null; trainEventCV = null; clustersByClass = null; attribsByClass = null; System.gc(); // So now we have the raw data in the correct form for each // attributor. // And now, we can construct a learner for each case. // Well, for now, I'm going to do something completely crazy. // Let's run each classifier nonetheless over the whole data // ... and see what the hell happens. Maybe some voting scheme // is possible!! This is a strange form of ensemble // classifier. // Each naive bayes algorithm only gets one Debug.setDebugLevel(Debug.PROGRESS); int[][] selectedIndices = new int[numClasses][]; J48[] dtLearners = new J48[numClasses]; for(int i=0; i < numClasses; i++){ dtLearners[i] = new J48(); Debug.dp(Debug.PROGRESS, "PROGRESS: Beginning format conversion for class " + i); Instances data = WekaBridge.makeInstances(trainAttsByClass[i], "Train "+i); Debug.dp(Debug.PROGRESS, "PROGRESS: Conversion complete. Starting learning"); if(thisExp.featureSel){ Debug.dp(Debug.PROGRESS, "PROGRESS: Doing feature selection"); BestFirst bfs = new BestFirst(); CfsSubsetEval cfs = new CfsSubsetEval(); cfs.buildEvaluator(data); selectedIndices[i] = bfs.search(cfs, data); // Now extract the features. System.out.print("Selected features for class " + i + ": "); String featureString = new String(); for(int j=0; j < selectedIndices[i].length; j++){ featureString += (selectedIndices[i][j] +1)+ ","; } featureString += ("last"); System.out.println(featureString); // Now apply the filter. AttributeFilter af = new AttributeFilter(); af.setInvertSelection(true); af.setAttributeIndices(featureString); af.inputFormat(data); data = af.useFilter(data, af); } dtLearners[i].buildClassifier(data); Debug.dp(Debug.PROGRESS, "Learnt tree: \n" + dtLearners[i].toString()); } DTClassifier[] dtClassifiers = new DTClassifier[numClasses]; for(int i=0; i < numClasses; i++){ dtClassifiers[i] = new DTClassifier(dtLearners[i]); // System.out.println(nbClassifiers[i].toString()); } Debug.dp(Debug.PROGRESS, "PROGRESS: Learning complete. "); // Now test on training data (each one) /* for(int i=0; i < numClasses; i++){ String className = domDesc.getClassDescVec().getClassLabel(i); ClassificationVecI classvi = (ClassificationVecI) trainAttsByClass[i].getClassVec().clone(); StreamAttValVecI savvi = trainAttsByClass[i].getStreamAttValVec(); for(int j=0; j < trainAttsByClass[i].size(); j++){ nbClassifiers[i].classify(savvi.elAt(j), classvi.elAt(j)); } System.out.println(">>> Learner for class " + className); int numCorrect = 0; for(int j=0; j < classvi.size(); j++){ System.out.print(classvi.elAt(j).toString()); if(classvi.elAt(j).getRealClass() == classvi.elAt(j).getPredictedClass()){ numCorrect++; } } System.out.println("Train accuracy for " + className + " classifier: " + numCorrect + " of " + numTrainStreams + " (" + numCorrect*100.0/numTrainStreams + "%)"); } */ System.out.println(">>> Testing stage <<<"); // First, print the results of using the straight testers. ClassificationVecI[] classns = new ClassificationVecI[numClasses]; for(int i=0; i < numClasses; i++){ String className = domDesc.getClassDescVec().getClassLabel(i); classns[i] = (ClassificationVecI) testAttsByClass[i].getClassVec().clone(); StreamAttValVecI savvi = testAttsByClass[i].getStreamAttValVec(); Instances data = WekaBridge.makeInstances(testAttsByClass[i], "Test " + i); if(thisExp.featureSel){ String featureString = new String(); for(int j=0; j < selectedIndices[i].length; j++){ featureString += (selectedIndices[i][j]+1) + ","; } featureString += "last"; // Now apply the filter. AttributeFilter af = new AttributeFilter(); af.setInvertSelection(true); af.setAttributeIndices(featureString); af.inputFormat(data); data = af.useFilter(data, af); } for(int j=0; j < numTestStreams; j++){ dtClassifiers[i].classify(data.instance(j), classns[i].elAt(j)); } System.out.println(">>> Learner for class " + className); int numCorrect = 0; for(int j=0; j < numTestStreams; j++){ System.out.print(classns[i].elAt(j).toString()); if(classns[i].elAt(j).getRealClass() == classns[i].elAt(j).getPredictedClass()){ numCorrect++; } } System.out.println("Test accuracy for " + className + " classifier: " + numCorrect + " of " + numTestStreams + " (" + numCorrect*100.0/numTestStreams + "%)"); } // Now do voting. This is a hack solution. int numCorrect = 0; for(int i=0; i < numTestStreams; i++){ int[] votes = new int[numClasses]; int realClass = classns[0].elAt(i).getRealClass(); String realClassName = domDesc.getClassDescVec().getClassLabel(realClass); for(int j=0; j < numClasses; j++){ int thisPrediction = classns[j].elAt(i).getPredictedClass(); // if(thisPrediction == j){ // votes[thisPrediction] += 2; // } //else { votes[thisPrediction]++; //} } int maxIndex = -1; int maxVotes = 0; String voteRes = "[ "; for(int j=0; j <numClasses; j++){ voteRes += votes[j] + " "; if(votes[j] > maxVotes){ maxIndex = j; maxVotes = votes[j]; } } voteRes += "]"; // Now print the result: String predictedClassName = domDesc.getClassDescVec().getClassLabel(maxIndex); if(maxIndex == realClass){ System.out.println("Class " + realClassName + " CORRECTLY classified with " + maxVotes + " votes. Votes: " + voteRes); numCorrect++; } else { System.out.println("Class " + realClassName + " INCORRECTLY classified as " + predictedClassName + " with " + maxVotes + " votes. Votes: " + voteRes); } } System.out.println("Final voted accuracy: " + numCorrect + " of " + numTestStreams + " (" + numCorrect*100.0/numTestStreams + "%)"); } }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -