⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 pulse.cpp

📁 C++编写的机器学习算法 Lemga is a C++ package which consists of classes for several learning models and gener
💻 CPP
字号:
/** @file *  $Id: pulse.cpp 2186 2005-06-29 03:40:11Z ling $ */#include <assert.h>#include <cmath>#include <vector>#include <map>#include <algorithm>#include "pulse.h"REGISTER_CREATOR(lemga::Pulse);template <typename II>bool serialize (std::ostream& os, const II& b, const II& e, bool l = true) {    if (l) if (!(os << (e - b) << '\n')) return false;    for (II i = b; i != e; ++i)        if (!(os << *i << ' ')) return false;    if (b != e) os << '\n';    return true;}template <typename II>bool unserialize (std::istream& is, const II& b, const II& e) {    for (II i = b; i != e; ++i)        if (!(is >> *i)) return false;    return true;}namespace lemga {bool Pulse::serialize (std::ostream& os, ver_list& vl) const {    SERIALIZE_PARENT(LearnModel, os, vl, 1);    if (!(os << idx << ' ' << th.size() << ' ' << (dir? '-':'+') << '\n'))        return false;    return ::serialize(os, th.begin(), th.end(), false);}bool Pulse::unserialize (std::istream& is, ver_list& vl, const id_t& d) {    if (d != id() && d != empty_id) return false;    UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);    assert(v > 0);    UINT nl;    if (!(is >> idx >> nl)) return false;    char c;    if (!(is >> c) || (c != '-' && c != '+')) return false;    dir = (c == '-');    th.resize(nl);    return ::unserialize(is, th.begin(), th.end());}void Pulse::set_threshold (const std::vector<REAL>& t) {    assert(t.size() <= max_l);#ifndef NDEBUG    // assert t is sorted (std::is_sorted is an SGL extension)    for (UINT i = 1; i < t.size(); ++i)        assert(t[i-1] <= t[i]);#endif    th = t;}REAL Pulse::train () {    const UINT N = n_samples;    assert(ptd != 0 && ptw != 0 && ptd->size() == N);    assert(n_input() > 0 && n_output() == 1);    std::vector<REAL> yw(N);    for (UINT i = 0; i < N; ++i)        yw[i] = ptd->y(i)[0] * (*ptw)[i];    REAL minerr = 2;         // a number large enough (> 1)    std::vector<UINT> thi;   // threshold index    std::vector<REAL> xb(N); // backup for sorted x    // no reallocation within loops    std::vector<REAL> x(N), ysum(N);    for (UINT d = 0; d < _n_in; ++d) {        // extract the dimension d info, collapse data with same x        std::map<REAL,REAL> xy;        for (UINT i = 0; i < N; ++i)            xy[ptd->x(i)[d]] += yw[i];        REAL sum = 0;        int last_sign = 0;  // 1: pos, 2: neg, 3: zero        std::vector<REAL>::iterator px = x.begin(), py = ysum.begin();        for (std::map<REAL,REAL>::const_iterator                 p = xy.begin(); p != xy.end(); ++p) {            static REAL last_x;            const int cur_sign = (p->second > 0)? 1:((p->second < 0)? 2:3);            if (last_sign != cur_sign && last_sign != 0) {                /** @note we can also save the threshold info. as lower                 *  and upper @a x values, which is a bit faster (less                 *  operations) but needs more memory space. */                *px = last_x + p->first; *py = sum * 2;                ++px; ++py;            }            last_sign = cur_sign;            last_x = p->first;            sum += p->second;            assert(-1.01 < sum && sum < 1.01);        }        *py = sum * 2;        const UINT n = py - ysum.begin();        std::vector<REAL> e0(n+1, 0); // error of pulses ending with -1        std::vector<REAL> e1(n+1, 0); // error of pulses ending with +1        std::vector<std::vector<UINT> > t0(n+1), t1(n+1); // transitions idx        // dynamic programming: compute err for level 1--max_l        // e0 and e1 at the begining of loop l are        //    e0[i] = 2*best_e_{i,l} - 1 - sum(w*y),        //    e1[i] = 2*best_e_{i,l} - 1 + sum(w*y).        // where best_e_{i,l} is the lowest error if l transitions        // happens before or at position i.        for (UINT l = 0; l < max_l; ++l) {            // swap e0 & e1, t0 & t1 to get rid of the sign change            e0.swap(e1); t0.swap(t1);            // compute errors for level (l+1)            std::vector<REAL>::iterator p0 = e0.begin(), p1 = e1.begin();            std::vector<REAL>::iterator ps = ysum.begin();            for (UINT i = 0; i <= n; ++p0, ++p1, ++ps, ++i) {                *p0 -= *ps; *p1 += *ps;            }            assert(p0 == e0.end());            std::vector<std::vector<UINT> >::iterator                pt0 = t0.begin(), pt1 = t1.begin();            REAL bst0 = 3, bst1 = 3;  // a number large enough (> 2)            p0 = e0.begin(); p1 = e1.begin();            for (UINT i = 0; i <= n; ++p0, ++p1, ++pt0, ++pt1, ++i) {                static std::vector<UINT> tb0, tb1;  // always the best                assert(-2.01 < *p0 && *p0 < 2.01);                assert(-2.01 < *p1 && *p1 < 2.01);                if (*p0 < bst0) {                    bst0 = *p0; tb0.swap(*pt0); // => tb0 = *pt0;                    if (i < n) tb0.push_back(i);                }                *p0 = bst0; *pt0 = tb0;                if (*p1 < bst1) {                    bst1 = *p1; tb1.swap(*pt1); // => tb1 = *pt1;                    if (i < n) tb1.push_back(i);                }                *p1 = bst1; *pt1 = tb1;            }            assert(p0 == e0.end());        }        e0[n] += ysum[n] / 2;        e1[n] -= ysum[n] / 2;        if (e0[n] <= e1[n] && e0[n] < minerr) {            minerr = e0[n]; idx = d; dir = !(max_l & 1);            thi.swap(t0[n]); xb.swap(x);        } else if (e1[n] < minerr) {            minerr = e1[n]; idx = d; dir = (max_l & 1);            thi.swap(t1[n]); xb.swap(x);        }    }    th.clear();    for (UINT i = 0; i < thi.size(); /* empty */) {        UINT ind = thi[i]; ++i;        if (i < thi.size() && ind == thi[i])            ++i;        else            th.push_back(xb[ind] / 2);    }    return (1 + minerr) / 2;}Output Pulse::operator() (const Input& x) const {    assert(idx < n_input() && x.size() == n_input());    if (th.empty())        return Output(1, dir? -1 : 1);    const UINT i =        std::lower_bound(th.begin(), th.end(), x[idx]) - th.begin();    return Output(1, ((i & 1) ^ dir)? -1 : 1);}} // namespace lemga

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -