📄 mts.cc
字号:
#include <iostream>#include <list>#include <map>#include <set>#include <vector>#include <assert.h>#include <limits.h>#include <values.h>#include <sys/types.h>#include <sys/time.h>#define UP 0#define DOWN 1#define RIGHT 2#define LEFT 3#define NOP 4#define BOXW 10#define MAX(i,j) ((i)>(j)?(i):(j))#define MIN(i,j) ((i)<(j)?(i):(j))#define IDX(n,i,j) ((i)*(n)+(j))#define VALID(n,i,j,k) (maze[IDX(n,i,j)]&(1<<k))#define MOVE(n,i,j) (maze[IDX(n,i,j)])#define MILLION 1000000#define TIME(s,u) ((float)(s) + ((float)u)/MILLION)static char *op[5] = { "UP", "DOWN", "RIGHT", "LEFT", "NOP" };static unsigned char *maze = 0;static unsigned expansions = 0;static bool verbose = false;static unsigned short rseed = 0;class hash_t;class state_t;class heuristic_t;typedef double (hash_t::*qvalue_func_t)( const state_t&, int ) const;inline voidindent( std::ostream &os, size_t depth ){ for( size_t i = 0; i < depth; ++i ) os << " ";}class state_t{ unsigned ax_ : 8; unsigned ay_ : 8; unsigned bx_ : 8; unsigned by_ : 8;public: state_t( int ax, int ay, int bx, int by ) : ax_(ax), ay_(ay), bx_(bx), by_(by) { } unsigned ax( void ) const { return( ax_ ); } unsigned ay( void ) const { return( ay_ ); } unsigned bx( void ) const { return( bx_ ); } unsigned by( void ) const { return( by_ ); } unsigned index( int n ) const { return( (bx_*n*n*n) + (by_*n*n) + (ax_*n) + ay_ ); } bool terminal( void ) const { int dx = (int)ax_-(int)bx_; int dy = (int)ay_-(int)by_; dx = (dx<0?-dx:dx); dy = (dy<0?-dy:dy); return( dx + dy <= 1 ); } bool terminal_value( void ) const { return( 0 ); } state_t a_move( int a ) const { state_t s( *this ); switch( a ) { case UP: --s.ay_; break; case DOWN: ++s.ay_; break; case LEFT: --s.ax_; break; case RIGHT: ++s.ax_; break; } return( s ); } state_t b_move( int b ) const { state_t s( *this ); switch( b ) { case UP: --s.by_; break; case DOWN: ++s.by_; break; case LEFT: --s.bx_; break; case RIGHT: ++s.bx_; break; } return( s ); } bool operator==( const state_t &s ) const { return( (ax_ == s.ax()) && (ay_ == s.ay()) && (bx_ == s.bx()) && (by_ == s.by()) ); } bool operator!=( const state_t &s ) const { return( !((*this) == s) ); } bool operator<( const state_t &s ) const { return( (ax_ < s.ax()) || ((ax_ == s.ax()) && (ay_ < s.ay())) || ((ax_ == s.ax()) && (ay_ == s.ay()) && (bx_ < s.bx())) || ((ax_ == s.ax()) && (ay_ == s.ay()) && (bx_ == s.bx()) && (by_ < s.by())) ); } void print( std::ostream &os ) const { os << "(" << ax_ << ":" << ay_ << ":" << bx_ << ":" << by_ << ")"; }};std::ostream& operator<<( std::ostream &os, const state_t &s ){ s.print( os ); return( os );}class hash_t{public: class entry_t { public: double lower_; double upper_; bool mark_; entry_t( double l = 0, double u = DBL_MAX, bool m = false ) : lower_(l), upper_(u), mark_(m) { } };private: size_t n_; size_t updates_; size_t upper_updates_; const heuristic_t *h_; entry_t *table_;public: hash_t( size_t n, const heuristic_t *h = 0 ); ~hash_t(); typedef std::pair<state_t,entry_t*> data_pair; data_pair get( const state_t &s ) { return( data_pair( s, &table_[s.index(n_)] ) ); } size_t updates( void ) { return( updates_ ); } size_t upper_updates( void ) { return( upper_updates_ ); } void update( const state_t &s, double value ) { table_[s.index(n_)].lower_ = value; ++updates_; } void update_upper( const state_t &s, double value ) { table_[s.index(n_)].upper_ = value; ++upper_updates_; } double value( const state_t &s ) const { return( table_[s.index(n_)].lower_ ); } double upper( const state_t &s ) const { return( table_[s.index(n_)].upper_ ); } void mark( const state_t &s ) { table_[s.index(n_)].mark_ = true; } bool solved( const state_t &s ) const { return( table_[s.index(n_)].mark_ ); } double QValueMax( const state_t &s_a, int a ) const; double QValueAdd( const state_t &s_a, int a ) const; double bestQValue( const state_t &s, qvalue_func_t qvalue ) const; std::pair<double,int> bestAction( const state_t &s, qvalue_func_t qvalue ) const; void dump( std::ostream &os ) const;};class heuristic_t{ int type_; size_t n_; hash_t *hash_;public: heuristic_t( int type, size_t n ) : type_(type), n_(n) { hash_ = new hash_t( n_ ); } ~heuristic_t() { delete hash_; } double value( const state_t &s ) const { return( hash_->value( s ) ); } void dump( std::ostream &os ) const { hash_->dump( os ); } void compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue );};voidprintMaze( int n, int d, const unsigned char *maze ){ for( int j = 0; j < n; ++j ) { indent( std::cout, d ); std::cout << "**"; for( int i = 0; i < n; ++i ) std::cout << (!VALID(n,i,j,UP)?"****":" **"); std::cout << std::endl; indent( std::cout, d ); std::cout << "**"; for( int i = 0; i < n; ++i ) std::cout << " " << (!VALID(n,i,j,RIGHT)?"**":" "); std::cout << std::endl; } indent( std::cout, d ); for( int i = 0; i < n; ++i ) std::cout << "****"; std::cout << "**" << std::endl;}voiddfs( int n, int x, int y, std::set<std::pair<int,int> > &visited, std::set<std::pair<int,int> > &nowalls ){ std::pair<int,int> p(x,y); if( visited.find( p ) != visited.end() ) return; else { int i, j; int map[4], d[4]; int dir = lrand48() % 24; d[0] = dir / 6; d[1] = (dir % 6) / 2; d[2] = (dir / 12); map[0] = map[1] = map[2] = map[3] = 0; map[d[0]] = 1; for( i = j = 0; i < 4; ++i ) if( map[i] == 0 ) { if( j == d[1] ) break; ++j; } d[1] = i; map[i] = 1; for( i = j = 0; i < 4; ++i ) if( map[i] == 0 ) { if( j == d[2] ) break; ++j; } d[2] = i; map[i] = 1; for( i = 0; i < 4; ++i ) if( map[i] == 0 ) break; d[3] = i; visited.insert( p ); for( i = 0; i < 4; ++i ) { int nx = x, ny = y; if( (ny > 0) && (d[i] == UP) ) { --ny; if( visited.find( std::make_pair(nx,ny) ) == visited.end() ) nowalls.insert( std::make_pair( 2*nx+1, 2*y ) ); } if( (ny < n-1) && (d[i] == DOWN) ) { ++ny; if( visited.find( std::make_pair(nx,ny) ) == visited.end() ) nowalls.insert( std::make_pair( 2*nx+1, 2*ny ) ); } if( (nx > 0) && (d[i] == LEFT) ) { --nx; if( visited.find( std::make_pair(nx,ny) ) == visited.end() ) nowalls.insert( std::make_pair( 2*x, 2*ny+1 ) ); } if( (nx < n-1) && (d[i] == RIGHT) ) { ++nx; if( visited.find( std::make_pair(nx,ny) ) == visited.end() ) nowalls.insert( std::make_pair( 2*nx, 2*ny+1 ) ); } dfs( n, nx, ny, visited, nowalls ); } }}voidmoveMap( int n, unsigned char* &maze, std::set<std::pair<int,int> > &nowalls ){ maze = (unsigned char*)calloc( n*n, sizeof(unsigned char) ); for( int j = 0; j < n; ++j ) for( int i = 0; i < n; ++i ) { int moves = 0; if( (j > 0) && (nowalls.find( std::make_pair((i<<1)+1,j<<1) ) != nowalls.end()) ) moves |= (1<<UP); if( (j < n-1) && (nowalls.find( std::make_pair((i<<1)+1,(j+1)<<1) ) != nowalls.end()) ) moves |= (1<<DOWN); if( (i > 0) && (nowalls.find( std::make_pair(i<<1,(j<<1)+1) ) != nowalls.end()) ) moves |= (1<<LEFT); if( (i < n-1) && (nowalls.find( std::make_pair((i+1)<<1,(j<<1)+1) ) != nowalls.end()) ) moves |= (1<<RIGHT); maze[IDX(n,i,j)] = moves; //std::cout << "(" << i << "," << j << ") --> " << moves << " = " << MOVE(n,i,j) << std::endl; }}hash_t::hash_t( size_t n, const heuristic_t *h ) : n_(n), updates_(0), upper_updates_(0), h_(h){ table_ = new entry_t[n_*n_*n_*n_]; if( h != 0 ) { for( int ax = 0; ax < n; ++ax ) for( int ay = 0; ay < n; ++ay ) for( int bx = 0; bx < n; ++bx ) for( int by = 0; by < n; ++by ) { state_t s( ax, ay, bx, by ); table_[s.index(n_)].lower_ = h_->value( s ); } }}hash_t::~hash_t(){ delete[] table_;}doublehash_t::QValueMax( const state_t &s_a, int a ) const{ double qv = 0; for( int b = 0; b < 4; ++b ) if( VALID(n_,s_a.bx(),s_a.by(),b) ) { state_t s_ab = s_a.b_move( b ); double val= value( s_ab ); qv = MAX(qv,val); } return( 1 + qv );}doublehash_t::QValueAdd( const state_t &s_a, int a ) const{ double qv = 0; for( int b = 0; b < 4; ++b ) if( VALID(n_,s_a.bx(),s_a.by(),b) ) { state_t s_ab = s_a.b_move( b ); qv += value( s_ab ); } return( 1 + qv );}inline doublehash_t::bestQValue( const state_t &s, qvalue_func_t qvalue ) const{ assert( !s.terminal() ); double bqv = DBL_MAX; for( int a = 0; a < 4; ++a ) if( VALID(n_,s.ax(),s.ay(),a) ) { state_t s_a = s.a_move( a ); double qv = (this->*qvalue)( s_a, a ); bqv = MIN(bqv,qv); } return( bqv );}inline std::pair<double,int>hash_t::bestAction( const state_t &s, qvalue_func_t qvalue ) const{ int ba = 0; double bqv = DBL_MAX; assert( !s.terminal() ); for( int a = 0; a < 4; ++a ) if( VALID(n_,s.ax(),s.ay(),a) ) { state_t s_a = s.a_move( a ); double qv = (this->*qvalue)( s_a, a ); if( qv < bqv ) { bqv = qv; ba = a; } } return( std::make_pair(bqv,ba) );}voidhash_t::dump( std::ostream &os ) const{ for( int ax = 0; ax < n_; ++ax ) for( int ay = 0; ay < n_; ++ay ) for( int bx = 0; bx < n_; ++bx ) for( int by = 0; by < n_; ++by ) { state_t s( ax, ay, bx, by ); os << "state=" << s << ", lower=" << value(s) << ", upper=" << upper(s) << ", mark=" << (solved(s)?1:0) << std::endl;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -