📄 multiclassclassifier.java
字号:
int[] pair = new int[2]; pair[0] = i; pair[1] = j; pairs.addElement(pair); } } numClassifiers = pairs.size(); m_Classifiers = Classifier.makeCopies(m_Classifier, numClassifiers); m_ClassFilters = new Filter[numClassifiers]; // generate the classifiers for (int i=0; i<numClassifiers; i++) { RemoveWithValues classFilter = new RemoveWithValues(); classFilter.setAttributeIndex("" + (insts.classIndex() + 1)); classFilter.setModifyHeader(true); classFilter.setInvertSelection(true); classFilter.setNominalIndicesArr((int[])pairs.elementAt(i)); Instances tempInstances = new Instances(insts, 0); tempInstances.setClassIndex(-1); classFilter.setInputFormat(tempInstances); newInsts = Filter.useFilter(insts, classFilter); if (newInsts.numInstances() > 0) { newInsts.setClassIndex(insts.classIndex()); m_Classifiers[i].buildClassifier(newInsts); m_ClassFilters[i] = classFilter; } else { m_Classifiers[i] = null; m_ClassFilters[i] = null; } } // construct a two-class header version of the dataset m_TwoClassDataset = new Instances(insts, 0); int classIndex = m_TwoClassDataset.classIndex(); m_TwoClassDataset.setClassIndex(-1); m_TwoClassDataset.deleteAttributeAt(classIndex); FastVector classLabels = new FastVector(); classLabels.addElement("class0"); classLabels.addElement("class1"); m_TwoClassDataset.insertAttributeAt(new Attribute("class", classLabels), classIndex); m_TwoClassDataset.setClassIndex(classIndex); } else { // use error correcting code style methods Code code = null; switch (m_Method) { case METHOD_ERROR_EXHAUSTIVE: code = new ExhaustiveCode(numClassifiers); break; case METHOD_ERROR_RANDOM: code = new RandomCode(numClassifiers, (int)(numClassifiers * m_RandomWidthFactor), insts); break; case METHOD_1_AGAINST_ALL: code = new StandardCode(numClassifiers); break; default: throw new Exception("Unrecognized correction code type"); } numClassifiers = code.size(); m_Classifiers = Classifier.makeCopies(m_Classifier, numClassifiers); m_ClassFilters = new 
MakeIndicator[numClassifiers]; for (int i = 0; i < m_Classifiers.length; i++) { m_ClassFilters[i] = new MakeIndicator(); MakeIndicator classFilter = (MakeIndicator) m_ClassFilters[i]; classFilter.setAttributeIndex("" + (insts.classIndex() + 1)); classFilter.setValueIndices(code.getIndices(i)); classFilter.setNumeric(false); classFilter.setInputFormat(insts); newInsts = Filter.useFilter(insts, m_ClassFilters[i]); m_Classifiers[i].buildClassifier(newInsts); } } m_ClassAttribute = insts.classAttribute(); } /** * Returns the individual predictions of the base classifiers * for an instance. Used by StackedMultiClassClassifier. * Returns the probability for the second "class" predicted * by each base classifier. * * @param inst the instance to get the prediction for * @return the individual predictions * @throws Exception if the predictions can't be computed successfully */ public double[] individualPredictions(Instance inst) throws Exception { double[] result = null; if (m_Classifiers.length == 1) { result = new double[1]; result[0] = m_Classifiers[0].distributionForInstance(inst)[1]; } else { result = new double[m_ClassFilters.length]; for(int i = 0; i < m_ClassFilters.length; i++) { if (m_Classifiers[i] != null) { if (m_Method == METHOD_1_AGAINST_1) { Instance tempInst = (Instance)inst.copy(); tempInst.setDataset(m_TwoClassDataset); result[i] = m_Classifiers[i].distributionForInstance(tempInst)[1]; } else { m_ClassFilters[i].input(inst); m_ClassFilters[i].batchFinished(); result[i] = m_Classifiers[i]. distributionForInstance(m_ClassFilters[i].output())[1]; } } } } return result; } /** * Returns the distribution for an instance. 
* * @param inst the instance to get the distribution for * @return the distribution * @throws Exception if the distribution can't be computed successfully */ public double[] distributionForInstance(Instance inst) throws Exception { if (m_Classifiers.length == 1) { return m_Classifiers[0].distributionForInstance(inst); } double[] probs = new double[inst.numClasses()]; if (m_Method == METHOD_1_AGAINST_1) { for(int i = 0; i < m_ClassFilters.length; i++) { if (m_Classifiers[i] != null) { Instance tempInst = (Instance)inst.copy(); tempInst.setDataset(m_TwoClassDataset); double [] current = m_Classifiers[i].distributionForInstance(tempInst); Range range = new Range(((RemoveWithValues)m_ClassFilters[i]) .getNominalIndices()); range.setUpper(m_ClassAttribute.numValues()); int[] pair = range.getSelection(); if (current[0] > current[1]) probs[pair[0]] += 1.0; else probs[pair[1]] += 1.0; } } } else { // error correcting style methods for(int i = 0; i < m_ClassFilters.length; i++) { m_ClassFilters[i].input(inst); m_ClassFilters[i].batchFinished(); double [] current = m_Classifiers[i]. distributionForInstance(m_ClassFilters[i].output()); for (int j = 0; j < m_ClassAttribute.numValues(); j++) { if (((MakeIndicator)m_ClassFilters[i]).getValueRange().isInRange(j)) { probs[j] += current[1]; } else { probs[j] += current[0]; } } } } if (Utils.gr(Utils.sum(probs), 0)) { Utils.normalize(probs); return probs; } else { return m_ZeroR.distributionForInstance(inst); } } /** * Prints the classifiers. 
* * @return a string representation of the classifier */ public String toString() { if (m_Classifiers == null) { return "MultiClassClassifier: No model built yet."; } StringBuffer text = new StringBuffer(); text.append("MultiClassClassifier\n\n"); for (int i = 0; i < m_Classifiers.length; i++) { text.append("Classifier ").append(i + 1); if (m_Classifiers[i] != null) { if ((m_ClassFilters != null) && (m_ClassFilters[i] != null)) { if (m_ClassFilters[i] instanceof RemoveWithValues) { Range range = new Range(((RemoveWithValues)m_ClassFilters[i]) .getNominalIndices()); range.setUpper(m_ClassAttribute.numValues()); int[] pair = range.getSelection(); text.append(", " + (pair[0]+1) + " vs " + (pair[1]+1)); } else if (m_ClassFilters[i] instanceof MakeIndicator) { text.append(", using indicator values: "); text.append(((MakeIndicator)m_ClassFilters[i]).getValueRange()); } } text.append('\n'); text.append(m_Classifiers[i].toString() + "\n\n"); } else { text.append(" Skipped (no training examples)\n"); } } return text.toString(); } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector vec = new Vector(3); vec.addElement(new Option( "\tSets the method to use. Valid values are 0 (1-against-all),\n" +"\t1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)\n", "M", 1, "-M <num>")); vec.addElement(new Option( "\tSets the multiplier when using random codes. (default 2.0)", "R", 1, "-R <num>")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { vec.addElement(enu.nextElement()); } return vec.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -M <num> * Sets the method to use. Valid values are 0 (1-against-all), * 1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0) * </pre> * * <pre> -R <num> * Sets the multiplier when using random codes. 
(default 2.0)</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.functions.Logistic)</pre> * * <pre> * Options specific to classifier weka.classifiers.functions.Logistic: * </pre> * * <pre> -D * Turn on debugging output.</pre> * * <pre> -R <ridge> * Set the ridge in the log-likelihood.</pre> * * <pre> -M <number> * Set the maximum number of iterations (default -1, until convergence).</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String errorString = Utils.getOption('M', options); if (errorString.length() != 0) { setMethod(new SelectedTag(Integer.parseInt(errorString), TAGS_METHOD)); } else { setMethod(new SelectedTag(METHOD_1_AGAINST_ALL, TAGS_METHOD)); } String rfactorString = Utils.getOption('R', options); if (rfactorString.length() != 0) { setRandomWidthFactor((new Double(rfactorString)).doubleValue()); } else { setRandomWidthFactor(2.0); } super.setOptions(options); } /** * Gets the current settings of the Classifier. 
* * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] superOptions = super.getOptions(); String [] options = new String [superOptions.length + 4]; int current = 0; options[current++] = "-M"; options[current++] = "" + m_Method; options[current++] = "-R"; options[current++] = "" + m_RandomWidthFactor; System.arraycopy(superOptions, 0, options, current, superOptions.length); current += superOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * @return a description of the classifier suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "A metaclassifier for handling multi-class datasets with 2-class " + "classifiers. This classifier is also capable of " + "applying error correcting output codes for increased accuracy."; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String randomWidthFactorTipText() { return "Sets the width multiplier when using random codes. The number " + "of codes generated will be thus number multiplied by the number of " + "classes."; } /** * Gets the multiplier when generating random codes. Will generate * numClasses * m_RandomWidthFactor codes. * * @return the width multiplier */ public double getRandomWidthFactor() { return m_RandomWidthFactor; } /** * Sets the multiplier when generating random codes. Will generate * numClasses * m_RandomWidthFactor codes. * * @param newRandomWidthFactor the new width multiplier */ public void setRandomWidthFactor(double newRandomWidthFactor) { m_RandomWidthFactor = newRandomWidthFactor; } /** * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String methodTipText() { return "Sets the method to use for transforming the multi-class problem into " + "several 2-class ones."; } /** * Gets the method used. 
Will be one of METHOD_1_AGAINST_ALL,
   * METHOD_ERROR_RANDOM, METHOD_ERROR_EXHAUSTIVE, or METHOD_1_AGAINST_1.
   *
   * @return the current method, wrapped in a tag over TAGS_METHOD
   */
  public SelectedTag getMethod() {

    SelectedTag selected = new SelectedTag(m_Method, TAGS_METHOD);
    return selected;
  }

  /**
   * Sets the method used. Will be one of METHOD_1_AGAINST_ALL,
   * METHOD_ERROR_RANDOM, METHOD_ERROR_EXHAUSTIVE, or METHOD_1_AGAINST_1.
   * A tag that was not created over TAGS_METHOD is ignored and the current
   * method is kept.
   *
   * @param newMethod the new method.
   */
  public void setMethod(SelectedTag newMethod) {

    if (newMethod.getTags() != TAGS_METHOD) {
      return;
    }
    m_Method = newMethod.getSelectedTag().getID();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    MultiClassClassifier classifier = new MultiClassClassifier();
    runClassifier(classifier, argv);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -