⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mdd.java

📁 代码是一个分类器的实现,其中使用了部分weka的源代码。可以将项目导入eclipse运行
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    protected double[] evaluateGradient(double[] x){      double[] grad = new double[x.length];      for(int i=0; i<m_Classes.length; i++){ // ith bag        int nI = m_Data[i][0].length; // numInstances in ith bag         double denom=0.0;        double[] numrt = new double[x.length];        for(int j=0; j<nI; j++){          double exp=0.0;          for(int k=0; k<m_Data[i].length; k++)            exp += (m_Data[i][k][j]-x[k*2])*(m_Data[i][k][j]-x[k*2])/              (x[k*2+1]*x[k*2+1]);			          exp = Math.exp(-exp);          if(m_Classes[i]==1)            denom += exp;          else            denom += (1.0-exp);		             // Instance-wise update          for(int p=0; p<m_Data[i].length; p++){  // pth variable            numrt[2*p] += exp*2.0*(x[2*p]-m_Data[i][p][j])/              (x[2*p+1]*x[2*p+1]);            numrt[2*p+1] +=               exp*(x[2*p]-m_Data[i][p][j])*(x[2*p]-m_Data[i][p][j])/              (x[2*p+1]*x[2*p+1]*x[2*p+1]);          }			        }        if(denom <= m_Zero){          denom = m_Zero;        }        // Bag-wise update         for(int q=0; q<m_Data[i].length; q++){          if(m_Classes[i]==1){            grad[2*q] += numrt[2*q]/denom;            grad[2*q+1] -= numrt[2*q+1]/denom;          }else{            grad[2*q] -= numrt[2*q]/denom;            grad[2*q+1] += numrt[2*q+1]/denom;          }        }      }      return grad;    }  }  /**   * Returns default capabilities of the classifier.   
*
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return            the capabilities of this object
   * @see               Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier.
   *
   * The bags are read from the relational attribute at index 1, flattened
   * into one dataset, optionally standardized/normalized, passed through
   * m_Missing, and stored in m_Data[bag][attribute][instance].  The
   * optimizer is then started once from every instance of the largest
   * bag(s) whose class value is 1, and the parameters with the smallest
   * reported function value are kept in m_Par.
   *
   * If no bag has class value 1, or no start yields a function value below
   * Double.MAX_VALUE, m_Par is left null.
   *
   * @param train the training data to be used for generating the classifier
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    m_ClassIndex = train.classIndex();
    m_NumClasses = train.numClasses();

    int nR = train.attribute(1).relation().numAttributes();
    int nC = train.numInstances();
    int[] bagSize = new int[nC];
    Instances datasets = new Instances(train.attribute(1).relation(), 0);

    m_Data = new double[nC][nR][];                // Data values
    m_Classes = new int[nC];                      // Class values
    m_Attributes = datasets.stringFreeStructure();
    // Drop any parameters from an earlier build so a stale model can never
    // survive a rebuild that finds no usable starting point.
    m_Par = null;

    if (m_Debug) {
      System.out.println("Extracting data...");
    }

    // Indices of the largest bag(s) with class value 1
    FastVector maxSzIdx = new FastVector();
    int maxSz = 0;

    for (int h = 0; h < nC; h++) {
      Instance current = train.instance(h);
      m_Classes[h] = (int) current.classValue();  // Class value starts from 0
      Instances currInsts = current.relationalValue(1);
      int nI = currInsts.numInstances();
      bagSize[h] = nI;

      // Flatten every bag into one dataset so the filters see all instances
      for (int i = 0; i < nI; i++) {
        datasets.add(currInsts.instance(i));
      }

      if (m_Classes[h] == 1) {
        if (nI > maxSz) {
          maxSz = nI;
          maxSzIdx = new FastVector(1);
          maxSzIdx.addElement(Integer.valueOf(h));
        } else if (nI == maxSz) {
          maxSzIdx.addElement(Integer.valueOf(h));
        }
      }
    }

    /* filter the training data */
    if (m_filterType == FILTER_STANDARDIZE) {
      m_Filter = new Standardize();
    } else if (m_filterType == FILTER_NORMALIZE) {
      m_Filter = new Normalize();
    } else {
      m_Filter = null;
    }

    if (m_Filter != null) {
      m_Filter.setInputFormat(datasets);
      datasets = Filter.useFilter(datasets, m_Filter);
    }

    m_Missing.setInputFormat(datasets);
    datasets = Filter.useFilter(datasets, m_Missing);

    // Copy the filtered values back into per-bag arrays; 'start' is the row
    // of the flattened dataset at which the current bag begins.
    int start = 0;
    for (int h = 0; h < nC; h++) {
      for (int i = 0; i < datasets.numAttributes(); i++) {
        m_Data[h][i] = new double[bagSize[h]];
        for (int k = 0; k < bagSize[h]; k++) {
          m_Data[h][i][k] = datasets.instance(start + k).value(i);
        }
      }
      start += bagSize[h];
    }

    if (m_Debug) {
      System.out.println("\nIteration History..." );
    }

    double[] x = new double[nR * 2];
    double[] tmp;
    // Both bound rows are NaN for every variable (presumably "no constraint"
    // for the optimizer - confirm against OptEng's superclass).
    double[][] b = new double[2][x.length];

    OptEng opt;
    double nll, bestnll = Double.MAX_VALUE;
    for (int t = 0; t < x.length; t++) {
      b[0][t] = Double.NaN;
      b[1][t] = Double.NaN;
    }

    // Largest positive exemplar
    for (int s = 0; s < maxSzIdx.size(); s++) {
      int exIdx = ((Integer) maxSzIdx.elementAt(s)).intValue();
      for (int p = 0; p < m_Data[exIdx][0].length; p++) {
        for (int q = 0; q < nR; q++) {
          x[2 * q] = m_Data[exIdx][q][p];  // pick one instance
          x[2 * q + 1] = 1.0;
        }

        opt = new OptEng();
        tmp = opt.findArgmin(x, b);
        // A null result is treated as "not finished yet": resume from the
        // optimizer's current variable values until it returns a solution.
        while (tmp == null) {
          tmp = opt.getVarbValues();
          if (m_Debug)
            System.out.println("200 iterations finished, not enough!");
          tmp = opt.findArgmin(tmp, b);
        }
        nll = opt.getMinFunction();

        if (nll < bestnll) {
          bestnll = nll;
          m_Par = tmp;
          if (m_Debug)
            System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: "+nll);
        }
        if (m_Debug)
          System.out.println(exIdx+":  -------------<Converged>--------------");
      }
    }
  }

  /**
   * Computes the distribution for a given exemplar.
   *
   * Each instance of the bag contributes exp(-sum_r (point_r - value_r)^2 /
   * scale_r^2), where point_r = m_Par[2*r] and scale_r = m_Par[2*r+1]; the
   * average of these contributions is returned as the probability of class
   * 1 and the average of their complements as the probability of class 0.
   * For a bag without instances both entries stay 0.
   *
   * @param exmp the exemplar for which distribution is computed
   * @return the distribution (index 0: class 0, index 1: class 1)
   * @throws IllegalStateException if no model parameters are available
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance exmp)
    throws Exception {

    // Fail with a clear message instead of a bare NullPointerException
    if (m_Par == null) {
      throw new IllegalStateException(
        "No model built yet: buildClassifier() has not produced any parameters.");
    }

    // Extract the data
    Instances ins = exmp.relationalValue(1);
    if (m_Filter != null)
      ins = Filter.useFilter(ins, m_Filter);

    ins = Filter.useFilter(ins, m_Missing);

    int nI = ins.numInstances(), nA = ins.numAttributes();
    double[][] dat = new double[nI][nA];
    for (int j = 0; j < nI; j++) {
      for (int k = 0; k < nA; k++) {
        dat[j][k] = ins.instance(j).value(k);
      }
    }

    // Compute the probability of the bag
    double[] distribution = new double[2];
    distribution[1] = 0.0;  // Prob. for class 1

    for (int i = 0; i < nI; i++) {
      double exp = 0.0;
      for (int r = 0; r < nA; r++)
        exp += (m_Par[r*2]-dat[i][r])*(m_Par[r*2]-dat[i][r])/
          ((m_Par[r*2+1])*(m_Par[r*2+1]));
      exp = Math.exp(-exp);

      // Prob. updated for one instance
      distribution[1] += exp/(double)nI;
      distribution[0] += (1.0-exp)/(double)nI;
    }

    return distribution;
  }

  /**
   * Gets a string describing the classifier.
   *
   * Prints one line per attribute with the attribute name followed by the
   * two parameters stored for it (m_Par[2*j] and m_Par[2*j+1]).
   *
   * NOTE(review): the heading reads "Modified Logistic Regression" although
   * main() instantiates MDD; the text is runtime output and is left
   * unchanged here - confirm whether it should be renamed.
   *
   * @return a string describing the classifier built.
   */
  public String toString() {

    String title = "Modified Logistic Regression";
    if (m_Par == null) {
      return title + ": No model built yet.";
    }

    StringBuilder result = new StringBuilder(title);
    result.append("\nCoefficients...\n")
      .append("Variable      Coeff.\n");
    for (int j = 0; j < m_Par.length/2; j++) {
      result.append(m_Attributes.attribute(j).name());
      result.append(" ").append(Utils.doubleToString(m_Par[j*2], 12, 4));
      result.append(" ").append(Utils.doubleToString(m_Par[j*2+1], 12, 4)).append("\n");
    }

    return result.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String[] argv) {
    runClassifier(new MDD(), argv);
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -