⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 acoins.cc

📁 基于学习的深度优先搜索算法
💻 CC
📖 第 1 页 / 共 3 页
字号:
voidprintStats( std::ostream &os, uchar_t n, const state_t s0 ){  std::set<state_t> space;  generateStateSpace( n, s0, space );  float max_OR_branch = 0;  float avg_OR_branch = 0;  for( std::set<state_t>::const_iterator si = space.begin(); si != space.end(); ++si )    if( !(*si).terminal(n) )      {        unsigned branch = 0;        for( int m = n>>1; m > 0; --m )        for( int lhs1 = m; lhs1 >= 0; --lhs1 )        for( int lhs2 = m; lhs2 >= 0; --lhs2 )          if( lhs1 + lhs2 <= (*si).lhs() ) {            for( int ls1 = m-lhs1; ls1 >= 0; --ls1 )            for( int ls2 = m-lhs2; ls2 >= 0; --ls2 )              if( ls1 + ls2 <= (*si).ls() ) {                for( int hs1 = m-lhs1-ls1; hs1 >= 0; --hs1 )                for( int hs2 = m-lhs2-ls2; hs2 >= 0; --hs2 )                  if( (hs1+hs2 <= (*si).hs()) && (m+m-lhs1-lhs2-ls1-ls2-hs1-hs2 <= (*si).s()) )                    {                      assert( m - lhs1 - ls1 - hs1 + m - lhs2 - ls2 - hs2 <= (*si).s() );                      action_t a( lhs1, ls1, hs1, m-lhs1-ls1-hs1, lhs2, ls2, hs2, m-lhs2-ls2-hs2 );                      if( !(*si).useful(a) ) continue;                      ++branch;                    }              }          }        max_OR_branch = MAX(max_OR_branch,branch);        avg_OR_branch += branch;    }  avg_OR_branch /= space.size();  os << space.size() << " " << max_OR_branch << " " << avg_OR_branch << std::endl;}boolchecksolved( uchar_t n, hash_t::data_pair e, hash_t &hash, qvalue_func_t qvalue ){  static std::set<hashing::entry_t*> aux;  static std::list<hash_t::data_pair> queue;  static std::list<hash_t::data_pair> stack;  assert( queue.empty() && stack.empty() );  aux.clear();  bool rv = true;  aux.insert( e.second );  queue.push_front( e );  while( !queue.empty() )    {      e = queue.front();      queue.pop_front();      if( ((e.second != 0) && e.second->mark_) || e.first.terminal(n) )	{	  if( e.second == 0 )	    {	      hash.update( e.first, e.first.terminal_value() );	      hash.mark( e.first );	    }	  else if( !e.second->mark_ )	    e.second->mark_ = true;	  continue;	}      assert( !e.first.terminal(n) );      stack.push_front( e );      std::pair<unsigned,action_t> p = hash.bestAction( n, e.first, qvalue );      if( (e.second == 0) || (e.second->lower_ != p.first) ) { rv = false; continue; }      for( int obs = -1; obs < 2; ++obs )        {          state_t s_obs = e.first.next( p.second, obs );          hash_t::data_pair tmp = hash.get( s_obs );          if( aux.find( tmp.second ) == aux.end() )            {              aux.insert( tmp.second );              queue.push_front( tmp );            }        }    }  while( !stack.empty() )    {      e = stack.front();      stack.pop_front();      assert( (e.second != 0) || !rv );      if( rv )	e.second->mark_ = true;      else	{	  assert( !e.first.terminal(n) );	  unsigned bqv = hash.bestQValue( n, e.first, qvalue );	  assert( (e.second == 0) || (bqv >= e.second->lower_) );	  if( e.second != 0 )	    e.second->lower_ = bqv;	  else	    hash.update( e.first, bqv );	}    }  return( rv );}voidlabeled_lrta_trial( uchar_t n, hash_t &hash, const state_t s0, qvalue_func_t qvalue ){  std::list<state_t> queue;  std::vector<size_t> b_moves;  state_t s = s0;  queue.push_front( s );  while( !s.terminal(n) && !hash.solved( s ) )    {      ++expansions;      std::pair<unsigned,action_t> p = hash.bestAction( n, s, qvalue );      hash.update( s, p.first );      int obs = (lrand48() % 3) - 1;      state_t s_obs = s.next( p.second, obs );      s = s_obs;      queue.push_front( s );    }  // mark terminal nodes as solved  if( s.terminal(n) || (hash.value( s ) == UMAX) )    {      if( s.terminal(n) ) hash.update( s, s.terminal_value() );      hash.mark( s );    }  // labeling  while( !queue.empty() )    {      state_t s = queue.front();      queue.pop_front();      if( !checksolved( n, hash.get(s), hash, qvalue ) ) break;    }}size_tlabeled_lrta( uchar_t n, hash_t &hash, const state_t s0, qvalue_func_t qvalue ){  expansions = 0;  size_t iterations = 0;  while( !hash.solved( s0 ) )    {      labeled_lrta_trial( n, hash, s0, qvalue );      ++iterations;      if( verbose ) std::cout << "V(s0)=" << hash.value( s0 ) << std::endl;    }  return( iterations );}size_tpolicySize( uchar_t n, const state_t s0, hash_t &hash, qvalue_func_t qvalue ){  std::set<state_t> policy, aux;  std::list<state_t> stack;  aux.insert( s0 );  stack.push_back( s0 );  while( !stack.empty() )    {      state_t s = stack.front();      stack.pop_front();      policy.insert( s );      if( s.terminal(n) ) continue;      std::pair<unsigned,action_t> p = hash.bestAction( n, s, qvalue );      for( int obs = -1; obs < 2; ++obs )        {          state_t s_obs = s.next( p.second, obs );          if( aux.find( s_obs ) == aux.end() )            {              stack.push_back( s_obs );              aux.insert( s_obs );            }        }    }  return( policy.size() );}// dump problem for CFC/REVvoiddumpCFC( uchar_t n, const state_t &s0, const heuristic_t *h, std::ostream &os, int format ){  std::map<state_t,unsigned> onodes;  std::map<unsigned,state_t> ionodes;  std::map<std::pair<unsigned,action_t>,unsigned> anodes;  std::set<std::pair<unsigned,std::pair<unsigned,int> > > edges;  std::set<state_t> space;  generateStateSpace( n, s0, space );  unsigned k = 0;  ionodes.insert( std::make_pair( k, s0 ) );  onodes.insert( std::make_pair( s0, k++ ) );  for( std::set<state_t>::iterator si = space.begin(); si != space.end(); ++si )    {      if( onodes.find( (*si) ) == onodes.end() )        {          ionodes.insert( std::make_pair( k, (*si) ) );          onodes.insert( std::make_pair( (*si), k++ ) );        }      if( (*si).terminal(n) ) continue;      unsigned s_idx = onodes[(*si)];      for( int m = n>>1; m > 0; --m )      for( int lhs1 = m; lhs1 >= 0; --lhs1 )      for( int lhs2 = m; lhs2 >= 0; --lhs2 )        if( lhs1 + lhs2 <= (*si).lhs() ) {          for( int ls1 = m-lhs1; ls1 >= 0; --ls1 )          for( int ls2 = m-lhs2; ls2 >= 0; --ls2 )            if( ls1 + ls2 <= (*si).ls() ) {              for( int hs1 = m-lhs1-ls1; hs1 >= 0; --hs1 )              for( int hs2 = m-lhs2-ls2; hs2 >= 0; --hs2 )                if( (hs1+hs2 <= (*si).hs()) && (m+m-lhs1-lhs2-ls1-ls2-hs1-hs2 <= (*si).s()) )                  {                    assert( m - lhs1 - ls1 - hs1 + m - lhs2 - ls2 - hs2 <= (*si).s() );                    action_t a( lhs1, ls1, hs1, m-lhs1-ls1-hs1, lhs2, ls2, hs2, m-lhs2-ls2-hs2 );                    if( !(*si).useful(a) ) continue;                    std::pair<unsigned,action_t> s_a( s_idx, a );                    if( anodes.find( s_a ) == anodes.end() )                      anodes.insert( std::make_pair( s_a, k++ ) );                    assert( anodes.find( s_a ) != anodes.end() );                    unsigned s_a_idx = anodes[s_a];                    edges.insert( std::make_pair( s_idx, std::make_pair( s_a_idx, 1 ) ) );                    for( int obs = -1; obs < 2; ++obs )                      {                        state_t s_obs = (*si).next( a, obs );                        if( onodes.find( s_obs ) == onodes.end() )                          {                            ionodes.insert( std::make_pair( k, s_obs ) );                            onodes.insert( std::make_pair( s_obs, k++ ) );                          }                        assert( onodes.find( s_obs ) != onodes.end() );                        edges.insert( std::make_pair(s_a_idx,std::make_pair(onodes[s_obs],0)) );                      }                  }              }          }    }  assert( space.size() == onodes.size() );  assert( k == onodes.size() + anodes.size() );  if( format == 1 )    os << "Number_of_nodes " << onodes.size() + anodes.size() << std::endl;  else    {      os << "comment coins " << n << std::endl;      os << "aograph " << onodes.size() << " "         << anodes.size() << " "         << edges.size() << std::endl;    }  // output nodes  for( size_t k = 0; k < onodes.size() + anodes.size(); ++k )    {      std::map<unsigned,state_t>::const_iterator it = ionodes.find( k );      if( it != ionodes.end() )        {          double hvalue = ( !h ? 0 : h->value( (*it).second ) );          int terminal = ( (*it).second.terminal(n) ? 0 : 1 );          if( format == 1 )            {              os << "h(" << k << ") " << hvalue << std::endl                 << "AND(0)/OR(1) 1" << std::endl                 << "SOLVED(0)/NON_TERMINAL(1) " << terminal << std::endl;            }          else            {              os << "or " << k << " " << terminal << " " << hvalue << std::endl;            }        }      else        {          if( format == 1 )            {              os << "h(" << k << ") 0" << std::endl                 << "AND(0)/OR(1) 0" << std::endl                 << "SOLVED(0)/NON_TERMINAL(1) 1" << std::endl;            }          else            {              os << "and " << k << " 1 0" << std::endl;            }        }    }  // output edges  if( format == 1 ) os << "arcarcarcarcarcarcarcarcarcarcarcarc" << std::endl;  for( std::set<std::pair<unsigned,std::pair<unsigned,int> > >::const_iterator it = edges.begin(); it != edges.end(); )    if( format == 1 )      {        os << "vtx1 " << (*it).first << std::endl           << "vtx2 " << (*it).second.first << std::endl           << "weight " << (*it).second.second << std::endl           << "another? " << (++it == edges.end() ? 0 : 1) << std::endl;      }    else      {        os << "edge " << (*it).first << " "           << (*it).second.first << " "           << (*it).second.second << std::endl;        ++it;      }}voiddiffTime( unsigned long& secs, unsigned long& usecs, struct timeval& t1, struct timeval& t2 ){  if( t1.tv_sec == t2.tv_sec )    {      secs = 0;      usecs = t2.tv_usec - t1.tv_usec;    }  else    {      secs = (t2.tv_sec - t1.tv_sec) - 1;      usecs = (MILLION - t1.tv_usec) + t2.tv_usec;      if( usecs > MILLION )	{	  ++secs;	  usecs = usecs % MILLION;	}    }}intmain( int argc, char **argv ){  unsigned long secs, usecs;  struct timeval startTime, elapsedTime;  qvalue_func_t qvalue = &hash_t::QValueMax;  int algorithm = 0;  bool dump = false;  int format = 1;  bool output = false;  bool stats = false;  int htype = 0;  size_t hiter = 0;  // read arguments  ++argv;  --argc;  if( argc == 0 ) exit( -1 );  while( **argv == '-' )    {      switch( argv[0][1] )        {        case 'A':          algorithm = atoi( argv[1] );          ++argv;          --argc;          break;        case 'a':          qvalue = &hash_t::QValueAdd;          break;        case 'd':          dump = true;          break;        case 'f':          format = 2;          break;        case 'h':          htype = atoi( argv[1] );          hiter = atoi( argv[2] );          argv += 2;          argc -= 2;          break;        case 'o':          output = true;          break;        case 'S':          stats = true;          break;        case 'v':          verbose = true;          break;        default:          std::cerr << "undefined option: " << argv[0] << std::endl;          exit( -1 );          break;        }      ++argv;      --argc;    }  int coins = atoi( argv[0] );  state_t s0( coins );  gettimeofday( &startTime, NULL );  heuristic_t heuristic( htype, coins );  heuristic.compute_heuristic( s0, hiter, qvalue );  gettimeofday( &elapsedTime, NULL );  diffTime( secs, usecs, startTime, elapsedTime );  float htime = TIME(secs,usecs);  if( output ) heuristic.dump( std::cout );  if( stats )    {      printStats( std::cout, coins, s0 );    }  else if( dump )    {      dumpCFC( coins, s0, &heuristic, std::cerr, format );      return( 0 );    }  size_t (*atable[5])( uchar_t, hash_t&, const state_t, qvalue_func_t );  atable[0] = valueIteration;  atable[1] = ldfsDriver;  atable[2] = ldfsBoundDriver;  atable[3] = aostar;  atable[4] = labeled_lrta;  unsigned last_solution = UINT_MAX;  for( int i = 0; i < 5; ++i )    if( ((algorithm>>i) % 2 == 1) && (atable[i] != 0) )      {        hash_t *hash = new hash_t( &heuristic );        gettimeofday( &startTime, NULL );        size_t iterations = (*atable[i])( coins, *hash, s0, qvalue );        gettimeofday( &elapsedTime, NULL );        diffTime( secs, usecs, startTime, elapsedTime );        std::cout << hash->value(s0) << " "                  << iterations << " "                  << hash->updates() << " "                  << expansions << " "                  << htime << " "                  << TIME(secs,usecs) << std::endl;        if( (last_solution < UINT_MAX) && (last_solution != hash->value(s0)) )          std::cout << "***** DISCREPANCY FOUND" << std::endl;        last_solution = hash->value( s0 );        if( stats ) std::cerr << "policy_size=" << policySize( coins, s0, *hash, qvalue ) << std::endl;        delete hash;      }  return( 0 );}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -