📄 miwrapper.java
字号:
} String weightString = Utils.getOption('A', options); if (weightString.length() != 0) { setWeightMethod( new SelectedTag( Integer.parseInt(weightString), MultiInstanceToPropositional.TAGS_WEIGHTMETHOD)); } else { setWeightMethod( new SelectedTag( MultiInstanceToPropositional.WEIGHTMETHOD_INVERSE2, MultiInstanceToPropositional.TAGS_WEIGHTMETHOD)); } super.setOptions(options); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { Vector result; String[] options; int i; result = new Vector(); options = super.getOptions(); for (i = 0; i < options.length; i++) result.add(options[i]); if (getDebug()) result.add("-D"); result.add("-P"); result.add("" + m_Method); result.add("-A"); result.add("" + m_WeightMethod); return (String[]) result.toArray(new String[result.size()]); } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String weightMethodTipText() { return "The method used for weighting the instances."; } /** * The new method for weighting the instances. * * @param method the new method */ public void setWeightMethod(SelectedTag method){ if (method.getTags() == MultiInstanceToPropositional.TAGS_WEIGHTMETHOD) m_WeightMethod = method.getSelectedTag().getID(); } /** * Returns the current weighting method for instances. * * @return the current weighting method */ public SelectedTag getWeightMethod(){ return new SelectedTag( m_WeightMethod, MultiInstanceToPropositional.TAGS_WEIGHTMETHOD); } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String methodTipText() { return "The method used for testing."; } /** * Set the method used in testing. * * @param method the index of method to use. 
*/
  public void setMethod(SelectedTag method) {
    // Tags from any other tag set are silently ignored.
    if (method.getTags() == TAGS_TESTMETHOD) {
      m_Method = method.getSelectedTag().getID();
    }
  }

  /**
   * Get the method used in testing.
   *
   * @return a fresh tag (from TAGS_TESTMETHOD) wrapping the current method
   */
  public SelectedTag getMethod() {
    return new SelectedTag(m_Method, TAGS_TESTMETHOD);
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * Starts from the inherited capabilities, disables all class capabilities
   * and class dependencies, then re-enables nominal/binary class support only
   * where the inherited capabilities handle it. Finally it enables relational
   * attributes, missing class values and the multi-instance-only restriction.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // class
    result.disableAllClasses();
    result.disableAllClassDependencies();
    // Query super again rather than 'result', whose class capabilities were
    // just disabled (presumably each call returns a fresh object).
    if (super.getCapabilities().handles(Capability.NOMINAL_CLASS)) {
      result.enable(Capability.NOMINAL_CLASS);
    }
    if (super.getCapabilities().handles(Capability.BINARY_CLASS)) {
      result.enable(Capability.BINARY_CLASS);
    }
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data: the inherited capabilities with every class capability
   * replaced by NO_CLASS.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier: flattens the multi-instance training bags into a
   * single-instance dataset (weighted by the configured weight method), drops
   * the bag index attribute and trains the base classifier on the result.
   *
   * @param data the multi-instance training data
   * @throws Exception if no base classifier is set, the data is not supported,
   * or the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    // Fail fast with a descriptive message before any capability checks or
    // copying of the data.
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class (on a copy; 'data' is untouched)
    Instances train = new Instances(data);
    train.deleteWithMissingClass();

    if (getDebug()) {
      System.out.println("Start training ...");
    }
    m_NumClasses = train.numClasses();

    // convert the multi-instance training bags into a single-instance dataset
    m_ConvertToProp.setWeightMethod(getWeightMethod());
    m_ConvertToProp.setInputFormat(train);
    train = Filter.useFilter(train, m_ConvertToProp);
    train.deleteAttributeAt(0); // remove the bag index attribute

    m_Classifier.buildClassifier(train);
  }

  /**
   * Computes the class distribution for a given bag (exemplar).
   *
   * The bag is flattened into single instances and each is classified by the
   * base classifier. The per-instance distributions are then combined
   * according to the test method:
   * - arithmetic: mean of the probabilities;
   * - geometric: geometric mean, with probabilities clamped to [0.001, 0.999];
   * - maxprob: the highest probability observed for the class(es).
   * The result is normalised, or made uniform if it sums to zero.
   *
   * @param exmp the exemplar for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance exmp) throws Exception {
    Instances testData = new Instances(exmp.dataset(), 0);
    testData.add(exmp);

    // convert the test bag into a single-instance dataset, keeping the
    // original instance weights at prediction time
    m_ConvertToProp.setWeightMethod(
        new SelectedTag(
            MultiInstanceToPropositional.WEIGHTMETHOD_ORIGINAL,
            MultiInstanceToPropositional.TAGS_WEIGHTMETHOD));
    testData = Filter.useFilter(testData, m_ConvertToProp);
    testData.deleteAttributeAt(0); // remove the bag index attribute

    double[] distribution = new double[m_NumClasses];
    double nI = (double) testData.numInstances();
    double[] maxPr = new double[m_NumClasses];

    for (int i = 0; i < nI; i++) {
      double[] dist = m_Classifier.distributionForInstance(testData.instance(i));
      for (int j = 0; j < m_NumClasses; j++) {
        switch (m_Method) {
          case TESTMETHOD_ARITHMETIC:
            distribution[j] += dist[j] / nI;
            break;
          case TESTMETHOD_GEOMETRIC:
            // Avoid 0/1 probability (log(0) = -inf). Clamp a local copy so
            // the array returned by the base classifier is not modified.
            double p = dist[j];
            if (p < 0.001) {
              p = 0.001;
            } else if (p > 0.999) {
              p = 0.999;
            }
            distribution[j] += Math.log(p) / nI;
            break;
          case TESTMETHOD_MAXPROB:
            if (dist[j] > maxPr[j]) {
              maxPr[j] = dist[j];
            }
            break;
        }
      }
    }

    if (m_Method == TESTMETHOD_GEOMETRIC) {
      // turn the mean log-probability back into a geometric mean
      for (int j = 0; j < m_NumClasses; j++) {
        distribution[j] = Math.exp(distribution[j]);
      }
    }

    if (m_Method == TESTMETHOD_MAXPROB) {
      if (m_NumClasses == 2) {
        // Binary case: class index 1 is treated as the positive bag class,
        // so the bag's positive probability is the highest instance-level
        // probability of class 1.
        distribution[1] = maxPr[1];
        distribution[0] = 1 - distribution[1];
      } else {
        // Any other number of classes: use each class's highest
        // instance-level probability; the normalisation below turns these
        // into a distribution. Indexing class 1 here would leave classes
        // beyond index 1 at zero, or be out of bounds for a single class.
        System.arraycopy(maxPr, 0, distribution, 0, m_NumClasses);
      }
    }

    if (Utils.eq(Utils.sum(distribution), 0)) {
      // nothing to normalise (e.g. an empty bag): fall back to uniform
      for (int i = 0; i < distribution.length; i++) {
        distribution[i] = 1.0 / (double) distribution.length;
      }
    } else {
      Utils.normalize(distribution);
    }

    return distribution;
  }

  /**
   * Gets a string describing the classifier.
   *
   * @return a string describing the classifier built
   */
  public String toString() {
    // String concatenation yields "null" instead of throwing when no base
    // classifier is set; otherwise it is identical to calling toString().
    return "MIWrapper with base classifier: \n" + m_Classifier;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String[] argv) {
    try {
      System.out.println(Evaluation.evaluateModel(new MIWrapper(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -