⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 semisupem.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 weka 的半监督学习（semi-supervised learning）分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    /**
     * Generates the classifier by running semi-supervised EM: the labeled
     * training data is combined with the previously supplied unlabeled data
     * (soft-labeled at random), an initial model is built from the labeled
     * seed, and EM iterations refine the soft labels of the unlabeled data.
     *
     * @param data set of labeled instances serving as training data
     * @exception Exception if the classifier has not been generated successfully
     */
    public void buildClassifier(Instances data) throws Exception {
        if (data.checkForStringAttributes()) {
            throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
        }
        if (data.classAttribute().isNumeric()) {
            throw new UnsupportedClassTypeException("Can't handle a numeric class!");
        }
        if (m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        m_LabeledInstances = data;
        // Add "hard" soft-labeled instances of labeled data to the data for EM
        m_AllInstances = new SoftClassifiedInstances(data);
        Random random = new Random(m_rseed);
        // Make random soft-labeled instances for unlabeled data
        m_UnlabeledInstances = new SoftClassifiedInstances(m_UnlabeledData, random);
        // Down-weight unlabeled instances when lambda < 1 so they influence
        // the model less than labeled ones.
        if (m_Lambda != 1.0) {
            weightInstances(m_UnlabeledInstances, m_Lambda);
        }
        // Add the unlabeled data to the complete data set
        m_AllInstances.addInstances(m_UnlabeledInstances);
        initModel();
        if (m_verbose) {
            System.out.println("Labeled Data Classes: ");
            Enumeration enumInsts = m_LabeledInstances.enumerateInstances();
            while (enumInsts.hasMoreElements()) {
                Instance instance = (Instance) enumInsts.nextElement();
                System.out.print(m_AllInstances.classAttribute().value((int) instance.classValue()) + " ");
            }
            System.out.println("\nNum Unlabeled: " + m_UnlabeledInstances.numInstances());
        }
        // With no unlabeled data the initial model is already the final model.
        if (m_UnlabeledInstances.numInstances() != 0) {
            iterate();
        }
    }

    /**
     * Weights all given instances with the given weight.
     *
     * @param insts  the instances to reweight (modified in place)
     * @param weight the weight assigned to every instance
     */
    protected void weightInstances(Instances insts, double weight) {
        Enumeration enumInsts = insts.enumerateInstances();
        while (enumInsts.hasMoreElements()) {
            Instance instance = (Instance) enumInsts.nextElement();
            instance.setWeight(weight);
        }
    }

    /**
     * Initializes the base classifier's model from the labeled data.
     * If unseen-class seeding is enabled and some class values have no
     * labeled examples, the unlabeled instance farthest from all labeled
     * data is added as a seed, soft-labeled uniformly over the unseen
     * classes.
     *
     * @exception Exception if the base classifier cannot be built
     */
    protected void initModel() throws Exception {
        SoftClassifiedInstances seedInstances = new SoftClassifiedInstances(m_LabeledInstances);
        if (m_seedUnseenClasses && m_UnlabeledInstances.numInstances() != 0) {
            List unseenClasses = unseenClasses(seedInstances);
            if (!unseenClasses.isEmpty()) {
                if (m_verbose) {
                    System.out.println("Unseen classes: " + unseenClasses);
                }
                // Add a seed instance for the unseen classes that is soft labeled
                // equally in all unknown classes.
                Instance farthest = farthestInstance(m_UnlabeledInstances, seedInstances);
                softLabelClasses((SoftClassifiedInstance) farthest, unseenClasses);
                if (m_verbose) {
                    System.out.println("Seeded Instance: "
                        + classDistributionString((SoftClassifiedInstance) farthest));
                }
                seedInstances.add(farthest);
            }
        }
        m_Classifier.buildClassifier(seedInstances);
    }

    /**
     * Returns a list of class values (as Integers) for which there are no
     * instances in insts.
     *
     * @param insts the instances whose class coverage is examined
     * @return list of class indices that never occur in insts
     */
    protected ArrayList unseenClasses(Instances insts) {
        int[] classCounts = new int[insts.numClasses()];
        Enumeration enumInsts = insts.enumerateInstances();
        while (enumInsts.hasMoreElements()) {
            Instance inst = (Instance) enumInsts.nextElement();
            classCounts[(int) inst.classValue()]++;
        }
        ArrayList result = new ArrayList();
        for (int i = 0; i < insts.numClasses(); i++) {
            if (classCounts[i] == 0) {
                result.add(new Integer(i));
            }
        }
        return result;
    }

    /**
     * Returns the instance in candidateInsts that is farthest from any
     * instance in insts (i.e. whose minimum distance to insts is maximal).
     * Min/max normalization statistics are (re)computed over all data first.
     *
     * @param candidateInsts instances to choose the farthest one from
     * @param insts          the reference set distances are measured against
     * @return the candidate with the largest minimum distance, or null if
     *         candidateInsts is empty
     */
    protected Instance farthestInstance(Instances candidateInsts, Instances insts) {
        double maxDist = Double.NEGATIVE_INFINITY;
        Instance farthestInst = null;
        double dist;
        setMinMax(m_AllInstances);
        Enumeration enumInsts = candidateInsts.enumerateInstances();
        while (enumInsts.hasMoreElements()) {
            Instance candidate = (Instance) enumInsts.nextElement();
            dist = minimumDistance(candidate, insts);
            if (dist > maxDist) {
                maxDist = dist;
                farthestInst = candidate;
            }
        }
        return farthestInst;
    }

    /**
     * Returns the distance from inst to the closest instance in insts.
     *
     * @param inst  the query instance
     * @param insts the instances searched for the nearest neighbor
     * @return the smallest distance found (POSITIVE_INFINITY if insts is empty)
     */
    protected double minimumDistance(Instance inst, Instances insts) {
        double minDist = Double.POSITIVE_INFINITY;
        double dist;
        Enumeration enumInsts = insts.enumerateInstances();
        while (enumInsts.hasMoreElements()) {
            Instance x = (Instance) enumInsts.nextElement();
            dist = distance(inst, x);
            if (dist < minDist) {
                minDist = dist;
            }
        }
        return minDist;
    }

    /**
     * Soft labels inst as being equally likely to be in any of the given
     * classes; all other classes get probability zero.
     *
     * @param inst    the instance whose class distribution is set
     * @param classes list of class indices (Integers) to spread mass over
     * @exception Exception if the distribution cannot be set
     */
    protected void softLabelClasses(SoftClassifiedInstance inst, List classes)
        throws Exception {
        double prob = 1.0 / classes.size();
        double[] dist = new double[((Instance) inst).dataset().numClasses()];
        for (int i = 0; i < classes.size(); i++) {
            dist[((Integer) classes.get(i)).intValue()] = prob;
        }
        inst.setClassDistribution(dist);
    }

    /**
     * Runs EM iterations until the log-likelihood stops increasing
     * significantly or the maximum number of iterations is exhausted.
     *
     * @exception Exception if an E- or M-step fails
     */
    protected void iterate() throws Exception {
        double logLikelihood = 0;
        double oldLogLikelihood = 0;
        for (int i = 0; i < m_max_iterations; i++) {
            oldLogLikelihood = logLikelihood;
            logLikelihood = eStep();
            if (m_verbose) {
                System.out.println("\nIteration " + i + ":  LogLikelihood = " + logLikelihood + "\n\n");
            }
            // Stop once the per-instance log-likelihood gain falls below the
            // configured threshold (never on the very first iteration).
            if ((i > 0) && ((logLikelihood - oldLogLikelihood) < m_minLogLikelihoodIncr)) {
                break;
            }
            mStep();
        }
    }

    /**
     * E-step: re-estimates the soft class distribution of every unlabeled
     * instance from the current model and accumulates the data
     * log-likelihood over both unlabeled and labeled instances.
     *
     * @return the log-likelihood averaged over all instances
     * @exception Exception if the classifier cannot score an instance
     */
    protected double eStep() throws Exception {
        double logLikelihood = 0;
        double classifiedCorrect = 0;
        double[] dist;
        Enumeration enumInsts = m_UnlabeledInstances.enumerateInstances();
        while (enumInsts.hasMoreElements()) {
            Instance instance = (Instance) enumInsts.nextElement();
            dist = m_Classifier.unNormalizedDistributionForInstance(instance);
            logLikelihood += logSum(dist);
            // Convert unnormalized log scores into a proper distribution.
            NaiveBayesSimple.normalizeLogs(dist);
            ((SoftClassifiedInstance) instance).setClassDistribution(dist);
            if (m_verbose) {
                // Track how often the model's argmax matches the (held) true class.
                if (Utils.maxIndex(dist) == (int) instance.classValue()) {
                    classifiedCorrect++;
                }
            }
        }
        if (m_verbose) {
            System.out.println("\nAccuracy on Unlabeled: "
                + classifiedCorrect / m_UnlabeledInstances.numInstances());
        }
        // Labeled instances contribute to the likelihood but keep their labels.
        enumInsts = m_LabeledInstances.enumerateInstances();
        while (enumInsts.hasMoreElements()) {
            Instance instance = (Instance) enumInsts.nextElement();
            dist = m_Classifier.unNormalizedDistributionForInstance(instance);
            logLikelihood += logSum(dist);
        }
        return logLikelihood / m_AllInstances.numInstances();
    }

    /**
     * Sums probabilities given in log space (log-sum-exp): subtracts the
     * maximum before exponentiating to avoid numeric underflow.
     *
     * @param logProbs log-probabilities to sum
     * @return log of the sum of the exponentiated values
     */
    public double logSum(double[] logProbs) {
        double sum = 0;
        double max = logProbs[Utils.maxIndex(logProbs)];
        for (int i = 0; i < logProbs.length; i++) {
            sum += Math.exp(logProbs[i] - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Formats an instance's hard class value followed by its full soft
     * class distribution, for verbose output.
     *
     * @param inst the instance to describe
     * @return a string of the form "class | class0:p0 class1:p1 ..."
     */
    protected String classDistributionString(SoftClassifiedInstance inst) {
        double[] dist = inst.getClassDistribution();
        StringBuffer text = new StringBuffer();
        Attribute classAtt = m_AllInstances.classAttribute();
        text.append(classAtt.value((int) ((Instance) inst).classValue()) + " | ");
        for (int i = 0; i < m_AllInstances.numClasses(); i++) {
            text.append(classAtt.value(i) + ":" + dist[i] + " ");
        }
        return text.toString();
    }

    /**
     * M-step: rebuilds the base classifier from all (labeled plus
     * soft-labeled unlabeled) instances.
     *
     * @exception Exception if the base classifier cannot be built
     */
    protected void mStep() throws Exception {
        m_Classifier.buildClassifier(m_AllInstances);
    }

    /**
     * Calculates the class membership probabilities for the given test
     * instance.
     *
     * @param instance the instance to be classified
     * @return predicted class probability distribution
     * @exception Exception if distribution can't be computed
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] dist = m_Classifier.unNormalizedDistributionForInstance(instance);
        NaiveBayesSimple.normalizeLogs(dist);
        return dist;
    }

    /**
     * Calculates the distance between two instances: squared Euclidean
     * distance over min/max-normalized numeric attributes, with a nominal
     * mismatch (or missing nominal value) contributing 1. The class
     * attribute is skipped. Note the value returned is the squared
     * distance (no square root is taken).
     *
     * @param first  the first instance
     * @param second the second instance
     * @return the distance between the two given instances
     */
    protected double distance(Instance first, Instance second) {
        double diff, distance = 0;
        Instances dataset = first.dataset();
        for (int i = 0; i < dataset.numAttributes(); i++) {
            if (i == dataset.classIndex()) {
                continue;
            }
            if (dataset.attribute(i).isNominal()) {
                // If attribute is nominal: any missing value or mismatch counts 1.
                if (first.isMissing(i) || second.isMissing(i)
                    || ((int) first.value(i) != (int) second.value(i))) {
                    distance += 1;
                }
            } else {
                // If attribute is numeric.
                if (first.isMissing(i) || second.isMissing(i)) {
                    if (first.isMissing(i) && second.isMissing(i)) {
                        // Both missing: assume the maximum possible difference.
                        diff = 1;
                    } else {
                        // One missing: use the known value's normalized distance
                        // to whichever end of [0,1] is farther.
                        if (second.isMissing(i)) {
                            diff = norm(first.value(i), i);
                        } else {
                            diff = norm(second.value(i), i);
                        }
                        if (diff < 0.5) {
                            diff = 1.0 - diff;
                        }
                    }
                } else {
                    diff = norm(first.value(i), i) - norm(second.value(i), i);
                }
                distance += diff * diff;
            }
        }
        return distance;
    }

    /**
     * Normalizes a given value of a numeric attribute into [0,1] using the
     * stored min/max statistics; returns 0 if no statistics are available
     * or the attribute has zero range.
     *
     * @param x the value to be normalized
     * @param i the attribute's index
     */
    protected double norm(double x, int i) {
        if (Double.isNaN(m_MinArray[i])
            || Utils.eq(m_MaxArray[i], m_MinArray[i])) {
            return 0;
        } else {
            return (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]);
        }
    }

    /**
     * Computes and stores min/max values for each numeric feature.
     * (The local was renamed from "enum", a reserved keyword since Java 5,
     * so this file compiles on modern JDKs.)
     *
     * @param insts the instances the statistics are computed from
     */
    protected void setMinMax(Instances insts) {
        m_MinArray = new double[insts.numAttributes()];
        m_MaxArray = new double[insts.numAttributes()];
        for (int i = 0; i < insts.numAttributes(); i++) {
            m_MinArray[i] = m_MaxArray[i] = Double.NaN;
        }
        Enumeration en = insts.enumerateInstances();
        while (en.hasMoreElements()) {
            updateMinMax((Instance) en.nextElement());
        }
    }

    /**
     * Updates the minimum and maximum values for all the attributes
     * based on a new instance.
     *
     * @param instance the new instance
     */
    protected void updateMinMax(Instance instance) {
        Instances dataset = instance.dataset();
        for (int j = 0; j < dataset.numAttributes(); j++) {
            if ((dataset.attribute(j).isNumeric()) && (!instance.isMissing(j))) {
                if (Double.isNaN(m_MinArray[j])) {
                    // First observed value initializes both bounds.
                    m_MinArray[j] = instance.value(j);
                    m_MaxArray[j] = instance.value(j);
                } else {
                    if (instance.value(j) < m_MinArray[j]) {
                        m_MinArray[j] = instance.value(j);
                    } else {
                        if (instance.value(j) > m_MaxArray[j]) {
                            m_MaxArray[j] = instance.value(j);
                        }
                    }
                }
            }
        }
    }

    /**
     * Main method for testing this class. Expects: argv[0] an ARFF file,
     * argv[1] the number of labeled instances, argv[2] a random seed;
     * remaining options are passed to setOptions.
     *
     * @param argv the options
     */
    public static void main(String[] argv) {
        try {
            Instances instances = new Instances(new BufferedReader(new FileReader(argv[0])));
            instances.setClassIndex(instances.numAttributes() - 1);
            Random random = new Random(Integer.parseInt(argv[2]));
            instances.randomize(random);
            int numLabeled = Integer.parseInt(argv[1]);
            // Split the shuffled data into a labeled prefix and unlabeled rest.
            Instances labeledInsts = new Instances(instances, 0, numLabeled);
            Instances unlabeledInsts =
                new Instances(instances, numLabeled, (instances.numInstances() - numLabeled));
            SemiSupEM emClassifier = new SemiSupEM();
            emClassifier.setOptions(argv);
            emClassifier.setUnlabeled(unlabeledInsts);
            emClassifier.buildClassifier(labeledInsts);
        } catch (Exception e) {
            System.err.println(e.getMessage());
        }
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -