⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 acoins.cc

📁 基于学习的深度优先搜索算法
💻 CC
📖 第 1 页 / 共 3 页
字号:
#include <iostream>
#include <ext/hash_map>
#include <list>
#include <map>
#include <set>
#include <assert.h>
#include <limits.h>    // UINT_MAX (used by UMAX)
#include <stdlib.h>    // drand48 (used by heuristic_t::compute_heuristic)
#include <sys/time.h>

// NOTE: MAX/MIN evaluate their arguments more than once; avoid passing
// expressions with side effects or expensive calls.
#define MAX(s,t)      ((s)>(t)?(s):(t))
#define MIN(s,t)      ((s)<(t)?(s):(t))
// Sentinel for an unbounded / unknown cost. Using UINT_MAX>>1 presumably
// leaves headroom so that small additions to finite values cannot wrap.
#define UMAX          (UINT_MAX>>1)
#define MILLION       1000000
// (seconds, microseconds) -> float seconds, e.g. from a struct timeval.
// Both arguments are parenthesised so expression arguments are cast as a whole.
#define TIME(s,u)     ((float)(s) + ((float)(u))/MILLION)

class hash_t;
class state_t;
class action_t;

typedef unsigned char uchar_t;

// Pointer to a const member of hash_t that scores action `a` in state `s`
// (hash_t::QValueMax and hash_t::QValueAdd have this signature).
typedef unsigned (hash_t::*qvalue_func_t)( const state_t&, const action_t& ) const;

// Fills the set with states reachable from the given state; its definition is
// not part of this section of the file.
void generateStateSpace( uchar_t, const state_t&, std::set<state_t>& );

// Not referenced in this section; presumably used elsewhere in the program.
static unsigned verbose = 0;
static unsigned expansions = 0;

// Print uchar_t values as numbers instead of as characters.
inline std::ostream& operator<<( std::ostream &os, uchar_t c )
{
  os << (int)c;
  return( os );
}

// Writes `depth` blanks to `os`.
inline void
indent( std::ostream &os, int depth )
{
  for( int i = 0; i < depth; ++i ) os << ' ';
}

// state and action classes

// A search state made of four small counters (lhs, ls, hs, s).
// NOTE(review): the names and next() suggest the counterfeit-coin weighing
// puzzle (coins that may be light-or-heavy, only light, only heavy, or
// standard) -- this interpretation is inferred, confirm with the author.
class state_t
{
  uchar_t lhs_, ls_, hs_, s_;
public:
  // All n items start in the lhs category; the other counters are zero.
  state_t( uchar_t n ) : lhs_(n), ls_(0), hs_(0), s_(0) { }
  state_t( uchar_t lhs, uchar_t ls, uchar_t hs, uchar_t s ) : lhs_(lhs), ls_(ls), hs_(hs), s_(s) { }
  uchar_t lhs( void ) const { return( lhs_ ); }
  uchar_t ls( void ) const { return( ls_ ); }
  uchar_t hs( void ) const { return( hs_ ); }
  uchar_t s( void ) const { return( s_ ); }

  // Successor of this state after action `a` with outcome obs (<0, 0 or >0).
  state_t next( const action_t &a, int obs ) const;

  // Terminal when s == n, or when lhs == 0 and exactly one ls/hs item remains.
  bool terminal( uchar_t n ) const { return( (s_==n) || ((lhs_==0) && (ls_+hs_==1)) ); }
  uchar_t terminal_value( void ) const { return( 0 ); }

  // True iff no outcome of `a` leads back to this same state.
  bool useful( const action_t &a ) const;

  // Lexicographic order on (lhs, ls, hs, s).
  bool operator<( const state_t &s ) const
  {
    if( lhs_ != s.lhs_ ) return( lhs_ < s.lhs_ );
    if( ls_ != s.ls_ ) return( ls_ < s.ls_ );
    if( hs_ != s.hs_ ) return( hs_ < s.hs_ );
    return( s_ < s.s_ );
  }
  bool operator==( const state_t &s ) const { return( (lhs_==s.lhs_) && (ls_==s.ls_) && (hs_==s.hs_) && (s_==s.s_) ); }
  void print( std::ostream &os ) const;
};

// An action: two groups of counts, l_ and r_ (printed as PICKUP[l:r]).
// NOTE(review): presumably the contents of the left and right pan -- confirm.
class action_t
{
  state_t l_, r_;
public:
  // Initialisers are listed in declaration order (l_ before r_).
  explicit action_t() : l_(0), r_(0) { }
  // The FIRST four counts initialise r_, the LAST four initialise l_.
  action_t( uchar_t lhs1, uchar_t ls1, uchar_t hs1, uchar_t s1, uchar_t lhs2, uchar_t ls2, uchar_t hs2, uchar_t s2 )
    : l_(lhs2,ls2,hs2,s2), r_(lhs1,ls1,hs1,s1) { }
  state_t l( void ) const { return( l_ ); }
  state_t r( void ) const { return( r_ ); }
  bool operator<( const action_t &a ) const { return( (l_ < a.l_) || ((l_ == a.l_) && (r_ < a.r_)) ); }
  bool operator==( const action_t &a ) const { return( (l_ == a.l_) && (r_ == a.r_) ); }
  void print( std::ostream &os ) const;
};

// obs < 0 : lhs/ls items of r become ls, lhs/hs items of l become hs, every
//           other item is added to s.
// obs > 0 : the mirror image (roles of l and r swapped).
// obs == 0: all lhs/ls/hs items used by the action move to s.
inline state_t
state_t::next( const action_t &a, int obs ) const
{
  const state_t l = a.l(), r = a.r();
  if( obs < 0 )
    return( state_t( 0, r.lhs()+r.ls(), l.lhs()+l.hs(), s_+ls_-r.ls()+hs_-l.hs()+lhs_-l.lhs()-r.lhs() ) );
  else if( obs > 0 )
    return( state_t( 0, l.lhs()+l.ls(), r.lhs()+r.hs(), s_+ls_-l.ls()+hs_-r.hs()+lhs_-l.lhs()-r.lhs() ) );
  else
    return( state_t( lhs_-l.lhs()-r.lhs(), ls_-l.ls()-r.ls(), hs_-l.hs()-r.hs(),
                     s_+l.ls()+r.ls()+l.hs()+r.hs()+l.lhs()+r.lhs() ) );
}

inline bool
state_t::useful( const action_t &a ) const
{
  for( int obs = -1; obs < 2; ++obs )
    if( next( a, obs ) == *this ) return( false );
  return( true );
}

std::ostream& operator<<( std::ostream &os, const state_t &s )
{
  s.print( os );
  return( os );
}

std::ostream& operator<<( std::ostream &os, const action_t &a )
{
  a.print( os );
  return( os );
}

void
state_t::print( std::ostream &os ) const
{
  os << "(" << lhs_ << "," << ls_ << "," << hs_ << "," << s_ << ")";
}

void
action_t::print( std::ostream &os ) const
{
  os << "PICKUP[" << l_ << ":" << r_ << "]";
}

// heuristic and hash classes

// A table of state values filled by compute_heuristic() with Bellman-style
// backups (value(s) <- best Q-value of s) over a generated state space.
class heuristic_t
{
  int type_;
  uchar_t n_;
  hash_t *hash_;   // owned: allocated in the constructor, deleted in the destructor

  // Non-copyable: a copy would share hash_ and delete it twice.
  // (Declared but intentionally not defined.)
  heuristic_t( const heuristic_t& );
  heuristic_t& operator=( const heuristic_t& );

public:
  heuristic_t( int type, uchar_t n );
  ~heuristic_t();
  unsigned value( const state_t &s ) const;
  void dump( std::ostream &os ) const;
  // type 1: `iters` full sweeps; type 2: randomised updates (see definition).
  void compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue );
};

namespace hashing = ::__gnu_cxx;

namespace __gnu_cxx {

  // Hash for state_t: packs the four 8-bit counters into one word. Operands
  // are widened to size_t before shifting so large counts cannot overflow int.
  template<> struct hash<state_t>
  {
    size_t operator()( const state_t &s ) const
    {
      return( ((size_t)s.lhs()<<24) | ((size_t)s.ls()<<16) | ((size_t)s.hs()<<8) | (size_t)s.s() );
    }
  };

  // Per-state record: lower value, upper value and a "solved" mark.
  class entry_t
  {
  public:
    unsigned lower_;
    unsigned upper_;
    bool mark_;
    entry_t( unsigned lower = 0, unsigned upper = UMAX, bool mark = false ) : lower_(lower), upper_(upper), mark_(mark) { }
  };

  // Map from state to entry_t. For states not yet stored, queries report the
  // heuristic value (0 if no heuristic is attached) as lower, UMAX as upper,
  // and "not solved"; updates insert the state on first use.
  class hash_t : public hash_map<state_t,entry_t>
  {
  public:
    // A state together with a pointer to its entry in the table.
    typedef std::pair<state_t,entry_t*> data_pair;

  private:
    const heuristic_t *h_;

    data_pair push( const state_t &s, unsigned l, unsigned u, bool m )
    {
      std::pair<iterator,bool> p = insert( std::make_pair( s, entry_t(l,u,m) ) );
      return( data_pair( (*p.first).first, &(*p.first).second ) );
    }
    data_pair push( const state_t &s ) { return( push( s, (!h_?0:h_->value(s)), UMAX, false ) ); }
    const_iterator lookup( const state_t &s ) const { return( find( s ) ); }
    iterator lookup( const state_t &s ) { return( find( s ) ); }

  public:
    hash_t( const heuristic_t *h = 0 ) : h_(h) { }

    // Returns the stored entry for s, inserting a default one if absent.
    data_pair get( const state_t &s )
    {
      iterator di = lookup( s );
      return( di == end() ? push( s ) : data_pair( (*di).first, &(*di).second ) );
    }
    unsigned value( const state_t &s ) const
    {
      const_iterator di = lookup( s );
      return( di == end() ? (!h_?0:h_->value(s)) : (*di).second.lower_ );
    }
    unsigned upper( const state_t &s ) const
    {
      const_iterator di = lookup( s );
      return( di == end() ? UMAX : (*di).second.upper_ );
    }
    void update( const state_t &s, unsigned value )
    {
      iterator di = lookup( s );
      if( di == end() ) push( s, value, UMAX, false );
      else (*di).second.lower_ = value;
    }
    void update_upper( const state_t &s, unsigned value )
    {
      iterator di = lookup( s );
      if( di == end() ) push( s, (!h_?0:h_->value(s)), value, false );
      else (*di).second.upper_ = value;
    }
    bool solved( const state_t &s ) const
    {
      const_iterator di = lookup( s );
      return( di == end() ? false : (*di).second.mark_ );
    }
    void mark( const state_t &s )
    {
      iterator di = lookup( s );
      if( di == end() ) push( s, (!h_?0:h_->value(s)), UMAX, true );
      else (*di).second.mark_ = true;
    }
    void dump( std::ostream &os ) const
    {
      for( const_iterator it = begin(); it != end(); ++it )
        os << "state=" << (*it).first << ", lower=" << (*it).second.lower_
           << ", upper=" << (*it).second.upper_ << ", mark=" << ((*it).second.mark_?1:0) << std::endl;
    }
  };
};

// Front end over hashing::hash_t that also counts update()/update_upper()
// calls and provides the Q-value / best-action computations.
class hash_t
{
  hashing::hash_t table_;
  unsigned updates_;
  unsigned upper_updates_;
public:
  hash_t( const heuristic_t *h = 0 ) : table_(h), updates_(0), upper_updates_(0) { }
  typedef hashing::hash_t::data_pair data_pair;
  data_pair get( const state_t &s ) { return( table_.get( s ) ); }
  unsigned updates( void ) const { return( updates_ ); }
  unsigned upper_updates( void ) const { return( upper_updates_ ); }
  size_t size( void ) const { return( table_.size() ); }
  size_t bucket_count( void ) const { return( table_.bucket_count() ); }
  unsigned value( const state_t &s ) const { return( table_.value( s ) ); }
  // const& (previously non-const ref / by value): still accepts every old caller.
  unsigned upper( const state_t &s ) const { return( table_.upper( s ) ); }
  void update( const state_t &s, unsigned value ) { table_.update( s, value ); ++updates_; }
  void update_upper( const state_t &s, unsigned value ) { table_.update_upper( s, value ); ++upper_updates_; }
  bool solved( const state_t &s ) const { return( table_.solved( s ) ); }
  void mark( const state_t &s ) { table_.mark( s ); }
  unsigned QValueMax( const state_t &s, const action_t &a ) const;
  unsigned QValueAdd( const state_t &s, const action_t &a ) const;
  unsigned bestQValue( uchar_t n, const state_t &s, qvalue_func_t qvalue ) const;
  std::pair<unsigned,action_t> bestAction( uchar_t n, const state_t &s, qvalue_func_t qvalue ) const;
  void dump( std::ostream &os ) const { table_.dump( os ); }
};

// 1 + the maximum value over the three outcomes of `a`; UMAX stays UMAX.
unsigned
hash_t::QValueMax( const state_t& s, const action_t &a ) const
{
  unsigned qv = 0;
  for( int obs = -1; obs < 2; ++obs )
    {
      // Look the successor up once (the MAX macro would call value() twice).
      unsigned v = value( s.next( a, obs ) );
      if( v > qv ) qv = v;
    }
  return( qv >= UMAX ? UMAX : 1 + qv );
}

// 1 + the sum of the values of the three outcomes of `a`, saturating at UMAX.
// The sum is checked before every addition: adding up to three values that
// may each be UMAX (= UINT_MAX>>1) would otherwise wrap around in unsigned
// arithmetic and turn an unbounded cost into a small finite one.
unsigned
hash_t::QValueAdd( const state_t& s, const action_t &a ) const
{
  unsigned qv = 0;
  for( int obs = -1; obs < 2; ++obs )
    {
      unsigned v = value( s.next( a, obs ) );
      if( v >= UMAX - qv ) return( UMAX );   // qv + v would reach UMAX
      qv += v;
    }
  return( 1 + qv );                          // qv < UMAX here, so 1 + qv <= UMAX
}

// Minimum Q-value over all useful actions in s, or UMAX if there is none.
// Uses exactly the same enumeration as bestAction(), so it delegates to it.
unsigned
hash_t::bestQValue( uchar_t n, const state_t& s, qvalue_func_t qvalue ) const
{
  return( bestAction( n, s, qvalue ).first );
}

// Enumerates every action with m items in each group (m = n/2 down to 1)
// whose combined category counts fit in s, skips actions that are not
// useful in s, and returns the minimal Q-value with the first action that
// achieved it (UMAX and a default action_t if no useful action exists).
std::pair<unsigned,action_t>
hash_t::bestAction( uchar_t n, const state_t &s, qvalue_func_t qvalue ) const
{
  action_t ba;
  unsigned bqv = UMAX;
  for( int m = n>>1; m > 0; --m )
    for( int lhs1 = m; lhs1 >= 0; --lhs1 )
      for( int lhs2 = m; lhs2 >= 0; --lhs2 )
        {
          if( lhs1 + lhs2 > s.lhs() ) continue;
          for( int ls1 = m-lhs1; ls1 >= 0; --ls1 )
            for( int ls2 = m-lhs2; ls2 >= 0; --ls2 )
              {
                if( ls1 + ls2 > s.ls() ) continue;
                for( int hs1 = m-lhs1-ls1; hs1 >= 0; --hs1 )
                  for( int hs2 = m-lhs2-ls2; hs2 >= 0; --hs2 )
                    {
                      // Remaining slots of each group are filled from category s.
                      const int s1 = m-lhs1-ls1-hs1, s2 = m-lhs2-ls2-hs2;
                      if( (hs1+hs2 > s.hs()) || (s1+s2 > s.s()) ) continue;
                      action_t a( lhs1, ls1, hs1, s1, lhs2, ls2, hs2, s2 );
                      if( !s.useful(a) ) continue;
                      unsigned qv = (this->*qvalue)( s, a );
                      if( qv < bqv )
                        {
                          bqv = qv;
                          ba = a;
                        }
                    }
              }
        }
  return( std::make_pair( bqv, ba ) );
}

heuristic_t::heuristic_t( int type, uchar_t n ) : type_(type), n_(n), hash_(new hash_t)
{
}

heuristic_t::~heuristic_t()
{
  delete hash_;
}

inline unsigned
heuristic_t::value( const state_t &s ) const
{
  return( hash_->value( s ) );
}

void
heuristic_t::dump( std::ostream &os ) const
{
  hash_->dump( os );
}

// Fills the heuristic table by backing up bestQValue() over the state space
// generated from s0 (nothing is generated when iters == 0).
//   type 1: `iters` complete sweeps over the space.
//   type 2: sweeps in which each state is updated with probability 1/2,
//           stopping as soon as iters * |space| updates have been made.
// Terminal states are assigned their terminal_value().
void
heuristic_t::compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue )
{
  std::set<state_t> space;
  if( iters > 0 ) generateStateSpace( n_, s0, space );

  size_t i = 0, ups = iters * space.size();
  while( ((type_ == 1) && (i < iters)) || ((type_ == 2) && (hash_->updates() < ups)) )
    {
      ++i;
      for( std::set<state_t>::const_iterator si = space.begin(); si != space.end(); ++si )
        {
          if( !((type_ == 1) || ((type_ == 2) && (drand48() < 0.5))) ) continue;

          if( !(*si).terminal(n_) )
            hash_->update( *si, hash_->bestQValue( n_, *si, qvalue ) );
          else
            hash_->update( *si, (*si).terminal_value() );

          // Update budget exhausted: leaving the sweep also ends the outer
          // loop, whose type-2 condition (updates() < ups) is now false.
          if( (type_ == 2) && (hash_->updates() >= ups) ) break;
        }
    }
  std::cout << "updates=" << hash_->updates() << ", iters=" << i << std::endl;
}

// graph class (for AO*)

// Explicit search graph: a map from state to node, a set of tip nodes, and a
// root. Only the declarations of add_tip/expand/expand_and_update are here.
class graph_t
{
public:
  // A graph node: its state, parent links, the currently marked action,
  // fringe/visited flags and an index. Also serves as a comparator on
  // entry_t* (by index) for std::multiset<entry_t*,entry_t>.
  class entry_t
  {
    const state_t s_;
    std::set<entry_t*> parents_;
    action_t marked_;
    bool fringe_;
    bool visited_;
    size_t index_;
  public:
    entry_t() : s_(0), fringe_(true), visited_(false), index_(0) { }
    entry_t( const state_t s ) : s_(s), fringe_(true), visited_(false), index_(0) { }
    const state_t& s( void ) const { return( s_ ); }
    bool fringe( void ) const { return( fringe_ ); }
    void set_fringe( bool fringe ) { fringe_ = fringe; }
    std::set<entry_t*>& parents( void ) { return( parents_ ); }
    const std::set<entry_t*>& parents( void ) const { return( parents_ ); }
    void add_parent( entry_t *parent ) { parents_.insert( parent ); }
    bool visited( void ) const { return( visited_ ); }
    void visit( void ) { visited_ = true; }
    void unvisit( void ) { visited_ = false; }
    action_t marked( void ) const { return( marked_ ); }
    void mark( const action_t a ) { marked_ = a; }
    void inc_index( void ) { ++index_; }
    void dec_index( void ) { assert( index_ > 0 ); --index_; }
    void set_index( size_t index ) { index_ = index; }
    size_t index( void ) const { return( index_ ); }
    bool revise( uchar_t n, hash_t &hash, qvalue_func_t qvalue );
    bool operator()( const entry_t *e1, const entry_t *e2 ) const { return( e1->index() < e2->index() ); }
  };

private:
  std::map<state_t,entry_t*> nodes_;
  std::set<entry_t*> tips_;
  entry_t *root_;

public:
  graph_t() : root_(0) { }
  entry_t* root( void ) { return( root_ ); }
  void set_root( entry_t *root ) { root_ = root; }
  const std::set<entry_t*>& tips( void ) const { return( tips_ ); }
  // Returns the first tip (in pointer order); there must be at least one.
  entry_t& choose_tip( void ) { assert( !tips_.empty() ); return( **tips_.begin() ); }
  entry_t* add_tip( entry_t *parent, const action_t &a, int obs, const state_t &child );
  bool expand( uchar_t n, entry_t &node, std::multiset<entry_t*,entry_t> &S );
  void expand_and_update( uchar_t n, entry_t &node, hash_t &hash, qvalue_func_t qvalue );
};

// Re-evaluates this node's state in `hash`:
//  - terminal state: store terminal_value(), mark solved, leave the fringe;
//  - otherwise store the best Q-value and mark the best action. The node is
//    marked solved (and leaves the fringe) when all three outcomes of that
//    action are already solved, or when no useful action exists (UMAX).
// Returns true when the node was solved here or its value increased.
inline bool
graph_t::entry_t::revise( uchar_t n, hash_t &hash, qvalue_func_t qvalue )
{
  if( s_.terminal(n) )
    {
      hash.update( s_, s_.terminal_value() );
      hash.mark( s_ );
      fringe_ = false;
      return( true );
    }

  unsigned old_value = hash.value( s_ );
  std::pair<unsigned,action_t> p = hash.bestAction( n, s_, qvalue );
  hash.update( s_, p.first );
  mark( p.second );

  bool solved = true;
  if( p.first < UMAX )
    {
      assert( s_.useful( p.second ) );
      for( int obs = -1; solved && (obs < 2); ++obs )
        {
          state_t s_obs = s_.next( p.second, obs );
          solved = hash.solved( s_obs );
        }
      if( solved )
        {
          hash.mark( s_ );
          fringe_ = false;
        }
    }
  else
    {
      // No useful action: the node cannot be improved, so it is closed.
      hash.mark( s_ );
      fringe_ = false;
    }
  return( solved || (p.first > old_value) );
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -