⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lbfgs.h

📁 最大熵等模型使用的 lbfgs 训练源代码。
💻 H
📖 第 1 页 / 共 4 页
字号:
           improper input parameters.
           @author Jorge J. More, David J. Thuente: original Fortran version,
             as part of Minpack project. Argonne Nat'l Laboratory, June 1983.
             Robert Dodier: Java translation, August 1997.
         */
        static int mcstep(
          FloatType& stx,
          FloatType& fx,
          FloatType& dx,
          FloatType& sty,
          FloatType& fy,
          FloatType& dy,
          FloatType& stp,
          FloatType fp,
          FloatType dp,
          bool& brackt,
          FloatType stpmin,
          FloatType stpmax);
    };

    /* Advances the reverse-communication (More-Thuente) line search.

       The routine hands control back to the caller every time it needs a
       new function value and gradient:

         - Called with info != -1: validates the inputs, checks that the
           search direction s[is0 .. is0+n) is a descent direction for g,
           saves the starting point x into wa, resets the line-search state,
           writes the first trial point into x and returns with info == -1.
         - Called with info == -1: f and g must hold the function value and
           gradient at the trial point x written by the previous call. The
           trial step is tested. If both acceptance conditions hold, info is
           set to 1 and the routine returns. Otherwise the interval of
           uncertainty is updated via mcstep(), a new trial point is written
           into x and the routine returns with info == -1 again.

       Here dginit denotes g . s at the starting point.

       Parameters:
         gtol    curvature tolerance (>= 0): acceptance requires
                 abs(g . s) <= gtol * (-dginit).
         stpmin  lower bound for the step (>= 0).
         stpmax  upper bound for the step (>= stpmin).
         n       number of variables (must be nonzero).
         x       in: starting point (first call).
                 out: trial point wa + stp * s[is0 ..] whenever the routine
                 returns with info == -1.
         f       function value at x.
         g       gradient at x (n entries).
         s       array holding the search direction at s[is0 .. is0+n).
         is0     offset of the search direction inside s.
         stp     in/out step length; must be > 0 on the first call.
         ftol    sufficient-decrease tolerance (>= 0): acceptance requires
                 f <= finit + stp * ftol * dginit.
         xtol    relative tolerance (>= 0) on the width of the interval of
                 uncertainty.
         maxfev  maximum number of function evaluations (must be nonzero).
         info    in/out protocol flag:
                   -1  evaluate f and g at x and call run() again;
                    1  both acceptance conditions are satisfied.
         nfev    out: number of function evaluations consumed so far.
         wa      work array of at least n entries; keeps the starting point
                 for the whole search.

       Throws:
         error_internal_error
             invalid input parameters (see the ranges above).
         error_search_direction_not_descent
             g . s >= 0 at the starting point.
         error_line_search_failed_rounding_errors
             the trial step left the bracket, or mcstep() rejected its
             input (infoc == 0).
         error_line_search_failed
             step stuck at stpmin or stpmax, maxfev evaluations reached, or
             the interval of uncertainty is narrower than xtol allows.
     */
    template <typename FloatType, typename SizeType>
    void mcsrch<FloatType, SizeType>::run(
      FloatType const& gtol,
      FloatType const& stpmin,
      FloatType const& stpmax,
      SizeType n,
      FloatType* x,
      FloatType f,
      const FloatType* g,
      FloatType* s,
      SizeType is0,
      FloatType& stp,
      FloatType ftol,
      FloatType xtol,
      SizeType maxfev,
      int& info,
      SizeType& nfev,
      FloatType* wa)
    {
      // First call (not a re-entry): validate inputs and initialize state.
      if (info != -1) {
        infoc = 1;
        if (   n == 0
            || maxfev == 0
            || gtol < FloatType(0)
            || xtol < FloatType(0)
            || stpmin < FloatType(0)
            || stpmax < stpmin) {
          throw error_internal_error(__FILE__, __LINE__);
        }
        if (stp <= FloatType(0) || ftol < FloatType(0)) {
          throw error_internal_error(__FILE__, __LINE__);
        }
        // Compute the initial gradient in the search direction
        // and check that s is a descent direction, i.e. that the
        // inner product g . s is negative.
        // NOTE(review): the L-BFGS driver presumably passes
        // s = -H_k g_k, making dginit = -g^T H_k g_k; confirm at caller.
        dginit = FloatType(0);
        for (SizeType j = 0; j < n; j++) {
          dginit += g[j] * s[is0+j];
        }
        if (dginit >= FloatType(0)) {
          throw error_search_direction_not_descent();
        }
        brackt = false;
        stage1 = true;
        nfev = 0;
        finit = f;
        // Right-hand-side slope of the sufficient-decrease test.
        dgtest = ftol*dginit;
        width = stpmax - stpmin;
        width1 = FloatType(2) * width;
        // Keep the starting point; every trial point is wa + stp * s.
        std::copy(x, x+n, wa);
        // The variables stx, fx, dgx contain the values of the step,
        // function, and directional derivative at the best step.
        // The variables sty, fy, dgy contain the value of the step,
        // function, and derivative at the other endpoint of
        // the interval of uncertainty.
        // The variables stp, f, dg contain the values of the step,
        // function, and derivative at the current step.
        stx = FloatType(0);
        fx = finit;
        dgx = dginit;
        sty = FloatType(0);
        fy = finit;
        dgy = dginit;
      }
      for (;;) {
        // info != -1 here means "a new trial step is needed": either the
        // first call, or a re-entry that did not yet converge (info == 0).
        if (info != -1) {
          // Set the minimum and maximum steps to correspond
          // to the present interval of uncertainty.
          if (brackt) {
            stmin = std::min(stx, sty);
            stmax = std::max(stx, sty);
          }
          else {
            // Not bracketed yet: allow extrapolation up to
            // 4 * (stp - stx) beyond the current step.
            stmin = stx;
            stmax = stp + FloatType(4) * (stp - stx);
          }
          // Force the step to be within the bounds stpmax and stpmin.
          stp = std::max(stp, stpmin);
          stp = std::min(stp, stpmax);
          // If an unusual termination is to occur then let
          // stp be the lowest point obtained so far. This includes the
          // last permitted evaluation (nfev == maxfev - 1).
          if (   (brackt && (stp <= stmin || stp >= stmax))
              || nfev >= maxfev - 1 || infoc == 0
              || (brackt && stmax - stmin <= xtol * stmax)) {
            stp = stx;
          }
          // Evaluate the function and gradient at stp
          // and compute the directional derivative.
          // We return to main program to obtain F and G.
          for (SizeType j = 0; j < n; j++) {
            // Trial point: starting point (saved in wa) plus stp * s.
            x[j] = wa[j] + stp * s[is0+j];
          }
          // Hand control back: the caller must evaluate f and g at x
          // and call run() again with info still equal to -1.
          info=-1;
          break;
        }
        // Re-entry: f and g now belong to the trial point x.
        info = 0;
        nfev++;
        FloatType dg(0);
        for (SizeType j = 0; j < n; j++) {
          dg += g[j] * s[is0+j];
        }
        // Sufficient-decrease threshold for the current step.
        FloatType ftest1 = finit + stp*dgtest;
        // Test for convergence.
        if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) {
          throw error_line_search_failed_rounding_errors(
            "Rounding errors prevent further progress."
            " There may not be a step which satisfies the"
            " sufficient decrease and curvature conditions."
            " Tolerances may be too small.");
        }
        if (stp == stpmax && f <= ftest1 && dg <= dgtest) {
          throw error_line_search_failed(
            "The step is at the upper bound stpmax().");
        }
        if (stp == stpmin && (f > ftest1 || dg >= dgtest)) {
          throw error_line_search_failed(
            "The step is at the lower bound stpmin().");
        }
        if (nfev >= maxfev) {
          throw error_line_search_failed(
            "Number of function evaluations has reached maxfev().");
        }
        if (brackt && stmax - stmin <= xtol * stmax) {
          throw error_line_search_failed(
            "Relative width of the interval of uncertainty"
            " is at most xtol().");
        }
        // Check for termination: sufficient decrease and curvature
        // conditions both hold.
        if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) {
          info = 1;
          break;
        }
        // In the first stage we seek a step for which the modified
        // function has a nonpositive value and nonnegative derivative.
        if (   stage1 && f <= ftest1
            && dg >= std::min(ftol, gtol) * dginit) {
          stage1 = false;
        }
        // A modified function is used to predict the step only if
        // we have not obtained a step for which the modified
        // function has a nonpositive function value and nonnegative
        // derivative, and if a lower function value has been
        // obtained but the decrease is not sufficient.
        if (stage1 && f <= fx && f > ftest1) {
          // Define the modified function and derivative values
          // (the original values minus the line step * dgtest).
          FloatType fm = f - stp*dgtest;
          FloatType fxm = fx - stx*dgtest;
          FloatType fym = fy - sty*dgtest;
          FloatType dgm = dg - dgtest;
          FloatType dgxm = dgx - dgtest;
          FloatType dgym = dgy - dgtest;
          // Call mcstep to update the interval of uncertainty
          // and to compute the new step.
          infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm,
                         brackt, stmin, stmax);
          // Reset the function and gradient values for f
          // (undo the modification on the updated endpoints).
          fx = fxm + stx*dgtest;
          fy = fym + sty*dgtest;
          dgx = dgxm + dgtest;
          dgy = dgym + dgtest;
        }
        else {
          // Call mcstep to update the interval of uncertainty
          // and to compute the new step.
          infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg,
                         brackt, stmin, stmax);
        }
        // Force a sufficient decrease in the size of the
        // interval of uncertainty. width1 trails width by one
        // bracketed update, so if the bracket has not shrunk below
        // 66% of that older width, the next trial step bisects it.
        if (brackt) {
          if (abs(sty - stx) >= FloatType(0.66) * width1) {
            stp = stx + FloatType(0.5) * (sty - stx);
          }
          width1 = width;
          width = abs(sty - stx);
        }
        // Loop back: info == 0, so a new trial point is produced above.
      }
    }
    template <typename FloatType, typename SizeType>
    int mcsrch<FloatType, SizeType>::mcstep(
      FloatType& stx,
      FloatType& fx,
      FloatType& dx,
      FloatType& sty,
      FloatType& fy,
      FloatType& dy,
      FloatType& stp,
      FloatType fp,
      FloatType dp,
      bool& brackt,
      FloatType stpmin,
      FloatType stpmax)
    {
      bool bound;
      FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta;
      int info = 0;
      if (   (   brackt && (stp <= std::min(stx, sty)
              || stp >= std::max(stx, sty)))
          || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) {
        return 0;
      }
      // Determine if the derivatives have opposite sign.
      sgnd = dp * (dx / abs(dx));
      if (fp > fx) {
        // First case. A higher function value.
        // The minimum is bracketed. If the cubic step is closer
        // to stx than the quadratic step, the cubic step is taken,
        // else the average of the cubic and quadratic steps is taken.
        info = 1;
        bound = true;
        theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
        s = max3(abs(theta), abs(dx), abs(dp));
        gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
        if (stp < stx) gamma = - gamma;
        p = (gamma - dx) + theta;
        q = ((gamma - dx) + gamma) + dp;
        r = p/q;
        stpc = stx + r * (stp - stx);
        stpq = stx
          + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2))
            * (stp - stx);
        if (abs(stpc - stx) < abs(stpq - stx)) {
          stpf = stpc;
        }
        else {
          stpf = stpc + (stpq - stpc) / FloatType(2);
        }
        brackt = true;
      }
      else if (sgnd < FloatType(0)) {
        // Second case.
A lower function value and derivatives of        // opposite sign. The minimum is bracketed. If the cubic        // step is closer to stx than the quadratic (secant) step,        // the cubic step is taken, else the quadratic step is taken.        info = 2;        bound = false;        theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;        s = max3(abs(theta), abs(dx), abs(dp));        gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));        if (stp > stx) gamma = - gamma;        p = (gamma - dp) + theta;        q = ((gamma - dp) + gamma) + dx;        r = p/q;        stpc = stp + r * (stx - stp);        stpq = stp + (dp / (dp - dx)) * (stx - stp);        if (abs(stpc - stp) > abs(stpq - stp)) {          stpf = stpc;        }        else {          stpf = stpq;        }        brackt = true;      }      else if (abs(dp) < abs(dx)) {        // Third case. A lower function value, derivatives of the        // same sign, and the magnitude of the derivative decreases.        // The cubic step is only used if the cubic tends to infinity        // in the direction of the step or if the minimum of the cubic        // is beyond stp. Otherwise the cubic step is defined to be        // either stpmin or stpmax. The quadratic (secant) step is also        // computed and if the minimum is bracketed then the the step        // closest to stx is taken, else the step farthest away is taken.        
info = 3;        bound = true;        theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;        s = max3(abs(theta), abs(dx), abs(dp));        gamma = s * std::sqrt(          std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s)));        if (stp > stx) gamma = -gamma;        p = (gamma - dp) + theta;        q = (gamma + (dx - dp)) + gamma;        r = p/q;        if (r < FloatType(0) && gamma != FloatType(0)) {          stpc = stp + r * (stx - stp);        }        else if (stp > stx) {          stpc = stpmax;        }        else {          stpc = stpmin;        }        stpq = stp + (dp / (dp - dx)) * (stx - stp);        if (brackt) {          if (abs(stp - stpc) < abs(stp - stpq)) {            stpf = stpc;          }          else {            stpf = stpq;          }        }        else {          if (abs(stp - stpc) > abs(stp - stpq)) {            stpf = stpc;          }          else {            stpf = stpq;          }        }      }      else {        // Fourth case. A lower function value, derivatives of the        // same sign, and the magnitude of the derivative does        // not decrease. If the minimum is not bracketed, the step        // is either stpmin or stpmax, else the cubic step is taken.        info = 4;        bound = false;        if (brackt) {          theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp;          s = max3(abs(theta), abs(dy), abs(dp));          gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s));          if (stp > sty) gamma = -gamma;          p = (gamma - dp) + theta;          q = ((gamma - dp) + gamma) + dy;          r = p/q;          stpc = stp + r * (sty - stp);          stpf = stpc;        }        else if (stp > stx) {          stpf = stpmax;        }        else {          stpf = stpmin;        }      }      // Update the interval of uncertainty. This update does not      // depend on the new step or the case analysis above.      
if (fp > fx) {        sty = stp;        fy = fp;        dy = dp;      }      else {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -