📄 misvm.java
   *
   * @param value the kernel
   */
  public void setKernel(Kernel value) {
    m_kernel = value;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String filterTypeTipText() {
    return "The filter type for transforming the training data.";
  }

  /**
   * Sets how the training data will be transformed. Should be one of
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
   *
   * @param newType the new filtering mode
   */
  public void setFilterType(SelectedTag newType) {
    if (newType.getTags() == TAGS_FILTER) {
      m_filterType = newType.getSelectedTag().getID();
    }
  }

  /**
   * Gets how the training data will be transformed. Will be one of
   * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
   *
   * @return the filtering mode
   */
  public SelectedTag getFilterType() {
    return new SelectedTag(m_filterType, TAGS_FILTER);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String cTipText() {
    return "The value for C.";
  }

  /**
   * Get the value of C.
   *
   * @return Value of C.
   */
  public double getC() {
    return m_C;
  }

  /**
   * Set the value of C.
   *
   * @param v Value to assign to C.
   */
  public void setC(double v) {
    m_C = v;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String maxIterationsTipText() {
    return "The maximum number of iterations to perform.";
  }

  /**
   * Gets the maximum number of iterations.
   *
   * @return the maximum number of iterations.
   */
  public int getMaxIterations() {
    return m_MaxIterations;
  }

  /**
   * Sets the maximum number of iterations.
   *
   * @param value the maximum number of iterations.
   */
  public void setMaxIterations(int value) {
    if (value < 1)
      System.out.println(
          "At least 1 iteration is necessary (provided: " + value + ")!");
    else
      m_MaxIterations = value;
  }

  /**
   * adapted version of SMO
   */
  private class SVM extends SMO {

    /** for serialization */
    static final long serialVersionUID = -8325638229658828931L;

    /**
     * Constructor
     */
    protected SVM() {
      super();
    }

    /**
     * Computes SVM output for given instance.
     *
     * @param index the instance for which output is to be computed
     * @param inst the instance
     * @return the output of the SVM for the given instance
     * @throws Exception in case of an error
     */
    protected double output(int index, Instance inst) throws Exception {
      double output = 0;
      output = m_classifiers[0][1].SVMOutput(index, inst);
      return output;
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.disableAllClassDependencies();
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    SVM          classifier;
    Capabilities result;

    classifier = null;
    result     = null;

    try {
      classifier = new SVM();
      classifier.setKernel(Kernel.makeCopy(getKernel()));
      result = classifier.getCapabilities();
      result.setOwner(this);
    } catch (Exception e) {
      e.printStackTrace();
    }

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier
   *
   * @param train the training data to be used for generating the
   *              boosted classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    int numBags = train.numInstances(); // number of bags
    int[] bagSize = new int[numBags];
    int[] classes = new int[numBags];

    Vector instLabels = new Vector(); // store the class label assigned to each single instance
    Vector pre_instLabels = new Vector();

    for (int h = 0; h < numBags; h++) { // h_th bag
      classes[h] = (int) train.instance(h).classValue();
      bagSize[h] = train.instance(h).relationalValue(1).numInstances();
      for (int i = 0; i < bagSize[h]; i++)
        instLabels.addElement(new Double(classes[h]));
    }

    // convert the training dataset into a single-instance dataset
    m_ConvertToProp.setWeightMethod(
        new SelectedTag(
          MultiInstanceToPropositional.WEIGHTMETHOD_1,
          MultiInstanceToPropositional.TAGS_WEIGHTMETHOD));
    m_ConvertToProp.setInputFormat(train);
    train = Filter.useFilter(train, m_ConvertToProp);
    train.deleteAttributeAt(0); // remove the bagIndex attribute

    if (m_filterType == FILTER_STANDARDIZE)
      m_Filter = new Standardize();
    else if (m_filterType == FILTER_NORMALIZE)
      m_Filter = new Normalize();
    else
      m_Filter = null;

    if (m_Filter != null) {
      m_Filter.setInputFormat(train);
      train = Filter.useFilter(train, m_Filter);
    }

    if (m_Debug) {
      System.out.println("\nIteration History...");
    }

    if (getDebug())
      System.out.println("\nstart building model ...");

    int index;
    double sum, max_output;
    Vector max_index = new Vector();
    Instance inst = null;

    int loopNum = 0;
    do {
      loopNum++;
      index = -1;
      if (m_Debug)
        System.out.println("=====================loop: " + loopNum);

      // store the previous label information
      pre_instLabels = (Vector) instLabels.clone();

      // set the proper SMO options in order to build an SVM model
      m_SVM = new SVM();
      m_SVM.setC(getC());
      m_SVM.setKernel(Kernel.makeCopy(getKernel()));
      // the SVM model does not normalize/standardize the input dataset,
      // as the dataset has already been processed
      m_SVM.setFilterType(new SelectedTag(FILTER_NONE, TAGS_FILTER));

      m_SVM.buildClassifier(train);

      for (int h = 0; h < numBags; h++) { // h_th bag
        if (classes[h] == 1) { // positive bag
          if (m_Debug)
            System.out.println("--------------- " + h + " ----------------");
          sum = 0;

          // compute outputs f=(w,x)+b for all instances in positive bags
          for (int i = 0; i < bagSize[h]; i++) {
            index++;
            inst = train.instance(index);
            double output = m_SVM.output(-1, inst);
            //System.out.println(output);
            if (output <= 0) {
              if (inst.classValue() == 1.0) {
                train.instance(index).setClassValue(0.0);
                instLabels.set(index, new Double(0.0));
                if (m_Debug)
                  System.out.println(index + "- changed to 0");
              }
            } else {
              if (inst.classValue() == 0.0) {
                train.instance(index).setClassValue(1.0);
                instLabels.set(index, new Double(1.0));
                if (m_Debug)
                  System.out.println(index + "+ changed to 1");
              }
            }
            sum += train.instance(index).classValue();
          }

          /* if the class values of all instances in a positive bag have been
             changed to 0.0, find the instance with the maximum SVMOutput value
             and assign the class value 1.0 to it. */
          if (sum == 0) {
            // find the instance with the maximum SVMOutput value
            max_output = -Double.MAX_VALUE;
            max_index.clear();
            for (int j = index - bagSize[h] + 1; j < index + 1; j++) {
              inst = train.instance(j);
              double output = m_SVM.output(-1, inst);
              if (max_output < output) {
                max_output = output;
                max_index.clear();
                max_index.add(new Integer(j));
              } else if (max_output == output)
                max_index.add(new Integer(j));
            }

            // assign the class value 1.0 to the instances with the maximum SVMOutput
            for (int vecIndex = 0; vecIndex < max_index.size(); vecIndex++) {
              Integer i = (Integer) max_index.get(vecIndex);
              train.instance(i.intValue()).setClassValue(1.0);
              instLabels.set(i.intValue(), new Double(1.0));
              if (m_Debug)
                System.out.println("##change to 1 ### output: " + max_output
                    + " max_index: " + i + " bag: " + h);
            }
          }
        } else { // negative bags
          index += bagSize[h];
        }
      }
    } while (!instLabels.equals(pre_instLabels) && loopNum < m_MaxIterations);

    if (getDebug())
      System.out.println("finish building model.");
  }

  /**
   * Computes the distribution for a given exemplar
   *
   * @param exmp the exemplar for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance exmp) throws Exception {
    double sum = 0;
    double classValue;
    double[] distribution = new double[2];

    Instances testData = new Instances(exmp.dataset(), 0);
    testData.add(exmp);

    // convert the test bag into a single-instance dataset
    testData = Filter.useFilter(testData, m_ConvertToProp);
    testData.deleteAttributeAt(0); // remove the bagIndex attribute

    if (m_Filter != null)
      testData = Filter.useFilter(testData, m_Filter);

    for (int j = 0; j < testData.numInstances(); j++) {
      Instance inst = testData.instance(j);
      double output = m_SVM.output(-1, inst);
      if (output <= 0)
        classValue = 0.0;
      else
        classValue = 1.0;
      sum += classValue;
    }

    if (sum == 0)
      distribution[0] = 1.0;
    else
      distribution[0] = 0.0;
    distribution[1] = 1.0 - distribution[0];
    return distribution;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   *             scheme (see Evaluation)
   */
  public static void main(String[] argv) {
    runClassifier(new MISVM(), argv);
  }
}
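
Usage sketch (not part of misvm.java): a minimal example of training and cross-validating this classifier through the standard Weka API. It assumes a multi-instance ARFF file (here the hypothetical path "musk1.arff", with bag id, relational attribute, and class) is available locally; DataSource, Evaluation, and RBFKernel are Weka core classes, while setC, setMaxIterations, and setKernel are the MISVM setters shown above.

import java.util.Random;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.supportVector.RBFKernel;
import weka.classifiers.mi.MISVM;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class MISVMExample {
  public static void main(String[] args) throws Exception {
    // load a multi-instance dataset (assumption: "musk1.arff" exists locally)
    Instances data = DataSource.read("musk1.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // configure MISVM via the setters defined in misvm.java
    MISVM misvm = new MISVM();
    misvm.setC(1.0);
    misvm.setMaxIterations(500);
    misvm.setKernel(new RBFKernel());

    // 10-fold cross-validation and summary of results
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(misvm, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}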