📄 costmatrix.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * CostMatrix.java * Copyright (C) 1999 Intelligenesis Corp. * */package weka.classifiers;import weka.core.Utils;import weka.core.Instance;import weka.core.Instances;import weka.core.Matrix;import java.io.Reader;import java.io.FileReader;import java.io.BufferedReader;import java.io.InputStreamReader;import java.io.StreamTokenizer;import java.util.Random;/** * Class for a misclassification cost matrix. The element in the i'th column * of the j'th row is the cost for (mis)classifying an instance of class j as * having class i. It is valid to have non-zero values down the diagonal * (these are typically negative to indicate some varying degree of "gain" * from making a correct prediction). * * @author Len Trigg (len@intelligenesis.net) * @version $Revision: 1.8 $ */public class CostMatrix extends Matrix { /** The filename extension that should be used for cost files */ public static String FILE_EXTENSION = ".cost"; /** * Creates a cost matrix identical to an existing matrix. * * @param toCopy the matrix to copy. */ public CostMatrix(CostMatrix toCopy) { this(toCopy.size()); for (int i = 0; i < size(); i++) { for (int j = 0; j < size(); j++) { setElement(i, j, toCopy.getElement(i, j)); } } } /** * Creates a default cost matrix for the given number of classes. The * default misclassification cost is 1. * * @param numClasses the number of classes */ public CostMatrix(int numClasses) { super(numClasses, numClasses); } /** * Creates a cost matrix from a cost file. * * @param r a reader from which the cost matrix will be read * @exception Exception if an error occurs */ public CostMatrix(Reader r) throws Exception { super(r); if (numColumns() != numRows()) { throw new Exception("Cost matrix is not square"); } } /** * Creates a cost matrix for the class attribute of the supplied instances, * where the misclassification costs are higher for misclassifying a rare * class as a frequent one. The cost of classifying an instance of class i * as class j is weight * Pj / Pi. (Pi and Pj are laplace estimates) * * @param instances a value of type 'Instances' * @param weight a value of type 'double' * @return a value of type CostMatrix * @exception Exception if no class attribute is assigned, or the class * attribute is not nominal */ public static CostMatrix makeFrequencyDependentMatrix(Instances instances, double weight) throws Exception { if (!instances.classAttribute().isNominal()) { throw new Exception("Class attribute is not nominal!"); } int numClasses = instances.numClasses(); // Collect class probabilities double probs [] = new double [numClasses]; for (int i = 0; i < probs.length; i++) { probs[i]++; } for (int i = 0; i < instances.numInstances(); i++) { Instance current = instances.instance(i); if (!current.classIsMissing()) { probs[(int)current.classValue()]++; } } Utils.normalize(probs); // Create and populate the cost matrix CostMatrix newMatrix = new CostMatrix(numClasses); for (int i = 0; i < numClasses; i++) { for (int j = 0; j < numClasses; j++) { if (i != j) { newMatrix.setElement(i, j, weight * probs[j] / probs[i]); } } } return newMatrix; } /** * Reads misclassification cost matrix from given reader. * Each line has to contain three numbers: the index of the true * class, the index of the incorrectly assigned class, and the * weight, separated by white space characters. Comments can be * appended to the end of a line by using the '%' character. * * @param reader the reader from which the cost matrix is to be read * @exception Exception if the cost matrix does not have the * right format */ public void readOldFormat(Reader reader)throws Exception { initialize(); StreamTokenizer tokenizer = new StreamTokenizer(reader); tokenizer.commentChar('%'); tokenizer.eolIsSignificant(true); int currentToken; while (StreamTokenizer.TT_EOF != (currentToken = tokenizer.nextToken())) { // Skip empty lines if (currentToken == StreamTokenizer.TT_EOL) { continue; } // Get index of first class. if (currentToken != StreamTokenizer.TT_NUMBER) { throw new Exception("Only numbers and comments allowed "+ "in cost file!"); } double firstIndex = tokenizer.nval; if (!Utils.eq((double)(int)firstIndex, firstIndex)) { throw new Exception("First number in line has to be "+ "index of a class!"); } if ((int)firstIndex >= size()) { throw new Exception("Class index out of range!"); } // Get index of second class. if (StreamTokenizer.TT_EOF == (currentToken = tokenizer.nextToken())) { throw new Exception("Premature end of file!"); } if (currentToken == StreamTokenizer.TT_EOL) { throw new Exception("Premature end of line!"); } if (currentToken != StreamTokenizer.TT_NUMBER) { throw new Exception("Only numbers and comments allowed "+ "in cost file!"); } double secondIndex = tokenizer.nval; if (!Utils.eq((double)(int)secondIndex,secondIndex)) { throw new Exception("Second number in line has to be "+ "index of a class!"); } if ((int)secondIndex >= size()) { throw new Exception("Class index out of range!"); } // Get cost factor. if (StreamTokenizer.TT_EOF == (currentToken = tokenizer.nextToken())) { throw new Exception("Premature end of file!"); } if (currentToken == StreamTokenizer.TT_EOL) { throw new Exception("Premature end of line!"); } if (currentToken != StreamTokenizer.TT_NUMBER) { throw new Exception("Only numbers and comments allowed "+ "in cost file!"); } double weight = tokenizer.nval; setElement((int)firstIndex, (int)secondIndex, weight); } } /** * Sets the costs to default values (i.e. 0 down the diagonal, and 1 for * any misclassification). */ public void initialize() { for (int i = 0; i < numRows(); i++) { for (int j = 0; j < numColumns(); j++) { if (i != j) { setElement(i, j, 1); } else { setElement(i, j, 0); } } } } /** * Gets the number of classes. * * @return the number of classes */ public int size() { return numColumns(); } /** * Normalizes the cost matrix so that diagonal elements are zero. The value * of non-zero diagonal elements is subtracted from the row containing the * value. For example: <p> * * <pre><code> * 2 5 * 3 -1 * </code></pre> * * <p> becomes <p> * * <pre><code> * 0 3 * 4 0 * </code></pre><p> * * This normalization will affect total classification cost during * evaluation, but will not affect the decision made by applying minimum * expected cost criteria during prediction. */ public void normalize() { for (int i = 0; i < size(); i++) { double diag = getElement(i, i); for (int j = 0; j < size(); j++) { addElement(i, j, -diag); } } } /** * Changes the dataset to reflect a given set of costs. * Sets the weights of instances according to the misclassification * cost matrix, or does resampling according to the cost matrix (if * a random number generator is provided). Returns a new dataset. * * @param instances the instances to apply cost weights to. * @param random a random number generator * @return the new dataset * @exception Exception if the cost matrix does not have the right * format */ public Instances applyCostMatrix(Instances instances, Random random) throws Exception { if (instances.classIndex() < 0) { throw new Exception("Class index is not set!"); } if (size() != instances.numClasses()) { throw new Exception("Cost matrix and instances have different class" + " size!"); } // If this cost matrix hasn't been normalized, apply a normalized // version instead. for (int i = 0; i < size(); i++) { if (!Utils.eq(m_Elements[i][i], 0)) { CostMatrix cm = new CostMatrix(this); cm.normalize(); return cm.applyCostMatrix(instances, random); } } // Determine the prior weights of all instances in each class double [] weightOfInstancesInClass = new double [size()]; for (int j = 0; j < instances.numInstances(); j++) { Instance current = instances.instance(j); weightOfInstancesInClass[(int)current.classValue()] += current.weight(); } double sumOfWeights = Utils.sum(weightOfInstancesInClass); double [] weightFactor = new double [size()]; double sumOfWeightFactors = 0; for (int i = 0; i < size(); i++) { // Using Kai Ming Ting's formula for deriving weights for // the classes and Breiman's heuristic for multiclass // problems. double sumOfMissClassWeights = 0; for (int j = 0; j < size(); j++) { if (Utils.sm(m_Elements[i][j], 0)) { throw new Exception("Neg. weights in misclassification "+ "cost matrix!"); } sumOfMissClassWeights += m_Elements[i][j]; } weightFactor[i] = sumOfMissClassWeights * sumOfWeights; sumOfWeightFactors += sumOfMissClassWeights * weightOfInstancesInClass[i]; } for (int i = 0; i < size(); i++) { weightFactor[i] /= sumOfWeightFactors; } // Store new weights double [] weightOfInstances = new double[instances.numInstances()]; for (int i = 0; i < instances.numInstances(); i++) { Instance current = instances.instance(i); weightOfInstances[i] = current.weight() * weightFactor[(int)current.classValue()]; } // Change instances weight or do resampling if (random != null) { return instances.resampleWithWeights(random, weightOfInstances); } else { instances = new Instances(instances); for (int i = 0; i < instances.numInstances(); i++) { instances.instance(i).setWeight(weightOfInstances[i]); } return instances; } } /** * Calculates the expected misclassification cost for each possible * class value, given class probability estimates. * * @param probabilities an array containing probability estimates for each * class value. * @return an array containing the expected misclassification cost for each * class. * @exception Exception if the number of probabilities does not match the * number of classes. */ public double [] expectedCosts(double [] probabilities) throws Exception { if (probabilities.length != size()) { throw new Exception("Number of classes in probability estimates does not" + " match size of cost matrix!"); } double [] costs = new double[size()]; for (int i = 0; i < size(); i++) { double expectedCost = 0; for (int j = 0; j < size(); j++) { expectedCost += m_Elements[j][i] * probabilities[j]; } costs[i] = expectedCost; } return costs; } /** * Gets the maximum misclassification cost possible for a given actual * class value * * @param actualClass the index of the actual class value * @return the highest cost possible for misclassifying this class */ public double getMaxCost(int actualClass) { return m_Elements[actualClass][Utils.maxIndex(m_Elements[actualClass])]; } /** * Tests out creation of a frequency dependent cost matrix from the command * line. Either pipe a set of instances into system.in or give the name of * a dataset as an argument. The last column will be treated as the class * attribute and a cost matrix with weight 1000 output. * * @param []args a value of type 'String' */ public static void main(String []args) { try { Reader r = null; if (args.length > 1) { throw (new Exception("Usage: Instances <filename>")); } else if (args.length == 0) { r = new BufferedReader(new InputStreamReader(System.in)); } else { r = new BufferedReader(new FileReader(args[0])); } Instances i = new Instances(r); i.setClassIndex(i.numAttributes() - 1); CostMatrix.makeFrequencyDependentMatrix(i, 1000) .write(new java.io.PrintWriter(System.out)); } catch (Exception ex) { System.err.println(ex); } }} // CostMatrix
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -