⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 costmatrix.java

📁 wekaUT是 university texas austin 开发的基于weka的半指导学习(semi supervised learning)的分类器
💻 JAVA
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    CostMatrix.java *    Copyright (C) 2001 Richard Kirkby * */package weka.classifiers;import weka.core.Matrix;import weka.core.Instances;import weka.core.Utils;import java.io.Reader;import java.io.StreamTokenizer;import java.util.Random;/** * Class for storing and manipulating a misclassification cost matrix. * The element at position i,j in the matrix is the penalty for classifying * an instance of class j as class i. * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */public class CostMatrix extends Matrix {  /** The deafult file extension for cost matrix files */  public static String FILE_EXTENSION = ".cost";  /**   * Creates a cost matrix that is a copy of another.   *   * @param toCopy the matrix to copy.   */  public CostMatrix(CostMatrix toCopy) {        super(toCopy.size(), toCopy.size());    for (int x=0; x<toCopy.size(); x++)       for (int y=0; y<toCopy.size(); y++) 	setElement(x, y, toCopy.getElement(x, y));   }  /**   * Creates a default cost matrix of a particular size. All values will be 0.   *   * @param numOfClasses the number of classes that the cost matrix holds.   */    public CostMatrix(int numOfClasses) {        super(numOfClasses, numOfClasses);  }  /**   * Creates a cost matrix from a reader.   *   * @param reader the reader to get the values from.   * @exception Exception if the matrix is invalid.   */    public CostMatrix(Reader reader) throws Exception {    super(reader);    // make sure that the matrix is square    if (numRows() != numColumns())      throw new Exception("Trying to create a non-square cost matrix");  }  /**   * Sets the cost of all correct classifications to 0, and all   * misclassifications to 1.   *   */   public void initialize() {    for (int i = 0; i < size(); i++) {      for (int j = 0; j < size(); j++) {	setElement(i, j, i == j ? 0.0 : 1.0);      }    }  }  /**   * Gets the size of the matrix.   *   * @return the size.   */  public int size() {    return numColumns();  }  /**   * Applies the cost matrix to a set of instances. If a random number generator is    * supplied the instances will be resampled, otherwise they will be rewighted.    * Adapted from code once sitting in Instances.java   *   * @param data the instances to reweight.   * @param random a random number generator for resampling, if null then instances are   * rewighted.   * @return a new dataset reflecting the cost of misclassification.   * @exception Exception if the data has no class or the matrix in inappropriate.   */  public Instances applyCostMatrix(Instances data, Random random) throws Exception {    double sumOfWeightFactors = 0, sumOfMissClassWeights,      sumOfWeights;    double [] weightOfInstancesInClass, weightFactor, weightOfInstances;    Instances newData;    if (data.classIndex() < 0) {      throw new Exception("Class index is not set!");    }     if (size() != data.numClasses()) {       throw new Exception("Misclassification cost matrix has "+			  "wrong format!");    }    weightFactor = new double[data.numClasses()];    weightOfInstancesInClass = new double[data.numClasses()];    for (int j = 0; j < data.numInstances(); j++) {      weightOfInstancesInClass[(int)data.instance(j).classValue()] += 	data.instance(j).weight();    }    sumOfWeights = Utils.sum(weightOfInstancesInClass);    // normalize the matrix if not already    for (int i=0; i<size(); i++)      if (!Utils.eq(getElement(i, i),0)) {	CostMatrix normMatrix = new CostMatrix(this);	normMatrix.normalize();	return normMatrix.applyCostMatrix(data, random);      }        for (int i = 0; i < data.numClasses(); i++) {      // Using Kai Ming Ting's formula for deriving weights for       // the classes and Breiman's heuristic for multiclass       // problems.      sumOfMissClassWeights = 0;      for (int j = 0; j < data.numClasses(); j++) {	if (Utils.sm(getElement(i,j),0)) {	  throw new Exception("Neg. weights in misclassification "+			      "cost matrix!"); 	}	sumOfMissClassWeights += getElement(i,j);      }      weightFactor[i] = sumOfMissClassWeights * sumOfWeights;      sumOfWeightFactors += sumOfMissClassWeights * 	weightOfInstancesInClass[i];    }    for (int i = 0; i < data.numClasses(); i++) {      weightFactor[i] /= sumOfWeightFactors;    }        // Store new weights    weightOfInstances = new double[data.numInstances()];    for (int i = 0; i < data.numInstances(); i++) {      weightOfInstances[i] = data.instance(i).weight()*	weightFactor[(int)data.instance(i).classValue()];    }     // Change instances weight or do resampling    if (random != null) {      return data.resampleWithWeights(random, weightOfInstances);    } else {       Instances instances = new Instances(data);      for (int i = 0; i < data.numInstances(); i++) {	instances.instance(i).setWeight(weightOfInstances[i]);      }      return instances;    }  }  /**   * Calculates the expected misclassification cost for each possible class value,   * given class probability estimates.    *   * @param classProbs the class probability estimates.   * @return the expected costs.   * @exception Exception if the wrong number of class probabilities is supplied.   */  public double[] expectedCosts(double[] classProbs) throws Exception {    if (classProbs.length != size())      throw new Exception("Length of probability estimates don't match cost matrix");    double[] costs = new double[size()];    for (int x=0; x<size(); x++)      for (int y=0; y<size(); y++) 	costs[x] += classProbs[y] * getElement(x, y);    return costs;  }  /**   * Gets the maximum cost for a particular class value.   *   * @param classVal the class value.   * @return the maximum cost.   */  public double getMaxCost(int classVal) {    double maxCost = Double.NEGATIVE_INFINITY;    for (int i=0; i<size(); i++) {      double cost = getElement(classVal, i);      if (cost > maxCost) maxCost = cost;    }    return maxCost;  }  /**   * Normalizes the matrix so that the diagonal contains zeros.   *   */  public void normalize() {    for (int y=0; y<size(); y++) {      double diag = getElement(y, y);      for (int x=0; x<size(); x++)	setElement(x, y, getElement(x, y) - diag);    }  }  /**   * Loads a cost matrix in the old format from a reader. Adapted from code once sitting    * in Instances.java   *   * @param reader the reader to get the values from.   * @exception Exception if the matrix cannot be read correctly.   */    public void readOldFormat(Reader reader) throws Exception {    StreamTokenizer tokenizer;    int currentToken;    double firstIndex, secondIndex, weight;    tokenizer = new StreamTokenizer(reader);    initialize();    tokenizer.commentChar('%');    tokenizer.eolIsSignificant(true);    while (StreamTokenizer.TT_EOF != 	   (currentToken = tokenizer.nextToken())) {      // Skip empty lines       if (currentToken == StreamTokenizer.TT_EOL) {	continue;      }      // Get index of first class.      if (currentToken != StreamTokenizer.TT_NUMBER) {	throw new Exception("Only numbers and comments allowed "+			    "in cost file!");      }      firstIndex = tokenizer.nval;      if (!Utils.eq((double)(int)firstIndex,firstIndex)) {	throw new Exception("First number in line has to be "+			    "index of a class!");      }      if ((int)firstIndex >= size()) {	throw new Exception("Class index out of range!");      }      // Get index of second class.      if (StreamTokenizer.TT_EOF == 	  (currentToken = tokenizer.nextToken())) {	throw new Exception("Premature end of file!");      }      if (currentToken == StreamTokenizer.TT_EOL) {	throw new Exception("Premature end of line!");      }      if (currentToken != StreamTokenizer.TT_NUMBER) {	throw new Exception("Only numbers and comments allowed "+			    "in cost file!");      }      secondIndex = tokenizer.nval;      if (!Utils.eq((double)(int)secondIndex,secondIndex)) {	throw new Exception("Second number in line has to be "+			    "index of a class!");      }      if ((int)secondIndex >= size()) {	throw new Exception("Class index out of range!");      }      if ((int)secondIndex == (int)firstIndex) {	throw new Exception("Diagonal of cost matrix non-zero!");      }      // Get cost factor.      if (StreamTokenizer.TT_EOF == 	  (currentToken = tokenizer.nextToken())) {	throw new Exception("Premature end of file!");      }      if (currentToken == StreamTokenizer.TT_EOL) {	throw new Exception("Premature end of line!");      }      if (currentToken != StreamTokenizer.TT_NUMBER) {	throw new Exception("Only numbers and comments allowed "+			    "in cost file!");      }      weight = tokenizer.nval;      if (!Utils.gr(weight,0)) {	throw new Exception("Only positive weights allowed!");      }      setElement((int)firstIndex, (int)secondIndex, weight);    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -