
InfoGainFeatureSelector.java

Project: Dragon Toolkit for machine learning
Language: Java
package dragon.ir.classification.featureselection;

import dragon.ir.classification.*;
import dragon.ir.classification.DocClassSet;
import dragon.ir.index.*;
import dragon.matrix.SparseMatrix;
import dragon.matrix.vector.DoubleVector;
import dragon.nlp.Token;
import dragon.nlp.compare.*;
import dragon.util.*;

/**
 * <p>A Feature Selector which uses information gain to select top features</p>
 * <p>Please refer to the paper below for details of the algorithm.<br>
 * Yang, Y. and Pedersen, J.O., “A comparative study on feature selection in text categorization,”
 * In Proceedings of International Conference on Machine Learning, 1997, pp. 412-420.
 * </p>
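 * <p>For a term t and class variable C, the score is
 * IG(t) = H(C) - P(t)H(C|t) - P(~t)H(C|~t),
 * where H is Shannon entropy computed with the natural logarithm.</p>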
 * <p>Copyright: Copyright (c) 2005</p>
 * <p>Company: IST, Drexel University</p>
 * @author Davis Zhou
 * @version 1.0
 */

public class InfoGainFeatureSelector extends AbstractFeatureSelector implements java.io.Serializable {
    private static final long serialVersionUID = 1L;
    private double topPercentage;

    public InfoGainFeatureSelector(double topPercentage) {
        this.topPercentage = topPercentage;
    }

    protected int[] getSelectedFeatures(SparseMatrix doctermMatrix, DocClassSet trainingSet){
        System.out.println("InfoGainSelector does not accept SparseMatrix as input. Please use IndexReader as input instead.");
        return null;
    }

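    /**
     * Ranks all terms by information gain over the training set, keeps the
     * top fraction of the vocabulary given by topPercentage, and returns the
     * index of each selected term.
     */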
    protected int[] getSelectedFeatures(IndexReader indexReader, DocClassSet trainingSet){
        SortedArray list,selectedList;
        int[] featureMap;
        int i,termNum;

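        //take the highest-IG fraction of the vocabulary and re-sort it by term index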
        list=computeTermIG(indexReader,trainingSet);
        termNum=(int)(topPercentage*indexReader.getCollection().getTermNum());
        termNum=Math.min(list.size(),termNum);
        selectedList=new SortedArray(termNum,new IndexComparator());
        for(i=0;i<termNum;i++){
            selectedList.add(list.get(i));
        }
        featureMap=new int[selectedList.size()];
        for(i=0;i<featureMap.length;i++)
            featureMap[i]=((Token)selectedList.get(i)).getIndex();
        return featureMap;
    }

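    /**
     * Computes the information gain of every term in the collection and
     * returns the terms as Tokens (term index, IG weight) sorted by
     * descending weight.
     */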
    private SortedArray computeTermIG(IndexReader indexReader, DocClassSet trainingSet){
        DoubleVector termVector, classVector, classPrior, classDistrWiTerm, classDistrWoTerm;
        DocClass docClass;
        SortedArray list;
        Token curTerm;
        int[] arrDoc, arrDocIndex;
        double classEntropy;
        int termNum, trainingDocNum, docCount, docLabel, i, j;

        trainingDocNum=0;
        for(i=0;i<trainingSet.getClassNum();i++)
            trainingDocNum+=trainingSet.getDocClass(i).getDocNum();

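        //prior class distribution P(c) over the training set and its entropy H(C)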
        classPrior=getClassPrior(trainingSet);
        classEntropy=calEntropy(classPrior);

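        //per-class training document counts, recovered from the prior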
        classVector=classPrior.copy();
        classVector.multiply(trainingDocNum);

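        //map each collection document to its training class label; -1 marks documents outside the training set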
        arrDoc=new int[indexReader.getCollection().getDocNum()];
        MathUtil.initArray(arrDoc,-1);
        for(i=0;i<trainingSet.getClassNum();i++){
            docClass=trainingSet.getDocClass(i);
            for(j=0;j<docClass.getDocNum();j++)
                arrDoc[docClass.getDoc(j).getIndex()]=i;
        }

        termNum=indexReader.getCollection().getTermNum();
        list=new SortedArray(termNum,new IndexComparator());
        termVector=new DoubleVector(termNum);
        classDistrWiTerm=new DoubleVector(classPrior.size());
        classDistrWoTerm=new DoubleVector(classPrior.size());
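        //for each term, split the training documents into those containing it and the rest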
        for(i=0;i<termNum;i++){
            arrDocIndex=indexReader.getTermDocIndexList(i);
            if(arrDocIndex==null || arrDocIndex.length==0)
                continue;
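            //class counts among documents containing term i; the without-term counts start from the full class counts and are decremented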
            classDistrWiTerm.assign(0);
            classDistrWoTerm.assign(classVector);
            docCount=0;
            for(j=0;j<arrDocIndex.length;j++){
                docLabel=arrDoc[arrDocIndex[j]];
                if(docLabel>=0){
                    //this document is in the training set
                    classDistrWiTerm.add(docLabel,1);
                    classDistrWoTerm.add(docLabel,-1);
                    docCount++;
                }
            }
            if(docCount==0 || docCount==trainingDocNum)
                continue; //a term occurring in no or in all training documents has zero IG; skipping also avoids division by zero below

            classDistrWiTerm.multiply(1.0/docCount);
            classDistrWoTerm.multiply(1.0/(trainingDocNum-docCount));
            //IG(t)=H(C)-P(t)*H(C|t)-P(~t)*H(C|~t), following the cited Yang & Pedersen formulation
            termVector.set(i,classEntropy
                -(double)docCount/trainingDocNum*calEntropy(classDistrWiTerm)
                -(double)(trainingDocNum-docCount)/trainingDocNum*calEntropy(classDistrWoTerm));
        }

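        //keep positive-IG terms as weighted Tokens and order the result by descending weight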
        for(i=0;i<termVector.size();i++){
            curTerm=new Token(i,0);
            if(termVector.get(i)<=0) //zero IG: the term is absent from the training documents or uninformative
                continue;
            curTerm.setWeight(termVector.get(i));
            list.add(curTerm);
        }
        list.setComparator(new WeightComparator(true));
        return list;
    }

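    /**
     * Shannon entropy (natural log) of a probability vector. Zero entries use
     * the smallest positive double so that p*log(p) contributes its limit of
     * zero instead of the NaN produced by 0*log(0).
     */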
    private double calEntropy(DoubleVector probVector){
        double sum;
        int i;

        sum=0;
        for(i=0;i<probVector.size();i++){
            if(probVector.get(i)==0)
                sum-=Double.MIN_VALUE*Math.log(Double.MIN_VALUE);
            else
                sum-=probVector.get(i)*Math.log(probVector.get(i));
        }
        return sum;
    }
}
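
A minimal sketch of how this selector might be driven. The demo class is placed in the same package so it can call the protected getSelectedFeatures method shown above directly; how the index and the labeled training set are built depends on your corpus, so openIndex and loadTrainingSet below are hypothetical placeholders, not Dragon Toolkit API.

package dragon.ir.classification.featureselection;

import dragon.ir.classification.DocClassSet;
import dragon.ir.index.IndexReader;

public class InfoGainDemo {
    public static void main(String[] args) {
        IndexReader indexReader = openIndex();        //hypothetical helper: open an existing Dragon index
        DocClassSet trainingSet = loadTrainingSet();  //hypothetical helper: assign training documents to classes

        //keep the top 10% of the vocabulary, ranked by information gain
        InfoGainFeatureSelector selector = new InfoGainFeatureSelector(0.1);
        int[] featureMap = selector.getSelectedFeatures(indexReader, trainingSet);
        System.out.println("selected " + featureMap.length + " features");
    }

    private static IndexReader openIndex() {
        throw new UnsupportedOperationException("corpus-specific setup");
    }

    private static DocClassSet loadTrainingSet() {
        throw new UnsupportedOperationException("corpus-specific setup");
    }
}

A topPercentage around 0.1 is only an illustration; Yang and Pedersen report that information gain tolerates aggressive vocabulary reduction in text categorization, so the useful range is task-dependent and worth tuning.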
