⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 prototype.java

📁 wekaUT 是由德克萨斯大学奥斯汀分校 (University of Texas at Austin) 开发的、基于 Weka 的半监督学习 (semi-supervised learning) 分类器
💻 JAVA
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    Prototype.java *    Copyright (C) 2004 Raymond J. Mooney * */package weka.classifiers.misc;import weka.classifiers.Classifier;import weka.classifiers.DistributionClassifier;import weka.classifiers.Evaluation;import java.io.*;import java.util.*;import weka.core.*;import weka.core.metrics.*;/** * Class for building and using a simple prototype classifier. * Computes an average/mean/prototype vector for each class. * New examples are classified based on computing distance * from the instance feature vector to the closest prototype. * * For real-valued attributes, standard vector mean and Euclidian distance are used. * To handle nominal attributes, the distribution of values for each category * are computed (as in naive Bayes) as part of the prototype. 
The distance along a * nominal attribute from an instance with a value V for this attribute to the  * prototype for a given class is then: 1- P(V|class) * * In order to make each attribute contribute equally to the distance, values * are normalized to [0,1] by setting NormalizeAttributes, which is set by default * * Predicted class probabilities to make a DistributionClassifier are * assumed to be inversely proportional to the distances from the prototypes * * Borrows some structure from NaiveBayesSimple * * @author Ray Mooney (mooney@cs.utexas.edu) * @version $Revision: 1.1 $  **/public class Prototype extends DistributionClassifier implements  WeightedInstancesHandler, OptionHandler{    /** All the counts for nominal attributes. */    protected double [][][] m_Counts;      /** The means for numeric attributes. */    protected double [][] m_Means;    /** The range (from min to max) taken on by each of the numeric attributes */    protected double [] m_Ranges;        /** The instances used for training. */    protected Instances m_Instances;    /** If set, Normalize all real attribute values between 0 and 1 so that     *  each dimension contributes equally to distance */    protected boolean m_NormalizeAttributes = true;    /**     * Returns an enumeration describing the available options.. <p>     *     * @return an enumeration of all the available options.     *     **/    public Enumeration listOptions () {	Vector newVector = new Vector(1);	newVector.addElement(new Option("\tNormal attribute values to [0,1].", "N", 0, "-N"));	return  newVector.elements();    }    /**     * Parses a given list of options.     * @param options the list of options as an array of strings     * @exception Exception if an option is not supported     *     **/    public void setOptions (String[] options)	throws Exception    {	setNormalizeAttributes(Utils.getFlag('N', options));    }    /**     * Gets the current settings.     
*     * @return an array of strings suitable for passing to setOptions()     */    public String[] getOptions () {	String [] options = new String [1];	int current = 0;	if (m_NormalizeAttributes) options[current++] = "-N";	while (current < options.length) {	    options[current++] = "";	}	return  options;    }    public void setNormalizeAttributes (boolean v) {	m_NormalizeAttributes = v;    }    public boolean getNormalizeAttributes () {	return m_NormalizeAttributes;    }    public String normalizeAttributesTipText() {	return "Scale all real-valued attributes to the range [0,1] to equalize contribution of all attributes.";    }    /**     * Returns a string describing this clusterer     * @return a description of the evaluator suitable for     * displaying in the explorer/experimenter gui     */    public String globalInfo() {	return "Simple algorithm that computes an average or prototype example for each class "+ 	    "and then classifies instances based on distance to closest prototype";    }    /**     * Generates the classifier.     
*     * @param instances set of instances serving as training data      * @exception Exception if the classifier has not been generated successfully     */    public void buildClassifier(Instances instances) throws Exception {	if (instances.checkForStringAttributes()) {	    throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");	}	if (instances.classAttribute().isNumeric()) {	    throw new UnsupportedClassTypeException("Only nominal class allowed");	}    	m_Instances = instances;    	// Reserve space	m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0];	m_Means = new double[instances.numClasses()][instances.numAttributes() - 1];	m_Ranges = new double[instances.numAttributes() - 1];	double[] maxs = new double[instances.numAttributes() - 1];	double[] mins = new double[instances.numAttributes()- 1];	Enumeration enum = instances.enumerateAttributes();	int attIndex = 0;    	while (enum.hasMoreElements()) {	    Attribute attribute = (Attribute) enum.nextElement();	    maxs[attIndex] = Double.NEGATIVE_INFINITY;	    mins[attIndex] = Double.POSITIVE_INFINITY;	    if (attribute.isNominal()) {		for (int j = 0; j < instances.numClasses(); j++) {		    m_Counts[j][attIndex] = new double[attribute.numValues()];		}	    } else {		for (int j = 0; j < instances.numClasses(); j++) {		    m_Counts[j][attIndex] = new double[1];		}	    }	    attIndex++;	}    	// Compute counts and sums 	Enumeration enumInsts = instances.enumerateInstances();	while (enumInsts.hasMoreElements()) {	    Instance instance = (Instance) enumInsts.nextElement();	    int classNum = (int)instance.classValue();	    Enumeration enumAtts = instances.enumerateAttributes();	    attIndex = 0;	    while (enumAtts.hasMoreElements()) {		Attribute attribute = (Attribute) enumAtts.nextElement();		if (!instance.isMissing(attribute)) {		    if (attribute.isNominal()) {			m_Counts[classNum][attIndex][(int)instance.value(attribute)] += instance.weight();		    } else {			double 
value = instance.value(attribute);			m_Means[classNum][attIndex] += value * instance.weight();			m_Counts[classNum][attIndex][0] += instance.weight();			if (m_NormalizeAttributes) {			    if (value < mins[attIndex])				mins[attIndex] = value;			    if (value > maxs[attIndex])				maxs[attIndex] = value;			}		    }		}		attIndex++;	    }	}	// Compute means across complete datset for use	// when not sufficient class-specific info	double[] overallMeans = new double[instances.numAttributes() - 1];	double[] overallCounts = new double[instances.numAttributes() - 1];	Enumeration enumAtts = instances.enumerateAttributes();	attIndex = 0;	while (enumAtts.hasMoreElements()) {	    Attribute attribute = (Attribute) enumAtts.nextElement();	    if (attribute.isNumeric()) {		for (int j = 0; j < instances.numClasses(); j++) {		    overallMeans[attIndex] += m_Means[j][attIndex];		    overallCounts[attIndex] += m_Counts[j][attIndex][0];		}		if (overallCounts[attIndex] !=0)		    overallMeans[attIndex] /= overallCounts[attIndex];	    }	    attIndex ++;    	}	// Compute conditional probs, means,	enumAtts = instances.enumerateAttributes();	attIndex = 0;	double sum = 0;	while (enumAtts.hasMoreElements()) {	    Attribute attribute = (Attribute) enumAtts.nextElement();	    m_Ranges[attIndex] = maxs[attIndex] - mins[attIndex];	    for (int j = 0; j < instances.numClasses(); j++) {		if (attribute.isNumeric()) {		    if (m_Counts[j][attIndex][0] != 0) {			m_Means[j][attIndex] /= m_Counts[j][attIndex][0];		    } else { // Back-off to class independent stats if no data for class			m_Means[j][attIndex] = overallMeans[attIndex];		    }		} else if (attribute.isNominal()) {		    sum = Utils.sum(m_Counts[j][attIndex]);		    if (sum != 0) 			for (int i = 0; i < attribute.numValues(); i++) {			    m_Counts[j][attIndex][i] = m_Counts[j][attIndex][i] / sum;			}		}	    }	    attIndex++;	}	//	System.out.println(toString());    }    /**     * Calculates the class membership probabilities for the given test 
instance.     *     * @param instance the instance to be classified     * @return predicted class probability distribution     * @exception Exception if distribution can't be computed     */    public double[] distributionForInstance(Instance instance) throws Exception {	double[] dists = new double[instance.numClasses()];	for (int j = 0; j < instance.numClasses(); j++) {	    Enumeration enumAtts = instance.enumerateAttributes();	    int attIndex = 0;	    while (enumAtts.hasMoreElements()) {		Attribute attribute = (Attribute) enumAtts.nextElement();		if (!instance.isMissing(attribute)) {		    if (attribute.isNominal()) {			// Nominal distance is based on 1 - P(value | class)			dists[j] += Math.pow((1 - m_Counts[j][attIndex][(int)instance.value(attribute)]),2);		    } else {			double diff = (instance.value(attribute) - m_Means[j][attIndex]);			if (m_NormalizeAttributes) {			    // Scaling by the range is equivalent to scaling all atrributes to [0,1]			    // and equalizes the contribution of all attributes			    if (m_Ranges[attIndex] == 0)				diff = 1.0;			    else				diff = diff / m_Ranges[attIndex];			}			dists[j] += Math.pow(diff, 2);		    }		}		attIndex++;	    }	    // Use inverse of Euclidian distance from prototype as similarity that is normalized	    // to a probability distribution	    dists[j] = Math.sqrt(dists[j]);	    if (dists[j] == 0.0)		dists[j] = Double.MAX_VALUE;	    else		dists[j] = 1 / dists[j];	}	Utils.normalize(dists);	return dists;    }	    /**     * Returns a description of the classifier.     *     * @return a description of the classifier as a string.     
*/    public String toString() {	if (m_Instances == null) {	    return "No model built yet.";	}	try {	    StringBuffer text = new StringBuffer("Prototype Model");	    int attIndex;      	    for (int i = 0; i < m_Instances.numClasses(); i++) {		text.append("\n\nClass " + m_Instances.classAttribute().value(i) 			    + "\n\n");		Enumeration enumAtts = m_Instances.enumerateAttributes();		attIndex = 0;		while (enumAtts.hasMoreElements()) {		    Attribute attribute = (Attribute) enumAtts.nextElement();		    text.append("Attribute " + attribute.name() + "\n");		    if (attribute.isNominal()) {			for (int j = 0; j < attribute.numValues(); j++) {			    text.append(attribute.value(j) + "\t");			}			text.append("\n");			for (int j = 0; j < attribute.numValues(); j++)			    text.append(Utils.					doubleToString(m_Counts[i][attIndex][j], 10, 8)					+ "\t");		    } else {			text.append("Mean: " + Utils.				    doubleToString(m_Means[i][attIndex], 10, 8) + "\t");		    }		    text.append("\n\n");		    attIndex++;		}	    }      	    return text.toString();	} catch (Exception e) {	    return "Can't print Prototype classifier!";	}    }    /**     * Main method for testing this class.     *     * @param argv the options     */    public static void main(String [] argv) {	Classifier scheme;	try {	    scheme = new Prototype();	    System.out.println(Evaluation.evaluateModel(scheme, argv));	} catch (Exception e) {	    System.err.println(e.getMessage());	}    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -