📄 rakel.java
字号:
for(int i = 0; i < testData.numInstances(); i++)
{
Instance instance = testData.instance(i);
Prediction result = updatePrediction(instance, i, model);
// Prediction result = makePrediction(instance);
//System.out.println(java.util.Arrays.toString(result.getConfidences()));
for(int j = 0; j < numLabels; j++)
{
int classIdx = testData.numAttributes() - numLabels + j;
boolean actual = Utils.eq(1, instance.value(classIdx));
predictions[i][j] = new BinaryPrediction(
result.getPrediction(j),
actual,
result.getConfidence(j));
}
}
}
/**
 * Builds the whole ensemble.
 *
 * When {@code cvParamSelection} is set, parameters are first chosen via
 * {@code paramSelectionViaCV} and reported on standard output. Afterwards the
 * registry of used label subsets is reset and one model is built for every
 * index in {@code [0, numOfModels)} by delegating to {@code updateClassifier}.
 *
 * @param trainData the training instances handed to parameter selection and
 *                  to every {@code updateClassifier} call
 * @throws Exception if parameter selection or building any model fails
 */
public void buildClassifier(Instances trainData) throws Exception {
    if (cvParamSelection) {
        paramSelectionViaCV(trainData);
        // Each parameter goes on its own line; without the separators the
        // three values were printed run together on a single line.
        System.out.println("Selected Parameters\n" +
                "Subset size : " + getSizeOfSubset() + "\n" +
                "Number of models: " + getNumModels() + "\n" +
                "Threshold : " + getThreshold());
    }
    // Fresh registry of label subsets already used, so that no two models
    // built below share the same subset.
    combinations = new HashSet<String>();
    for (int i = 0; i < numOfModels; i++) {
        updateClassifier(trainData, i);
    }
}
/**
 * Builds the model with index {@code model} on a randomly drawn label subset.
 *
 * Steps:
 * <ol>
 *   <li>draw {@code sizeOfSubset} distinct label indices at random, retrying
 *       until the (sorted) subset has not been registered in
 *       {@code combinations} before;</li>
 *   <li>remove all unselected label attributes from the training data;</li>
 *   <li>train a {@code LabelPowersetClassifier} on the reduced data;</li>
 *   <li>keep the emptied reduced data set as header for prediction time.</li>
 * </ol>
 *
 * NOTE(review): the outer retry loop still cannot terminate once every
 * distinct subset of the current size has been registered; callers must not
 * request more models than there are distinct subsets.
 *
 * @param trainData training instances; the code treats the last
 *                  {@code numLabels} attributes as the labels
 * @param model     index of the model to (re)build
 * @throws IllegalArgumentException if {@code sizeOfSubset} exceeds
 *                                  {@code numLabels}
 * @throws Exception if filtering the data or building the classifier fails
 */
public void updateClassifier(Instances trainData, int model) throws Exception {
    // Drawing more distinct labels than exist would spin forever in the
    // rejection loop below, so fail fast instead.
    if (sizeOfSubset > numLabels) {
        throw new IllegalArgumentException("Subset size (" + sizeOfSubset
                + ") must not exceed the number of labels (" + numLabels + ")");
    }
    if (combinations == null) {
        combinations = new HashSet<String>();
    }
    Random rnd = new Random();

    // --select a random subset of classes not seen before
    boolean[] selected;
    do {
        selected = new boolean[numLabels];
        for (int j = 0; j < sizeOfSubset; j++) {
            // Rejection sampling: redraw until an unselected label comes up.
            int randomLabel = rnd.nextInt(numLabels);
            while (selected[randomLabel]) {
                randomLabel = rnd.nextInt(numLabels);
            }
            selected[randomLabel] = true;
            classIndicesPerSubset[model][j] = randomLabel;
        }
        // Sorted indices give a canonical string key for the subset.
        Arrays.sort(classIndicesPerSubset[model]);
    } while (!combinations.add(Arrays.toString(classIndicesPerSubset[model])));
    System.out.println("Building model " + model + ", subset: " + Arrays.toString(classIndicesPerSubset[model]));

    // --remove the unselected labels
    int numPredictors = trainData.numAttributes() - numLabels;
    absoluteIndicesToRemove[model] = new int[numLabels - sizeOfSubset];
    int k = 0;
    for (int j = 0; j < numLabels; j++) {
        if (!selected[j]) {
            absoluteIndicesToRemove[model][k] = numPredictors + j;
            k++;
        }
    }
    Remove remove = new Remove();
    remove.setAttributeIndicesArray(absoluteIndicesToRemove[model]);
    remove.setInvertSelection(false);
    // Filter options have to be complete before the input format is set.
    remove.setInputFormat(trainData);
    Instances trainSubset = Filter.useFilter(trainData, remove);

    // build a LabelPowersetClassifier for the selected label subset
    subsetClassifiers[model] = new LabelPowersetClassifier(Classifier.makeCopy(getBaseClassifier()), sizeOfSubset);
    subsetClassifiers[model].buildClassifier(trainSubset);

    // keep only the header of the training data for testing
    trainSubset.delete();
    metadataTest[model] = trainSubset;
}
/**
 * Adds the votes of model {@code model} for one instance to the running
 * per-instance tallies and returns the prediction derived from all votes
 * accumulated so far for that instance.
 *
 * The instance is first reduced to its predictor attributes followed by
 * {@code sizeOfSubset} zero-valued label slots, the same way
 * {@code makePrediction} prepares instances. This replaces an earlier
 * sparse-only code path that trimmed attributes in place but started its loop
 * at 1 (removing one attribute too few) and addressed the out-of-range
 * position {@code numAttributes()}.
 *
 * @param instance       the instance to predict; the code treats its last
 *                       {@code numLabels} attributes as the labels
 * @param instanceNumber row of the incremental vote tallies belonging to
 *                       this instance
 * @param model          index of the subset model whose votes are added
 * @return labels and confidences for all {@code numLabels} labels; a label
 *         that has received no vote yet gets confidence 0 and label 0
 * @throws Exception if the subset classifier fails to predict
 */
public Prediction updatePrediction(Instance instance, int instanceNumber, int model) throws Exception {
    int numPredictors = instance.numAttributes() - numLabels;

    // transform instance: copy the predictors, leave the label slots at 0.0
    double[] vals = new double[numPredictors + sizeOfSubset];
    for (int j = 0; j < numPredictors; j++) {
        vals[j] = instance.value(j);
    }
    Instance newInstance = (instance instanceof SparseInstance)
            ? new SparseInstance(instance.weight(), vals)
            : new Instance(instance.weight(), vals);
    newInstance.setDataset(metadataTest[model]);

    // add this model's votes to the labels of its subset
    double[] predictions = subsetClassifiers[model].makePrediction(newInstance).getPredictedLabels();
    for (int j = 0; j < sizeOfSubset; j++) {
        sumVotesIncremental[instanceNumber][classIndicesPerSubset[model][j]] += predictions[j];
        lengthVotesIncremental[instanceNumber][classIndicesPerSubset[model][j]]++;
    }

    double[] confidence = new double[numLabels];
    double[] labels = new double[numLabels];
    for (int i = 0; i < numLabels; i++) {
        // Labels not covered by any model seen so far have no votes; report
        // confidence 0 for them instead of dividing zero by zero.
        if (lengthVotesIncremental[instanceNumber][i] != 0) {
            confidence[i] = sumVotesIncremental[instanceNumber][i] / lengthVotesIncremental[instanceNumber][i];
        } else {
            confidence[i] = 0;
        }
        // NOTE(review): the cut-off is fixed at 0.5 here — confirm whether
        // the configured threshold should be used instead.
        labels[i] = (confidence[i] >= 0.5) ? 1 : 0;
    }
    return new Prediction(labels, confidence);
}
/**
 * Predicts all labels of an instance by averaging the votes of every subset
 * model that is currently present.
 *
 * For each non-null subset classifier the instance is reduced to its
 * predictor attributes followed by {@code sizeOfSubset} zero-valued label
 * slots, attached to that model's stored header, and classified. Each
 * predicted subset label counts as one vote for the corresponding original
 * label.
 *
 * @param instance the instance to predict; the code treats its last
 *                 {@code numLabels} attributes as the labels
 * @return labels and confidences for all {@code numLabels} labels; a label
 *         that received no vote gets confidence 0 and label 0
 * @throws Exception if any subset classifier fails to predict
 */
public Prediction makePrediction(Instance instance) throws Exception {
    int numPredictors = instance.numAttributes() - numLabels;
    Arrays.fill(sumVotes, 0);
    Arrays.fill(lengthVotes, 0);

    for (int i = 0; i < numOfModels; i++) {
        // Models may have been removed via nullSubsetClassifier.
        if (subsetClassifiers[i] == null) {
            continue;
        }

        // transform instance: copy the predictors, leave the label slots at 0.0
        double[] vals = new double[numPredictors + sizeOfSubset];
        for (int j = 0; j < numPredictors; j++) {
            vals[j] = instance.value(j);
        }
        Instance newInstance = (instance instanceof SparseInstance)
                ? new SparseInstance(instance.weight(), vals)
                : new Instance(instance.weight(), vals);
        newInstance.setDataset(metadataTest[i]);

        double[] predictions = subsetClassifiers[i].makePrediction(newInstance).getPredictedLabels();
        for (int j = 0; j < sizeOfSubset; j++) {
            sumVotes[classIndicesPerSubset[i][j]] += predictions[j];
            lengthVotes[classIndicesPerSubset[i][j]]++;
        }
    }

    double[] confidence = new double[numLabels];
    double[] labels = new double[numLabels];
    for (int i = 0; i < numLabels; i++) {
        // A label covered by no remaining model has zero votes; report
        // confidence 0 for it instead of dividing zero by zero.
        if (lengthVotes[i] != 0) {
            confidence[i] = sumVotes[i] / lengthVotes[i];
        } else {
            confidence[i] = 0;
        }
        // NOTE(review): the cut-off is fixed at 0.5 here — confirm whether
        // the configured threshold should be used instead.
        labels[i] = (confidence[i] >= 0.5) ? 1 : 0;
    }
    return new Prediction(labels, confidence);
}
/**
 * Discards the subset classifier stored at the given position by replacing it
 * with {@code null}.
 *
 * @param i index of the subset classifier slot to clear
 */
public void nullSubsetClassifier(int i)
{
    this.subsetClassifiers[i] = null;
}
/**
 * Revision information is not provided by this class.
 *
 * @return never returns normally
 * @throws UnsupportedOperationException always
 */
public String getRevision()
{
    final String message = "Not supported yet.";
    throw new UnsupportedOperationException(message);
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -