📄 lbfgs.cpp
字号:
infoc = 1; /* CHECK THE INPUT PARAMETERS FOR ERRORS. */ if (*n <= 0 || *stp <= 0) { return 0; } /* COMPUTE THE INITIAL GRADIENT IN THE SEARCH DIRECTION */ /* AND CHECK THAT S IS A DESCENT DIRECTION. */ dginit = 0; i__1 = *n; for (j = 1; j <= i__1; ++j) { dginit += g[j] * s[j]; /* L10: */ } if (dginit >= 0) { return 0; } /* INITIALIZE LOCAL VARIABLES. */ brackt = 0; stage1 = 1; *nfev = 0; finit = *f; dgtest = FTOL * dginit; width = STPMAX - STPMIN; width1 = width / P5; i__1 = *n; for (j = 1; j <= i__1; ++j) { wa[j] = x[j]; /* L20: */ } /* THE VARIABLES STX, FX, DGX CONTAIN THE VALUES OF THE STEP, */ /* FUNCTION, AND DIRECTIONAL DERIVATIVE AT THE BEST STEP. */ /* THE VARIABLES STY, FY, DGY CONTAIN THE VALUE OF THE STEP, */ /* FUNCTION, AND DERIVATIVE AT THE OTHER ENDPOINT OF */ /* THE INTERVAL OF UNCERTAINTY. */ /* THE VARIABLES STP, F, DG CONTAIN THE VALUES OF THE STEP, */ /* FUNCTION, AND DERIVATIVE AT THE CURRENT STEP. */ stx = 0; fx = finit; dgx = dginit; sty = 0; fy = finit; dgy = dginit; /* START OF ITERATION. */ L30: /* SET THE MINIMUM AND MAXIMUM STEPS TO CORRESPOND */ /* TO THE PRESENT INTERVAL OF UNCERTAINTY. */ if (brackt) { stmin = min(stx,sty); stmax = max(stx,sty); } else { stmin = stx; stmax = *stp + XTRAPF *(*stp - stx); } /* FORCE THE STEP TO BE WITHIN THE BOUNDS STPMAX AND STPMIN. */ *stp = max(*stp,STPMIN); *stp = min(*stp,STPMAX); /* IF AN UNUSUAL TERMINATION IS TO OCCUR THEN LET */ /* STP BE THE LOWEST POINT OBTAINED SO FAR. */ if ((brackt && ((*stp <= stmin || *stp >= stmax) || *nfev >= MAXFEV - 1 || infoc == 0)) ||(brackt && (stmax - stmin <= XTOL * stmax))) { *stp = stx; } /* EVALUATE THE FUNCTION AND GRADIENT AT STP */ /* AND COMPUTE THE DIRECTIONAL DERIVATIVE. */ /* We return to main program to obtain F and G. */ i__1 = *n; if(orthant) { for (j = 1; j <= i__1; ++j) { if(g[j]==0 && x[j]==0) continue; //stop at 0 if x[j] cross 0 x[j] = wa[j] + *stp * s[j]; if(wa[j]*x[j]<0) x[j]=0; } }else{ for (j = 1; j <= i__1; ++j) { x[j] = wa[j] + *stp * s[j]; /* L40: */ } } *info = -1; return 0; L45: *info = 0; ++(*nfev); dg = 0; i__1 = *n; for (j = 1; j <= i__1; ++j) { dg += g[j] * s[j]; /* L50: */ } ftest1 = finit + *stp * dgtest; /* TEST FOR CONVERGENCE. */ if (brackt && ((*stp <= stmin || *stp >= stmax) || infoc == 0)) { *info = 6; } if (*stp == STPMAX && *f <= ftest1 && dg <= dgtest) { *info = 5; } if (*stp == STPMIN && (*f > ftest1 || dg >= dgtest)) { *info = 4; } if (*nfev >= MAXFEV) { *info = 3; } if (brackt && stmax - stmin <= XTOL * stmax) { *info = 2; } if (*f <= ftest1 && fabs(dg) <= GTOL * (-dginit)) { *info = 1; } /* CHECK FOR TERMINATION. */ if (*info != 0) { return 0; } /* IN THE FIRST STAGE WE SEEK A STEP FOR WHICH THE MODIFIED */ /* FUNCTION HAS A NONPOSITIVE VALUE AND NONNEGATIVE DERIVATIVE. */ if (stage1 && *f <= ftest1 && dg >= min(FTOL,GTOL) * dginit) { stage1 = 0; } /* A MODIFIED FUNCTION IS USED TO PREDICT THE STEP ONLY IF */ /* WE HAVE NOT OBTAINED A STEP FOR WHICH THE MODIFIED */ /* FUNCTION HAS A NONPOSITIVE FUNCTION VALUE AND NONNEGATIVE */ /* DERIVATIVE, AND IF A LOWER FUNCTION VALUE HAS BEEN */ /* OBTAINED BUT THE DECREASE IS NOT SUFFICIENT. */ if (stage1 && *f <= fx && *f > ftest1) { /* DEFINE THE MODIFIED FUNCTION AND DERIVATIVE VALUES. */ fm = *f - *stp * dgtest; fxm = fx - stx * dgtest; fym = fy - sty * dgtest; dgm = dg - dgtest; dgxm = dgx - dgtest; dgym = dgy - dgtest; /* CALL CSTEP TO UPDATE THE INTERVAL OF UNCERTAINTY */ /* AND TO COMPUTE THE NEW STEP. */ mcstep_(&stx, &fxm, &dgxm, &sty, &fym, &dgym, stp, &fm, &dgm, &brackt, &stmin, &stmax, &infoc); /* RESET THE FUNCTION AND GRADIENT VALUES FOR F. */ fx = fxm + stx * dgtest; fy = fym + sty * dgtest; dgx = dgxm + dgtest; dgy = dgym + dgtest; } else { /* CALL MCSTEP TO UPDATE THE INTERVAL OF UNCERTAINTY */ /* AND TO COMPUTE THE NEW STEP. */ mcstep_(&stx, &fx, &dgx, &sty, &fy, &dgy, stp, f, &dg, &brackt, & stmin, &stmax, &infoc); } /* FORCE A SUFFICIENT DECREASE IN THE SIZE OF THE */ /* INTERVAL OF UNCERTAINTY. */ if (brackt) { if ((d__1 = sty - stx, fabs(d__1)) >= P66 * width1) { *stp = stx + P5 *(sty - stx); } width1 = width; width =(d__1 = sty - stx, fabs(d__1)); } /* END OF ITERATION. */ goto L30; /* LAST LINE OF SUBROUTINE MCSRCH. */} /* mcsrch_ */int LBFGS::optimize(double *x, double *f, double *g, int orthant, double *w0, double *w1){ int i,j,k; int index; double ys,yy,h_0; double *w2,*w3;//4 work spaces for memory saving prior if(prior==1) { w2=x; w3=g; }else if(prior==0){ w1=&diag[0]; } double gnorm; if(iflag==0)//first iteration { if(prior==1) { s[0]=w0; q=w1; } for(i=0;i<n;i++) s[0][i]=-g[i]; gnorm=sqrt(inner_product(n,g,g)); stp1=1/gnorm; }else{ if(prior==1) { load(w0,"__w0"); load(w1,"__w1"); s[iter%m]=w0; } goto L172; } while(1) {//main loop //w0:null, w1:q, w2:null, w3:g //for first iteration //w0:s,w1:q,w2:x,w3:g info=0; bound=iter>m?m:iter; if(iter==0) goto L165; index=(iter+m-1)%m; if(prior==1) { s[index]=w0; y[index]=w2; load(s[index],"__s",index); load(y[index],"__y",index); } //w0:s, w1:q, w2:y, w3:g ys=inner_product(n,y[index],s[index]); yy=inner_product(n,y[index],y[index]); h_0=ys/yy; rho[index]=1/ys; for(i=0;i<n;i++) q[i]=-g[i]; index=iter%m; for(i=0;i<bound;i++) { index=(index+m-1)%m; if(prior==1 && i>0) { s[index]=w0; y[index]=w2; load(s[index],"__s",index); load(y[index],"__y",index); } double sq=inner_product(n,q,s[index]); alpha[index]=sq*rho[index]; daxpy(n,q,-alpha[index],y[index]); } for(i=0;i<n;i++) q[i]*=h_0; for(i=0;i<bound;i++) { if(prior==1 && i>0) { s[index]=w0; y[index]=w2; load(s[index],"__s",index); load(y[index],"__y",index); } double yr=inner_product(n,y[index],q); double beta=yr*rho[index]; beta=alpha[index]-beta; daxpy(n,q,beta,s[index]); index=(index+1)%m; } if(prior==1) s[iter%m]=w0; //w0:s, w1:q, w2:y, w3:g for(i=0;i<n;i++) s[iter%m][i]=q[i]; if(prior==1) { x=w2; load(x,"__x"); }L165://w0:s, w1:q, w2:x, w3:g nfev = 0; stp = 1; if (iter == 0) stp = stp1; for(i=0;i<n;i++) q[i] = g[i]; if(prior==1) save(q,"__q");L172://w0:s, w1:temp, w2:x, w3:g mcsrch_(&n, x, f, g, s[iter%m], &stp, &info, &nfev, w1, orthant); if (info == -1) { if(prior==1) { save(w0,"__w0"); save(w1,"__w1"); } iflag = 1; return 1; } if (info != 1) return -1;//error, see iflag gnorm=sqrt(inner_product(n,g,g)); double xnorm=sqrt(inner_product(n,x,x)); if(prior==1){ q=w1; load(q,"__q"); save(x,"__x"); y[iter%m]=w2; } //w0:s, w1:q, w2:y, w3:g for(i=0;i<n;i++) { s[iter%m][i]*=stp; y[iter%m][i]=g[i]-q[i]; } if(xnorm<1) xnorm=1; if (gnorm / xnorm <= ETA) { finish = true; } if(finish) { iflag=0; return 0; } if(prior==1) { save(y[iter%m],"__y",iter%m); save(s[iter%m],"__s",iter%m); } //w0:null, w1:q, w2:null, w3:g iter++; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -