⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 costmatrix.java

📁 一个数据挖掘系统的源码
💻 JAVA
字号:

/**
 *   
 *   AgentAcademy - an open source Data Mining framework for
 *   training intelligent agents
 *
 *   Copyright (C)   2001-2003 AA Consortium.
 *
 *   This library is open source software; you can redistribute it 
 *   and/or modify it under the terms of the GNU Lesser General 
 *   Public License as published by the Free Software Foundation;   
 *   either version 2.0 of the License, or (at your option) any later 
 *   version.
 *
 *   This library is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the    
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU Lesser General Public
 *   License along with this library; if not, write to the Free 
 *   Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, 
 *   MA  02111-1307 USA
 * 
 */

package org.agentacademy.modules.dataminer.classifiers.evaluation;

/**
 * <p>Title: The Data Miner prototype</p>
 * <p>Description: A prototype for the DataMiner (DM), the Agent Academy (AA) module responsible for performing data mining on the contents of the Agent Use Repository (AUR). The extracted knowledge is to be sent back to the AUR in the form of a PMML document.</p>
 * <p>Copyright: Copyright (c) 2002</p>
 * <p>Company: CERTH</p>
 * @author asymeon
 * @version 0.3
 */

import org.agentacademy.modules.dataminer.core.*;
import java.io.*;
import java.util.Random;

/**
 * Class for storing and manipulating a misclassification cost matrix.
 * The element at position i,j in the matrix is the penalty for classifying
 * an instance of class j as class i.
 *
 * <p>NOTE(review): the methods of this class do not agree on which index is
 * the actual class. {@link #expectedCosts(double[])} and {@link #normalize()}
 * treat column j as actual class j, matching the description above. However,
 * {@link #applyCostMatrix(Instances, Random)} derives the weight for
 * instances of actual class i from the sum of row i. Confirm the intended
 * convention against the callers before relying on asymmetric matrices.
 */
public class CostMatrix extends Matrix {

  /** The default file extension for cost matrix files. */
  public static final String FILE_EXTENSION = ".cost";

  /**
   * Creates a cost matrix that is a copy of another.
   *
   * @param toCopy the matrix to copy.
   */
  public CostMatrix(CostMatrix toCopy) {

    super(toCopy.size(), toCopy.size());

    for (int x = 0; x < toCopy.size(); x++) {
      for (int y = 0; y < toCopy.size(); y++) {
        setElement(x, y, toCopy.getElement(x, y));
      }
    }
  }

  /**
   * Creates a default cost matrix of a particular size. All values will be 0.
   *
   * @param numOfClasses the number of classes that the cost matrix holds.
   */
  public CostMatrix(int numOfClasses) {

    super(numOfClasses, numOfClasses);
  }

  /**
   * Creates a cost matrix from a reader.
   *
   * @param reader the reader to get the values from.
   * @exception Exception if the matrix is invalid or is not square.
   */
  public CostMatrix(Reader reader) throws Exception {

    super(reader);

    // make sure that the matrix is square
    if (numRows() != numColumns()) {
      throw new Exception("Trying to create a non-square cost matrix");
    }
  }

  /**
   * Sets the cost of all correct classifications (the diagonal) to 0, and
   * all misclassifications to 1.
   */
  public void initialize() {

    for (int i = 0; i < size(); i++) {
      for (int j = 0; j < size(); j++) {
        setElement(i, j, i == j ? 0.0 : 1.0);
      }
    }
  }

  /**
   * Gets the size of the matrix (the number of classes).
   *
   * @return the number of columns, which equals the number of rows for a
   * square cost matrix.
   */
  public int size() {

    return numColumns();
  }

  /**
   * Applies the cost matrix to a set of instances. If a random number
   * generator is supplied the instances are resampled; otherwise they are
   * reweighted. The per-class factors are scaled so that the total instance
   * weight is preserved. Adapted from code once sitting in Instances.java.
   *
   * @param data the instances to reweight.
   * @param random a random number generator for resampling; if null, the
   * instances are reweighted instead.
   * @return a new dataset reflecting the cost of misclassification.
   * @exception Exception if the data has no class, the matrix has the wrong
   * size, or the matrix contains negative costs.
   */
  public Instances applyCostMatrix(Instances data, Random random) throws Exception {

    if (data.classIndex() < 0) {
      throw new Exception("Class index is not set!");
    }

    if (size() != data.numClasses()) {
      throw new Exception("Misclassification cost matrix has "+
                          "wrong format!");
    }

    // The weighting formula assumes zero cost for correct classifications.
    // If any diagonal element is non-zero, work on a normalized copy instead.
    // normalize() sets the diagonal to exactly 0.0, so the recursion stops
    // after one level. The check comes first to avoid wasted work.
    for (int i = 0; i < size(); i++) {
      if (!Utils.eq(getElement(i, i), 0)) {
        CostMatrix normMatrix = new CostMatrix(this);
        normMatrix.normalize();
        return normMatrix.applyCostMatrix(data, random);
      }
    }

    int numClasses = data.numClasses();
    int numInstances = data.numInstances();

    // Total instance weight per class value.
    double[] weightOfInstancesInClass = new double[numClasses];
    for (int j = 0; j < numInstances; j++) {
      weightOfInstancesInClass[(int)data.instance(j).classValue()] +=
        data.instance(j).weight();
    }
    double sumOfWeights = Utils.sum(weightOfInstancesInClass);

    // Using Kai Ming Ting's formula for deriving weights for
    // the classes and Breiman's heuristic for multiclass
    // problems.
    double[] weightFactor = new double[numClasses];
    double sumOfWeightFactors = 0;
    for (int i = 0; i < numClasses; i++) {
      double sumOfMissClassWeights = 0;
      for (int j = 0; j < numClasses; j++) {
        if (Utils.sm(getElement(i, j), 0)) {
          throw new Exception("Neg. weights in misclassification "+
                              "cost matrix!");
        }
        sumOfMissClassWeights += getElement(i, j);
      }
      weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
      sumOfWeightFactors += sumOfMissClassWeights * weightOfInstancesInClass[i];
    }

    if (sumOfWeightFactors > 0) {
      for (int i = 0; i < numClasses; i++) {
        weightFactor[i] /= sumOfWeightFactors;
      }
    } else {
      // Either every relevant misclassification cost is zero or there is no
      // instance weight to redistribute. Dividing here would give NaN or
      // infinite weights, so the original weights are kept unchanged.
      for (int i = 0; i < numClasses; i++) {
        weightFactor[i] = 1.0;
      }
    }

    // Store new weights
    double[] weightOfInstances = new double[numInstances];
    for (int i = 0; i < numInstances; i++) {
      weightOfInstances[i] = data.instance(i).weight() *
        weightFactor[(int)data.instance(i).classValue()];
    }

    // Change instances weight or do resampling
    if (random != null) {
      return data.resampleWithWeights(random, weightOfInstances);
    }
    Instances instances = new Instances(data);
    for (int i = 0; i < numInstances; i++) {
      instances.instance(i).setWeight(weightOfInstances[i]);
    }
    return instances;
  }

  /**
   * Calculates the expected misclassification cost for each possible class
   * value, given class probability estimates. Entry x of the result is the
   * sum over y of classProbs[y] * getElement(x, y).
   *
   * @param classProbs the class probability estimates.
   * @return the expected costs, one per class value.
   * @exception Exception if the wrong number of class probabilities is supplied.
   */
  public double[] expectedCosts(double[] classProbs) throws Exception {

    if (classProbs.length != size()) {
      throw new Exception("Length of probability estimates don't match cost matrix");
    }

    double[] costs = new double[size()];

    // NOTE(review): column y is weighted by the probability of class y here,
    // while applyCostMatrix treats row i as the costs for actual class i.
    // Confirm which orientation callers expect.
    for (int x = 0; x < size(); x++) {
      for (int y = 0; y < size(); y++) {
        costs[x] += classProbs[y] * getElement(x, y);
      }
    }

    return costs;
  }

  /**
   * Gets the maximum cost in the row of a particular class value.
   *
   * @param classVal the class value (row index).
   * @return the largest element of that row, or negative infinity if the
   * matrix is empty.
   */
  public double getMaxCost(int classVal) {

    double maxCost = Double.NEGATIVE_INFINITY;

    for (int i = 0; i < size(); i++) {
      double cost = getElement(classVal, i);
      if (cost > maxCost) {
        maxCost = cost;
      }
    }

    return maxCost;
  }

  /**
   * Normalizes the matrix so that the diagonal contains zeros, by
   * subtracting each column's diagonal element from every entry of that
   * column.
   */
  public void normalize() {

    for (int y = 0; y < size(); y++) {
      // Read the diagonal before the loop overwrites it.
      double diag = getElement(y, y);
      for (int x = 0; x < size(); x++) {
        setElement(x, y, getElement(x, y) - diag);
      }
    }
  }

  /**
   * Loads a cost matrix in the old format from a reader. Adapted from code
   * once sitting in Instances.java.
   *
   * <p>The matrix is first reset with {@link #initialize()}. Each non-empty
   * line then holds three numbers, "firstIndex secondIndex cost", which
   * overwrite one off-diagonal element. Text after '%' is a comment.
   *
   * @param reader the reader to get the values from.
   * @exception Exception if the matrix cannot be read correctly.
   */
  public void readOldFormat(Reader reader) throws Exception {

    StreamTokenizer tokenizer = new StreamTokenizer(reader);

    initialize();

    tokenizer.commentChar('%');
    tokenizer.eolIsSignificant(true);

    int currentToken;
    while ((currentToken = tokenizer.nextToken()) != StreamTokenizer.TT_EOF) {

      // Skip empty lines
      if (currentToken == StreamTokenizer.TT_EOL) {
        continue;
      }

      // Get index of first class.
      if (currentToken != StreamTokenizer.TT_NUMBER) {
        throw new Exception("Only numbers and comments allowed "+
                            "in cost file!");
      }
      int firstIndex = toClassIndex(tokenizer.nval, "First");

      // Get index of second class.
      int secondIndex = toClassIndex(nextNumberOnLine(tokenizer), "Second");
      if (secondIndex == firstIndex) {
        throw new Exception("Diagonal of cost matrix non-zero!");
      }

      // Get cost factor.
      double weight = nextNumberOnLine(tokenizer);
      if (!Utils.gr(weight, 0)) {
        throw new Exception("Only positive weights allowed!");
      }
      setElement(firstIndex, secondIndex, weight);
    }
  }

  /**
   * Reads the next token, which must be a number on the current line.
   *
   * @param tokenizer the tokenizer positioned inside a cost line.
   * @return the numeric value of the token.
   * @exception Exception if the file or line ends early, or the token is not
   * a number.
   */
  private static double nextNumberOnLine(StreamTokenizer tokenizer) throws Exception {

    int token = tokenizer.nextToken();
    if (token == StreamTokenizer.TT_EOF) {
      throw new Exception("Premature end of file!");
    }
    if (token == StreamTokenizer.TT_EOL) {
      throw new Exception("Premature end of line!");
    }
    if (token != StreamTokenizer.TT_NUMBER) {
      throw new Exception("Only numbers and comments allowed "+
                          "in cost file!");
    }
    return tokenizer.nval;
  }

  /**
   * Converts a number read from an old-format cost file into a class index.
   *
   * @param value the number read from the file.
   * @param position "First" or "Second", used in the error message.
   * @return the value as an index in [0, size()).
   * @exception Exception if the value is not an integer or is out of range
   * (including negative values).
   */
  private int toClassIndex(double value, String position) throws Exception {

    if (!Utils.eq((double)(int)value, value)) {
      throw new Exception(position + " number in line has to be "+
                          "index of a class!");
    }
    int index = (int)value;
    if (index < 0 || index >= size()) {
      throw new Exception("Class index out of range!");
    }
    return index;
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -