MILR.java
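The listing below is an excerpt from WEKA's MILR classifier (weka.classifiers.mi.MILR), a multiple-instance variant of logistic regression. It picks up part-way through the gradient of the negative log-likelihood, apparently inside the evaluateGradient method of the inner OptEng optimizer class, and runs to the end of the file. The switch blocks distinguish the three ways the classifier can combine instance-level probabilities into a bag-level probability: ALGORITHMTYPE_DEFAULT, ALGORITHMTYPE_ARITHMETIC, and ALGORITHMTYPE_GEOMETRIC.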
          int nI = m_Data[i][0].length; // numInstances in ith bag
          double denom = 0.0;
          double[] numrt = new double[x.length];

          for (int j = 0; j < nI; j++) {
            // Compute exp(b0+b1*Xi1j+...)/[1+exp(b0+b1*Xi1j+...)]
            double exp = 0.0;
            for (int k = m_Data[i].length - 1; k >= 0; k--)
              exp += m_Data[i][k][j] * x[k + 1];
            exp += x[0];
            exp = Math.exp(exp);

            if (m_Classes[i] == 1)
              denom += exp / (1.0 + exp);
            else
              denom += 1.0 / (1.0 + exp);

            // Instance-wise update of dNLL/dBk
            for (int p = 0; p < x.length; p++) { // pth variable
              double m = 1.0;
              if (p > 0)
                m = m_Data[i][p - 1][j];
              numrt[p] += m * exp / ((1.0 + exp) * (1.0 + exp));
            }
          }

          // Bag-wise update of dNLL/dBk
          for (int q = 0; q < grad.length; q++) {
            if (m_Classes[i] == 1)
              grad[q] -= numrt[q] / denom;
            else
              grad[q] += numrt[q] / denom;
          }
        }
        break;

      case ALGORITHMTYPE_GEOMETRIC:
        for (int i = 0; i < m_Classes.length; i++) { // ith bag
          int nI = m_Data[i][0].length; // numInstances in ith bag
          double bag = 0;
          double[] sumX = new double[x.length];

          for (int j = 0; j < nI; j++) {
            // Compute the linear predictor b0+b1*Xi1j+... (no sigmoid here:
            // the geometric mean is taken on the linear scale)
            double exp = 0.0;
            for (int k = m_Data[i].length - 1; k >= 0; k--)
              exp += m_Data[i][k][j] * x[k + 1];
            exp += x[0];

            if (m_Classes[i] == 1) {
              bag -= exp / (double) nI;
              for (int q = 0; q < grad.length; q++) {
                double m = 1.0;
                if (q > 0)
                  m = m_Data[i][q - 1][j];
                sumX[q] -= m / (double) nI;
              }
            } else {
              bag += exp / (double) nI;
              for (int q = 0; q < grad.length; q++) {
                double m = 1.0;
                if (q > 0)
                  m = m_Data[i][q - 1][j];
                sumX[q] += m / (double) nI;
              }
            }
          }

          for (int p = 0; p < x.length; p++)
            grad[p] += Math.exp(bag) * sumX[p] / (1.0 + Math.exp(bag));
        }
        break;
      }

      // ridge: note that the intercept is NOT included
      for (int r = 1; r < x.length; r++) {
        grad[r] += 2.0 * m_Ridge * x[r];
      }

      return grad;
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier.
   *
   * @param train the training data to be used for generating the classifier
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances train) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(train);

    // remove instances with missing class
    train = new Instances(train);
    train.deleteWithMissingClass();

    m_NumClasses = train.numClasses();

    int nR = train.attribute(1).relation().numAttributes();
    int nC = train.numInstances();

    m_Data = new double[nC][nR][];   // Data values
    m_Classes = new int[nC];         // Class values
    m_Attributes = train.attribute(1).relation();

    xMean = new double[nR];          // Mean of mean
    xSD = new double[nR];            // Mean of stddev

    double sY1 = 0, sY0 = 0, totIns = 0; // Bags per class, total instances
    int[] missingbags = new int[nR];

    if (m_Debug) {
      System.out.println("Extracting data...");
    }

    for (int h = 0; h < m_Data.length; h++) {
      Instance current = train.instance(h);
      m_Classes[h] = (int) current.classValue(); // Class value starts from 0
      Instances currInsts = current.relationalValue(1);
      int nI = currInsts.numInstances();
      totIns += (double) nI;

      for (int i = 0; i < nR; i++) {
        // initialize m_Data[][][]
        m_Data[h][i] = new double[nI];
        double avg = 0, std = 0, num = 0;
        for (int k = 0; k < nI; k++) {
          if (!currInsts.instance(k).isMissing(i)) {
            m_Data[h][i][k] = currInsts.instance(k).value(i);
            avg += m_Data[h][i][k];
            std += m_Data[h][i][k] * m_Data[h][i][k];
            num++;
          } else {
            m_Data[h][i][k] = Double.NaN;
          }
        }
        if (num > 0) {
          xMean[i] += avg / num;
          xSD[i] += std / num;
        } else {
          missingbags[i]++;
        }
      }

      // Class count
      if (m_Classes[h] == 1)
        sY1++;
      else
        sY0++;
    }

    for (int j = 0; j < nR; j++) {
      xMean[j] = xMean[j] / (double) (nC - missingbags[j]);
      xSD[j] = Math.sqrt(Math.abs(
          xSD[j] / ((double) (nC - missingbags[j]) - 1.0)
          - xMean[j] * xMean[j] * (double) (nC - missingbags[j])
            / ((double) (nC - missingbags[j]) - 1.0)));
    }

    if (m_Debug) {
      // Output stats about input data
      System.out.println("Descriptives...");
      System.out.println(sY0 + " bags have class 0 and "
          + sY1 + " bags have class 1");
      System.out.println("\n Variable     Avg       SD    ");
      for (int j = 0; j < nR; j++)
        System.out.println(Utils.doubleToString(j, 8, 4)
            + Utils.doubleToString(xMean[j], 10, 4)
            + Utils.doubleToString(xSD[j], 10, 4));
    }

    // Normalise input data and remove ignored attributes
    for (int i = 0; i < nC; i++) {
      for (int j = 0; j < nR; j++) {
        for (int k = 0; k < m_Data[i][j].length; k++) {
          if (xSD[j] != 0) {
            if (!Double.isNaN(m_Data[i][j][k]))
              m_Data[i][j][k] = (m_Data[i][j][k] - xMean[j]) / xSD[j];
            else
              m_Data[i][j][k] = 0;
          }
        }
      }
    }

    if (m_Debug) {
      System.out.println("\nIteration History...");
    }

    double x[] = new double[nR + 1];
    x[0] = Math.log((sY1 + 1.0) / (sY0 + 1.0));
    double[][] b = new double[2][x.length];
    b[0][0] = Double.NaN; // NaN bounds: coefficients are unconstrained
    b[1][0] = Double.NaN;
    for (int q = 1; q < x.length; q++) {
      x[q] = 0.0;
      b[0][q] = Double.NaN;
      b[1][q] = Double.NaN;
    }

    OptEng opt = new OptEng(m_AlgorithmType);
    opt.setDebug(m_Debug);
    m_Par = opt.findArgmin(x, b);
    while (m_Par == null) {
      m_Par = opt.getVarbValues();
      if (m_Debug)
        System.out.println("200 iterations finished, not enough!");
      m_Par = opt.findArgmin(m_Par, b);
    }
    if (m_Debug)
      System.out.println(" -------------<Converged>--------------");

    // rank attributes by absolute coefficient size (feature selection aid)
    if (m_AlgorithmType == ALGORITHMTYPE_ARITHMETIC) {
      double[] fs = new double[nR];
      for (int k = 1; k < nR + 1; k++)
        fs[k - 1] = Math.abs(m_Par[k]);
      int[] idx = Utils.sort(fs);
      double max = fs[idx[idx.length - 1]];
      for (int k = idx.length - 1; k >= 0; k--)
        System.out.println(m_Attributes.attribute(idx[k]).name() + "\t"
            + (fs[idx[k]] * 100 / max));
    }

    // Convert coefficients back to non-normalized attribute units
    for (int j = 1; j < nR + 1; j++) {
      if (xSD[j - 1] != 0) {
        m_Par[j] /= xSD[j - 1];
        m_Par[0] -= m_Par[j] * xMean[j - 1];
      }
    }
  }

  /**
   * Computes the distribution for a given exemplar.
   *
   * @param exmp the exemplar for which the distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance exmp) throws Exception {

    // Extract the data
    Instances ins = exmp.relationalValue(1);
    int nI = ins.numInstances(), nA = ins.numAttributes();
    double[][] dat = new double[nI][nA + 1];
    for (int j = 0; j < nI; j++) {
      dat[j][0] = 1.0;
      int idx = 1;
      for (int k = 0; k < nA; k++) {
        if (!ins.instance(j).isMissing(k))
          dat[j][idx] = ins.instance(j).value(k);
        else
          dat[j][idx] = xMean[idx - 1];
        idx++;
      }
    }

    // Compute the probability of the bag
    double[] distribution = new double[2];
    switch (m_AlgorithmType) {
    case ALGORITHMTYPE_DEFAULT:
      distribution[0] = 0.0; // Log-Prob. for class 0
      for (int i = 0; i < nI; i++) {
        double exp = 0.0;
        for (int r = 0; r < m_Par.length; r++)
          exp += m_Par[r] * dat[i][r];
        exp = Math.exp(exp);

        // Prob. updated for one instance
        distribution[0] -= Math.log(1.0 + exp);
      }

      // Prob. for class 0
      distribution[0] = Math.exp(distribution[0]);
      // Prob. for class 1
      distribution[1] = 1.0 - distribution[0];
      break;

    case ALGORITHMTYPE_ARITHMETIC:
      distribution[0] = 0.0; // Prob. for class 0
      for (int i = 0; i < nI; i++) {
        double exp = 0.0;
        for (int r = 0; r < m_Par.length; r++)
          exp += m_Par[r] * dat[i][r];
        exp = Math.exp(exp);

        // Prob. updated for one instance
        distribution[0] += 1.0 / (1.0 + exp);
      }

      // Prob. for class 0
      distribution[0] /= (double) nI;
      // Prob. for class 1
      distribution[1] = 1.0 - distribution[0];
      break;

    case ALGORITHMTYPE_GEOMETRIC:
      for (int i = 0; i < nI; i++) {
        double exp = 0.0;
        for (int r = 0; r < m_Par.length; r++)
          exp += m_Par[r] * dat[i][r];
        distribution[1] += exp / (double) nI;
      }

      // Prob. for class 1
      distribution[1] = 1.0 / (1.0 + Math.exp(-distribution[1]));
      // Prob. for class 0
      distribution[0] = 1 - distribution[1];
      break;
    }

    return distribution;
  }

  /**
   * Gets a string describing the classifier.
   *
   * @return a string describing the classifier built
   */
  public String toString() {
    String result = "Modified Logistic Regression";
    if (m_Par == null) {
      return result + ": No model built yet.";
    }

    result += "\nMean type: "
        + getAlgorithmType().getSelectedTag().getReadable() + "\n";
    result += "\nCoefficients...\n"
        + "Variable      Coeff.\n";
    for (int j = 1, idx = 0; j < m_Par.length; j++, idx++) {
      result += m_Attributes.attribute(idx).name();
      result += " " + Utils.doubleToString(m_Par[j], 12, 4);
      result += "\n";
    }

    result += "Intercept:";
    result += " " + Utils.doubleToString(m_Par[0], 10, 4);
    result += "\n";

    result += "\nOdds Ratios...\n"
        + "Variable         O.R.\n";
    for (int j = 1, idx = 0; j < m_Par.length; j++, idx++) {
      result += " " + m_Attributes.attribute(idx).name();
      double ORc = Math.exp(m_Par[j]);
      result += " " + ((ORc > 1e10) ? "" + ORc : Utils.doubleToString(ORc, 12, 4));
      result += "\n"; // one ratio per line, matching the coefficient table
    }

    return result;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the command line arguments to the
   * scheme (see Evaluation)
   */
  public static void main(String[] argv) {
    runClassifier(new MILR(), argv);
  }
}
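For context, here is a minimal usage sketch showing how a classifier like this is typically driven through the standard WEKA API. It is not part of the original file: it assumes WEKA 3.x on the classpath with the class at weka.classifiers.mi.MILR, and the file name musk1.arff is only a placeholder for any multi-instance ARFF dataset (a bag identifier, one relational attribute holding each bag's instances, and a binary class).

import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.mi.MILR;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class MILRDemo {
  public static void main(String[] args) throws Exception {
    // Load a multi-instance dataset; the file name is a placeholder.
    Instances data = DataSource.read("musk1.arff");
    data.setClassIndex(data.numAttributes() - 1);

    MILR milr = new MILR();

    // Estimate performance with 10-fold cross-validation (fixed seed
    // for reproducibility), then print the usual summary statistics.
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(milr, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());

    // Train on the full data and print the coefficient and odds-ratio
    // tables produced by toString() above.
    milr.buildClassifier(data);
    System.out.println(milr);
  }
}

Because getCapabilities() enables ONLY_MULTIINSTANCE, handing this classifier an ordinary flat dataset fails the testWithFail check at the top of buildClassifier, so the relational bag structure is mandatory.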