📄 multiclassclassifier.java
字号:
} else { // use error correcting code style methods
Code code = null;
switch (m_Method) {
case METHOD_ERROR_EXHAUSTIVE:
code = new ExhaustiveCode(numClassifiers);
break;
case METHOD_ERROR_RANDOM:
code = new RandomCode(numClassifiers,
(int)(numClassifiers * m_RandomWidthFactor),
insts);
break;
case METHOD_1_AGAINST_ALL:
code = new StandardCode(numClassifiers);
break;
default:
throw new Exception("Unrecognized correction code type");
}
numClassifiers = code.size();
m_Classifiers = Classifier.makeCopies(m_Classifier, numClassifiers);
m_ClassFilters = new MakeIndicator[numClassifiers];
AttributeStats classStats = insts.attributeStats(insts.classIndex());
for (int i = 0; i < m_Classifiers.length; i++) {
m_ClassFilters[i] = new MakeIndicator();
MakeIndicator classFilter = (MakeIndicator) m_ClassFilters[i];
classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
classFilter.setValueIndices(code.getIndices(i));
classFilter.setNumeric(false);
classFilter.setInputFormat(insts);
newInsts = Filter.useFilter(insts, m_ClassFilters[i]);
m_Classifiers[i].buildClassifier(newInsts);
}
}
m_ClassAttribute = insts.classAttribute();
}
/**
 * Computes, for each base classifier, the probability it assigns to its
 * second "class" (index 1 of its distribution). Used by
 * StackedMultiClassClassifier.
 *
 * <p>With a single base classifier the instance is passed to it unchanged
 * and a one-element array is returned. Otherwise the result has one slot per
 * class filter; slots whose classifier is null are left at 0.0. In
 * 1-against-1 mode the instance is copied and attached to
 * m_TwoClassDataset; in every other mode it is first pushed through the
 * corresponding class filter.
 *
 * @param inst the instance to get the individual predictions for
 * @return the second-class probability produced by each base classifier
 * @exception Exception if the predictions can't be computed successfully
 */
public double[] individualPredictions(Instance inst) throws Exception {
  if (m_Classifiers.length == 1) {
    return new double[] {m_Classifiers[0].distributionForInstance(inst)[1]};
  }
  double[] secondClassProbs = new double[m_ClassFilters.length];
  for (int k = 0; k < secondClassProbs.length; k++) {
    if (m_Classifiers[k] == null) {
      continue; // no model was built for this slot: keep the default 0.0
    }
    Instance converted;
    if (m_Method == METHOD_1_AGAINST_1) {
      converted = (Instance) inst.copy();
      converted.setDataset(m_TwoClassDataset);
    } else {
      m_ClassFilters[k].input(inst);
      m_ClassFilters[k].batchFinished();
      converted = m_ClassFilters[k].output();
    }
    secondClassProbs[k] = m_Classifiers[k].distributionForInstance(converted)[1];
  }
  return secondClassProbs;
}
/**
 * Returns the class-probability distribution for an instance.
 *
 * <p>With a single base classifier its distribution is returned directly.
 * In 1-against-1 mode every non-null pairwise classifier casts one vote for
 * the class of its pair that it favours; when both probabilities are equal
 * the vote goes to the second class of the pair. In the error-correcting
 * modes every base classifier adds its probability for index 1 to the class
 * values inside its indicator range and its probability for index 0 to all
 * other class values. The accumulated scores are normalized; if they sum to
 * zero, the distribution of m_ZeroR is returned instead.
 *
 * @param inst the instance to compute the distribution for
 * @return the (normalized) class distribution
 * @exception Exception if the distribution can't be computed successfully
 */
public double[] distributionForInstance(Instance inst) throws Exception {
  if (m_Classifiers.length == 1) {
    return m_Classifiers[0].distributionForInstance(inst);
  }
  double[] probs = new double[inst.numClasses()];
  // Invariant for the whole prediction; previously re-read in every loop test.
  int numClassValues = m_ClassAttribute.numValues();
  if (m_Method == METHOD_1_AGAINST_1) {
    for (int i = 0; i < m_ClassFilters.length; i++) {
      if (m_Classifiers[i] != null) {
        Instance tempInst = (Instance) inst.copy();
        tempInst.setDataset(m_TwoClassDataset);
        double[] current = m_Classifiers[i].distributionForInstance(tempInst);
        // The filter's nominal indices identify the two classes of this pair.
        Range range = new Range(((RemoveWithValues) m_ClassFilters[i])
                                .getNominalIndices());
        range.setUpper(numClassValues);
        int[] pair = range.getSelection();
        if (current[0] > current[1]) {
          probs[pair[0]] += 1.0;
        } else {
          probs[pair[1]] += 1.0; // ties are resolved in favour of pair[1]
        }
      }
    }
  } else {
    // error correcting style methods
    for (int i = 0; i < m_ClassFilters.length; i++) {
      m_ClassFilters[i].input(inst);
      m_ClassFilters[i].batchFinished();
      double[] current = m_Classifiers[i].
        distributionForInstance(m_ClassFilters[i].output());
      // The indicator range depends only on the filter, not on j: fetch it once.
      Range valueRange = ((MakeIndicator) m_ClassFilters[i]).getValueRange();
      for (int j = 0; j < numClassValues; j++) {
        probs[j] += valueRange.isInRange(j) ? current[1] : current[0];
      }
    }
  }
  if (Utils.gr(Utils.sum(probs), 0)) {
    Utils.normalize(probs);
    return probs;
  } else {
    return m_ZeroR.distributionForInstance(inst);
  }
}
/**
 * Renders a textual description of every base classifier.
 *
 * <p>Each entry is numbered from 1. Pairwise (RemoveWithValues) entries are
 * labelled with the two 1-based class indices they separate; MakeIndicator
 * entries list their indicator value range; entries without a model are
 * reported as skipped.
 *
 * @return the description, or a placeholder if no model has been built
 */
public String toString() {
  if (m_Classifiers == null) {
    return "MultiClassClassifier: No model built yet.";
  }
  StringBuffer buf = new StringBuffer("MultiClassClassifier\n\n");
  for (int idx = 0; idx < m_Classifiers.length; idx++) {
    buf.append("Classifier ").append(idx + 1);
    if (m_Classifiers[idx] == null) {
      buf.append(" Skipped (no training examples)\n");
      continue;
    }
    Object filter = (m_ClassFilters == null) ? null : m_ClassFilters[idx];
    if (filter instanceof RemoveWithValues) {
      Range classes = new Range(((RemoveWithValues) filter).getNominalIndices());
      classes.setUpper(m_ClassAttribute.numValues());
      int[] sel = classes.getSelection();
      buf.append(", ").append(sel[0] + 1).append(" vs ").append(sel[1] + 1);
    } else if (filter instanceof MakeIndicator) {
      buf.append(", using indicator values: ")
         .append(((MakeIndicator) filter).getValueRange());
    }
    buf.append('\n').append(m_Classifiers[idx].toString()).append("\n\n");
  }
  return buf.toString();
}
/**
 * Returns an enumeration describing the available options: -M (method) and
 * -R (random-code width multiplier), followed by the options reported by the
 * superclass.
 *
 * @return an enumeration of all the available options
 */
public Enumeration listOptions() {
  Vector vec = new Vector(3);
  vec.addElement(new Option(
            "\tSets the method to use. Valid values are 0 (1-against-all),\n"
            +"\t1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)\n",
            "M", 1, "-M <num>"));
  vec.addElement(new Option(
            "\tSets the multiplier when using random codes. (default 2.0)",
            "R", 1, "-R <num>"));
  // Append the options inherited from the superclass after our own.
  Enumeration enu = super.listOptions();
  while (enu.hasMoreElements()) {
    vec.addElement(enu.nextElement());
  }
  return vec.elements();
}
/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -M num <br>
 * Sets the method to use. Valid values are 0 (1-against-all),
 * 1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0) <p>
 *
 * -R num <br>
 * Sets the multiplier when using random codes. (default 2.0)<p>
 *
 * -W classname <br>
 * Specify the full class name of a learner as the basis for
 * the multiclassclassifier (required).<p>
 *
 * -S seed <br>
 * Random number seed (default 1).<p>
 *
 * When -M is absent the method is reset to 1-against-all; when -R is absent
 * the multiplier is reset to 2.0. All remaining options are handed to the
 * superclass.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported, or if the -M / -R
 * values are not valid numbers
 */
public void setOptions(String[] options) throws Exception {
  String methodString = Utils.getOption('M', options);
  if (methodString.length() != 0) {
    setMethod(new SelectedTag(Integer.parseInt(methodString),
                              TAGS_METHOD));
  } else {
    setMethod(new SelectedTag(METHOD_1_AGAINST_ALL, TAGS_METHOD));
  }
  String rfactorString = Utils.getOption('R', options);
  if (rfactorString.length() != 0) {
    // Double.parseDouble replaces the deprecated new Double(String) constructor.
    setRandomWidthFactor(Double.parseDouble(rfactorString));
  } else {
    setRandomWidthFactor(2.0);
  }
  super.setOptions(options);
}
/**
 * Gets the current settings of the Classifier: "-M method" and
 * "-R widthFactor", followed by the superclass settings.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {
  String [] inherited = super.getOptions();
  String [] own = new String [] {
    "-M", String.valueOf(m_Method),
    "-R", String.valueOf(m_RandomWidthFactor)
  };
  String [] all = new String [own.length + inherited.length];
  System.arraycopy(own, 0, all, 0, own.length);
  System.arraycopy(inherited, 0, all, own.length, inherited.length);
  return all;
}
/**
 * Describes this classifier for the explorer/experimenter GUI.
 *
 * @return a one-paragraph description of the classifier
 */
public String globalInfo() {
  String purpose =
    "A metaclassifier for handling multi-class datasets with 2-class classifiers. ";
  String ecoc =
    "This classifier is also capable of applying error correcting output codes for increased accuracy.";
  return purpose + ecoc;
}
/**
 * Returns the GUI tip text for the randomWidthFactor property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String randomWidthFactorTipText() {
  // Fixed typo: "thus number" -> "this number".
  return "Sets the width multiplier when using random codes. The number "
    + "of codes generated will be this number multiplied by the number of "
    + "classes.";
}
/**
 * Returns the multiplier used when generating random codes; the code width
 * is numClasses * m_RandomWidthFactor.
 *
 * @return the width multiplier
 */
public double getRandomWidthFactor() {
  return this.m_RandomWidthFactor;
}
/**
 * Stores the multiplier used when generating random codes; the code width
 * is numClasses * m_RandomWidthFactor. No validation is performed.
 *
 * @param newRandomWidthFactor the new width multiplier
 */
public void setRandomWidthFactor(double newRandomWidthFactor) {
  this.m_RandomWidthFactor = newRandomWidthFactor;
}
/**
 * Returns the GUI tip text for the method property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String methodTipText() {
  String tip = "Sets the method to use for transforming the multi-class problem into "
    + "several 2-class ones.";
  return tip;
}
/**
 * Returns the decomposition method currently in use, wrapped as a
 * SelectedTag over TAGS_METHOD. The ID is one of METHOD_1_AGAINST_ALL,
 * METHOD_ERROR_RANDOM, METHOD_ERROR_EXHAUSTIVE, or METHOD_1_AGAINST_1.
 *
 * @return the current method.
 */
public SelectedTag getMethod() {
  SelectedTag current = new SelectedTag(m_Method, TAGS_METHOD);
  return current;
}
/**
 * Selects the decomposition method: one of METHOD_1_AGAINST_ALL,
 * METHOD_ERROR_RANDOM, METHOD_ERROR_EXHAUSTIVE, or METHOD_1_AGAINST_1.
 * A tag drawn from any tag array other than TAGS_METHOD (compared by
 * reference) is silently ignored.
 *
 * @param newMethod the new method.
 */
public void setMethod(SelectedTag newMethod) {
  if (newMethod.getTags() != TAGS_METHOD) {
    return; // not one of our tags: leave the current method unchanged
  }
  m_Method = newMethod.getSelectedTag().getID();
}
/**
 * Command-line entry point: evaluates a MultiClassClassifier configured by
 * argv and prints the evaluation output. On failure, the exception message
 * and stack trace are written to standard error.
 *
 * @param argv the options
 */
public static void main(String [] argv) {
  try {
    Classifier scheme = new MultiClassClassifier();
    String report = Evaluation.evaluateModel(scheme, argv);
    System.out.println(report);
  } catch (Exception e) {
    System.err.println(e.getMessage());
    e.printStackTrace();
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -