MDD.java
    protected double[] evaluateGradient(double[] x) {
      double[] grad = new double[x.length];
      for (int i = 0; i < m_Classes.length; i++) { // ith bag
        int nI = m_Data[i][0].length; // numInstances in ith bag
        double denom = 0.0;
        double[] numrt = new double[x.length];

        for (int j = 0; j < nI; j++) {
          double exp = 0.0;
          for (int k = 0; k < m_Data[i].length; k++)
            exp += (m_Data[i][k][j] - x[k * 2]) * (m_Data[i][k][j] - x[k * 2])
              / (x[k * 2 + 1] * x[k * 2 + 1]);
          exp = Math.exp(-exp);
          if (m_Classes[i] == 1)
            denom += exp;
          else
            denom += (1.0 - exp);

          // Instance-wise update
          for (int p = 0; p < m_Data[i].length; p++) { // pth variable
            numrt[2 * p] += exp * 2.0 * (x[2 * p] - m_Data[i][p][j])
              / (x[2 * p + 1] * x[2 * p + 1]);
            numrt[2 * p + 1] += exp * (x[2 * p] - m_Data[i][p][j]) * (x[2 * p] - m_Data[i][p][j])
              / (x[2 * p + 1] * x[2 * p + 1] * x[2 * p + 1]);
          }
        }

        if (denom <= m_Zero) {
          denom = m_Zero;
        }

        // Bag-wise update
        for (int q = 0; q < m_Data[i].length; q++) {
          if (m_Classes[i] == 1) {
            grad[2 * q] += numrt[2 * q] / denom;
            grad[2 * q + 1] -= numrt[2 * q + 1] / denom;
          } else {
            grad[2 * q] -= numrt[2 * q] / denom;
            grad[2 * q + 1] += numrt[2 * q + 1] / denom;
          }
        }
      }
      return grad;
    }
  } // end of inner class OptEng
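  /*
   * Note (added commentary, inferred from the code above; not part of the
   * original source): the optimizer works on a parameter vector x in which
   * x[2k] holds the target point t_k and x[2k+1] the scale s_k for the kth
   * attribute. An instance x_j is scored by
   *
   *   Pr(x_j) = exp( - sum_k (x_jk - t_k)^2 / s_k^2 )
   *
   * A positive bag contributes sum_j Pr(x_j) to denom, a negative bag
   * contributes sum_j (1 - Pr(x_j)), and evaluateGradient appears to return
   * the gradient of the corresponding negative log-likelihood with respect
   * to x, accumulated instance-wise in numrt[] and normalized bag-wise by
   * denom.
   */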
  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier.
   *
   * @param train the training data to be used for generating the classifier
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    m_ClassIndex = train.classIndex();
    m_NumClasses = train.numClasses();

    int nR = train.attribute(1).relation().numAttributes();
    int nC = train.numInstances();
    int[] bagSize = new int[nC];
    Instances datasets = new Instances(train.attribute(1).relation(), 0);

    m_Data = new double[nC][nR][]; // Data values
    m_Classes = new int[nC];       // Class values
    m_Attributes = datasets.stringFreeStructure();
    double sY1 = 0, sY0 = 0;       // Number of bags in each class

    if (m_Debug) {
      System.out.println("Extracting data...");
    }

    FastVector maxSzIdx = new FastVector();
    int maxSz = 0;
    for (int h = 0; h < nC; h++) {
      Instance current = train.instance(h);
      m_Classes[h] = (int) current.classValue(); // Class value starts from 0
      Instances currInsts = current.relationalValue(1);
      int nI = currInsts.numInstances();
      bagSize[h] = nI;
      for (int i = 0; i < nI; i++) {
        Instance inst = currInsts.instance(i);
        datasets.add(inst);
      }
      // track the largest positive bag(s)
      if (m_Classes[h] == 1) {
        if (nI > maxSz) {
          maxSz = nI;
          maxSzIdx = new FastVector(1);
          maxSzIdx.addElement(new Integer(h));
        } else if (nI == maxSz) {
          maxSzIdx.addElement(new Integer(h));
        }
      }
    }

    /* filter the training data */
    if (m_filterType == FILTER_STANDARDIZE)
      m_Filter = new Standardize();
    else if (m_filterType == FILTER_NORMALIZE)
      m_Filter = new Normalize();
    else
      m_Filter = null;

    if (m_Filter != null) {
      m_Filter.setInputFormat(datasets);
      datasets = Filter.useFilter(datasets, m_Filter);
    }

    m_Missing.setInputFormat(datasets);
    datasets = Filter.useFilter(datasets, m_Missing);

    int instIndex = 0;
    int start = 0;
    for (int h = 0; h < nC; h++) {
      for (int i = 0; i < datasets.numAttributes(); i++) {
        // initialize m_Data[][][]
        m_Data[h][i] = new double[bagSize[h]];
        instIndex = start;
        for (int k = 0; k < bagSize[h]; k++) {
          m_Data[h][i][k] = datasets.instance(instIndex).value(i);
          instIndex++;
        }
      }
      start = instIndex;

      // Class count
      if (m_Classes[h] == 1)
        sY1++;
      else
        sY0++;
    }

    if (m_Debug) {
      System.out.println("\nIteration History...");
    }

    double[] x = new double[nR * 2], tmp = new double[x.length];
    double[][] b = new double[2][x.length];
    OptEng opt;
    double nll, bestnll = Double.MAX_VALUE;

    // no bounds on the parameters
    for (int t = 0; t < x.length; t++) {
      b[0][t] = Double.NaN;
      b[1][t] = Double.NaN;
    }

    // Largest positive exemplar
    for (int s = 0; s < maxSzIdx.size(); s++) {
      int exIdx = ((Integer) maxSzIdx.elementAt(s)).intValue();
      for (int p = 0; p < m_Data[exIdx][0].length; p++) {
        for (int q = 0; q < nR; q++) {
          x[2 * q] = m_Data[exIdx][q][p]; // pick one instance
          x[2 * q + 1] = 1.0;
        }

        opt = new OptEng();
        tmp = opt.findArgmin(x, b);
        while (tmp == null) {
          tmp = opt.getVarbValues();
          if (m_Debug)
            System.out.println("200 iterations finished, not enough!");
          tmp = opt.findArgmin(tmp, b);
        }
        nll = opt.getMinFunction();

        if (nll < bestnll) {
          bestnll = nll;
          m_Par = tmp;
          if (m_Debug)
            System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: " + nll);
        }
        if (m_Debug)
          System.out.println(exIdx + ": -------------<Converged>--------------");
      }
    }
  }
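  /*
   * Note (added commentary, not in the original source): buildClassifier uses
   * a multi-start strategy. Each instance of the largest positive bag(s) is
   * used in turn as the optimizer's starting point (t_k = the instance's
   * attribute values, s_k = 1.0), and the parameter vector with the smallest
   * negative log-likelihood across all restarts is kept as m_Par.
   */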
  /**
   * Computes the distribution for a given exemplar.
   *
   * @param exmp the exemplar for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance exmp) throws Exception {
    // Extract the data
    Instances ins = exmp.relationalValue(1);
    if (m_Filter != null)
      ins = Filter.useFilter(ins, m_Filter);
    ins = Filter.useFilter(ins, m_Missing);

    int nI = ins.numInstances(), nA = ins.numAttributes();
    double[][] dat = new double[nI][nA];
    for (int j = 0; j < nI; j++) {
      for (int k = 0; k < nA; k++) {
        dat[j][k] = ins.instance(j).value(k);
      }
    }

    // Compute the probability of the bag
    double[] distribution = new double[2];
    distribution[1] = 0.0; // Prob. for class 1

    for (int i = 0; i < nI; i++) {
      double exp = 0.0;
      for (int r = 0; r < nA; r++)
        exp += (m_Par[r * 2] - dat[i][r]) * (m_Par[r * 2] - dat[i][r])
          / ((m_Par[r * 2 + 1]) * (m_Par[r * 2 + 1]));
      exp = Math.exp(-exp);

      // Prob. updated for one instance
      distribution[1] += exp / (double) nI;
      distribution[0] += (1.0 - exp) / (double) nI;
    }

    return distribution;
  }

  /**
   * Gets a string describing the classifier.
   *
   * @return a string describing the classifier built
   */
  public String toString() {
    String result = "Modified Logistic Regression";
    if (m_Par == null) {
      return result + ": No model built yet.";
    }

    result += "\nCoefficients...\n" + "Variable       Coeff.\n";
    for (int j = 0; j < m_Par.length / 2; j++) {
      result += m_Attributes.attribute(j).name();
      result += " " + Utils.doubleToString(m_Par[j * 2], 12, 4);
      result += " " + Utils.doubleToString(m_Par[j * 2 + 1], 12, 4) + "\n";
    }

    return result;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   *             scheme (see Evaluation)
   */
  public static void main(String[] argv) {
    runClassifier(new MDD(), argv);
  }
}
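For context, below is a minimal driver showing how this classifier might be trained and evaluated from code. It is a sketch only: it assumes the standard Weka packaging (weka.classifiers.mi.MDD) with weka.jar on the classpath, and "musk1.arff" is a hypothetical multi-instance ARFF file whose last attribute is the class.

// Minimal usage sketch under the assumptions stated above.
import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.mi.MDD;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class MDDExample {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("musk1.arff"); // hypothetical dataset path
    data.setClassIndex(data.numAttributes() - 1);

    MDD mdd = new MDD();

    // 10-fold cross-validation; buildClassifier is invoked once per fold
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(mdd, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());

    // train on the full data and query the first bag
    mdd.buildClassifier(data);
    double[] dist = mdd.distributionForInstance(data.instance(0));
    System.out.println("P(negative)=" + dist[0] + "  P(positive)=" + dist[1]);
  }
}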