📄 olm.java
字号:
 * Sets if the instances are to be sorted prior to building the rule bases.
 *
 * @param sort if <code> true </code> the instances will be sorted
 */
public void setSort(boolean sort) {
  m_sort = sort;
}

/**
 * Returns if the instances are sorted prior to building the rule bases.
 *
 * @return <code> true </code> if instances are sorted prior to building
 * the rule bases, <code> false </code> otherwise.
 */
public boolean getSort() {
  return m_sort;
}

/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String seedTipText() {
  return "Sets the seed that is used to randomize the instances prior "
    + "to building the rule bases";
}

/**
 * Return the number of examples in the minimal rule base.
 * The minimal rule base is the one that corresponds to the
 * rule base of Ben-David.
 *
 * @return the number of examples in the minimal rule base
 */
public int getSizeRuleBaseMin() {
  return m_baseMin.numInstances();
}

/**
 * Return the number of examples in the maximal rule base.
 * The maximal rule base is built using an algorithm
 * dual to that for building the minimal rule base.
 *
 * @return the number of examples in the maximal rule base
 */
public int getSizeRuleBaseMax() {
  return m_baseMax.numInstances();
}

/**
 * Classifies a given instance according to the current settings
 * of the classifier.
 *
 * @param instance the instance to be classified
 * @return a <code> double </code> that represents the classification,
 * this could either be the internal value of a label, when rounding is
 * on, or a real number.
 */
public double classifyInstance(Instance instance) {

  double classValueMin = -1;
  double classValueMax = -1;
  double classValue;

  // evaluate only the rule base(s) required by the current mode
  if (m_etype == ET_MIN || m_etype == ET_BOTH) {
    classValueMin = classifyInstanceMin(instance);
  }
  if (m_etype == ET_MAX || m_etype == ET_BOTH) {
    classValueMax = classifyInstanceMax(instance);
  }

  switch (m_etype) {
    case ET_MIN:
      classValue = classValueMin;
      break;
    case ET_MAX:
      classValue = classValueMax;
      break;
    case ET_BOTH:
      // average the predictions of the two rule bases
      classValue = (classValueMin + classValueMax) / 2;
      break;
    default:
      throw new IllegalStateException("Illegal mode type!");
  }

  // round if necessary and return
  return (m_ctype == CT_ROUNDED ? Utils.round(classValue) : classValue);
}

/**
 * Classify <code> instance </code> using the minimal rule base.
 * Rounding is never performed, this is the responsibility
 * of <code> classifyInstance </code>.
 *
 * @param instance the instance to be classified
 * @return the classification according to the minimal rule base
 */
private double classifyInstanceMin(Instance instance) {

  double classValue = -1;  // -1 marks "no covering rule found yet"

  if (m_baseMin == null) {
    throw new IllegalStateException
      ("Classifier has not yet been built");
  }

  Iterator it = new EnumerationIterator(m_baseMin.enumerateInstances());
  while (it.hasNext()) {
    Instance r = (Instance) it.next();
    // we assume that rules are ordered in decreasing class value order
    // so that the first one that we encounter is immediately the
    // one with the biggest class value
    if (InstancesUtil.smallerOrEqual(r, instance)) {
      classValue = r.classValue();
      break;
    }
  }

  // there is no smaller rule in the database
  if (classValue == -1) {
    if (m_dtype != DT_NONE) {
      // fall back to the nearest rules under the configured distance
      Instance[] nn = nearestRules(instance, m_baseMin);
      classValue = 0;
      // XXX for the moment we only use the mean to extract a
      // classValue; other possibilities might be included later
      for (int i = 0; i < nn.length; i++) {
        classValue += nn[i].classValue();
      }
      classValue /= nn.length;
    } else {
      classValue = 0;  // minimal class value
    }
  }

  return classValue;  // no rounding!
}

/**
 * Classify <code> instance </code> using the maximal rule base.
 * Rounding is never performed, this is the responsibility
 * of <code> classifyInstance </code>.
 *
 * @param instance the instance to be classified
 * @return the classification according to the maximal rule base
 */
private double classifyInstanceMax(Instance instance) {

  double classValue = -1;  // -1 marks "no covering rule found yet"

  if (m_baseMax == null) {
    throw new IllegalStateException
      ("Classifier has not yet been built");
  }

  Iterator it = new EnumerationIterator(m_baseMax.enumerateInstances());
  while (it.hasNext()) {
    Instance r = (Instance) it.next();
    // we assume that rules are ordered in increasing class value order
    // so that the first bigger one we encounter will be the one
    // with the smallest label
    if (InstancesUtil.smallerOrEqual(instance, r)) {
      classValue = r.classValue();
      break;
    }
  }

  // there is no bigger rule in the database
  if (classValue == -1) {
    if (m_dtype != DT_NONE) {
      // XXX see remark in classifyInstanceMin
      Instance[] nn = nearestRules(instance, m_baseMax);
      classValue = 0;
      for (int i = 0; i < nn.length; i++) {
        classValue += nn[i].classValue();
      }
      classValue /= nn.length;
    } else {
      classValue = m_numClasses - 1;  // maximal label
    }
  }

  return classValue;
}

/**
 * Find the instances in <code> base </code> that are,
 * according to the current distance type, closest
 * to <code> instance </code>.
 *
 * @param instance the instance around which one looks
 * @param base the instances to choose from
 * @return an array of <code> Instance </code> which contains
 * the instances closest to <code> instance </code>
 */
private Instance[] nearestRules(Instance instance, Instances base) {

  double min = Double.POSITIVE_INFINITY;  // smallest distance seen so far
  double dist = 0;
  double[] instanceDouble = InstancesUtil.toDataDouble(instance);
  ArrayList nn = new ArrayList();  // current candidates at distance min

  Iterator it = new EnumerationIterator(base.enumerateInstances());
  while (it.hasNext()) {
    Instance r = (Instance) it.next();
    double[] rDouble = InstancesUtil.toDataDouble(r);

    switch (m_dtype) {
      case DT_EUCLID:
        dist = euclidDistance(instanceDouble, rDouble);
        break;
      case DT_HAMMING:
        dist = hammingDistance(instanceDouble, rDouble);
        break;
      default:
        throw new IllegalArgumentException("distance type is not valid");
    }

    if (dist < min) {
      // strictly closer rule: restart the candidate list with r
      min = dist;
      nn.clear();
      nn.add(r);
    } else if (dist == min) {
      // ties at the current minimum are all kept
      nn.add(r);
    }
  }

  nn.trimToSize();
  return (Instance[]) nn.toArray(new Instance[0]);
}

/**
 * Build the OLM classifier, meaning that the rule bases
 * are built.
 *
 * @param instances the instances to use for building the rule base
 * @throws Exception if <code> instances </code> cannot be handled by
 * the classifier.
 */
public void buildClassifier(Instances instances) throws Exception {

  getCapabilities().testWithFail(instances);

  // copy the dataset
  m_train = new Instances(instances);
  m_numClasses = m_train.numClasses();

  // new dataset in which examples with missing class value are removed
  m_train.deleteWithMissingClass();

  // build the Map for the estimatedDistributions:
  // one class-distribution estimator per distinct attribute vector
  m_estimatedDistributions = new HashMap(m_train.numInstances() / 2);

  // cycle through all instances
  Iterator it = new EnumerationIterator(m_train.enumerateInstances());
  while (it.hasNext() == true) {
    Instance instance = (Instance) it.next();
    Coordinates c = new Coordinates(instance);

    // get DiscreteEstimator from the map
    DiscreteEstimator df =
      (DiscreteEstimator) m_estimatedDistributions.get(c);

    // if no DiscreteEstimator is present in the map, create one
    if (df == null) {
      df = new DiscreteEstimator(instances.numClasses(), 0);
    }
    df.addValue(instance.classValue(), instance.weight());  // update
    m_estimatedDistributions.put(c, df);  // put back in map
  }

  // Create the attributes for m_baseMin and m_baseMax.
  // These are identical to those of m_train, except that the
  // class is set to 'numeric'.
  // The class attribute is moved to the back.
  FastVector newAtts = new FastVector(m_train.numAttributes());
  Attribute classAttribute = null;
  for (int i = 0; i < m_train.numAttributes(); i++) {
    Attribute att = m_train.attribute(i);
    if (i != m_train.classIndex()) {
      newAtts.addElement(att.copy());
    } else {
      classAttribute = new Attribute(att.name());  // numeric attribute
    }
  }
  newAtts.addElement(classAttribute);

  // original training instances are replaced by an empty set
  // of instances
  m_train = new Instances(m_train.relationName(), newAtts,
    m_estimatedDistributions.size());
  m_train.setClassIndex(m_train.numAttributes() - 1);

  // We cycle through the map of estimatedDistributions and
  // create one Instance for each entry in the map, with
  // a class value that is calculated from the distribution of
  // the class values
  it = m_estimatedDistributions.keySet().iterator();
  while (it.hasNext()) {
    // XXX attValues must be here, otherwise things go wrong
    double[] attValues = new double[m_train.numAttributes()];
    Coordinates cc = (Coordinates) it.next();
    DiscreteEstimator df =
      (DiscreteEstimator) m_estimatedDistributions.get(cc);
    cc.getValues(attValues);
    // pick the representative class value according to the averaging type
    switch (m_atype) {
      case AT_MEAN:
        attValues[attValues.length - 1] =
          (new DiscreteDistribution(df)).mean();
        break;
      case AT_MEDIAN:
        attValues[attValues.length - 1] =
          (new DiscreteDistribution(df)).median();
        break;
      case AT_MAXPROB:
        attValues[attValues.length - 1] =
          (new DiscreteDistribution(df)).modes()[0];
        break;
      default:
        throw new IllegalStateException("Not a valid averaging type");
    }
    // add the instance, we give it weight one
    m_train.add(new Instance(1, attValues));
  }

  if (m_Debug == true) {
    System.out.println("The dataset after phase 1 :");
    System.out.println(m_train.toString());
  }

  /* Shuffle the training instances, to prevent the order dictated
   * by the map to play an important role in how the rule bases
   * are built.
   */
  m_train.randomize(new Random(getSeed()));

  if (m_sort == false) {
    // sort the instances only in increasing class order
    m_train.sort(m_train.classIndex());
  } else {
    // sort instances completely
    Comparator[] cc = new Comparator[m_train.numAttributes()];
    // sort the class, increasing
    cc[0] = new InstancesComparator(m_train.classIndex());
    // sort the attributes, decreasing
    for (int i = 1; i < cc.length; i++) {
      cc[i] = new InstancesComparator(i - 1, true);
    }
    // copy instances into an array
    Instance[] tmp = new Instance[m_train.numInstances()];
    for (int i = 0; i < tmp.length; i++) {
      tmp[i] = m_train.instance(i);
    }
    MultiDimensionalSort.multiDimensionalSort(tmp, cc);
    // copy sorted array back into Instances
    m_train.delete();
    for (int i = 0; i < tmp.length; i++) {
      m_train.add(tmp[i]);
    }
  }

  // phase 2: building the rule bases themselves
  m_baseMin = new Instances(m_train, m_estimatedDistributions.size() / 4);
  phaseTwoMin();
  m_baseMax = new Instances(m_train, m_estimatedDistributions.size() / 4);
  phaseTwoMax();
}

/**
 * This implements the second phase of the OLM algorithm.
 * We build the rule base m_baseMin, according to the conflict
 * resolution mechanism described in the thesis.
 */
private void phaseTwoMin() {
  // loop through instances backwards, this is biggest class labels first
  for (int i = m_train.numInstances() - 1; i >= 0; i--) {
    Instance e = m_train.instance(i);
    // if the example is redundant with m_base, we discard it
    if (isRedundant(e) == false) {
      // how many examples are redundant if we would add e
      int[] redundancies = makesRedundant(e);
      if (redundancies[0] == 1 && causesReversedPreference(e) == false) {
        // there is one example made redundant by e, and
        // adding e doesn't cause reversed preferences,
        // so we replace the indicated rule by e
        m_baseMin.delete(redundancies[1]);
        m_baseMin.add(e);
        continue;
      }
      if (redundancies[0] == 0) {
        // adding e causes no redundancies, what about
        // reversed preferences ?
        int[] revPref = reversedPreferences(e);
        if (revPref[0] == 1) {
          // there would be one reversed preference, we
          // prefer the example e because it has a lower label
          m_baseMin.delete(revPref[1]);
          m_baseMin.add(e);
          continue;
        }
        if (revPref[0] == 0) {
          // this means: e causes no redundancies and no
          // reversed preferences. We can simply add it.
          m_baseMin.add(e);
        }
      }
    }
  }
}

/**
 * This implements the second phase of the OLM algorithm.
 * We build the rule base m_baseMax, using the dual of the
 * procedure in <code> phaseTwoMin </code>.
 */
private void phaseTwoMax() {
  // loop through instances, smallest class labels first
  for (int i = 0; i < m_train.numInstances(); i++) {
    Instance e = m_train.instance(i);
    // if the example is redundant with m_base, we discard it
    if (isRedundantMax(e) == false) {
      // how many examples are redundant if we would add e
      int[] redundancies = makesRedundantMax(e);
      if (redundancies[0] == 1 && causesReversedPreferenceMax(e) == false) {
        // there is one example made redundant by e, and
        // adding e doesn't cause reversed preferences,
        // so we replace the indicated rule by e
        m_baseMax.delete(redundancies[1]);
        m_baseMax.add(e);
        continue;
      }
      if (redundancies[0] == 0) {
        // adding e causes no redundancies, what about
        // reversed preferences ?
        int[] revPref = reversedPreferencesMax(e);
        if (revPref[0] == 1) {
          // there would be one reversed preference, so e replaces
          // the conflicting rule.
          // NOTE(review): the original comment ("lower label") was copied
          // from phaseTwoMin; in this dual, max-base case e presumably
          // wins because of a HIGHER label — confirm against the thesis.
          m_baseMax.delete(revPref[1]);
          m_baseMax.add(e);
          continue;
        }
        if (revPref[0] == 0) {
          // this means: e causes no redundancies and no
          // reversed preferences. We can simply add it.
          m_baseMax.add(e);
        }
      }
    }
  }
}

/**
 * Returns a string description of the classifier. In debug
 * mode, the rule bases are added to the string representation
 * as well. This means that the description can become rather
 * lengthy.
 *
 * @return a <code> String </code> describing the classifier.
*/ public String toString() { StringBuffer sb = new StringBuffer(); sb.append("OLM\n===\n\n"); if (m_etype == ET_MIN || m_etype == ET_BOTH) { if (m_baseMin != null) { sb.append("Number of examples in the minimal rule base = " + m_baseMin.numInstances() + "\n"); } else { sb.append("minimal rule base not yet created"); }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -