📄 rakel.java
字号:
package mulan.classifier;
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import mulan.evaluation.BinaryPrediction;
import mulan.evaluation.Evaluator;
import mulan.evaluation.IntegratedEvaluation;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
/**
* Implements the RAKEL (Random k-labelsets) ensemble method for multilabel classification. <p>
*
* @author Grigorios Tsoumakas
* @version $Revision: 0.02 $
*/
@SuppressWarnings("serial")
public class RAKEL extends AbstractMultiLabelClassifier
{
// Per-instance, per-label accumulators, allocated in updatePredictions as
// [number of test instances][numLabels]. Presumably the running vote sum and
// vote count over the models added so far -- TODO confirm.
double[][] sumVotesIncremental;
double[][] lengthVotesIncremental;
// Per-label accumulators, allocated in the constructor with length numLabels;
// presumably vote sum / vote count for a single instance -- TODO confirm.
double[] sumVotes;
double[] lengthVotes;
// Number of labelset models in the ensemble.
int numOfModels;
// Size k of each labelset.
int sizeOfSubset;
// [numOfModels][sizeOfSubset]; presumably the label indices of each model's
// labelset -- TODO confirm.
int[][] classIndicesPerSubset;
// [numOfModels][sizeOfSubset]; presumably the attribute indices dropped (e.g.
// via the Remove filter) when building each model's data -- TODO confirm.
int[][] absoluteIndicesToRemove;
// One label powerset classifier per model; presumably released one by one via
// nullSubsetClassifier (as paramSelectionViaCV does) -- TODO confirm.
LabelPowersetClassifier[] subsetClassifiers;
// One Instances object per model; presumably the header used to transform
// test instances for that model -- TODO confirm.
protected Instances[] metadataTest;
// NOTE(review): not used in the visible code; presumably tracks labelsets
// already chosen so that duplicates are avoided.
HashSet<String> combinations;
// [number of test instances][numLabels], allocated lazily by updatePredictions
// and returned by getPredictions.
BinaryPrediction[][] predictions;
// NOTE(review): not read in the visible code -- TODO confirm its meaning.
boolean incremental =true;
// Whether parameters should be chosen by cross-validation; set by
// setParamSelectionViaCV.
boolean cvParamSelection=false;
// Cross-validation search settings; set by setParamSets and read by
// paramSelectionViaCV.
int cvNumFolds, cvMinK, cvMaxK, cvStepK, cvMaxM, cvThresholdSteps;
double cvThresholdStart, cvThresholdIncrement;
/**
 * Describes the publication this classifier is based on: Tsoumakas and
 * Vlahavas, "Random k-Labelsets: An Ensemble Method for Multilabel
 * Classification", ECML 2007.
 *
 * @return an INPROCEEDINGS entry with author, title, booktitle, pages,
 *         location, month and year filled in
 */
public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation info = new TechnicalInformation(Type.INPROCEEDINGS);

    // Who and what
    info.setValue(Field.AUTHOR, "Grigorios Tsoumakas, Ioannis Vlahavas");
    info.setValue(Field.TITLE, "Random k-Labelsets: An Ensemble Method for Multilabel Classification");

    // Where it appeared
    info.setValue(Field.BOOKTITLE, "Proc. 18th European Conference on Machine Learning (ECML 2007)");
    info.setValue(Field.PAGES, "406 - 417");
    info.setValue(Field.LOCATION, "Warsaw, Poland");

    // When
    info.setValue(Field.MONTH, "17-21 September");
    info.setValue(Field.YEAR, "2007");

    return info;
}
/**
 * Creates a RAKEL ensemble and allocates its per-model and per-label storage.
 *
 * @param labels number of labels of the multi-label problem
 * @param models number of labelset models in the ensemble
 * @param subset size k of each labelset
 */
public RAKEL(int labels, int models, int subset) {
    this.numLabels = labels;
    this.numOfModels = models;
    this.sizeOfSubset = subset;

    // per-model storage
    this.classIndicesPerSubset = new int[models][subset];
    this.absoluteIndicesToRemove = new int[models][subset];
    this.subsetClassifiers = new LabelPowersetClassifier[models];
    this.metadataTest = new Instances[models];

    // per-label storage
    this.sumVotes = new double[labels];
    this.lengthVotes = new double[labels];
}
/**
 * Sets the labelset size k and reallocates the per-model index tables
 * (numOfModels x k) to match. Previous contents of those tables are discarded.
 *
 * @param size the new labelset size
 */
public void setSizeOfSubset(int size) {
    this.sizeOfSubset = size;
    this.classIndicesPerSubset = new int[this.numOfModels][size];
    this.absoluteIndicesToRemove = new int[this.numOfModels][size];
}
/**
 * Returns the labelset size k used by each model of the ensemble.
 *
 * @return the current labelset size
 */
public int getSizeOfSubset() {
    return this.sizeOfSubset;
}
/**
 * Sets the number of models in the ensemble and reallocates all per-model
 * storage (index tables, classifiers and test metadata) for that many models.
 * Previously built models and metadata are discarded.
 *
 * @param models the new number of models
 */
public void setNumModels(int models) {
    this.numOfModels = models;
    this.classIndicesPerSubset = new int[models][this.sizeOfSubset];
    this.absoluteIndicesToRemove = new int[models][this.sizeOfSubset];
    this.subsetClassifiers = new LabelPowersetClassifier[models];
    this.metadataTest = new Instances[models];
}
/**
 * Returns the number of labelset models in the ensemble.
 *
 * @return the current number of models
 */
public int getNumModels() {
    return this.numOfModels;
}
/**
 * Returns the prediction matrix indexed as [test instance][label].
 * The internal array is returned directly (no copy). It is null until first
 * allocated; in the visible code that happens in updatePredictions.
 *
 * @return the current predictions, possibly null
 */
public BinaryPrediction[][] getPredictions() {
    return this.predictions;
}
/**
 * Computes the binomial coefficient C(n, m), i.e. the number of distinct
 * m-element subsets of an n-element set (the number of possible labelsets of
 * size m over n labels).
 *
 * @param n size of the set
 * @param m size of the subsets
 * @return C(n, m), or 0 when m is negative or greater than n
 * @throws ArithmeticException if the result does not fit in an int
 */
private int binomial(int n, int m) {
    // Outside 0 <= m <= n there are no such subsets. This also covers n < 0.
    if (m < 0 || m > n) {
        return 0;
    }
    // C(n, m) == C(n, n - m); iterating over the smaller side keeps the loop short.
    int steps = Math.min(m, n - m);
    long result = 1;
    for (int i = 0; i < steps; i++) {
        // Invariant: result == C(n, i), so result * (n - i) == C(n, i + 1) * (i + 1)
        // and the division is exact. result <= Integer.MAX_VALUE, so the
        // product fits comfortably in a long.
        result = result * (n - i) / (i + 1);
        // For i < steps <= n/2 the sequence is non-decreasing, so a
        // representable final value never trips this check early.
        if (result > Integer.MAX_VALUE) {
            throw new ArithmeticException("binomial(" + n + ", " + m + ") overflows int");
        }
    }
    return (int) result;
}
/**
 * Configures the grid searched by cross-validation parameter selection.
 *
 * @param numFolds           number of cross-validation folds
 * @param minK               smallest labelset size tried
 * @param maxK               largest labelset size tried (inclusive)
 * @param stepK              increment between tried labelset sizes
 * @param maxM               upper bound on the number of models tried per k
 * @param thresholdStart     first threshold evaluated
 * @param thresholdIncrement spacing between evaluated thresholds
 * @param thresholdSteps     number of thresholds evaluated
 */
public void setParamSets(int numFolds, int minK, int maxK, int stepK, int maxM, double thresholdStart, double thresholdIncrement, int thresholdSteps) {
    this.cvNumFolds = numFolds;

    // labelset size (k) grid and model count bound
    this.cvMinK = minK;
    this.cvMaxK = maxK;
    this.cvStepK = stepK;
    this.cvMaxM = maxM;

    // threshold grid
    this.cvThresholdStart = thresholdStart;
    this.cvThresholdIncrement = thresholdIncrement;
    this.cvThresholdSteps = thresholdSteps;
}
/**
 * Enables or disables choosing parameters by cross-validation.
 *
 * @param flag true to enable cross-validated parameter selection
 */
public void setParamSelectionViaCV(boolean flag) {
    this.cvParamSelection = flag;
}
/**
 * Chooses the labelset size k, the number of models and the decision threshold
 * by cross-validation on the training data. The criterion is the lowest
 * Hamming loss averaged over folds.
 * <p>
 * For every fold and every k in cvMinK, cvMinK + cvStepK, ..., cvMaxK, a RAKEL
 * ensemble is grown one model at a time, up to min(C(numLabels, k), cvMaxM)
 * models. After each added model, the test-fold predictions are evaluated at
 * every threshold returned by Evaluator.evaluateOverThreshold. The candidate
 * (k, number of models, threshold) with the lowest mean Hamming loss wins; on
 * ties the first one wins. It is applied via setSizeOfSubset, setNumModels and
 * setThreshold. If no candidate is evaluated, the current settings are left
 * unchanged.
 *
 * @param trainData the training data to split into cvNumFolds folds
 * @throws IllegalStateException if cvNumFolds or cvStepK is not positive
 * @throws Exception if training or evaluation on a fold fails
 */
private void paramSelectionViaCV(Instances trainData) throws Exception {
    if (cvNumFolds < 1) {
        throw new IllegalStateException("cvNumFolds must be positive, got " + cvNumFolds);
    }
    // A non-positive step would make the k loop below spin forever.
    if (cvStepK <= 0) {
        throw new IllegalStateException("cvStepK must be positive, got " + cvStepK);
    }

    // foldMetrics.get(f).get(c): Hamming loss of candidate c on fold f.
    ArrayList<ArrayList<Double>> foldMetrics = new ArrayList<ArrayList<Double>>(cvNumFolds);
    // candidates.get(c): {k, number of models, threshold step}. Every fold
    // enumerates candidates in the same order (assuming evaluateOverThreshold
    // returns the same number of results each time), so they are recorded on
    // fold 0 only.
    ArrayList<int[]> candidates = new ArrayList<int[]>();

    for (int f = 0; f < cvNumFolds; f++) {
        ArrayList<Double> metrics = new ArrayList<Double>();
        foldMetrics.add(metrics);
        Instances foldTrainData = trainData.trainCV(cvNumFolds, f);
        Instances foldTestData = trainData.testCV(cvNumFolds, f);
        for (int k = cvMinK; k <= cvMaxK; k += cvStepK) {
            int numSubsets = binomial(numLabels, k);
            RAKEL rakel = new RAKEL(numLabels, numSubsets, k);
            rakel.setBaseClassifier(baseClassifier);
            int finalM = Math.min(numSubsets, cvMaxM);
            for (int m = 0; m < finalM; m++) {
                // Add model m: the ensemble now holds models 0..m, i.e. m + 1 models.
                rakel.updateClassifier(foldTrainData, m);
                Evaluator evaluator = new Evaluator();
                rakel.updatePredictions(foldTestData, m);
                // Drop model m once its votes are in (presumably to save memory).
                rakel.nullSubsetClassifier(m);
                IntegratedEvaluation[] results = evaluator.evaluateOverThreshold(
                        rakel.getPredictions(), foldTestData,
                        cvThresholdStart, cvThresholdIncrement, cvThresholdSteps);
                for (int t = 0; t < results.length; t++) {
                    metrics.add(results[t].hammingLoss());
                    if (f == 0) {
                        candidates.add(new int[] {k, m + 1, t});
                    }
                }
            }
        }
    }

    // Start above any possible loss so that the first candidate is always eligible.
    double bestMetric = Double.POSITIVE_INFINITY;
    int[] best = null;
    for (int c = 0; c < candidates.size(); c++) {
        double sum = 0;
        for (int f = 0; f < cvNumFolds; f++) {
            sum += foldMetrics.get(f).get(c);
        }
        double avgMetric = sum / cvNumFolds;
        if (avgMetric < bestMetric) {
            bestMetric = avgMetric;
            best = candidates.get(c);
        }
    }

    if (best != null) {
        setSizeOfSubset(best[0]);
        setNumModels(best[1]);
        setThreshold(cvThresholdStart + cvThresholdIncrement * best[2]);
    }
}
public void updatePredictions(Instances testData, int model) throws Exception {
if (predictions == null) {
predictions = new BinaryPrediction[testData.numInstances()][numLabels];
sumVotesIncremental = new double[testData.numInstances()][numLabels];
lengthVotesIncremental = new double[testData.numInstances()][numLabels];
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -