⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rules.cc

📁 基于学习的深度优先搜索算法
💻 CC
📖 第 1 页 / 共 3 页
字号:
  aux.clear();  bool rv = true;  aux.insert( e.second );  queue.push_front( e );  while( !queue.empty() )    {      e = queue.front();      queue.pop_front();      if( ((e.second != 0) && e.second->mark_) || e.first.terminal() )	{	  if( e.second == 0 )	    {	      hash.update( e.first, e.first.terminal_value() );	      hash.mark( e.first );	    }	  else if( !e.second->mark_ )	    e.second->mark_ = true;	  continue;	}      assert( !e.first.terminal() );      stack.push_front( e );      std::pair<double,const rule_t*> p = hash.bestAction( e.first, qvalue );      if( (e.second == 0) || (e.second->lower_ != p.first) ) { rv = false; continue; }      const rule_body_t &body = p.second->body();      for( rule_body_t::const_iterator bi = body.begin(); bi != body.end(); ++bi )        {          state_t s_next = state_t( *bi );          hash_t::data_pair tmp = hash.get( s_next );          if( aux.find( tmp.second ) == aux.end() )            {              aux.insert( tmp.second );              queue.push_front( tmp );            }        }    }  while( !stack.empty() )    {      e = stack.front();      stack.pop_front();      assert( (e.second != 0) || !rv );      if( rv )	e.second->mark_ = true;      else	{	  assert( !e.first.terminal() );	  double bqv = hash.bestQValue( e.first, qvalue );	  assert( (e.second == 0) || (bqv >= e.second->lower_) );	  if( e.second != 0 )	    e.second->lower_ = bqv;	  else	    hash.update( e.first, bqv );	}    }  return( rv );}voidlabeled_lrta_trial( hash_t &hash, const state_t s0, qvalue_func_t qvalue ){  std::list<state_t> queue;  std::vector<size_t> b_moves;  state_t s = s0;  queue.push_front( s );  while( !s.terminal() && !hash.solved( s ) )    {      ++expansions;      std::pair<double,const rule_t*> p = hash.bestAction( s, qvalue );      assert( p.first != DBL_MAX );      hash.update( s, p.first );      const rule_body_t &body = p.second->body();      size_t prop = lrand48() % body.size();      s = state_t( body[prop] );      queue.push_front( s );    }  // mark terminal nodes as solved  if( s.terminal() || (hash.value( s ) == DBL_MAX) )    {      if( s.terminal() ) hash.update( s, s.terminal_value() );      hash.mark( s );    }  // labeling  while( !queue.empty() )    {      state_t s = queue.front();      queue.pop_front();      if( !checksolved( hash.get(s), hash, qvalue ) ) break;    }}size_tlabeled_lrta( hash_t &hash, const state_t s0, qvalue_func_t qvalue ){  expansions = 0;  size_t iterations = 0;  while( !hash.solved( s0 ) )    {      labeled_lrta_trial( hash, s0, qvalue );      ++iterations;      if( verbose ) std::cout << "V(s0)=" << hash.value( s0 ) << std::endl;    }  return( iterations );}size_tpolicySize( const state_t s0, hash_t &hash, qvalue_func_t qvalue ){  std::set<state_t> policy, aux;  std::list<state_t> stack;  aux.insert( s0 );  stack.push_back( s0 );  while( !stack.empty() )    {      state_t s = stack.front();      stack.pop_front();      policy.insert( s );      if( s.terminal() ) continue;      std::pair<double,const rule_t*> p = hash.bestAction( s, qvalue );      const rule_body_t &body = p.second->body();       for( rule_body_t::const_iterator bi = body.begin(); bi != body.end(); ++bi )        {          state_t s_next( *bi );          if( aux.find( s_next ) == aux.end() )            {              stack.push_back( s_next );              aux.insert( s_next );            }        }    }  return( policy.size() );}voiddumpCFC( const state_t s0, const heuristic_t *h, std::ostream &os, int format ){  std::map<state_t,unsigned> onodes;  std::map<unsigned,state_t> ionodes;  std::map<std::pair<unsigned,const rule_t*>,unsigned> anodes;  std::set<std::pair<unsigned,std::pair<unsigned,int> > > edges;  unsigned k = 0;  ionodes.insert( std::make_pair( k, s0 ) );  onodes.insert( std::make_pair( s0, k++ ) );  for( atom_t p = 0; p < rule_system->num_atoms(); ++p )    {      state_t s( p );      if( onodes.find( s ) == onodes.end() )        {          ionodes.insert( std::make_pair( k, s ) );          onodes.insert( std::make_pair( s, k++ ) );        }      if( s.terminal() ) continue;      unsigned s_idx = onodes[s];      const rule_list_t &rules = rule_system->rules( s.atom() );      for( rule_list_t::const_iterator ri = rules.begin(); ri != rules.end(); ++ri )        {           std::pair<unsigned,const rule_t*> s_a( s_idx, *ri );           if( anodes.find( s_a ) == anodes.end() ) anodes.insert( std::make_pair( s_a, k++ ) );           assert( anodes.find( s_a ) != anodes.end() );           unsigned s_a_idx = anodes[s_a];           edges.insert( std::make_pair( s_idx, std::make_pair( s_a_idx, 1 ) ) );           const rule_body_t &body = (*ri)->body();           if( !body.empty() )             for( rule_body_t::const_iterator bi = body.begin(); bi != body.end(); ++bi )               {                 state_t s_n( *bi );                 if( onodes.find( s_n ) == onodes.end() )                   {                     ionodes.insert( std::make_pair( k, s_n ) );                     onodes.insert( std::make_pair( s_n, k++ ) );                   }                 assert( onodes.find( s_n ) != onodes.end() );                 edges.insert( std::make_pair(s_a_idx,std::make_pair(onodes[s_n],0)) );               }        }    }  assert( onodes.size() == rule_system->num_atoms() );  assert( k == onodes.size() + anodes.size() );  if( format == 1 )    os << "Number_of_nodes " << onodes.size() + anodes.size() << std::endl;  else    {      os << "comment atoms" << num_atoms         << " rules " << max_rules_per_atom         << " body " << max_body_size         << " random seed " << rseed << std::endl;      os << "aograph " << onodes.size() << " "         << anodes.size() << " "         << edges.size() << std::endl;    }  // output nodes  for( size_t k = 0; k < onodes.size() + anodes.size(); ++k )    {      std::map<unsigned,state_t>::const_iterator it = ionodes.find( k );      if( it != ionodes.end() )        {          double hvalue = ( !h ? 0 : h->value( (*it).second ) );          int terminal = ( (*it).second.terminal() ? 0 : 1 );          if( format == 1 )            {              os << "h(" << k << ") " << hvalue << std::endl                 << "AND(0)/OR(1) 1" << std::endl                 << "SOLVED(0)/NON_TERMINAL(1) " << terminal << std::endl;            }          else            {              os << "or " << k << " " << terminal << " " << hvalue << std::endl;            }        }      else        {          if( format == 1 )            {              os << "h(" << k << ") 0" << std::endl                 << "AND(0)/OR(1) 0" << std::endl                 << "SOLVED(0)/NON_TERMINAL(1) 1" << std::endl;            }          else            {              os << "and " << k << " 1 0" << std::endl;            }        }    }  // output edges  if( format == 1 ) os << "arcarcarcarcarcarcarcarcarcarcarcarc" << std::endl;  for( std::set<std::pair<unsigned,std::pair<unsigned,int> > >::const_iterator it = edges.begin(); it != edges.end(); )    if( format == 1 )      {        os << "vtx1 " << (*it).first << std::endl           << "vtx2 " << (*it).second.first << std::endl           << "weight " << (*it).second.second << std::endl           << "another? " << (++it == edges.end() ? 0 : 1) << std::endl;      }    else      {        os << "edge " << (*it).first << " "           << (*it).second.first << " "           << (*it).second.second << std::endl;        ++it;      }}voiddiffTime( unsigned long& secs, unsigned long& usecs, struct timeval& t1, struct timeval& t2 ){  if( t1.tv_sec == t2.tv_sec )    {      secs = 0;      usecs = t2.tv_usec - t1.tv_usec;    }  else    {      secs = (t2.tv_sec - t1.tv_sec) - 1;      usecs = (MILLION - t1.tv_usec) + t2.tv_usec;      if( usecs > MILLION )	{	  ++secs;	  usecs = usecs % MILLION;	}    }}intmain( int argc, char **argv ){  unsigned short seed[3];  unsigned long secs, usecs;  struct timeval startTime, elapsedTime;  qvalue_func_t qvalue = &hash_t::QValueMax;  int algorithm = 0;  bool dump = false;  int format = 1;  bool output = false;  bool stats = false;  int htype = 0;  size_t hiter = 0;  // read arguments  ++argv;  --argc;  if( argc == 0 ) exit( -1 );  while( **argv == '-' )    {      switch( argv[0][1] )        {        case 'a':          qvalue = &hash_t::QValueAdd;          break;        case 'A':          algorithm = atoi( argv[1] );          ++argv;          --argc;          break;        case 'd':          dump = true;          break;        case 'f':          format = 2;          break;        case 'h':          htype = atoi( argv[1] );          hiter = atoi( argv[2] );          argv += 2;          argc -= 2;          break;        case 'o':          output = true;          break;        case 's':          rseed = atoi( argv[1] );          ++argv;          --argc;          break;        case 'S':          stats = true;          break;        case 'v':          verbose = true;          break;        default:          std::cerr << "undefined option: " << argv[0] << std::endl;          exit( -1 );          break;        }      ++argv;      --argc;    }  num_atoms = atoi( argv[0] );  max_rules_per_atom = atoi( argv[1] );  max_body_size = atoi( argv[2] );  atom_t theorem = atoi( argv[3] );  seed[0] = seed[1] = seed[2] = rseed;  seed48( seed );  //std::cout << "seed = " << rseed << std::endl;  //std::cout << "building rule system ... "; std::cout.flush();  gettimeofday( &startTime, NULL );  rule_system = &rule_system_t::generate_random( num_atoms, max_rules_per_atom, max_body_size );  gettimeofday( &elapsedTime, NULL );  diffTime( secs, usecs, startTime, elapsedTime );  //std::cout << "done in " << TIME(secs,usecs) << std::endl;  if( output )    {      std::cout << "Rule System:" << std::endl                << "------------" << std::endl                << *rule_system << std::endl;    }  state_t s0( theorem );  gettimeofday( &startTime, NULL );  heuristic_t heuristic( htype, *rule_system );  heuristic.compute_heuristic( hiter, qvalue );  gettimeofday( &elapsedTime, NULL );  diffTime( secs, usecs, startTime, elapsedTime );  float htime = TIME(secs,usecs);  if( output ) heuristic.dump( std::cout );  if( dump )    {      dumpCFC( s0, &heuristic, std::cerr, format );      return( 0 );    }  size_t (*atable[5])( hash_t &, const state_t, qvalue_func_t );  atable[0] = valueIteration;  atable[1] = ldfsDriver;  atable[2] = ldfsBoundDriver;  atable[3] = aostar;  atable[4] = labeled_lrta;  int udiff;  double last_solution = DBL_MAX;  for( int i = 0; i < 5; ++i )    if( ((algorithm>>i) % 2 == 1) && (atable[i] != 0) )      {        hash_t *hash = new hash_t( &heuristic );        gettimeofday( &startTime, NULL );        size_t iterations = (*atable[i])( *hash, s0, qvalue );        gettimeofday( &elapsedTime, NULL );        diffTime( secs, usecs, startTime, elapsedTime );        std::cout << hash->value(s0) << " "                  << iterations << " "                  << hash->updates() << " "                  << expansions << " "                  << htime << " "                  << TIME(secs,usecs) << std::endl;        if( (last_solution < UINT_MAX) && (last_solution != hash->value(s0)) )          std::cout << "***** DISCREPANCY FOUND" << std::endl;        last_solution = hash->value( s0 );        if( stats ) std::cerr << "policy_size=" << policySize( s0, *hash, qvalue ) << std::endl;        delete hash;      }  return( 0 );}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -