
miboost.java

The code implements a classifier (WEKA's multi-instance boosting classifier, MIBoost) and reuses part of the WEKA source code. The project can be imported into Eclipse and run; a usage sketch follows the listing. Note that the listing is an excerpt: the package and import declarations and the class header with its fields are not shown.
Language: Java
  private class OptEng
    extends Optimization {

    private double[] weights, errs;

    public void setWeights(double[] w) {
      weights = w;
    }

    public void setErrs(double[] e) {
      errs = e;
    }

    /**
     * Evaluate objective function
     * @param x the current values of variables
     * @return the value of the objective function
     * @throws Exception if result is NaN
     */
    protected double objectiveFunction(double[] x) throws Exception {
      double obj = 0;
      for (int i = 0; i < weights.length; i++) {
        obj += weights[i] * Math.exp(x[0] * (2.0 * errs[i] - 1.0));
        if (Double.isNaN(obj))
          throw new Exception("Objective function value is NaN!");
      }
      return obj;
    }

    /**
     * Evaluate Jacobian vector
     * @param x the current values of variables
     * @return the gradient vector
     * @throws Exception if gradient is NaN
     */
    protected double[] evaluateGradient(double[] x) throws Exception {
      double[] grad = new double[1];
      for (int i = 0; i < weights.length; i++) {
        grad[0] += weights[i] * (2.0 * errs[i] - 1.0)
          * Math.exp(x[0] * (2.0 * errs[i] - 1.0));
        if (Double.isNaN(grad[0]))
          throw new Exception("Gradient is NaN!");
      }
      return grad;
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return      the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.disableAllClassDependencies();
    if (super.getCapabilities().handles(Capability.BINARY_CLASS))
      result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return            the capabilities of this object
   * @see               Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier
   *
   * @param exps the training data to be used for generating the
   * boosted classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances exps) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(exps);

    // remove instances with missing class
    Instances train = new Instances(exps);
    train.deleteWithMissingClass();

    m_NumClasses = train.numClasses();
    m_NumIterations = m_MaxIterations;

    if (m_Classifier == null)
      throw new Exception("A base classifier has not been specified!");
    if (!(m_Classifier instanceof WeightedInstancesHandler))
      throw new Exception("Base classifier cannot handle weighted instances!");

    m_Models = Classifier.makeCopies(m_Classifier, getMaxIterations());
    if (m_Debug)
      System.err.println("Base classifier: " + m_Classifier.getClass().getName());

    m_Beta = new double[m_NumIterations];

    /* modified by Lin Dong.
       (use MIToSingleInstance filter to convert the MI datasets) */
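    // Each of the N bags is given the same starting weight sumNi/N, so the
    // bag weights always sum to sumNi, the total number of single instances
    // across all bags; the reweighting step at the end of each boosting
    // iteration renormalizes the bag weights back to this same total.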
    // Initialize the bags' weights
    double N = (double) train.numInstances(), sumNi = 0;
    for (int i = 0; i < N; i++)
      sumNi += train.instance(i).relationalValue(1).numInstances();
    for (int i = 0; i < N; i++) {
      train.instance(i).setWeight(sumNi / N);
    }

    // convert the training dataset into a single-instance dataset
    m_ConvertToSI.setInputFormat(train);
    Instances data = Filter.useFilter(train, m_ConvertToSI);
    data.deleteAttributeAt(0); // remove the bagIndex attribute

    // Assume the order of the instances is preserved in the Discretize filter
    if (m_DiscretizeBin > 0) {
      m_Filter = new Discretize();
      m_Filter.setInputFormat(new Instances(data, 0));
      m_Filter.setBins(m_DiscretizeBin);
      data = Filter.useFilter(data, m_Filter);
    }

    // Main algorithm
    int dataIdx;
iterations:
    for (int m = 0; m < m_MaxIterations; m++) {
      if (m_Debug)
        System.err.println("\nIteration " + m);

      // Build a model
      m_Models[m].buildClassifier(data);

      // Prediction of each bag
      double[] err = new double[(int) N], weights = new double[(int) N];
      boolean perfect = true, tooWrong = true;
      dataIdx = 0;
      for (int n = 0; n < N; n++) {
        Instance exn = train.instance(n);
        // Prediction of each instance and the predicted class distribution
        // of the bag
        double nn = (double) exn.relationalValue(1).numInstances();
        for (int p = 0; p < nn; p++) {
          Instance testIns = data.instance(dataIdx++);
          if ((int) m_Models[m].classifyInstance(testIns)
              != (int) exn.classValue()) // Weighted instance-wise 0-1 errors
            err[n]++;
        }
        weights[n] = exn.weight();
        err[n] /= nn;
        if (err[n] > 0.5)
          perfect = false;
        if (err[n] < 0.5)
          tooWrong = false;
      }

      if (perfect || tooWrong) { // No or 100% classification error, cannot find beta
        if (m == 0)
          m_Beta[m] = 1.0;
        else
          m_Beta[m] = 0;
        m_NumIterations = m + 1;
        if (m_Debug) System.err.println("No errors");
        break iterations;
      }

      double[] x = new double[1];
      x[0] = 0;
      double[][] b = new double[2][x.length];
      b[0][0] = Double.NaN;
      b[1][0] = Double.NaN;

      OptEng opt = new OptEng();
      opt.setWeights(weights);
      opt.setErrs(err);
      //opt.setDebug(m_Debug);
      if (m_Debug)
        System.out.println("Start searching for c... ");
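      // The quasi-Newton search in weka.core.Optimization minimizes
      // f(c) = sum_i w_i * exp(c * (2 * err_i - 1)) over the single variable
      // c (the objective defined in OptEng above); b holds NaN bounds, so the
      // search is unconstrained. findArgmin returns null when it has not
      // converged within 200 iterations, in which case the loop below resumes
      // the search from the current variable values (getVarbValues).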
");      x = opt.findArgmin(x, b);      while(x==null){        x = opt.getVarbValues();        if (m_Debug)          System.out.println("200 iterations finished, not enough!");        x = opt.findArgmin(x, b);      }	      if (m_Debug)        System.out.println("Finished.");          m_Beta[m] = x[0];      if(m_Debug)        System.err.println("c = "+m_Beta[m]);      // Stop if error too small or error too big and ignore this model      if (Double.isInfinite(m_Beta[m])           || Utils.smOrEq(m_Beta[m], 0)         ) {        if (m == 0)          m_Beta[m] = 1.0;        else		              m_Beta[m] = 0;        m_NumIterations = m+1;        if(m_Debug)          System.err.println("Errors out of range!");        break iterations;         }      // Update weights of data and class label of wfData      dataIdx=0;      double totWeights=0;      for(int r=0; r<N; r++){		        Instance exr = train.instance(r);        exr.setWeight(weights[r]*Math.exp(m_Beta[m]*(2.0*err[r]-1.0)));        totWeights += exr.weight();      }      if(m_Debug)        System.err.println("Total weights = "+totWeights);      for(int r=0; r<N; r++){		        Instance exr = train.instance(r);        double num = (double)exr.relationalValue(1).numInstances();        exr.setWeight(sumNi*exr.weight()/totWeights);        //if(m_Debug)        //    System.err.print("\nExemplar "+r+"="+exr.weight()+": \t");        for(int s=0; s<num; s++){          Instance inss = data.instance(dataIdx);	          inss.setWeight(exr.weight()/num);		             //    if(m_Debug)          //  System.err.print("instance "+s+"="+inss.weight()+          //			 "|ew*iw*sumNi="+data.instance(dataIdx).weight()+"\t");          if(Double.isNaN(inss.weight()))            throw new Exception("instance "+s+" in bag "+r+" has weight NaN!");           dataIdx++;        }        //if(m_Debug)        //    System.err.println();      }	           }  }		  /**   * Computes the distribution for a given exemplar   *   * @param exmp the exemplar for which distribution is computed   * @return the classification   * @throws Exception if the distribution can't be computed successfully   */  public double[] distributionForInstance(Instance exmp)     throws Exception {     double[] rt = new double[m_NumClasses];    Instances insts = new Instances(exmp.dataset(), 0);    insts.add(exmp);    // convert the training dataset into single-instance dataset    insts = Filter.useFilter( insts, m_ConvertToSI);    insts.deleteAttributeAt(0); //remove the bagIndex attribute	    double n = insts.numInstances();    if(m_DiscretizeBin > 0)      insts = Filter.useFilter(insts, m_Filter);    for(int y=0; y<n; y++){      Instance ins = insts.instance(y);	      for(int x=0; x<m_NumIterations; x++){         rt[(int)m_Models[x].classifyInstance(ins)] += m_Beta[x]/n;      }    }    for(int i=0; i<rt.length; i++)      rt[i] = Math.exp(rt[i]);    Utils.normalize(rt);    return rt;  }  /**   * Gets a string describing the classifier.   *   * @return a string describing the classifer built.   
  /**
   * Gets a string describing the classifier.
   *
   * @return a string describing the classifier built.
   */
  public String toString() {
    if (m_Models == null) {
      return "No model built yet!";
    }
    StringBuffer text = new StringBuffer();
    text.append("MIBoost: number of bins in discretization = " + m_DiscretizeBin + "\n");

    if (m_NumIterations == 0) {
      text.append("No model built yet.\n");
    } else if (m_NumIterations == 1) {
      text.append("No boosting possible, one classifier used: Weight = "
          + Utils.roundDouble(m_Beta[0], 2) + "\n");
      text.append("Base classifiers:\n" + m_Models[0].toString());
    } else {
      text.append("Base classifiers and their weights: \n");
      for (int i = 0; i < m_NumIterations; i++) {
        text.append("\n\n" + i + ": Weight = " + Utils.roundDouble(m_Beta[i], 2)
            + "\nBase classifier:\n" + m_Models[i].toString());
      }
    }

    text.append("\n\nNumber of performed Iterations: "
        + m_NumIterations + "\n");

    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String[] argv) {
    runClassifier(new MIBoost(), argv);
  }
}
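For readers who want to try the classifier after importing the project into Eclipse, here is a minimal usage sketch, not part of the original file. It assumes WEKA is on the classpath, that setters setMaxIterations and setDiscretizeBin exist to match the getMaxIterations call and m_DiscretizeBin field seen above, and that data/musk1.arff is a placeholder path to a multi-instance ARFF file (bag id, relational attribute, class). NaiveBayes is an arbitrary base classifier that implements WeightedInstancesHandler, which buildClassifier requires.

import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.mi.MIBoost;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class MIBoostDemo {
  public static void main(String[] args) throws Exception {
    // Load a multi-instance dataset; "data/musk1.arff" is a placeholder path.
    Instances data = DataSource.read("data/musk1.arff");
    data.setClassIndex(data.numAttributes() - 1);

    MIBoost boost = new MIBoost();
    boost.setClassifier(new NaiveBayes()); // must handle weighted instances
    boost.setMaxIterations(10);            // assumed setter for m_MaxIterations
    boost.setDiscretizeBin(0);             // assumed setter; 0 = no discretization

    // Evaluate with 10-fold cross-validation.
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(boost, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}

Alternatively, since main delegates to runClassifier, the classifier can be run from the command line with WEKA's standard options, e.g. -t for the training file and -W for the base classifier.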
