/*
 * ExampleSet.java
 *
 * From the listing "A very good Java source of LIBSVM; researchers who want to
 * study and improve SVM algorithms can use it as a reference" (Java code, 435 lines).
 */
/*
 *  YALE - Yet Another Learning Environment
 *  Copyright (C) 2001-2004
 *      Simon Fischer, Ralf Klinkenberg, Ingo Mierswa, 
 *          Katharina Morik, Oliver Ritthoff
 *      Artificial Intelligence Unit
 *      Computer Science Department
 *      University of Dortmund
 *      44221 Dortmund,  Germany
 *  email: yale-team@lists.sourceforge.net
 *  web:   http://yale.cs.uni-dortmund.de/
 *
 *  This program is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU General Public License as 
 *  published by the Free Software Foundation; either version 2 of the
 *  License, or (at your option) any later version. 
 *
 *  This program is distributed in the hope that it will be useful, but
 *  WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 *  General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 *  USA.
 */
package edu.udo.cs.mySVM.Examples;

import edu.udo.cs.yale.example.ExampleReader;
import edu.udo.cs.yale.example.Attribute;
import java.lang.Integer;
import java.lang.Double;

import java.io.*;
import java.util.Vector;
import java.util.Map;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Iterator;

/** Implementation of a sparse example set which can be used for learning. This data structure is also used
 *  as SVM model: for every example it stores the known (non-default) attribute values together with their
 *  attribute indices, the label and the SVM alpha value, plus one hyperplane offset b for the whole set.
 *
 *  @author Stefan Rueping, Ingo Mierswa
 *  @version 1.0
 */
public class ExampleSet {

    /** Mean and variance of a single attribute, used to standardize attribute values. */
    private static class MeanVariance {
        private final double mean;
        private final double variance;
        public MeanVariance(double mean, double variance) {
            this.mean = mean;
            this.variance = variance;
        }
        public double getMean() { return mean; }
        public double getVariance() { return variance; }
    }

    /** The dimension of the example set. */
    private int dim;
    /** The number of examples. */
    private int train_size;

    // sparse representation of examples. public for avoiding invocation of a method (slower)
    /** The known attribute values for each example. */
    public double[][] atts;
    /** The corresponding indices for the known attribute values for each example. */
    public int[][] index;

    /** The ids of all examples (entries stay null if no id is known). */
    public String[] ids;

    /** The SVM alpha values. Will be filled by learning. */
    private double[] alphas;
    /** The labels of the examples if known. -1 and +1 for classification or the real value for regression
     *  tasks. Will be filled by prediction. */
    private double[] ys;
    /** The hyperplane offset. */
    private double b;
    /** This example will be once constructed and delivered with the asked values. */
    private Example x;

    /** This map stores the mean-variance informations about all attributes (att index --> mean-variance).
     *  This information is used to scale the data from the test set. */
    private Map meanVarianceMap = new HashMap();


    /** Creates an empty example set of the given size.
     *  @param size the number of examples
     *  @param b the initial hyperplane offset */
    public ExampleSet(int size, double b) {
        this.train_size = size;
        this.b = b;

        atts = new double[train_size][];
        index = new int[train_size][];
        ys = new double[train_size];
        alphas = new double[train_size];

        ids = new String[size];

        x = new Example();
    }


    /** Computes the mean and the variance (E[x^2] - E[x]^2) of every attribute over all examples of the
     *  given set, including default values.
     *  @return a map from attribute index (Integer) to MeanVariance; empty if the set has no examples, since
     *          no statistics can be estimated then (dividing by zero would only produce NaN entries) */
    private static Map createMeanVariances(edu.udo.cs.yale.example.ExampleSet exampleSet) {
        Map meanVariances = new HashMap();
        int size = exampleSet.getSize();
        if (size == 0) {
            return meanVariances;
        }
        int numberOfAttributes = exampleSet.getNumberOfAttributes();
        double[] sum = new double[numberOfAttributes];
        double[] squaredSum = new double[numberOfAttributes];

        ExampleReader reader = exampleSet.getExampleReader();
        while (reader.hasNext()) {
            edu.udo.cs.yale.example.Example example = reader.next();
            for (int a = 0; a < example.getNumberOfAttributes(); a++) {
                double value = example.getValue(example.getAttribute(a));
                sum[a] += value;
                squaredSum[a] += value * value;
            }
        }

        for (int a = 0; a < numberOfAttributes; a++) {
            double mean = sum[a] / size;
            double meanOfSquares = squaredSum[a] / size;
            meanVariances.put(new Integer(a), new MeanVariance(mean, meanOfSquares - (mean * mean)));
        }
        return meanVariances;
    }

    /** Creates an SVM example set from a Yale example set.
     *  @param scale if true, the attribute values are standardized with mean and variance computed from
     *               the given example set itself */
    public ExampleSet(edu.udo.cs.yale.example.ExampleSet exampleSet, boolean scale){
        this(exampleSet, scale ? createMeanVariances(exampleSet) : new HashMap());
    }

    /** Creates a fresh example set of the given size from the Yale example reader. The alpha values and b are
     *  zero, the label will be set if it is known. Only non-default attribute values are stored.
     *  @param meanVariances attribute index (Integer) --> MeanVariance used to standardize the stored values;
     *                       attributes without an entry (or a null/empty map) are stored unscaled */
    public ExampleSet(edu.udo.cs.yale.example.ExampleSet exampleSet, Map meanVariances){
        this(exampleSet.getSize(), 0.0d);
        this.meanVarianceMap = meanVariances;

        ExampleReader reader = exampleSet.getExampleReader();
        int exampleCounter = 0;
        while (reader.hasNext()) {
            edu.udo.cs.yale.example.Example current = reader.next();
            int numberOfAttributes = current.getNumberOfAttributes();

            // Collect the non-default values in attribute order into scratch arrays, then trim them.
            int[] knownIndices = new int[numberOfAttributes];
            double[] knownValues = new double[numberOfAttributes];
            int known = 0;
            for (int a = 0; a < numberOfAttributes; a++) {
                Attribute attribute = current.getAttribute(a);
                double value = current.getValue(attribute);
                if (!attribute.isDefault(value)) {
                    knownIndices[known] = a;
                    knownValues[known] = scaleValue(a, value);
                    known++;
                }
            }
            if (numberOfAttributes > dim) dim = numberOfAttributes;

            index[exampleCounter] = new int[known];
            atts[exampleCounter] = new double[known];
            System.arraycopy(knownIndices, 0, index[exampleCounter], 0, known);
            System.arraycopy(knownValues, 0, atts[exampleCounter], 0, known);

            Attribute labelAttribute = current.getLabelAttribute();
            if (labelAttribute != null) {
                double label = current.getLabel();
                if (labelAttribute.isNominal()) {
                    // nominal labels are mapped to +1 (positive class) / -1 (everything else)
                    ys[exampleCounter] = (label == labelAttribute.getPositiveIndex() ? 1 : -1);
                } else {
                    ys[exampleCounter] = label;
                }
            }
            Attribute idAttribute = current.getIdAttribute();
            if (idAttribute != null) {
                ids[exampleCounter] = current.getValueAsString(idAttribute);
            }
            exampleCounter++;
        }
    }

    /** Standardizes one attribute value with the stored mean/variance of that attribute.
     *  Returns the value unchanged if no scaling information exists for the attribute. */
    private double scaleValue(int attributeIndex, double value) {
        if (meanVarianceMap == null) {
            return value;
        }
        MeanVariance meanVariance = (MeanVariance)meanVarianceMap.get(new Integer(attributeIndex));
        if (meanVariance == null) {
            return value;
        }
        double variance = meanVariance.getVariance();
        // E[x^2] - E[x]^2 may come out slightly negative through rounding (or NaN); sqrt would then give NaN.
        // Such attributes are treated like constant ones.
        if (!(variance > 0.0d)) {
            return 0.0d;
        }
        return (value - meanVariance.getMean()) / Math.sqrt(variance);
    }

    /** Reads an example set in the format written by {@link #writeSupportVectors(ObjectOutputStream)}:
     *  number of examples, b, dim, the scaling information ("scale" followed by the mean/variance entries,
     *  or "noscale"), then for every example its sparse values, its alpha and its label.
     *  Example ids are not part of this format and stay null. */
    public ExampleSet(ObjectInputStream in) throws IOException {
        this(in.readInt(), in.readDouble());
        this.dim = in.readInt();
        String scaleString = in.readUTF();
        if (scaleString.equals("scale")) {
            int numberOfAttributes = in.readInt();
            this.meanVarianceMap = new HashMap();
            for (int i = 0; i < numberOfAttributes; i++) {
                int attributeIndex = in.readInt();
                double mean = in.readDouble();
                double variance = in.readDouble();
                meanVarianceMap.put(new Integer(attributeIndex), new MeanVariance(mean, variance));
            }
        }
        for (int e = 0; e < this.train_size; e++) {
            int numberOfValues = in.readInt();
            index[e] = new int[numberOfValues];
            atts[e] = new double[numberOfValues];
            for (int a = 0; a < numberOfValues; a++) {
                index[e][a] = in.readInt();
                atts[e][a] = in.readDouble();
            }
            alphas[e] = in.readDouble();
            ys[e] = in.readDouble();
        }
    }

    /** Returns the mean/variance map used for scaling (attribute index --> mean-variance). */
    public Map getMeanVariances() {
        return meanVarianceMap;
    }

    /** Counts the examples with a non-zero alpha value. */
    private int getNumberOfSupportVectors() {
        int result = 0;
        for (int i = 0; i < alphas.length; i++)
            if (alphas[i] != 0.0d) result++;
        return result;
    }

    /** Writes the support vectors (examples with non-zero alpha), b, dim and the scaling information into
     *  the given output stream, in the format read by {@link #ExampleSet(ObjectInputStream)}. */
    public void writeSupportVectors(ObjectOutputStream out) throws IOException {
        out.writeInt(getNumberOfSupportVectors());
        out.writeDouble(b);
        out.writeInt(dim);
        if ((meanVarianceMap == null) || (meanVarianceMap.size() == 0)) {
            out.writeUTF("noscale");
        } else {
            out.writeUTF("scale");
            out.writeInt(meanVarianceMap.size());
            Iterator i = meanVarianceMap.keySet().iterator();
            while (i.hasNext()) {
                Integer attributeIndex = (Integer)i.next();
                MeanVariance meanVariance = (MeanVariance)meanVarianceMap.get(attributeIndex);
                out.writeInt(attributeIndex.intValue());
                out.writeDouble(meanVariance.getMean());
                out.writeDouble(meanVariance.getVariance());
            }
        }
        for (int e = 0; e < train_size; e++) {
            if (alphas[e] != 0.0d) {
                out.writeInt(atts[e].length);
                for (int a = 0; a < atts[e].length; a++) {
                    out.writeInt(index[e][a]);
                    out.writeDouble(atts[e][a]);
                }
                out.writeDouble(alphas[e]);
                out.writeDouble(ys[e]);
            }
        }
    }


    /**
     * Counts the training examples.
     * @return Number of examples
     */
    public int count_examples() {
        return train_size;
    }


    /**
     * Counts the positive training examples (label greater than zero).
     * @return Number of positive examples
     */
    public int count_pos_examples() {
        int result = 0;
        for (int i = 0; i < train_size; i++) {
            if (ys[i] > 0) {
                result++;
            }
        }
        return result;
    }


    /**
     * Gets the dimension of the examples
     * @return dim
     */
    public int get_dim() {
        return dim;
    }

    /**
     * Gets an example. Note that the same Example instance is reused and overwritten by every call,
     * so its contents are only valid until the next call.
     * @param pos Number of example
     * @return the sparse example (attribute values and their indices) at the given position
     */
    public Example get_example(int pos) {
        x.att = atts[pos];
        x.index = index[pos];
        return x;
    }


    /**
     * Gets an y-value.
     * @param pos Number of example
     * @return y
     */
    public double get_y(int pos) {
        return ys[pos];
    }


    /** Sets the label value for the specified example. */
    public void set_y(int pos, double y) {
        ys[pos] = y;
    }


    /**
     * Gets the y array (the internal array, not a copy).
     * @return y
     */
    public double[] get_ys() {
        return ys;
    }


    /**
     * Gets an alpha-value.
     * @param pos Number of example
     * @return alpha
     */
    public double get_alpha(int pos) {
        return alphas[pos];
    }


    /**
     * Gets the alpha array (the internal array, not a copy).
     * @return alpha
     */
    public double[] get_alphas() {
        return alphas;
    }


    /**
     * Swaps two training examples: their attribute values, indices, alphas, labels and ids.
     * @param pos1
     * @param pos2
     */
    public void swap(int pos1, int pos2) {
        double[] dummyA = atts[pos1];
        atts[pos1] = atts[pos2];
        atts[pos2] = dummyA;
        int[] dummyI = index[pos1];
        index[pos1] = index[pos2];
        index[pos2] = dummyI;
        double dummyd = alphas[pos1];
        alphas[pos1] = alphas[pos2];
        alphas[pos2] = dummyd;
        dummyd = ys[pos1];
        ys[pos1] = ys[pos2];
        ys[pos2] = dummyd;
        // ids must travel with their example, otherwise getId(pos) reports the wrong example after a swap
        String dummyS = ids[pos1];
        ids[pos1] = ids[pos2];
        ids[pos2] = dummyS;
    }


    /**
     * get b
     * @return b
     */
    public double get_b() {
        return b;
    }


    /**
     * set b
     * @param new_b
     */
    public void set_b(double new_b) {
        b = new_b;
    }


    /**
     * Sets an alpha value.
     * @param pos Number of example
     * @param alpha New value
     */
    public void set_alpha(int pos, double alpha) {
        alphas[pos] = alpha;
    }


    /** Resets all alpha values to zero. */
    public void clearAlphas() {
        for (int i = 0; i < alphas.length; i++)
            alphas[i] = 0.0d;
    }

    // ================================================================================

    /** Returns the id of the example at the given position, or null if none is known. */
    public String getId(int index) {
        return ids[index];
    }

    public String toString() {
        return toString(atts.length, false);
    }

    public String toString(boolean onlySV) {
        return toString(atts.length, onlySV);
    }

    /** Creates a textual representation of (at most) the first numberOfExamples examples.
     *  @param onlySV if true, only examples with a non-zero alpha are printed */
    public String toString(int numberOfExamples, boolean onlySV) {
        StringBuffer result = new StringBuffer("SVM Example Set (" +
                                               (onlySV ? (getNumberOfSupportVectors() + " support vectors") : (train_size + " examples")) +
                                               "):\nb: " + b + "\n");
        // never read past the stored examples, whatever count the caller asks for
        int limit = Math.min(numberOfExamples, train_size);
        for (int e = 0; e < limit; e++) {
            if (onlySV && (alphas[e] == 0.0d)) {
                continue;
            }
            for (int a = 0; a < atts[e].length; a++) {
                result.append(index[e][a]).append(':');
                result.append(atts[e][a]).append(' ');
            }
            result.append(", alpha: ").append(alphas[e]);
            result.append(", y: ").append(ys[e]).append('\n');
        }
        return result.toString();
    }
}

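/*
 * Viewer keyboard shortcuts (page residue from the source listing):
 * copy code Ctrl+C, search code Ctrl+F, full screen F11,
 * increase font size Ctrl+=, decrease font size Ctrl+-, show shortcuts "?".
 */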