📄 rules.cc
字号:
#include <iostream>#include <ext/hash_map>#include <list>#include <map>#include <set>#include <vector>#include <assert.h>#include <limits.h>#include <values.h>#include <sys/time.h>#define MIN(a,b) ((a)<(b)?(a):(b))#define MAX(a,b) ((a)>(b)?(a):(b))#define MILLION 1000000#define TIME(s,u) ((float)(s) + ((float)u)/MILLION)class rule_t;class hash_t;class state_t;class rule_system_t;typedef unsigned atom_t;typedef double (hash_t::*qvalue_func_t)( const state_t&, const rule_t& ) const;static bool verbose = false;static size_t expansions = 0;static const rule_system_t *rule_system = 0;static unsigned short rseed = 0;static size_t num_atoms = 0;static size_t max_rules_per_atom = 0;static size_t max_body_size = 0;inline voidindent( std::ostream &os, size_t depth ){ for( size_t i = 0; i < depth; ++i ) os << ' ';}class rule_body_t : public std::vector<atom_t>{public: rule_body_t( size_t size ) : std::vector<atom_t>( size ) { }};class rule_t{ atom_t head_; rule_body_t body_;public: rule_t( size_t body_size ) : head_(UINT_MAX), body_(body_size) { } void set_head( atom_t atom ) { head_ = atom; } atom_t head( void ) const { return( head_ ); } void insert_body( size_t pos, atom_t atom ) { body_[pos] = atom; } const rule_body_t& body( void ) const { return( body_ ); } void print( std::ostream &os ) const;};class rule_list_t : public std::list<const rule_t*> { };class rule_system_t{ size_t num_atoms_; rule_list_t *rules_; rule_list_t *inv_rules_;public: rule_system_t( size_t num_atoms = 0 ) : num_atoms_(num_atoms) { rules_ = new rule_list_t[num_atoms_]; inv_rules_ = new rule_list_t[num_atoms_]; } ~rule_system_t() { for( size_t i = 0; i < num_atoms_; ++i ) for( rule_list_t::const_iterator ri = rules_[i].begin(); ri != rules_[i].end(); ++ri ) delete *ri; delete[] rules_; delete[] inv_rules_; } static const rule_system_t& generate_random( size_t num_atoms, size_t max_rules_per_atom, size_t max_body_size ); size_t num_atoms( void ) const { return( num_atoms_ ); } void insert_rule( atom_t head, const rule_t *rule ); const rule_list_t& rules( atom_t head ) const { return( rules_[head] ); } const rule_list_t& inv_rules( atom_t p ) const { return( inv_rules_[p] ); } void print( std::ostream &os, size_t depth ) const;};voidrule_system_t::insert_rule( atom_t head, const rule_t *rule ){ rules_[head].push_back( rule ); const rule_body_t &body = rule->body(); for( rule_body_t::const_iterator bi = body.begin(); bi != body.end(); ++bi ) inv_rules_[*bi].push_back( rule );}const rule_system_t&rule_system_t::generate_random( size_t num_atoms, size_t max_rules_per_atom, size_t max_body_size ){ rule_system_t *rs = new rule_system_t( num_atoms ); for( atom_t head = 0; head < num_atoms; ++head ) { std::set<std::set<atom_t> > bodies; size_t nrules = lrand48() % max_rules_per_atom; for( size_t r = 0; r < 1+nrules; ++r ) { std::set<atom_t> body; size_t goal = 1 + (lrand48() % max_body_size); goal = MIN(num_atoms-head-1,goal); while( body.size() < goal ) { for( atom_t atom = 1+head; (atom < num_atoms) && (body.size() < goal); ++atom ) if( !(lrand48()%10) ) body.insert( atom ); } bodies.insert( body ); } for( std::set<std::set<atom_t> >::const_iterator bi = bodies.begin(); bi != bodies.end(); ++bi ) { rule_t *rule = new rule_t( (*bi).size() ); rule->set_head( head ); size_t pos = 0; for( std::set<atom_t>::const_iterator ai = (*bi).begin(); ai != (*bi).end(); ++ai ) rule->insert_body( pos++, *ai ); rs->insert_rule( head, rule ); } } return( *rs );}std::ostream& operator<<( std::ostream &os, const rule_t &r ){ r.print( os ); return( os );}std::ostream& operator<<( std::ostream &os, const rule_system_t &rs ){ rs.print( os, 0 ); return( os );}voidrule_t::print( std::ostream &os ) const{ os << head_ << " <-- "; for( rule_body_t::const_iterator bi = body_.begin(); bi != body_.end(); ++bi ) os << *bi << " ";}voidrule_system_t::print( std::ostream &os, size_t depth ) const{ for( atom_t head = 0; head < num_atoms_; ++head ) for( rule_list_t::const_iterator ri = rules_[head].begin(); ri != rules_[head].end(); ++ri ) { indent( os, depth ); os << **ri << std::endl; }}class state_t{ atom_t atom_;public: state_t( atom_t atom = UINT_MAX ) : atom_(atom) { } atom_t atom( void ) const { return( atom_ ); } void print( std::ostream &os, size_t depth ) const { indent( os, depth ); os << "{ " << atom_ << " }"; } bool terminal( void ) const { const rule_list_t &rules = rule_system->rules( atom_ ); for( rule_list_t::const_iterator ri = rules.begin(); ri != rules.end(); ++ri ) if( (*ri)->body().empty() ) return( true ); return( false ); } double terminal_value( void ) const { return( 1 ); } bool operator==( const state_t &s ) const { return( atom_ == s.atom_ ); } bool operator<( const state_t &s ) const { return( atom_ < s.atom_ ); }};std::ostream& operator<<( std::ostream &os, const state_t &s ){ s.print( os, 0 ); return( os );}class heuristic_t{ int type_; hash_t *hash_; const rule_system_t &rule_system_;public: heuristic_t( int type, const rule_system_t &rs ); ~heuristic_t(); double value( const state_t &s ) const; void compute_heuristic( size_t iters, qvalue_func_t qvalue ); void dump( std::ostream &os ) const;};namespace hashing = ::__gnu_cxx;namespace __gnu_cxx { template<> class hash<const state_t> { public: size_t operator()( const state_t &s ) const { return( s.atom() ); } }; class entry_t { public: double lower_; double upper_; bool mark_; entry_t( double lower = 0, double upper = DBL_MAX, bool mark = false ) : lower_(lower), upper_(upper), mark_(mark) { } }; class hash_t : public hash_map<const state_t,entry_t> { public: typedef std::pair<state_t,entry_t*> data_pair; private: const heuristic_t *h_; data_pair push( const state_t &s, double l, double u, bool m ) { std::pair<iterator,bool> p = insert( std::make_pair( s, entry_t(l,u,m) ) ); return( data_pair( (*p.first).first, &(*p.first).second ) ); } data_pair push( const state_t &s ) { return( push( s, (!h_?0:h_->value(s)), DBL_MAX, false ) ); } const_iterator lookup( const state_t &s ) const { return( find( s ) ); } iterator lookup( const state_t &s ) { return( find( s ) ); } public: hash_t( const heuristic_t *h = 0 ) : h_(h) { } data_pair get( const state_t &s ) { iterator di = lookup( s ); return( di == end() ? push( s ) : data_pair( (*di).first, &(*di).second ) ); } double value( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? (!h_?0:h_->value(s)) : (*di).second.lower_ ); } double upper( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? DBL_MAX : (*di).second.upper_ ); } void update( const state_t &s, double value ) { iterator di = lookup( s ); if( di == end() ) push( s, value, DBL_MAX, false ); else (*di).second.lower_ = value; } void update_upper( const state_t &s, double value ) { iterator di = lookup( s ); if( di == end() ) push( s, (!h_?0:h_->value(s)), value, false ); else (*di).second.upper_ = value; } bool solved( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? false : (*di).second.mark_ ); } void mark( const state_t &s ) { iterator di = lookup( s ); if( di == end() ) push( s, (!h_?0:h_->value(s)), DBL_MAX, true ); else (*di).second.mark_ = true; } void dump( std::ostream &os ) const { for( const_iterator it = begin(); it != end(); ++it ) os << "state=" << (*it).first << ", lower=" << (*it).second.lower_ << ", upper=" << (*it).second.upper_ << ", mark=" << ((*it).second.mark_?1:0) << std::endl; } };};class hash_t{ hashing::hash_t table_; unsigned updates_; unsigned upper_updates_;public: hash_t( const heuristic_t *h = 0 ) : table_(h), updates_(0), upper_updates_(0) { } typedef hashing::hash_t::data_pair data_pair; data_pair get( const state_t &s ) { return( table_.get( s ) ); } unsigned updates( void ) const { return( updates_ ); } unsigned upper_updates( void ) const { return( upper_updates_ ); } size_t size( void ) const { return( table_.size() ); } size_t bucket_count( void ) const { return( table_.bucket_count() ); } double value( const state_t &s ) const { return( table_.value( s ) ); } double upper( const state_t &s ) const { return( table_.upper( s ) ); } void update( const state_t &s, double value ) { table_.update( s, value ); ++updates_; } void update_upper( const state_t &s, double value ) { table_.update_upper( s, value ); ++upper_updates_; } bool solved( const state_t &s ) const { return( table_.solved( s ) ); } void mark( state_t s ) { table_.mark( s ); } double QValueMax( const state_t &s, const rule_t &r ) const; double QValueAdd( const state_t &s, const rule_t &r ) const; double bestQValue( const state_t &s, qvalue_func_t qvalue ) const; std::pair<double,const rule_t*> bestAction( const state_t &s, qvalue_func_t qvalue ) const; void dump( std::ostream &os ) const { table_.dump( os ); }};class graph_t{public: class entry_t { const state_t s_; std::set<entry_t*> parents_; const rule_t *marked_; bool fringe_; bool visited_; size_t index_; public: entry_t() : marked_(0), fringe_(true), visited_(false), index_(0) { } entry_t( const state_t s ) : s_(s), marked_(0), fringe_(true), visited_(false), index_(0) { } const state_t& s( void ) const { return( s_ ); } bool fringe( void ) const { return( fringe_ ); } void set_fringe( bool fringe ) { fringe_ = fringe; } std::set<entry_t*>& parents( void ) { return( parents_ ); } const std::set<entry_t*>& parents( void ) const { return( parents_ ); } void add_parent( entry_t *parent ) { parents_.insert( parent ); } bool visited( void ) const { return( visited_ ); } void visit( void ) { visited_ = true; } void unvisit( void ) { visited_ = false; } const rule_t* marked( void ) const { return( marked_ ); } void mark( const rule_t *r ) { marked_ = r; } void inc_index( void ) { ++index_; } void dec_index( void ) { assert( index_ > 0 ); --index_; } void set_index( size_t index ) { index_ = index; } size_t index( void ) const { return( index_ ); } bool revise( hash_t &hash, qvalue_func_t qvalue ); bool operator()( const entry_t *e1, const entry_t *e2 ) const { return( e1->index() < e2->index() ); } };private: std::map<state_t,entry_t*> nodes_; std::set<entry_t*> tips_; entry_t *root_;public: graph_t() : root_(0) { } entry_t* root( void ) { return( root_ ); } void set_root( entry_t *root ) { root_ = root; } const std::set<entry_t*>& tips( void ) const { return( tips_ ); } entry_t& choose_tip( void ) { return( **tips_.begin() ); } entry_t* add_tip( entry_t *parent, const rule_t *rule, unsigned target, const state_t &child ); void expand( entry_t &n, std::multiset<entry_t*,entry_t> &S ); void expand_and_update( entry_t &n, hash_t &hash, qvalue_func_t qvalue );};doublehash_t::QValueMax( const state_t &s, const rule_t &r ) const{ double qv = 0; for( rule_body_t::const_iterator bi = r.body().begin(); bi != r.body().end(); ++bi ) { double v = value( state_t( *bi ) ); qv = MAX(qv,v); } return( 1 + qv );}doublehash_t::QValueAdd( const state_t &s, const rule_t &r ) const{ double qv = 0; for( rule_body_t::const_iterator bi = r.body().begin(); bi != r.body().end(); ++bi ) qv += value( state_t( *bi ) ); return( 1 + qv );}doublehash_t::bestQValue( const state_t &s, qvalue_func_t qvalue ) const{ double bqv = DBL_MAX; const rule_list_t &rules = rule_system->rules( s.atom() ); for( rule_list_t::const_iterator ri = rules.begin(); ri != rules.end(); ++ri ) { double qv = (this->*qvalue)( s, **ri ); bqv = MIN(bqv,qv); } return( bqv );}inline std::pair<double,const rule_t*>hash_t::bestAction( const state_t &s, qvalue_func_t qvalue ) const{ const rule_t *best_rule = 0; double bqv = DBL_MAX; const rule_list_t &rules = rule_system->rules( s.atom() ); for( rule_list_t::const_iterator ri = rules.begin(); ri != rules.end(); ++ri ) { double qv = (this->*qvalue)( s, **ri ); if( qv < bqv ) { bqv = qv; best_rule = *ri; } } return( std::make_pair( bqv, best_rule ) );}heuristic_t::heuristic_t( int type, const rule_system_t &rs ) : type_(type), rule_system_(rs){ hash_ = new hash_t;}heuristic_t::~heuristic_t(){ delete hash_;}inline doubleheuristic_t::value( const state_t &s ) const{ return( hash_->value( s ) );}voidheuristic_t::dump( std::ostream &os ) const{ hash_->dump( os );}voidheuristic_t::compute_heuristic( size_t iters, qvalue_func_t qvalue ){ size_t i = 0, ups = iters * rule_system_.num_atoms(); while( ((type_ == 1) && (i < iters)) || ((type_ == 2) && (hash_->updates() < ups)) ) { ++i; for( atom_t p = 0; p < rule_system_.num_atoms(); ++p ) if( (type_ == 1) || ((type_ == 2) && (drand48() < 0.5)) ) { state_t s( p ); if( !s.terminal() ) hash_->update( s, hash_->bestQValue( s, qvalue ) ); else hash_->update( s, s.terminal_value() ); if( (type_ == 2) && (hash_->updates() >= ups) ) goto end; } } end:; std::cerr << "updates=" << hash_->updates() << ", iters=" << i << std::endl;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -