📄 tld.java
字号:
setNumRuns(1); super.setOptions(options); }

  /**
   * Gets the current settings of the Classifier. The returned array holds the
   * options reported by the superclass, then "-D" if debugging is enabled,
   * then "-C" if the empirical cutoff is enabled, and finally "-R" followed by
   * the number of runs (the "-R" pair is always present).
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector result;
    String[] options;
    int i;

    result = new Vector();

    // Start from the options the superclass knows about.
    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    if (getDebug())
      result.add("-D");

    if (getUsingCutOff())
      result.add("-C");

    // The number of runs is emitted unconditionally.
    result.add("-R");
    result.add("" + getNumRuns());

    // NOTE(review): raw Vector is kept for compatibility with the file's
    // pre-generics style; the cast is safe because only Strings are added.
    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Returns the tip text for the numRuns property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numRunsTipText() {
    return "The number of runs to perform.";
  }

  /**
   * Sets the number of runs to perform. The value is stored as-is in m_Run;
   * no range check is done here.
   *
   * @param numRuns the number of runs to perform
   */
  public void setNumRuns(int numRuns) {
    m_Run = numRuns;
  }

  /**
   * Returns the number of runs to perform.
   *
   * @return the number of runs to perform (the value of m_Run)
   */
  public int getNumRuns() {
    return m_Run;
  }

  /**
   * Returns the tip text for the usingCutOff property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String usingCutOffTipText() {
    return "Whether to use an empirical cutoff.";
  }

  /**
   * Sets whether to use an empirical cutoff (stored in m_UseEmpiricalCutOff).
   * When enabled, getOptions() reports the "-C" flag.
   *
   * @param cutOff whether to use an empirical cutoff
   */
  public void setUsingCutOff (boolean cutOff) {
    m_UseEmpiricalCutOff = cutOff;
  }

  /**
   * Returns whether an empirical cutoff is used.
   *
   * @return true if an empirical cutoff is used
   */
  public boolean getUsingCutOff() {
    return m_UseEmpiricalCutOff;
  }

  /**
   * Main method for testing.
*
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new TLD(), args);
  }
}

/**
 * Optimisation problem used by TLD. It minimises a negative log-likelihood
 * over the four parameters x = {a, b, w, m}. The data enter only through three
 * per-exemplar statistics, set with setNum, setSSquare and setXBar. Debug
 * output labels them "n", "S^2" and "x-".
 *
 * Exemplars whose xBar entry is NaN (all values missing) are skipped by the
 * objective, the gradient and the Hessian alike.
 */
class TLD_Optm extends Optimization {

  /** Per-exemplar count "n"; its parity is tested, so it is expected to be integral. */
  private double[] num;

  /** Per-exemplar statistic printed as "S^2" in debug output. */
  private double[] sSq;

  /** Per-exemplar mean printed as "x-"; NaN marks an exemplar with all values missing. */
  private double[] xBar;

  /**
   * Coefficients of the Lanczos series used by diffLnGamma.
   * They are hoisted here so the array is not reallocated on every call.
   */
  private static final double[] LANCZOS_COEF = {
    76.18009172947146, -86.50532032941677, 24.01409824083091,
    -1.231739572450155, 0.1208650973866179e-2, -0.5395239384953e-5
  };

  /** @param n the per-exemplar counts (parallel to sSq and xBar) */
  public void setNum(double[] n) {
    num = n;
  }

  /** @param s the per-exemplar "S^2" statistics (parallel to num and xBar) */
  public void setSSquare(double[] s) {
    sSq = s;
  }

  /** @param x the per-exemplar means; NaN marks an all-missing exemplar */
  public void setXBar(double[] x) {
    xBar = x;
  }

  /**
   * Computes Ln[Gamma(b+0.5)] - Ln[Gamma(b)].
   *
   * Both log-gamma terms use the Lanczos approximation. They are subtracted
   * analytically, so each series is evaluated once and the result comes from
   * a single difference of logarithms.
   *
   * @param b the value in the above formula
   * @return the result
   */
  public static double diffLnGamma(double b) {
    double rt = -0.5;
    rt += (b + 1.0) * Math.log(b + 6.0) - (b + 0.5) * Math.log(b + 5.5);

    // series1 is the Lanczos series at b+0.5; series2 is the series at b.
    double series1 = 1.000000000190015, series2 = 1.000000000190015;
    for (int i = 0; i < 6; i++) {
      series1 += LANCZOS_COEF[i] / (b + 1.5 + (double) i);
      series2 += LANCZOS_COEF[i] / (b + 1.0 + (double) i);
    }

    rt += Math.log(series1 * b) - Math.log(series2 * (b + 0.5));
    return rt;
  }

  /**
   * Computes dLn[Gamma(x+0.5)]/dx - dLn[Gamma(x)]/dx, i.e. psi(x+0.5) - psi(x).
   *
   * Uses the series sum_{i>=0} 0.5 / ((x+i)(x+i+0.5)). It stops after the
   * first term that falls below m_Zero*1e-3.
   * NOTE(review): the terms decay like 1/i^2, so the iteration count grows
   * like sqrt(1/(m_Zero*1e-3)).
   *
   * @param x the value in the above formula
   * @return the result
   */
  protected double diffFstDervLnGamma(double x) {
    double rt = 0, series = 1.0; // start > 0 so the loop runs at least once
    for (int i = 0; series >= m_Zero * 1e-3; i++) {
      series = 0.5 / ((x + (double) i) * (x + (double) i + 0.5));
      rt += series;
    }
    return rt;
  }

  /**
   * Computes {Ln[Gamma(x+0.5)]}'' - {Ln[Gamma(x)]}'', i.e. psi'(x+0.5) - psi'(x).
   *
   * Uses the series -sum_{i>=0} (x+i+0.25) / ((x+i)^2 (x+i+0.5)^2). It stops
   * after the first term that falls below m_Zero*1e-3.
   *
   * @param x the value in the above formula
   * @return the result
   */
  protected double diffSndDervLnGamma(double x) {
    double rt = 0, series = 1.0; // start > 0 so the loop runs at least once
    for (int i = 0; series >= m_Zero * 1e-3; i++) {
      series = (x + (double) i + 0.25)
          / ((x + (double) i) * (x + (double) i)
             * (x + (double) i + 0.5) * (x + (double) i + 0.5));
      rt -= series;
    }
    return rt;
  }

  /**
   * Evaluates the objective function to be minimised: the negative
   * log-likelihood, summed over all exemplars whose xBar is not NaN. The
   * constant term 0.5*n*log(pi) is omitted because it does not affect
   * the optimum.
   *
   * When m_Debug is set, the running value is checked for NaN after each
   * term, and a diagnostic line identifying the offending term is printed
   * to stderr.
   *
   * @param x the parameter vector {a, b, w, m}
   * @return the negative log-likelihood at x
   * @throws ArithmeticException if the negative log-likelihood is NaN
   *         (previously this terminated the JVM via System.exit)
   */
  protected double objectiveFunction(double[] x) {
    int numExs = num.length;
    double nll = 0; // negative log-likelihood
    double a = x[0], b = x[1], w = x[2], m = x[3];

    for (int j = 0; j < numExs; j++) {
      if (Double.isNaN(xBar[j])) {
        continue; // all values missing
      }

      nll += 0.5 * (b + num[j])
          * Math.log((1.0 + num[j] * w) * (a + sSq[j])
                     + num[j] * (xBar[j] - m) * (xBar[j] - m));
      if (Double.isNaN(nll) && m_Debug) {
        debugAbort("???????????1: " + a + " " + b + " " + w + " " + m
            + "|x-: " + xBar[j] + "|n: " + num[j] + "|S^2: " + sSq[j]);
      }

      // 0.5*num[j]*log(pi) would be added here; it is constant in x.
      nll -= 0.5 * (b + num[j] - 1.0) * Math.log(1.0 + num[j] * w);
      if (Double.isNaN(nll) && m_Debug) {
        debugAbort("???????????2: " + a + " " + b + " " + w + " " + m
            + "|x-: " + xBar[j] + "|n: " + num[j] + "|S^2: " + sSq[j]);
      }

      // Ln[Gamma((b+n)/2) / Gamma(b/2)]: peel off floor(n/2) factors
      // explicitly. If n is odd, the remaining ratio is
      // Gamma(b/2+0.5)/Gamma(b/2).
      int halfNum = ((int) num[j]) / 2;
      for (int z = 1; z <= halfNum; z++) {
        nll -= Math.log(0.5 * b + 0.5 * num[j] - (double) z);
      }
      if (0.5 * num[j] > halfNum) { // num[j] is odd
        nll -= diffLnGamma(0.5 * b);
      }
      if (Double.isNaN(nll) && m_Debug) {
        debugAbort("???????????3: " + a + " " + b + " " + w + " " + m
            + "|x-: " + xBar[j] + "|n: " + num[j] + "|S^2: " + sSq[j]);
      }

      nll -= 0.5 * Math.log(a) * b;
      if (Double.isNaN(nll) && m_Debug) {
        debugAbort("???????????4:" + a + " " + b + " " + w + " " + m);
      }
    }

    if (m_Debug) {
      System.err.println("?????????????5: " + nll);
    }
    if (Double.isNaN(nll)) {
      // Fail loudly to the caller instead of killing the whole JVM.
      throw new ArithmeticException("TLD_Optm: negative log-likelihood is NaN at a="
          + a + ", b=" + b + ", w=" + w + ", m=" + m);
    }
    return nll;
  }

  /**
   * Prints a debug diagnostic and aborts the objective evaluation.
   *
   * @param detail parameter and exemplar values at the point where NaN appeared
   * @throws ArithmeticException always
   */
  private static void debugAbort(String detail) {
    System.err.println(detail);
    throw new ArithmeticException("TLD_Optm: NaN negative log-likelihood: " + detail);
  }

  /**
   * Evaluates the gradient of objectiveFunction with respect to {a, b, w, m}.
   *
   * @param x the parameter vector {a, b, w, m}
   * @return a new array {dNLL/da, dNLL/db, dNLL/dw, dNLL/dm}
   */
  protected double[] evaluateGradient(double[] x) {
    double[] g = new double[x.length];
    int numExs = num.length;
    double a = x[0], b = x[1], w = x[2], m = x[3];
    double da = 0.0, db = 0.0, dw = 0.0, dm = 0.0;

    for (int j = 0; j < numExs; j++) {
      if (Double.isNaN(xBar[j])) {
        continue; // all values missing
      }

      // The argument of the first log term of the objective.
      double denorm = (1.0 + num[j] * w) * (a + sSq[j])
          + num[j] * (xBar[j] - m) * (xBar[j] - m);

      da += 0.5 * (b + num[j]) * (1.0 + num[j] * w) / denorm - 0.5 * b / a;

      db += 0.5 * Math.log(denorm) - 0.5 * Math.log(1.0 + num[j] * w)
          - 0.5 * Math.log(a);
      int halfNum = ((int) num[j]) / 2;
      for (int z = 1; z <= halfNum; z++) {
        db -= 1.0 / (b + num[j] - 2.0 * (double) z);
      }
      if (num[j] / 2.0 > halfNum) { // num[j] is odd
        db -= 0.5 * diffFstDervLnGamma(0.5 * b);
      }

      dw += 0.5 * (b + num[j]) * (a + sSq[j]) * num[j] / denorm
          - 0.5 * (b + num[j] - 1.0) * num[j] / (1.0 + num[j] * w);

      dm += num[j] * (b + num[j]) * (m - xBar[j]) / denorm;
    }

    g[0] = da;
    g[1] = db;
    g[2] = dw;
    g[3] = dm;
    return g;
  }

  /**
   * Evaluates one row of the Hessian of objectiveFunction, i.e. the
   * derivatives of the index-th gradient component with respect to
   * {a, b, w, m}.
   *
   * @param x the parameter vector {a, b, w, m}
   * @param index which row: 0 = a, 1 = b, 2 = w, 3 = m; any other value
   *        yields an all-zero row
   * @return a new array holding the requested Hessian row
   */
  protected double[] evaluateHessian(double[] x, int index) {
    double[] h = new double[x.length];
    int numExs = num.length;
    double a = x[0], b = x[1], w = x[2], m = x[3];

    switch (index) {
      case 0: // a
        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) {
            continue; // all values missing
          }
          double denorm = (1.0 + num[j] * w) * (a + sSq[j])
              + num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] += 0.5 * b / (a * a)
              - 0.5 * (b + num[j]) * (1.0 + num[j] * w) * (1.0 + num[j] * w)
                / (denorm * denorm);
          h[1] += 0.5 * (1.0 + num[j] * w) / denorm - 0.5 / a;
          h[2] += 0.5 * num[j] * num[j] * (b + num[j])
              * (xBar[j] - m) * (xBar[j] - m) / (denorm * denorm);
          h[3] -= num[j] * (b + num[j]) * (m - xBar[j])
              * (1.0 + num[j] * w) / (denorm * denorm);
        }
        break;

      case 1: // b
        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) {
            continue; // all values missing
          }
          double denorm = (1.0 + num[j] * w) * (a + sSq[j])
              + num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] += 0.5 * (1.0 + num[j] * w) / denorm - 0.5 / a;

          int halfNum = ((int) num[j]) / 2;
          for (int z = 1; z <= halfNum; z++) {
            h[1] += 1.0 / ((b + num[j] - 2.0 * (double) z)
                           * (b + num[j] - 2.0 * (double) z));
          }
          if (num[j] / 2.0 > halfNum) { // num[j] is odd
            h[1] -= 0.25 * diffSndDervLnGamma(0.5 * b);
          }

          h[2] += 0.5 * (a + sSq[j]) * num[j] / denorm
              - 0.5 * num[j] / (1.0 + num[j] * w);
          h[3] += num[j] * (m - xBar[j]) / denorm;
        }
        break;

      case 2: // w
        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) {
            continue; // all values missing
          }
          double denorm = (1.0 + num[j] * w) * (a + sSq[j])
              + num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] += 0.5 * num[j] * num[j] * (b + num[j])
              * (xBar[j] - m) * (xBar[j] - m) / (denorm * denorm);
          h[1] += 0.5 * (a + sSq[j]) * num[j] / denorm
              - 0.5 * num[j] / (1.0 + num[j] * w);
          h[2] += 0.5 * (b + num[j] - 1.0) * num[j] * num[j]
                / ((1.0 + num[j] * w) * (1.0 + num[j] * w))
              - 0.5 * (b + num[j]) * (a + sSq[j]) * (a + sSq[j])
                * num[j] * num[j] / (denorm * denorm);
          h[3] -= num[j] * num[j] * (b + num[j])
              * (m - xBar[j]) * (a + sSq[j]) / (denorm * denorm);
        }
        break;

      case 3: // m
        for (int j = 0; j < numExs; j++) {
          if (Double.isNaN(xBar[j])) {
            continue; // all values missing
          }
          double denorm = (1.0 + num[j] * w) * (a + sSq[j])
              + num[j] * (xBar[j] - m) * (xBar[j] - m);

          h[0] -= num[j] * (b + num[j]) * (m - xBar[j])
              * (1.0 + num[j] * w) / (denorm * denorm);
          h[1] += num[j] * (m - xBar[j]) / denorm;
          h[2] -= num[j] * num[j] * (b + num[j])
              * (m - xBar[j]) * (a + sSq[j]) / (denorm * denorm);
          h[3] += num[j] * (b + num[j])
              * ((1.0 + num[j] * w) * (a + sSq[j])
                 - num[j] * (m - xBar[j]) * (m - xBar[j]))
              / (denorm * denorm);
        }
        break;

      default:
        // Unknown index: leave the row as zeros (original behaviour).
        break;
    }
    return h;
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -