⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mts.cc

📁 基于学习的深度优先搜索算法
💻 CC
📖 第 1 页 / 共 3 页
字号:
          }}voidheuristic_t::compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue ){  size_t i = 0, ups = iters * n_ * n_ * n_ * n_;  while( ((type_ == 1) && (i < iters)) || ((type_ == 2) && (hash_->updates() < ups)) )    {      ++i;      for( int ax = 0; ax < n_; ++ax )        for( int ay = 0; ay < n_; ++ay )          for( int bx = 0; bx < n_; ++bx )            for( int by = 0; by < n_; ++by )              if( (type_ == 1) || ((type_ == 2) && (drand48() < 0.5)) )                {                  state_t s( ax, ay, bx, by );                  if( !s.terminal() )                    {                      double bqv = hash_->bestQValue( s, qvalue );                      assert( bqv >= hash_->value( s ) );                      hash_->update( s, bqv );                    }                  else                    hash_->update( s, s.terminal_value() );                  if( (type_ == 2) && (hash_->updates() >= ups) ) goto end;                }    }  end:;  if( verbose ) std::cerr << "updates=" << hash_->updates() << ", iters=" << i << std::endl;}size_tvalueIteration( int n, hash_t &hash, const state_t &s0, qvalue_func_t qvalue ){  size_t iterations = 0;  double res = 1;  while( res > 0 )    {      res = 0;      for( int ax = 0; ax < n; ++ax )        for( int ay = 0; ay < n; ++ay )          for( int bx = 0; bx < n; ++bx )            for( int by = 0; by < n; ++by )              {                state_t s( ax, ay, bx, by );                if( !s.terminal() )                  {                    double value = hash.value( s );                    double bqv = hash.bestQValue( s, qvalue );                    res = MAX(res,bqv-value);                    hash.update( s, bqv );                    assert( bqv >= value );                  }              }      ++iterations;      if( verbose ) std::cout << "residual=" << res << ", V(s0)=" << hash.value( s0 ) << std::endl;    }  return( iterations );}boolldfsBound( int depth, int n, hash_t &hash, const state_t &s, double bound, qvalue_func_t qvalue ){  if( s.terminal() || (hash.upper( s ) <= bound) )    {      if( s.terminal() )        {          hash.update( s, s.terminal_value() );          hash.update_upper( s, s.terminal_value() );        }      return( true );    }  ++expansions;  int a;  bool flag = false;  assert( hash.value( s ) <= bound );  for( a = 0; a < 4; ++a )    if( VALID(n,s.ax(),s.ay(),a) )      {        state_t s_a = s.a_move( a );        double qv = (hash.*qvalue)( s_a, a );        if( qv > bound ) continue;        flag = true;        for( int b = 0; b < 4; ++b )        if( VALID(n,s_a.bx(),s_a.by(),b) )          {            double nbound = bound - 1;            state_t s_ab = s_a.b_move( b );            if( qvalue == &hash_t::QValueAdd )              {                for( int t = 0; t < 4; ++t )                  if( t != b ) nbound -= hash.value( s_a.b_move( t ) );              }            flag = ldfsBound( 1+depth, n, hash, s_ab, nbound, qvalue );            if( flag )              {                double qv = (hash.*qvalue)( s_a, a );                flag = (qv <= bound);              }            if( !flag ) break;          }        if( flag ) break;      }  if( flag )    {      hash.update_upper( s, bound );      if( verbose )        {          indent( std::cout, depth );          std::cout << "upper=" << bound << std::endl;        }    }  else    {      double bqv = hash.bestQValue( s, qvalue );      //assert( bqv > hash.value( s ) );      hash.update( s, bqv );      if( verbose )        {          indent( std::cout, depth );          std::cout << s << ", update=" << bqv << std::endl;        }    }  return( flag );}size_tldfsBoundDriver( int n, hash_t &hash, const state_t &s, qvalue_func_t qvalue ){  expansions = 0;  size_t iterations = 0;  while( hash.value( s ) < hash.upper( s ) )    {      ldfsBound( 0, n, hash, s, hash.value( s ), qvalue );      ++iterations;      if( verbose ) std::cout << "V(s0)=" << hash.value( s ) << std::endl;    }  return( iterations );}boolldfs( int n, hash_t &hash, const state_t &s, qvalue_func_t qvalue ){  if( s.terminal() || hash.solved( s ) )    {      if( s.terminal() )        {          hash.update( s, s.terminal_value() );          hash.update_upper( s, s.terminal_value() );        }      hash.mark( s );      return( true );    }  ++expansions;  int a;  bool flag = false;  double bound = hash.value( s );  for( a = 0; a < 4; ++a )    if( VALID(n,s.ax(),s.ay(),a) )      {        state_t s_a = s.a_move( a );        double qv = (hash.*qvalue)( s_a, a );        if( qv > bound ) continue;        flag = true;        for( int b = 0; b < 4; ++b )          if( VALID(n,s_a.bx(),s_a.by(),b) )            {              flag = ldfs( n, hash, s_a.b_move( b ), qvalue );              if( flag )                {                  double qv = (hash.*qvalue)( s_a, a );                  flag = (qv <= bound);                }              if( !flag ) break;            }        if( flag ) break;      }  if( flag )    {      hash.mark( s );    }  else    {      double bqv = hash.bestQValue( s, qvalue );      assert( (qvalue != &hash_t::QValueAdd) || (bqv > hash.value( s )) );      hash.update( s, bqv );    }  return( flag );}size_tldfsDriver( int n, hash_t &hash, const state_t &s, qvalue_func_t qvalue ){  expansions = 0;  size_t iterations = 0;  while( !hash.solved( s ) )    {      ldfs( n, hash, s, qvalue );      ++iterations;      if( verbose ) std::cout << "V(s0)=" << hash.value( s ) << std::endl;    }  return( iterations );}boolchecksolved( int n, hash_t::data_pair e, hash_t &hash, qvalue_func_t qvalue ){  static std::set<hash_t::entry_t*> aux;  static std::list<hash_t::data_pair> queue;  static std::list<hash_t::data_pair> stack;  assert( queue.empty() && stack.empty() );  aux.clear();  bool rv = true;  aux.insert( e.second );  queue.push_front( e );  while( !queue.empty() )    {      e = queue.front();      queue.pop_front();      if( ((e.second != 0) && e.second->mark_) || e.first.terminal() )	{	  if( e.second == 0 )	    {	      hash.update( e.first, e.first.terminal_value() );	      hash.mark( e.first );	    }	  else if( !e.second->mark_ )	    e.second->mark_ = true;	  continue;	}      assert( !e.first.terminal() );      stack.push_front( e );      std::pair<double,int> p = hash.bestAction( e.first, qvalue );      if( (e.second == 0) || (e.second->lower_ != p.first) ) { rv = false; continue; }      state_t s_a = e.first.a_move( p.second );      for( int b = 0; b < 4; ++b )        if( VALID(n,s_a.bx(),s_a.by(),b) )          {            state_t s_ab = s_a.b_move( b );	    hash_t::data_pair tmp = hash.get( s_ab );	    if( aux.find( tmp.second ) == aux.end() )	      {	        aux.insert( tmp.second );	        queue.push_front( tmp );	      }	}    }  while( !stack.empty() )    {      e = stack.front();      stack.pop_front();      assert( (e.second != 0) || !rv );      if( rv )	e.second->mark_ = true;      else	{	  assert( !e.first.terminal() );	  double bqv = hash.bestQValue( e.first, qvalue );	  assert( (e.second == 0) || (bqv >= e.second->lower_) );	  if( e.second != 0 )	    e.second->lower_ = bqv;	  else	    hash.update( e.first, bqv );	}    }  return( rv );}voidlabeled_lrta_trial( int n, hash_t &hash, const state_t &s0, qvalue_func_t qvalue ){  std::list<state_t> queue;  std::vector<size_t> b_moves;  state_t s = s0;  queue.push_front( s );  //std::cout << "begin trial" << std::endl;  while( !s.terminal() && !hash.solved( s ) )    {      //std::cout << "  state = " << s << std::endl;      ++expansions;      std::pair<double,int> p = hash.bestAction( s, qvalue );      assert( p.first != DBL_MAX );      hash.update( s, p.first );      state_t s_a = s.a_move( p.second );      //std::cout << "  a-action = " << p.second << std::endl;      //std::cout << "  s_a = " << s_a << std::endl;      b_moves.clear();      for( int b = 0; b < 4; ++b ) if( VALID(n,s_a.bx(),s_a.by(),b) ) b_moves.push_back( b );      assert( b_moves.size() > 0 );      int b = b_moves[lrand48() % b_moves.size()];      //std::cout << "  b-action = " << b << std::endl;      s = s_a.b_move( b );      //std::cout << "  s_ab = " << s << " [" << s_a.b_move(b) << "]" << std::endl;      queue.push_front( s );    }  //std::cout << "end trial" << std::endl;  // mark terminal nodes as solved  if( s.terminal() )    {      hash.update( s, s.terminal_value() );      hash.mark( s );    }  // labeling  while( !queue.empty() )    {      state_t s = queue.front();      queue.pop_front();      if( !checksolved( n, hash.get(s), hash, qvalue ) ) break;    }}size_tlabeled_lrta( int n, hash_t &hash, const state_t &s0, qvalue_func_t qvalue ){  expansions = 0;  size_t iterations = 0;  while( !hash.solved( s0 ) )    {      labeled_lrta_trial( n, hash, s0, qvalue );      ++iterations;      if( verbose ) std::cout << "V(s0)=" << hash.value( s0 ) << std::endl;    }  return( iterations );}size_tpolicySize( int n, const state_t s0, hash_t &hash, qvalue_func_t qvalue ){  std::set<state_t> policy, aux;  std::list<state_t> stack;  aux.insert( s0 );  stack.push_back( s0 );  while( !stack.empty() )    {      state_t s = stack.front();      stack.pop_front();      policy.insert( s );      if( s.terminal() ) continue;      std::pair<double,int> p = hash.bestAction( s, qvalue );      state_t s_a = s.a_move( p.second );      for( int b = 0; b < 4; ++b )        if( VALID(n,s_a.bx(),s_a.by(),b) )        {          state_t s_b = s_a.b_move( b );          if( aux.find( s_b ) == aux.end() )            {              stack.push_back( s_b );              aux.insert( s_b );            }        }    }  return( policy.size() );}voiddumpCFC( int n, state_t s0, const heuristic_t *h, std::ostream &os, int format ){

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -