📄 costsensitiveclassifier.java
字号:
public String globalInfo() { return "A metaclassifier that makes its base classifier cost-sensitive. " + "Two methods can be used to introduce cost-sensitivity: reweighting " + "training instances according to the total cost assigned to each " + "class; or predicting the class with minimum expected " + "misclassification cost (rather than the most likely class). " + "Performance can often be " + "improved by using a Bagged classifier to improve the probability " + "estimates of the base classifier."; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String costMatrixSourceTipText() { return "Sets where to get the cost matrix. The two options are" + "to use the supplied explicit cost matrix (the setting of the " + "costMatrix property), or to load a cost matrix from a file when " + "required (this file will be loaded from the directory set by the " + "onDemandDirectory property and will be named relation_name" + CostMatrix.FILE_EXTENSION + ")."; } /** * Gets the source location method of the cost matrix. Will be one of * MATRIX_ON_DEMAND or MATRIX_SUPPLIED. * * @return the cost matrix source. */ public SelectedTag getCostMatrixSource() { return new SelectedTag(m_MatrixSource, TAGS_MATRIX_SOURCE); } /** * Sets the source location of the cost matrix. Values other than * MATRIX_ON_DEMAND or MATRIX_SUPPLIED will be ignored. * * @param newMethod the cost matrix location method. */ public void setCostMatrixSource(SelectedTag newMethod) { if (newMethod.getTags() == TAGS_MATRIX_SOURCE) { m_MatrixSource = newMethod.getSelectedTag().getID(); } } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String onDemandDirectoryTipText() { return "Sets the directory where cost files are loaded from. This option " + "is used when the costMatrixSource is set to \"On Demand\"."; } /** * Returns the directory that will be searched for cost files when * loading on demand. * * @return The cost file search directory. */ public File getOnDemandDirectory() { return m_OnDemandDirectory; } /** * Sets the directory that will be searched for cost files when * loading on demand. * * @param newDir The cost file search directory. */ public void setOnDemandDirectory(File newDir) { if (newDir.isDirectory()) { m_OnDemandDirectory = newDir; } else { m_OnDemandDirectory = new File(newDir.getParent()); } m_MatrixSource = MATRIX_ON_DEMAND; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String minimizeExpectedCostTipText() { return "Sets whether the minimum expected cost criteria will be used. If " + "this is false, the training data will be reweighted according to the " + "costs assigned to each class. If true, the minimum expected cost " + "criteria will be used."; } /** * Gets the value of MinimizeExpectedCost. * * @return Value of MinimizeExpectedCost. */ public boolean getMinimizeExpectedCost() { return m_MinimizeExpectedCost; } /** * Set the value of MinimizeExpectedCost. * * @param newMinimizeExpectedCost Value to assign to MinimizeExpectedCost. */ public void setMinimizeExpectedCost(boolean newMinimizeExpectedCost) { m_MinimizeExpectedCost = newMinimizeExpectedCost; } /** * Gets the classifier specification string, which contains the class name of * the classifier and any options to the classifier * * @return the classifier string. */ protected String getClassifierSpec() { Classifier c = getClassifier(); if (c instanceof OptionHandler) { return c.getClass().getName() + " " + Utils.joinOptions(((OptionHandler)c).getOptions()); } return c.getClass().getName(); } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String costMatrixTipText() { return "Sets the cost matrix explicitly. This matrix is used if the " + "costMatrixSource property is set to \"Supplied\"."; } /** * Gets the misclassification cost matrix. * * @return the cost matrix */ public CostMatrix getCostMatrix() { return m_CostMatrix; } /** * Sets the misclassification cost matrix. * * @param newCostMatrix the cost matrix */ public void setCostMatrix(CostMatrix newCostMatrix) { m_CostMatrix = newCostMatrix; m_MatrixSource = MATRIX_SUPPLIED; } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // class result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.NOMINAL_CLASS); return result; } /** * Builds the model of the base learner. * * @param data the training data * @throws Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); if (m_Classifier == null) { throw new Exception("No base classifier has been set!"); } if (m_MatrixSource == MATRIX_ON_DEMAND) { String costName = data.relationName() + CostMatrix.FILE_EXTENSION; File costFile = new File(getOnDemandDirectory(), costName); if (!costFile.exists()) { throw new Exception("On-demand cost file doesn't exist: " + costFile); } setCostMatrix(new CostMatrix(new BufferedReader( new FileReader(costFile)))); } else if (m_CostMatrix == null) { // try loading an old format cost file m_CostMatrix = new CostMatrix(data.numClasses()); m_CostMatrix.readOldFormat(new BufferedReader( new FileReader(m_CostFile))); } if (!m_MinimizeExpectedCost) { Random random = null; if (!(m_Classifier instanceof WeightedInstancesHandler)) { random = new Random(m_Seed); } data = m_CostMatrix.applyCostMatrix(data, random); } m_Classifier.buildClassifier(data); } /** * Returns class probabilities. When minimum expected cost approach is chosen, * returns probability one for class with the minimum expected misclassification * cost. Otherwise it returns the probability distribution returned by * the base classifier. * * @param instance the instance to be classified * @return the computed distribution for the given instance * @throws Exception if instance could not be classified * successfully */ public double[] distributionForInstance(Instance instance) throws Exception { if (!m_MinimizeExpectedCost) { return m_Classifier.distributionForInstance(instance); } double [] pred = m_Classifier.distributionForInstance(instance); double [] costs = m_CostMatrix.expectedCosts(pred, instance); /* for (int i = 0; i < pred.length; i++) { System.out.print(pred[i] + " "); } System.out.println(); for (int i = 0; i < costs.length; i++) { System.out.print(costs[i] + " "); } System.out.println("\n"); */ // This is probably not ideal int classIndex = Utils.minIndex(costs); for (int i = 0; i < pred.length; i++) { if (i == classIndex) { pred[i] = 1.0; } else { pred[i] = 0.0; } } return pred; } /** * Returns the type of graph this classifier * represents. * * @return the type of graph this classifier represents */ public int graphType() { if (m_Classifier instanceof Drawable) return ((Drawable)m_Classifier).graphType(); else return Drawable.NOT_DRAWABLE; } /** * Returns graph describing the classifier (if possible). * * @return the graph of the classifier in dotty format * @throws Exception if the classifier cannot be graphed */ public String graph() throws Exception { if (m_Classifier instanceof Drawable) return ((Drawable)m_Classifier).graph(); else throw new Exception("Classifier: " + getClassifierSpec() + " cannot be graphed"); } /** * Output a representation of this classifier * * @return a string representation of the classifier */ public String toString() { if (m_Classifier == null) { return "CostSensitiveClassifier: No model built yet."; } String result = "CostSensitiveClassifier using "; if (m_MinimizeExpectedCost) { result += "minimized expected misclasification cost\n"; } else { result += "reweighted training instances\n"; } result += "\n" + getClassifierSpec() + "\n\nClassifier Model\n" + m_Classifier.toString() + "\n\nCost Matrix\n" + m_CostMatrix.toString(); return result; } /** * Main method for testing this class. * * @param argv should contain the following arguments: * -t training file [-T test file] [-c class index] */ public static void main(String [] argv) { runClassifier(new CostSensitiveClassifier(), argv); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -