📄 acoins.cc
字号:
#include <iostream>
#include <ext/hash_map>
#include <list>
#include <map>
#include <set>
#include <utility>
#include <assert.h>
#include <limits.h>   // UINT_MAX (used by UMAX)
#include <stddef.h>   // size_t
#include <stdlib.h>   // drand48
#include <sys/time.h>

#define MAX(s,t) ((s)>(t)?(s):(t))
#define MIN(s,t) ((s)<(t)?(s):(t))
// UMAX acts as "infinity" for state values; half of UINT_MAX leaves headroom.
#define UMAX (UINT_MAX>>1)
#define MILLION 1000000
#define TIME(s,u) ((float)(s) + ((float)u)/MILLION)

class hash_t;
class state_t;
class action_t;

typedef unsigned char uchar_t;
// Pointer to one of hash_t's Q-value functions (QValueMax / QValueAdd).
typedef unsigned (hash_t::*qvalue_func_t)( const state_t&, const action_t& ) const;

void generateStateSpace( uchar_t, const state_t&, std::set<state_t>& );

static unsigned verbose = 0;
static unsigned expansions = 0;

// Print uchar_t counters as numbers instead of raw characters.
inline std::ostream&
operator<<( std::ostream &os, uchar_t c )
{
  os << static_cast<int>( c );
  return( os );
}

// Write `depth` spaces to `os`.
inline void
indent( std::ostream &os, int depth )
{
  for( int i = 0; i < depth; ++i ) os << ' ';
}

// state and action classes

// A state is four small counters lhs_, ls_, hs_, s_.
// NOTE(review): the names suggest counts of coins that may be light-or-heavy,
// only light, only heavy, and known standard -- confirm against the caller.
class state_t
{
  uchar_t lhs_, ls_, hs_, s_;

public:
  // All n units start in lhs_.
  state_t( uchar_t n ) : lhs_(n), ls_(0), hs_(0), s_(0) { }
  state_t( uchar_t lhs, uchar_t ls, uchar_t hs, uchar_t s )
    : lhs_(lhs), ls_(ls), hs_(hs), s_(s) { }

  uchar_t lhs( void ) const { return( lhs_ ); }
  uchar_t ls( void ) const { return( ls_ ); }
  uchar_t hs( void ) const { return( hs_ ); }
  uchar_t s( void ) const { return( s_ ); }

  // Successor after action `a` with outcome `obs` (<0, 0 or >0).
  state_t next( const action_t &a, int obs ) const;

  // Terminal when all n units are in s_, or when lhs_ is empty and exactly
  // one unit remains in ls_ + hs_.
  bool terminal( uchar_t n ) const
  {
    return( (s_==n) || ((lhs_==0) && (ls_+hs_==1)) );
  }
  uchar_t terminal_value( void ) const { return( 0 ); }

  // True iff no outcome of `a` leaves this state unchanged.
  bool useful( const action_t &a ) const;

  // Lexicographic order on (lhs_, ls_, hs_, s_).
  bool operator<( const state_t &s ) const
  {
    if( lhs_ != s.lhs_ ) return( lhs_ < s.lhs_ );
    if( ls_ != s.ls_ ) return( ls_ < s.ls_ );
    if( hs_ != s.hs_ ) return( hs_ < s.hs_ );
    return( s_ < s.s_ );
  }
  bool operator==( const state_t &s ) const
  {
    return( (lhs_==s.lhs_) && (ls_==s.ls_) && (hs_==s.hs_) && (s_==s.s_) );
  }

  void print( std::ostream &os ) const;
};

// An action is a pair of sub-states l_ / r_ (printed as "PICKUP[l:r]").
class action_t
{
  state_t l_, r_;

public:
  explicit action_t() : l_(0), r_(0) { }
  // NOTE: the FIRST quadruple initializes r_, the SECOND initializes l_.
  action_t( uchar_t lhs1, uchar_t ls1, uchar_t hs1, uchar_t s1,
            uchar_t lhs2, uchar_t ls2, uchar_t hs2, uchar_t s2 )
    : l_(lhs2,ls2,hs2,s2), r_(lhs1,ls1,hs1,s1) { }

  state_t l( void ) const { return( l_ ); }
  state_t r( void ) const { return( r_ ); }

  // Lexicographic order on (l_, r_).
  bool operator<( const action_t &a ) const
  {
    return( (l_ < a.l_) || ((l_ == a.l_) && (r_ < a.r_)) );
  }
  bool operator==( const action_t &a ) const
  {
    return( (l_ == a.l_) && (r_ == a.r_) );
  }

  void print( std::ostream &os ) const;
};

inline state_t
state_t::next( const action_t &a, int obs ) const
{
  const state_t L = a.l(), R = a.r();
  if( obs < 0 ) {
    state_t s( 0, R.lhs()+R.ls(), L.lhs()+L.hs(),
               s_+ls_-R.ls()+hs_-L.hs()+lhs_-L.lhs()-R.lhs() );
    return( s );
  }
  else if( obs > 0 ) {
    state_t s( 0, L.lhs()+L.ls(), R.lhs()+R.hs(),
               s_+ls_-L.ls()+hs_-R.hs()+lhs_-L.lhs()-R.lhs() );
    return( s );
  }
  else {
    // Balanced outcome: everything in l_ and r_ moves into s_.
    state_t s( lhs_-L.lhs()-R.lhs(), ls_-L.ls()-R.ls(), hs_-L.hs()-R.hs(),
               s_+L.ls()+R.ls()+L.hs()+R.hs()+L.lhs()+R.lhs() );
    return( s );
  }
}

inline bool
state_t::useful( const action_t &a ) const
{
  for( int obs = -1; obs < 2; ++obs ) {
    state_t s_obs = next( a, obs );
    if( s_obs == *this ) return( false );
  }
  return( true );
}

std::ostream&
operator<<( std::ostream &os, const state_t &s )
{
  s.print( os );
  return( os );
}

std::ostream&
operator<<( std::ostream &os, const action_t &a )
{
  a.print( os );
  return( os );
}

void
state_t::print( std::ostream &os ) const
{
  os << "(" << lhs_ << "," << ls_ << "," << hs_ << "," << s_ << ")";
}

void
action_t::print( std::ostream &os ) const
{
  os << "PICKUP[" << l_ << ":" << r_ << "]";
}

// heuristic and hash classes

// Heuristic table computed by compute_heuristic() via Bellman updates over
// the state space produced by generateStateSpace().
// NOTE(review): hash_ is an owning raw pointer and the implicit copy
// constructor/assignment are not disabled -- copying a heuristic_t would make
// two objects delete the same table. Do not copy instances.
class heuristic_t
{
  int type_;      // 1: fixed number of sweeps; 2: randomized updates (see compute_heuristic)
  uchar_t n_;
  hash_t *hash_;  // owned; deleted in the destructor

public:
  heuristic_t( int type, uchar_t n );
  ~heuristic_t();
  unsigned value( const state_t &s ) const;
  void dump( std::ostream &os ) const;
  void compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue );
};

namespace hashing = ::__gnu_cxx;

// NOTE(review): these classes live in the implementation's namespace
// __gnu_cxx; they are left there so that `hashing::hash_t` keeps resolving.
namespace __gnu_cxx {

  // Hash for state_t: packs the four 8-bit counters into one word.
  template<> class hash<state_t> {
  public:
    size_t operator()( const state_t &s ) const
    {
      // Widen before shifting: an int shifted left by 24 overflows (UB)
      // when lhs() >= 128.
      return( (static_cast<size_t>(s.lhs()) << 24) |
              (static_cast<size_t>(s.ls()) << 16) |
              (static_cast<size_t>(s.hs()) << 8) |
              static_cast<size_t>(s.s()) );
    }
  };

  // Stored data for one state: lower/upper value bounds and a mark flag
  // (the flag is what hash_t::solved() reports).
  class entry_t {
  public:
    unsigned lower_;
    unsigned upper_;
    bool mark_;
    entry_t( unsigned lower = 0, unsigned upper = UMAX, bool mark = false )
      : lower_(lower), upper_(upper), mark_(mark) { }
  };

  // state_t -> entry_t table. A state that is not stored reads as:
  // lower = heuristic estimate (0 when no heuristic), upper = UMAX, unmarked.
  class hash_t : public hash_map<state_t,entry_t> {
  public:
    typedef std::pair<state_t,entry_t*> data_pair;

  private:
    const heuristic_t *h_;  // optional; may be null

    // Lower-bound estimate for a state that has no entry yet.
    unsigned estimate( const state_t &s ) const
    {
      return( !h_ ? 0 : h_->value(s) );
    }
    data_pair push( const state_t &s, unsigned l, unsigned u, bool m )
    {
      std::pair<iterator,bool> p = insert( std::make_pair( s, entry_t(l,u,m) ) );
      return( data_pair( (*p.first).first, &(*p.first).second ) );
    }
    data_pair push( const state_t &s )
    {
      return( push( s, estimate(s), UMAX, false ) );
    }
    const_iterator lookup( const state_t &s ) const { return( find( s ) ); }
    iterator lookup( const state_t &s ) { return( find( s ) ); }

  public:
    hash_t( const heuristic_t *h = 0 ) : h_(h) { }

    // Return (state, pointer to its entry), inserting a default entry if absent.
    data_pair get( const state_t &s )
    {
      iterator di = lookup( s );
      return( di == end() ? push( s ) : data_pair( (*di).first, &(*di).second ) );
    }
    unsigned value( const state_t &s ) const
    {
      const_iterator di = lookup( s );
      return( di == end() ? estimate(s) : (*di).second.lower_ );
    }
    unsigned upper( const state_t &s ) const
    {
      const_iterator di = lookup( s );
      return( di == end() ? UMAX : (*di).second.upper_ );
    }
    void update( const state_t &s, unsigned value )
    {
      iterator di = lookup( s );
      if( di == end() ) push( s, value, UMAX, false );
      else (*di).second.lower_ = value;
    }
    void update_upper( const state_t &s, unsigned value )
    {
      iterator di = lookup( s );
      if( di == end() ) push( s, estimate(s), value, false );
      else (*di).second.upper_ = value;
    }
    bool solved( const state_t &s ) const
    {
      const_iterator di = lookup( s );
      return( di == end() ? false : (*di).second.mark_ );
    }
    void mark( const state_t &s )
    {
      iterator di = lookup( s );
      if( di == end() ) push( s, estimate(s), UMAX, true );
      else (*di).second.mark_ = true;
    }
    void dump( std::ostream &os ) const
    {
      for( const_iterator it = begin(); it != end(); ++it )
        os << "state=" << (*it).first
           << ", lower=" << (*it).second.lower_
           << ", upper=" << (*it).second.upper_
           << ", mark=" << ((*it).second.mark_?1:0) << std::endl;
    }
  };

}

// Front end to hashing::hash_t that also counts lower-bound writes
// (updates_) and upper-bound writes (upper_updates_).
class hash_t
{
  hashing::hash_t table_;
  unsigned updates_;
  unsigned upper_updates_;

public:
  hash_t( const heuristic_t *h = 0 ) : table_(h), updates_(0), upper_updates_(0) { }
  typedef hashing::hash_t::data_pair data_pair;

  data_pair get( const state_t &s ) { return( table_.get( s ) ); }
  unsigned updates( void ) const { return( updates_ ); }
  unsigned upper_updates( void ) const { return( upper_updates_ ); }
  size_t size( void ) const { return( table_.size() ); }
  size_t bucket_count( void ) const { return( table_.bucket_count() ); }
  unsigned value( const state_t &s ) const { return( table_.value( s ) ); }
  unsigned upper( const state_t &s ) const { return( table_.upper( s ) ); }
  void update( const state_t &s, unsigned value )
  {
    table_.update( s, value );
    ++updates_;
  }
  void update_upper( const state_t &s, unsigned value )
  {
    table_.update_upper( s, value );
    ++upper_updates_;
  }
  bool solved( const state_t &s ) const { return( table_.solved( s ) ); }
  void mark( const state_t &s ) { table_.mark( s ); }

  unsigned QValueMax( const state_t &s, const action_t &a ) const;
  unsigned QValueAdd( const state_t &s, const action_t &a ) const;
  unsigned bestQValue( uchar_t n, const state_t &s, qvalue_func_t qvalue ) const;
  std::pair<unsigned,action_t> bestAction( uchar_t n, const state_t &s, qvalue_func_t qvalue ) const;
  void dump( std::ostream &os ) const { table_.dump( os ); }
};

// 1 + max of the three successor values; UMAX (infinity) is absorbing.
unsigned
hash_t::QValueMax( const state_t& s, const action_t &a ) const
{
  unsigned qv = 0;
  for( int obs = -1; obs < 2; ++obs ) {
    unsigned v = value( s.next( a, obs ) );  // one lookup (MAX evaluates twice)
    qv = MAX(qv,v);
  }
  return( qv >= UMAX ? UMAX : 1 + qv );
}

// 1 + sum of the three successor values, saturating at UMAX (infinity).
unsigned
hash_t::QValueAdd( const state_t& s, const action_t &a ) const
{
  unsigned qv = 0;  // invariant: qv <= UMAX
  for( int obs = -1; obs < 2; ++obs ) {
    unsigned v = value( s.next( a, obs ) );
    // Saturate instead of wrapping: three values near UMAX exceed UINT_MAX
    // and would wrap to a small, wrongly finite cost.
    qv = ( v >= UMAX - qv ) ? UMAX : qv + v;
  }
  return( qv >= UMAX ? UMAX : 1 + qv );
}

// Minimum Q-value over all useful actions in `s` (UMAX if there is none).
unsigned
hash_t::bestQValue( uchar_t n, const state_t& s, qvalue_func_t qvalue ) const
{
  return( bestAction( n, s, qvalue ).first );
}

// Enumerate every action putting m units (m = n/2 .. 1) into each of l_ and r_,
// drawn from the counters available in `s`, and return the first action with
// the minimum Q-value together with that value. Returns (UMAX, action_t())
// when no useful action exists.
std::pair<unsigned,action_t>
hash_t::bestAction( uchar_t n, const state_t &s, qvalue_func_t qvalue ) const
{
  action_t ba;
  unsigned bqv = UMAX;
  for( int m = n>>1; m > 0; --m ) {
    for( int lhs1 = m; lhs1 >= 0; --lhs1 ) {
      for( int lhs2 = m; lhs2 >= 0; --lhs2 ) {
        if( lhs1 + lhs2 > s.lhs() ) continue;
        for( int ls1 = m-lhs1; ls1 >= 0; --ls1 ) {
          for( int ls2 = m-lhs2; ls2 >= 0; --ls2 ) {
            if( ls1 + ls2 > s.ls() ) continue;
            for( int hs1 = m-lhs1-ls1; hs1 >= 0; --hs1 ) {
              for( int hs2 = m-lhs2-ls2; hs2 >= 0; --hs2 ) {
                // The remainder of each side is filled from s_.
                int s1 = m-lhs1-ls1-hs1, s2 = m-lhs2-ls2-hs2;
                if( (hs1+hs2 > s.hs()) || (s1+s2 > s.s()) ) continue;
                action_t a( lhs1, ls1, hs1, s1, lhs2, ls2, hs2, s2 );
                if( !s.useful(a) ) continue;
                unsigned qv = (this->*qvalue)( s, a );
                if( qv < bqv ) {
                  bqv = qv;
                  ba = a;
                }
              }
            }
          }
        }
      }
    }
  }
  return( std::make_pair( bqv, ba ) );
}

heuristic_t::heuristic_t( int type, uchar_t n )
  : type_(type), n_(n)
{
  hash_ = new hash_t;
}

heuristic_t::~heuristic_t()
{
  delete hash_;
}

inline unsigned
heuristic_t::value( const state_t &s ) const
{
  return( hash_->value( s ) );
}

void
heuristic_t::dump( std::ostream &os ) const
{
  hash_->dump( os );
}

// Fill the heuristic table by Bellman updates over the state space reachable
// from s0 (only generated when iters > 0):
//   type 1: `iters` full sweeps, every state updated once per sweep;
//   type 2: sweeps where each state is updated with probability 0.5, stopping
//           as soon as iters * |space| updates have been performed.
// Prints the number of updates and sweeps to std::cout.
void
heuristic_t::compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue )
{
  std::set<state_t> space;
  if( iters > 0 ) generateStateSpace( n_, s0, space );

  size_t i = 0, ups = iters * space.size();
  while( ((type_ == 1) && (i < iters)) || ((type_ == 2) && (hash_->updates() < ups)) ) {
    ++i;
    for( std::set<state_t>::const_iterator si = space.begin(); si != space.end(); ++si ) {
      if( !((type_ == 1) || ((type_ == 2) && (drand48() < 0.5))) ) continue;
      if( !(*si).terminal(n_) )
        hash_->update( *si, hash_->bestQValue( n_, *si, qvalue ) );
      else
        hash_->update( *si, (*si).terminal_value() );
      // Budget reached: stop mid-sweep. The while condition is then false as
      // well, so the outer loop ends without another sweep.
      if( (type_ == 2) && (hash_->updates() >= ups) ) break;
    }
  }
  std::cout << "updates=" << hash_->updates() << ", iters=" << i << std::endl;
}

// graph class (for AO*)
class graph_t
{
public:
  // A node of the explicit graph: its state, parent nodes, the currently
  // marked (best) action, and bookkeeping flags.
  class entry_t
  {
    const state_t s_;
    std::set<entry_t*> parents_;
    action_t marked_;
    bool fringe_;
    bool visited_;
    size_t index_;

  public:
    entry_t() : s_(0), fringe_(true), visited_(false), index_(0) { }
    entry_t( const state_t s ) : s_(s), fringe_(true), visited_(false), index_(0) { }

    const state_t& s( void ) const { return( s_ ); }
    bool fringe( void ) const { return( fringe_ ); }
    void set_fringe( bool fringe ) { fringe_ = fringe; }
    std::set<entry_t*>& parents( void ) { return( parents_ ); }
    const std::set<entry_t*>& parents( void ) const { return( parents_ ); }
    void add_parent( entry_t *parent ) { parents_.insert( parent ); }
    bool visited( void ) const { return( visited_ ); }
    void visit( void ) { visited_ = true; }
    void unvisit( void ) { visited_ = false; }
    action_t marked( void ) const { return( marked_ ); }
    void mark( const action_t a ) { marked_ = a; }
    void inc_index( void ) { ++index_; }
    void dec_index( void ) { assert( index_ > 0 ); --index_; }
    void set_index( size_t index ) { index_ = index; }
    size_t index( void ) const { return( index_ ); }

    bool revise( uchar_t n, hash_t &hash, qvalue_func_t qvalue );

    // Orders node pointers by index(); used as the multiset comparator in expand().
    bool operator()( const entry_t *e1, const entry_t *e2 ) const
    {
      return( e1->index() < e2->index() );
    }
  };

private:
  std::map<state_t,entry_t*> nodes_;
  std::set<entry_t*> tips_;
  entry_t *root_;

public:
  graph_t() : root_(0) { }
  entry_t* root( void ) { return( root_ ); }
  void set_root( entry_t *root ) { root_ = root; }
  const std::set<entry_t*>& tips( void ) const { return( tips_ ); }
  // Precondition: at least one tip exists.
  entry_t& choose_tip( void )
  {
    assert( !tips_.empty() );
    return( **tips_.begin() );
  }
  entry_t* add_tip( entry_t *parent, const action_t &a, int obs, const state_t &child );
  bool expand( uchar_t n, entry_t &node, std::multiset<entry_t*,entry_t> &S );
  void expand_and_update( uchar_t n, entry_t &node, hash_t &hash, qvalue_func_t qvalue );
};

// Recompute this node's value and best (marked) action. The node is marked
// solved when it is terminal, when it has no useful action (value UMAX), or
// when all three successors of the best action are solved.
// Returns true if the node became solved or its value increased.
inline bool
graph_t::entry_t::revise( uchar_t n, hash_t &hash, qvalue_func_t qvalue )
{
  if( s_.terminal(n) ) {
    hash.update( s_, s_.terminal_value() );
    hash.mark( s_ );
    fringe_ = false;
    return( true );
  }

  unsigned old_value = hash.value( s_ );
  std::pair<unsigned,action_t> p = hash.bestAction( n, s_, qvalue );
  hash.update( s_, p.first );
  mark( p.second );

  bool solved = true;
  if( p.first < UMAX ) {
    assert( s_.useful( p.second ) );
    for( int obs = -1; solved && (obs < 2); ++obs ) {
      state_t s_obs = s_.next( p.second, obs );
      solved = hash.solved( s_obs );
    }
    if( solved ) {
      hash.mark( s_ );
      fringe_ = false;
    }
  }
  else {
    // No useful action: the value is infinite and the node is closed.
    hash.mark( s_ );
    fringe_ = false;
  }
  return( solved || (p.first > old_value) );
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -