⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rakel.java

📁 Multi-label classification 和weka集成
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
package mulan.classifier;

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;

import mulan.evaluation.BinaryPrediction;
import mulan.evaluation.Evaluator;
import mulan.evaluation.IntegratedEvaluation;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Class that implements the RAKEL (Random k-labelsets) algorithm <p>
 *
 * @author Grigorios Tsoumakas 
 * @version $Revision: 0.02 $ 
 */
@SuppressWarnings("serial")
public class RAKEL extends AbstractMultiLabelClassifier
{
    // Incremental vote accumulators; updatePredictions allocates both as
    // [number of test instances][numLabels].
    // NOTE(review): presumably the running vote sum and vote count per
    // instance/label as models are added - confirm against updatePredictions.
    double[][] sumVotesIncremental; /* [test instance][label] */
    double[][] lengthVotesIncremental;
    // Per-label vote accumulators; the constructor allocates both with
    // length numLabels.
    double[] sumVotes;
    double[] lengthVotes;
    // Number of models in the ensemble (constructor / setNumModels).
    int numOfModels;
    // Number of labels in each subset (constructor / setSizeOfSubset).
    int sizeOfSubset;
    // Both tables are allocated as [numOfModels][sizeOfSubset].
    // NOTE(review): names suggest label indices per model and the attribute
    // indices handed to a Remove filter - confirm where they are filled.
    int[][] classIndicesPerSubset;
    int[][] absoluteIndicesToRemove;
    // One slot per model (length numOfModels).
    LabelPowersetClassifier[] subsetClassifiers;
    // One slot per model (length numOfModels).
    protected Instances[] metadataTest;
    // Not referenced in the visible part of the file.
    // NOTE(review): presumably the label subsets already drawn - confirm.
    HashSet<String> combinations;		
    // Allocated by updatePredictions as [number of test instances][numLabels]
    // and exposed through getPredictions().
    BinaryPrediction[][] predictions;
    // Defaults to true; its use is not visible in this part of the file.
    boolean incremental =true;
    // Toggled by setParamSelectionViaCV(boolean); defaults to false.
    boolean cvParamSelection=false;
    // Cross-validation search settings, written by setParamSets and read by
    // paramSelectionViaCV.
    int cvNumFolds, cvMinK, cvMaxK, cvStepK, cvMaxM, cvThresholdSteps;
    double cvThresholdStart, cvThresholdIncrement;    
    
    /**
     * Builds the bibliographic record of the publication this classifier is
     * based on.
     *
     * @return a TechnicalInformation entry of type INPROCEEDINGS carrying the
     *         author, title, book title, pages, location, month and year
     */
    public TechnicalInformation getTechnicalInformation() {
        final TechnicalInformation paper = new TechnicalInformation(Type.INPROCEEDINGS);

        // Authors and title of the paper.
        paper.setValue(Field.AUTHOR, "Grigorios Tsoumakas, Ioannis Vlahavas");
        paper.setValue(Field.TITLE, "Random k-Labelsets: An Ensemble Method for Multilabel Classification");

        // Where and when it was published.
        paper.setValue(Field.BOOKTITLE, "Proc. 18th European Conference on Machine Learning (ECML 2007)");
        paper.setValue(Field.PAGES, "406 - 417");
        paper.setValue(Field.LOCATION, "Warsaw, Poland");
        paper.setValue(Field.MONTH, "17-21 September");
        paper.setValue(Field.YEAR, "2007");

        return paper;
    }
               
    /**
     * Creates a RAKEL ensemble and allocates its per-model and per-label
     * bookkeeping arrays.
     *
     * @param labels number of labels; stored in numLabels and used as the
     *               length of the vote arrays
     * @param models number of models in the ensemble
     * @param subset number of labels in each subset
     */
    public RAKEL(int labels, int models, int subset) {
        numLabels = labels;
        this.numOfModels = models;
        this.sizeOfSubset = subset;

        // Tables with one row per model and one column per subset member.
        this.classIndicesPerSubset = new int[models][subset];
        this.absoluteIndicesToRemove = new int[models][subset];

        // One slot per model.
        this.subsetClassifiers = new LabelPowersetClassifier[models];
        this.metadataTest = new Instances[models];

        // One accumulator per label.
        this.sumVotes = new double[labels];
        this.lengthVotes = new double[labels];
    }
	
	/**
	 * Changes the subset size and re-allocates the two index tables so that
	 * their second dimension matches it. Existing table contents are dropped.
	 *
	 * @param size new number of labels per subset
	 */
	public void setSizeOfSubset(int size) {
		this.sizeOfSubset = size;

		// Both tables are dimensioned [numOfModels][size].
		this.classIndicesPerSubset = new int[this.numOfModels][size];
		this.absoluteIndicesToRemove = new int[this.numOfModels][size];
	}
	
        /**
         * @return the number of labels per subset currently configured
         */
        public int getSizeOfSubset() {
            return this.sizeOfSubset;
        }
        
	/**
	 * Changes the number of models and re-allocates every per-model
	 * structure to that length. Existing contents are dropped.
	 *
	 * @param models new number of models in the ensemble
	 */
	public void setNumModels(int models) {
		this.numOfModels = models;

		// Index tables: one row per model, sizeOfSubset columns.
		this.classIndicesPerSubset = new int[models][this.sizeOfSubset];
		this.absoluteIndicesToRemove = new int[models][this.sizeOfSubset];

		// One slot per model.
		this.subsetClassifiers = new LabelPowersetClassifier[models];
		this.metadataTest = new Instances[models];
	}
	
        /**
         * @return the number of models currently configured
         */
        public int getNumModels() {
            return this.numOfModels;
        }
                
        
	/**
	 * Gives access to the prediction table held by this object. The array is
	 * returned as is (not copied) and is null until it has been allocated.
	 *
	 * @return the internal predictions array
	 */
	public BinaryPrediction[][] getPredictions() {
		return this.predictions;
	}
		
        /**
         * Computes the binomial coefficient C(n, m), i.e. the number of
         * distinct m-element subsets of an n-element set.
         * <p>
         * The value is built with the multiplicative formula in a long
         * accumulator, so it never wraps around silently: results that do not
         * fit into an int are reported as Integer.MAX_VALUE.
         *
         * @param n size of the full set, must not be negative
         * @param m size of the subsets
         * @return C(n, m); 0 when m is negative or greater than n;
         *         Integer.MAX_VALUE when the true value exceeds the int range
         * @throws IllegalArgumentException if n is negative
         */
        private static int binomial(int n, int m)
        {
            if (n < 0)
                throw new IllegalArgumentException("n must not be negative: " + n);
            if (m < 0 || m > n)
                return 0; // there is no such subset

            // C(n, m) == C(n, n-m); the smaller index needs fewer steps.
            int k = Math.min(m, n - m);
            long result = 1;
            for (int i = 1; i <= k; i++)
            {
                // Before this step result == C(n-k+i-1, i-1); multiplying first
                // keeps the division exact, giving C(n-k+i, i).
                result = result * (n - k + i) / i;
                // The partial values never decrease, so once the int range is
                // exceeded the final value exceeds it as well.
                if (result > Integer.MAX_VALUE)
                    return Integer.MAX_VALUE;
            }
            return (int) result;
        }
        
 
        /**
         * Stores the search grid used by the cross-validated parameter
         * selection.
         *
         * @param numFolds           number of cross-validation folds
         * @param minK               first subset size to try
         * @param maxK               last subset size to try (inclusive)
         * @param stepK              increment between consecutive subset sizes
         * @param maxM               upper bound on the number of models to try
         * @param thresholdStart     first threshold value
         * @param thresholdIncrement distance between consecutive thresholds
         * @param thresholdSteps     number of threshold values
         */
        public void setParamSets(int numFolds, int minK, int maxK, int stepK, int maxM, double thresholdStart, double thresholdIncrement, int thresholdSteps){
            // Fold count.
            this.cvNumFolds = numFolds;

            // Range of subset sizes and cap on the number of models.
            this.cvMinK = minK;
            this.cvMaxK = maxK;
            this.cvStepK = stepK;
            this.cvMaxM = maxM;

            // Threshold grid.
            this.cvThresholdStart = thresholdStart;
            this.cvThresholdIncrement = thresholdIncrement;
            this.cvThresholdSteps = thresholdSteps;
        }
        
        /**
         * Sets the flag that records whether parameter selection via
         * cross-validation is requested.
         *
         * @param flag new value of the cvParamSelection flag
         */
        public void setParamSelectionViaCV(boolean flag){
            this.cvParamSelection = flag;
        }
        
        /**
         * Chooses the subset size, the number of models and the threshold by
         * cross-validation on the given data, using the grid stored by
         * setParamSets.
         * <p>
         * For every fold and every subset size k a fresh RAKEL is grown one
         * model at a time; after each added model the Hamming loss on the
         * fold's test part is recorded for every threshold. The configuration
         * with the lowest Hamming loss averaged over the folds (first one wins
         * on ties) is then applied to this object through setSizeOfSubset,
         * setNumModels and setThreshold. If the grid is empty nothing is
         * changed.
         *
         * @param trainData data that is split into cross-validation folds
         * @throws IllegalStateException if the configured number of folds is
         *         not positive, or the step for k is not positive while the
         *         range of k is non-empty
         * @throws Exception if training or evaluating a candidate fails
         */
        private void paramSelectionViaCV(Instances trainData) throws Exception {
            if (cvNumFolds < 1)
                throw new IllegalStateException("Number of CV folds must be positive: " + cvNumFolds);
            // A non-positive step would never leave the loop over k.
            if (cvMinK <= cvMaxK && cvStepK < 1)
                throw new IllegalStateException("Step for k must be positive: " + cvStepK);

            // metric.get(f) lists the Hamming loss of every evaluated
            // (k, m, threshold) configuration on fold f, in evaluation order.
            ArrayList<ArrayList<Double>> metric = new ArrayList<ArrayList<Double>>(cvNumFolds);
            for (int f=0; f<cvNumFolds; f++)
            {
                ArrayList<Double> foldMetric = new ArrayList<Double>();
                metric.add(foldMetric);
                Instances foldTrainData = trainData.trainCV(cvNumFolds, f);
                Instances foldTestData = trainData.testCV(cvNumFolds, f);

                for (int k=cvMinK; k<=cvMaxK; k+=cvStepK)
                {
                    int numSubsets = binomial(numLabels, k);
                    RAKEL rakel = new RAKEL(numLabels, numSubsets, k);
                    rakel.setBaseClassifier(baseClassifier);
                    int finalM = Math.min(numSubsets, cvMaxM);
                    for (int m=0; m<finalM; m++)
                    {
                        // Add model m, fold its votes into the predictions,
                        // then release the trained model again.
                        rakel.updateClassifier(foldTrainData, m);
                        Evaluator evaluator = new Evaluator();
                        rakel.updatePredictions(foldTestData, m);
                        rakel.nullSubsetClassifier(m);
                        IntegratedEvaluation[] results = evaluator.evaluateOverThreshold(rakel.getPredictions(), foldTestData, cvThresholdStart, cvThresholdIncrement, cvThresholdSteps);
                        for (int t=0; t<results.length; t++)
                            foldMetric.add(results[t].hammingLoss());
                    }
                }
            }

            // Walk the grid in the same order as above and remember the
            // configuration with the lowest average Hamming loss.
            // NOTE(review): the index arithmetic assumes evaluateOverThreshold
            // returns exactly cvThresholdSteps results - confirm.
            double minMetric = Double.POSITIVE_INFINITY;
            boolean found = false;
            int bestK = 0;
            int bestNumModels = 0;
            int bestT = 0;
            int counter = 0;
            for (int k=cvMinK; k<=cvMaxK; k+=cvStepK)
            {
                int finalM = Math.min(binomial(numLabels, k), cvMaxM);
                for (int m=0; m<finalM; m++)
                {
                    for (int t=0; t<cvThresholdSteps; t++)
                    {
                        double avgMetric = 0;
                        for (int f=0; f<cvNumFolds; f++)
                            avgMetric += metric.get(f).get(counter);
                        avgMetric /= cvNumFolds;
                        if (avgMetric < minMetric)
                        {
                            minMetric = avgMetric;
                            bestK = k;
                            // m is the zero-based index of the last model that
                            // was added, so the ensemble holds m+1 models.
                            bestNumModels = m + 1;
                            bestT = t;
                            found = true;
                        }
                        counter++;
                    }
                }
            }

            // Apply the winner once instead of re-allocating on every improvement.
            if (found)
            {
                setSizeOfSubset(bestK);
                setNumModels(bestNumModels);
                setThreshold(cvThresholdStart + cvThresholdIncrement*bestT);
            }
        }        
        
        public void updatePredictions(Instances testData, int model) throws Exception {
		if (predictions == null) {
			predictions = new BinaryPrediction[testData.numInstances()][numLabels];
			sumVotesIncremental = new double[testData.numInstances()][numLabels];
			lengthVotesIncremental = new double[testData.numInstances()][numLabels];
		}
		

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -