⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rules.cc

📁 基于学习的深度优先搜索算法
💻 CC
📖 第 1 页 / 共 3 页
字号:
#include <iostream>#include <ext/hash_map>#include <list>#include <map>#include <set>#include <vector>#include <assert.h>#include <limits.h>#include <values.h>#include <sys/time.h>#define MIN(a,b)        ((a)<(b)?(a):(b))#define MAX(a,b)        ((a)>(b)?(a):(b))#define MILLION         1000000#define TIME(s,u)       ((float)(s) + ((float)u)/MILLION)class rule_t;class hash_t;class state_t;class rule_system_t;typedef unsigned atom_t;typedef double (hash_t::*qvalue_func_t)( const state_t&, const rule_t& ) const;static bool verbose = false;static size_t expansions = 0;static const rule_system_t *rule_system = 0;static unsigned short rseed = 0;static size_t num_atoms = 0;static size_t max_rules_per_atom = 0;static size_t max_body_size = 0;inline voidindent( std::ostream &os, size_t depth ){  for( size_t i = 0; i < depth; ++i ) os << ' ';}class rule_body_t : public std::vector<atom_t>{public:  rule_body_t( size_t size ) : std::vector<atom_t>( size ) { }};class rule_t{  atom_t head_;  rule_body_t body_;public:  rule_t( size_t body_size ) : head_(UINT_MAX), body_(body_size) { }  void set_head( atom_t atom ) { head_ = atom; }  atom_t head( void ) const { return( head_ ); }  void insert_body( size_t pos, atom_t atom ) { body_[pos] = atom; }  const rule_body_t& body( void ) const { return( body_ ); }  void print( std::ostream &os ) const;};class rule_list_t : public std::list<const rule_t*> { };class rule_system_t{  size_t num_atoms_;  rule_list_t *rules_;  rule_list_t *inv_rules_;public:  rule_system_t( size_t num_atoms = 0 ) : num_atoms_(num_atoms) { rules_ = new rule_list_t[num_atoms_]; inv_rules_ = new rule_list_t[num_atoms_]; }  ~rule_system_t() { for( size_t i = 0; i < num_atoms_; ++i ) for( rule_list_t::const_iterator ri = rules_[i].begin(); ri != rules_[i].end(); ++ri ) delete *ri; delete[] rules_; delete[] inv_rules_; }  static const rule_system_t& generate_random( size_t num_atoms, size_t max_rules_per_atom, size_t max_body_size );  size_t num_atoms( void ) const { return( num_atoms_ ); }  void insert_rule( atom_t head, const rule_t *rule );  const rule_list_t& rules( atom_t head ) const { return( rules_[head] ); }  const rule_list_t& inv_rules( atom_t p ) const { return( inv_rules_[p] ); }  void print( std::ostream &os, size_t depth ) const;};voidrule_system_t::insert_rule( atom_t head, const rule_t *rule ){  rules_[head].push_back( rule );  const rule_body_t &body = rule->body();  for( rule_body_t::const_iterator bi = body.begin(); bi != body.end(); ++bi )    inv_rules_[*bi].push_back( rule );}const rule_system_t&rule_system_t::generate_random( size_t num_atoms, size_t max_rules_per_atom, size_t max_body_size ){  rule_system_t *rs = new rule_system_t( num_atoms );  for( atom_t head = 0; head < num_atoms; ++head )    {      std::set<std::set<atom_t> > bodies;      size_t nrules = lrand48() % max_rules_per_atom;      for( size_t r = 0; r < 1+nrules; ++r )        {          std::set<atom_t> body;          size_t goal = 1 + (lrand48() % max_body_size);          goal = MIN(num_atoms-head-1,goal);          while( body.size() < goal )            {              for( atom_t atom = 1+head; (atom < num_atoms) && (body.size() < goal); ++atom )                if( !(lrand48()%10) ) body.insert( atom );            }          bodies.insert( body );        }      for( std::set<std::set<atom_t> >::const_iterator bi = bodies.begin(); bi != bodies.end(); ++bi )        {          rule_t *rule = new rule_t( (*bi).size() );          rule->set_head( head );          size_t pos = 0;          for( std::set<atom_t>::const_iterator ai = (*bi).begin(); ai != (*bi).end(); ++ai )            rule->insert_body( pos++, *ai );          rs->insert_rule( head, rule );        }    }  return( *rs );}std::ostream& operator<<( std::ostream &os, const rule_t &r ){  r.print( os );  return( os );}std::ostream& operator<<( std::ostream &os, const rule_system_t &rs ){  rs.print( os, 0 );  return( os );}voidrule_t::print( std::ostream &os ) const{  os << head_ << " <-- ";  for( rule_body_t::const_iterator bi = body_.begin(); bi != body_.end(); ++bi )    os << *bi << " ";}voidrule_system_t::print( std::ostream &os, size_t depth ) const{  for( atom_t head = 0; head < num_atoms_; ++head )    for( rule_list_t::const_iterator ri = rules_[head].begin(); ri != rules_[head].end(); ++ri )      {        indent( os, depth );        os << **ri << std::endl;      }}class state_t{  atom_t atom_;public:  state_t( atom_t atom = UINT_MAX ) : atom_(atom) { }    atom_t atom( void ) const { return( atom_ ); }  void print( std::ostream &os, size_t depth ) const { indent( os, depth ); os << "{ " << atom_ << " }"; }  bool terminal( void ) const { const rule_list_t &rules = rule_system->rules( atom_ ); for( rule_list_t::const_iterator ri = rules.begin(); ri != rules.end(); ++ri ) if( (*ri)->body().empty() ) return( true ); return( false ); }  double terminal_value( void ) const { return( 1 ); }  bool operator==( const state_t &s ) const { return( atom_ == s.atom_ ); }  bool operator<( const state_t &s ) const { return( atom_ < s.atom_ ); }};std::ostream& operator<<( std::ostream &os, const state_t &s ){  s.print( os, 0 );  return( os );}class heuristic_t{  int type_;  hash_t *hash_;  const rule_system_t &rule_system_;public:  heuristic_t( int type, const rule_system_t &rs );  ~heuristic_t();  double value( const state_t &s ) const;  void compute_heuristic( size_t iters, qvalue_func_t qvalue );  void dump( std::ostream &os ) const;};namespace hashing = ::__gnu_cxx;namespace __gnu_cxx {  template<> class hash<const state_t>  {  public:    size_t operator()( const state_t &s ) const { return( s.atom() ); }  };  class entry_t  {  public:    double lower_;    double upper_;    bool mark_;    entry_t( double lower = 0, double upper = DBL_MAX, bool mark = false ) : lower_(lower), upper_(upper), mark_(mark) { }  };  class hash_t : public hash_map<const state_t,entry_t>  {  public:    typedef std::pair<state_t,entry_t*> data_pair;  private:    const heuristic_t *h_;    data_pair push( const state_t &s, double l, double u, bool m ) { std::pair<iterator,bool> p = insert( std::make_pair( s, entry_t(l,u,m) ) ); return( data_pair( (*p.first).first, &(*p.first).second ) ); }    data_pair push( const state_t &s ) { return( push( s, (!h_?0:h_->value(s)), DBL_MAX, false ) ); }    const_iterator lookup( const state_t &s ) const { return( find( s ) ); }    iterator lookup( const state_t &s ) { return( find( s ) ); }  public:    hash_t( const heuristic_t *h = 0 ) : h_(h) { }    data_pair get( const state_t &s ) { iterator di = lookup( s ); return( di == end() ? push( s  ) : data_pair( (*di).first, &(*di).second ) ); }    double value( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? (!h_?0:h_->value(s)) : (*di).second.lower_ ); }    double upper( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? DBL_MAX : (*di).second.upper_ ); }    void update( const state_t &s, double value ) { iterator di = lookup( s ); if( di == end() ) push( s, value, DBL_MAX, false ); else (*di).second.lower_ = value; }    void update_upper( const state_t &s, double value ) { iterator di = lookup( s ); if( di == end() ) push( s, (!h_?0:h_->value(s)), value, false ); else (*di).second.upper_ = value; }    bool solved( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? false : (*di).second.mark_ ); }    void mark( const state_t &s ) { iterator di = lookup( s ); if( di == end() ) push( s, (!h_?0:h_->value(s)), DBL_MAX, true ); else (*di).second.mark_ = true; }    void dump( std::ostream &os ) const { for( const_iterator it = begin(); it != end(); ++it ) os << "state=" << (*it).first << ", lower=" << (*it).second.lower_ << ", upper=" << (*it).second.upper_ << ", mark=" << ((*it).second.mark_?1:0) << std::endl; }  };};class hash_t{  hashing::hash_t table_;  unsigned updates_;  unsigned upper_updates_;public:  hash_t( const heuristic_t *h = 0 ) : table_(h), updates_(0), upper_updates_(0) { }  typedef hashing::hash_t::data_pair data_pair;  data_pair get( const state_t &s ) { return( table_.get( s ) ); }  unsigned updates( void ) const { return( updates_ ); }  unsigned upper_updates( void ) const { return( upper_updates_ ); }  size_t size( void ) const { return( table_.size() ); }  size_t bucket_count( void ) const { return( table_.bucket_count() ); }  double value( const state_t &s ) const { return( table_.value( s ) ); }  double upper( const state_t &s ) const { return( table_.upper( s ) ); }  void update( const state_t &s, double value ) { table_.update( s, value ); ++updates_; }  void update_upper( const state_t &s, double value ) { table_.update_upper( s, value ); ++upper_updates_; }  bool solved( const state_t &s ) const { return( table_.solved( s ) ); }  void mark( state_t s ) { table_.mark( s ); }  double QValueMax( const state_t &s, const rule_t &r ) const;  double QValueAdd( const state_t &s, const rule_t &r ) const;  double bestQValue( const state_t &s, qvalue_func_t qvalue ) const;  std::pair<double,const rule_t*> bestAction( const state_t &s, qvalue_func_t qvalue ) const;  void dump( std::ostream &os ) const { table_.dump( os ); }};class graph_t{public:  class entry_t  {    const state_t s_;    std::set<entry_t*> parents_;    const rule_t *marked_;    bool fringe_;    bool visited_;    size_t index_;  public:    entry_t() : marked_(0), fringe_(true), visited_(false), index_(0) { }    entry_t( const state_t s ) : s_(s), marked_(0), fringe_(true), visited_(false), index_(0) { }    const state_t& s( void ) const { return( s_ ); }    bool fringe( void ) const { return( fringe_ ); }    void set_fringe( bool fringe ) { fringe_ = fringe; }    std::set<entry_t*>& parents( void ) { return( parents_ ); }    const std::set<entry_t*>& parents( void ) const { return( parents_ ); }    void add_parent( entry_t *parent ) { parents_.insert( parent ); }    bool visited( void ) const { return( visited_ ); }    void visit( void ) { visited_ = true; }    void unvisit( void ) { visited_ = false; }    const rule_t* marked( void ) const { return( marked_ ); }    void mark( const rule_t *r ) { marked_ = r; }    void inc_index( void ) { ++index_; }    void dec_index( void ) { assert( index_ > 0 ); --index_; }    void set_index( size_t index ) { index_ = index; }    size_t index( void ) const { return( index_ ); }    bool revise( hash_t &hash, qvalue_func_t qvalue );    bool operator()( const entry_t *e1, const entry_t *e2 ) const { return( e1->index() < e2->index() ); }  };private:  std::map<state_t,entry_t*> nodes_;  std::set<entry_t*> tips_;  entry_t *root_;public:  graph_t() : root_(0) { }  entry_t* root( void ) { return( root_ ); }  void set_root( entry_t *root ) { root_ = root; }  const std::set<entry_t*>& tips( void ) const { return( tips_ ); }  entry_t& choose_tip( void ) { return( **tips_.begin() ); }  entry_t* add_tip( entry_t *parent, const rule_t *rule, unsigned target, const state_t &child );  void expand( entry_t &n, std::multiset<entry_t*,entry_t> &S );  void expand_and_update( entry_t &n, hash_t &hash, qvalue_func_t qvalue );};doublehash_t::QValueMax( const state_t &s, const rule_t &r ) const{  double qv = 0;   for( rule_body_t::const_iterator bi = r.body().begin(); bi != r.body().end(); ++bi )    {       double v = value( state_t( *bi ) );      qv = MAX(qv,v);    }  return( 1 + qv );}doublehash_t::QValueAdd( const state_t &s, const rule_t &r ) const{  double qv = 0;   for( rule_body_t::const_iterator bi = r.body().begin(); bi != r.body().end(); ++bi )    qv += value( state_t( *bi ) );  return( 1 + qv );}doublehash_t::bestQValue( const state_t &s, qvalue_func_t qvalue ) const{  double bqv = DBL_MAX;  const rule_list_t &rules = rule_system->rules( s.atom() );  for( rule_list_t::const_iterator ri = rules.begin(); ri != rules.end(); ++ri )    {      double qv = (this->*qvalue)( s, **ri );      bqv = MIN(bqv,qv);    }  return( bqv );}inline std::pair<double,const rule_t*>hash_t::bestAction( const state_t &s, qvalue_func_t qvalue ) const{  const rule_t *best_rule = 0;  double bqv = DBL_MAX;  const rule_list_t &rules = rule_system->rules( s.atom() );  for( rule_list_t::const_iterator ri = rules.begin(); ri != rules.end(); ++ri )    {      double qv = (this->*qvalue)( s, **ri );      if( qv < bqv )        {          bqv = qv;          best_rule = *ri;        }    }  return( std::make_pair( bqv, best_rule ) );}heuristic_t::heuristic_t( int type, const rule_system_t &rs )  : type_(type), rule_system_(rs){  hash_ = new hash_t;}heuristic_t::~heuristic_t(){  delete hash_;}inline doubleheuristic_t::value( const state_t &s ) const{  return( hash_->value( s ) );}voidheuristic_t::dump( std::ostream &os ) const{  hash_->dump( os );}voidheuristic_t::compute_heuristic( size_t iters, qvalue_func_t qvalue ){  size_t i = 0, ups = iters * rule_system_.num_atoms();  while( ((type_ == 1) && (i < iters)) || ((type_ == 2) && (hash_->updates() < ups)) )    {      ++i;      for( atom_t p = 0; p < rule_system_.num_atoms(); ++p )        if( (type_ == 1) || ((type_ == 2) && (drand48() < 0.5)) )          {            state_t s( p );            if( !s.terminal() )              hash_->update( s, hash_->bestQValue( s, qvalue ) );            else              hash_->update( s, s.terminal_value() );            if( (type_ == 2) && (hash_->updates() >= ups) ) goto end;          }    }  end:;  std::cerr << "updates=" << hash_->updates() << ", iters=" << i << std::endl;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -