📄 diagnosis.cc
字号:
#include <iostream>#include <ext/hash_map>#include <list>#include <map>#include <set>#include <limits.h>#include <math.h>#include <assert.h>#include <sys/time.h>#define MAX(s,t) ((s)>(t)?(s):(t))#define MIN(s,t) ((s)<(t)?(s):(t))#define UMAX (UINT_MAX>>1)#define MILLION 1000000#define TIME(s,u) ((float)(s) + ((float)u)/MILLION)class hash_t;class state_t;class test_matrix_t;typedef unsigned (hash_t::*qvalue_func_t)( const state_t&, size_t ) const;void generateStateSpace( const state_t, std::set<state_t>& );static bool verbose = false;static size_t expansions = 0;static const test_matrix_t *test_matrix = 0;static unsigned short rseed = 0;inline voidindent( std::ostream &os, size_t depth ){ for( size_t i = 0; i < depth; ++i ) os << ' ';}class test_matrix_t{ size_t m_, n_; char *matrix_;public: test_matrix_t( size_t m, size_t n ) : m_(m), n_(n) { matrix_ = new char[m_*n_]; memset( matrix_, 0, m_ * n_ * sizeof(char) ); } ~test_matrix_t() { delete[] matrix_; } size_t test( size_t t, size_t s ) const { return( matrix_[s*n_+t] ); } void set( size_t s, size_t t, char value ) { matrix_[s*n_+t] = value; } bool solvable( void ) const; void print( std::ostream &os, size_t depth ) const; static const test_matrix_t* generate_random( size_t m, size_t n );};class state_t{ unsigned states_[4]; unsigned tests_[4]; static size_t max_states_; static size_t max_tests_;public: state_t() { states_[0] = states_[1] = states_[2] = states_[3] = 0; tests_[0] = tests_[1] = tests_[2] = tests_[3] = 0; } static void set_max( size_t maxs, size_t maxt ) { max_states_ = maxs; max_tests_ = maxt; } static size_t max_states( void ) { return( max_states_ ); } static size_t max_tests( void ) { return( max_tests_ ); } size_t hash( void ) const { return( states_[0] ^ states_[1] ^ states_[2] ^ states_[3] ^ tests_[0] ^ tests_[1] ^ tests_[2] ^ tests_[3] ); } void insert_state( size_t s ) { states_[s/32] |= (1<<(s%32)); } bool available_state( size_t s ) const { return( states_[s/32] & (1<<(s%32)) ); } void clear_test( size_t t ) { tests_[t/32] |= (1<<(t%32)); } void clear_all_states( void ) { states_[0] = states_[1] = states_[2] = states_[3] = 0; } bool available_test( size_t t ) const { return( !(tests_[t/32] & (1<<(t%32))) ); } size_t num_states( void ) const { size_t n = 0; for( size_t i = 0; i < max_states_; ++i ) n += available_state(i); return( n ); } void print( std::ostream &os, size_t depth ) const; bool terminal( void ) const { return( num_states() == 1 ); } unsigned terminal_value( void ) const { return( 0 ); } state_t next_state( size_t test, size_t obs, const test_matrix_t &matrix ) const; bool operator<( const state_t &s ) const { return( (states_[0] < s.states_[0]) || ((states_[0] == s.states_[0]) && (states_[1] < s.states_[1])) || ((states_[0] == s.states_[0]) && (states_[1] == s.states_[1]) && (states_[2] < s.states_[2])) || ((states_[0] == s.states_[0]) && (states_[1] == s.states_[1]) && (states_[2] == s.states_[2]) && (states_[3] < s.states_[3])) || ((states_[0] == s.states_[0]) && (states_[1] == s.states_[1]) && (states_[2] == s.states_[2]) && (states_[3] == s.states_[3]) && (tests_[0] < s.tests_[0])) || ((states_[0] == s.states_[0]) && (states_[1] == s.states_[1]) && (states_[2] == s.states_[2]) && (states_[3] == s.states_[3]) && (tests_[0] == s.tests_[0]) && (tests_[1] < s.tests_[1])) || ((states_[0] == s.states_[0]) && (states_[1] == s.states_[1]) && (states_[2] == s.states_[2]) && (states_[3] == s.states_[3]) && (tests_[0] == s.tests_[0]) && (tests_[1] == s.tests_[1]) && (tests_[2] < s.tests_[2])) || ((states_[0] == s.states_[0]) && (states_[1] == s.states_[1]) && (states_[2] == s.states_[2]) && (states_[3] == s.states_[3]) && (tests_[0] == s.tests_[0]) && (tests_[1] == s.tests_[1]) && (tests_[2] == s.tests_[2]) && (tests_[3] < s.tests_[3])) ); } bool operator==( const state_t &s ) const { return( (states_[0] == s.states_[0]) && (states_[1] == s.states_[1]) && (states_[2] == s.states_[2]) && (states_[3] == s.states_[3]) && (tests_[0] == s.tests_[0]) && (tests_[1] == s.tests_[1]) && (tests_[2] == s.tests_[2]) && (tests_[3] == s.tests_[3]) ); } class const_iterator { size_t pos_; const state_t &s_; public: const_iterator( size_t pos, const state_t &s ) : pos_(pos), s_(s) { } void advance( void ) { while( !s_.available_state(pos_) && (pos_ < max_states_) ) ++pos_; } const_iterator operator++( void ) { ++pos_; advance(); return( *this ); } size_t operator*( void ) const { return( pos_ ); } bool operator==( const const_iterator &it ) const { return( pos_ == it.pos_ ); } bool operator!=( const const_iterator &it ) const { return( !(*this == it) ); } }; const_iterator begin( void ) const { const_iterator it(0,*this); it.advance(); return( it ); } const_iterator end( void ) const { const_iterator it(max_states_,*this); return( it ); }};size_t state_t::max_states_ = 0;size_t state_t::max_tests_ = 0;voidstate_t::print( std::ostream &os, size_t depth ) const{ indent( os, depth ); os << "{ "; for( size_t s = 0; s < state_t::max_states(); ++s ) if( available_state( s ) ) os << s << ' '; os << "} : { "; for( size_t t = 0; t < state_t::max_tests(); ++t ) if( available_test( t ) ) os << t << ' '; os << '}';}std::ostream& operator<<( std::ostream &os, const state_t &s ){ s.print( os, 0 ); return( os );}class heuristic_t{ int type_; hash_t *hash_; const test_matrix_t &test_matrix_;public: heuristic_t( int type, const test_matrix_t &tm ); ~heuristic_t(); unsigned value( const state_t &s ) const; void dump( std::ostream &os ) const; void compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue );};namespace hashing = ::__gnu_cxx;namespace __gnu_cxx { template<> class hash<const state_t> { public: unsigned operator()( const state_t &s ) const { return( s.hash() ); } }; class entry_t { public: unsigned lower_; unsigned upper_; bool mark_; entry_t( unsigned lower = 0, unsigned upper = UMAX, bool mark = false ) : lower_(lower), upper_(upper), mark_(mark) { } }; class hash_t : public hash_map<const state_t,entry_t> { public: typedef std::pair<state_t,entry_t*> data_pair; private: const heuristic_t *h_; data_pair push( const state_t &s, unsigned l, unsigned u, bool m ) { std::pair<iterator,bool> p = insert( std::make_pair( s, entry_t(l,u,m) ) ); return( data_pair( (*p.first).first, &(*p.first).second ) ); } data_pair push( const state_t &s ) { return( push( s, (!h_?0:h_->value(s)), UMAX, false ) ); } const_iterator lookup( const state_t &s ) const { return( find( s ) ); } iterator lookup( const state_t &s ) { return( find( s ) ); } public: hash_t( const heuristic_t *h = 0 ) : h_(h) { } data_pair get( const state_t &s ) { iterator di = lookup( s ); return( di == end() ? push( s ) : data_pair( (*di).first, &(*di).second ) ); } unsigned value( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? (!h_?0:h_->value(s)) : (*di).second.lower_ ); } unsigned upper( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? UMAX : (*di).second.upper_ ); } void update( const state_t &s, unsigned value ) { iterator di = lookup( s ); if( di == end() ) push( s, value, UMAX, false ); else (*di).second.lower_ = value; } void update_upper( const state_t &s, unsigned value ) { iterator di = lookup( s ); if( di == end() ) push( s, (!h_?0:h_->value(s)), value, false ); else (*di).second.upper_ = value; } bool solved( const state_t &s ) const { const_iterator di = lookup( s ); return( di == end() ? false : (*di).second.mark_ ); } void mark( const state_t &s ) { iterator di = lookup( s ); if( di == end() ) push( s, (!h_?0:h_->value(s)), UMAX, true ); else (*di).second.mark_ = true; } void dump( std::ostream &os ) const { for( const_iterator it = begin(); it != end(); ++it ) os << "state=" << (*it).first << ", lower=" << (*it).second.lower_ << ", upper=" << (*it).second.upper_ << ", mark=" << ((*it).second.mark_?1:0) << std::endl; } };};class hash_t{ hashing::hash_t table_; unsigned updates_; unsigned upper_updates_;public: hash_t( const heuristic_t *h = 0) : table_(h), updates_(0), upper_updates_(0) { } typedef hashing::hash_t::data_pair data_pair; data_pair get( const state_t &s ) { return( table_.get( s ) ); } unsigned updates( void ) const { return( updates_ ); } unsigned upper_updates( void ) const { return( upper_updates_ ); } size_t size( void ) const { return( table_.size() ); } size_t bucket_count( void ) const { return( table_.bucket_count() ); } unsigned value( const state_t &s ) const { return( table_.value( s ) ); } unsigned upper( const state_t &s ) const { return( table_.upper( s ) ); } void update( const state_t &s, unsigned value ) { table_.update( s, value ); ++updates_; } void update_upper( const state_t &s, unsigned value ) { table_.update_upper( s, value ); ++upper_updates_; } bool solved( const state_t &s ) const { return( table_.solved( s ) ); } void mark( state_t s ) { table_.mark( s ); } unsigned QValueMax( const state_t &s, size_t t ) const; unsigned QValueAdd( const state_t &s, size_t t ) const; unsigned bestQValue( const state_t &s, qvalue_func_t qvalue ) const; std::pair<unsigned,size_t> bestAction( const state_t &s, qvalue_func_t qvalue ) const; void dump( std::ostream &os ) const { table_.dump( os ); }};class graph_t{public: class entry_t { const state_t s_; std::set<entry_t*> parents_; size_t marked_; bool fringe_; bool visited_; size_t index_; public: entry_t() : marked_(UINT_MAX), fringe_(true), visited_(false), index_(0) { } entry_t( const state_t s ) : s_(s), marked_(UINT_MAX), fringe_(true), visited_(false), index_(0) { } const state_t& s( void ) const { return( s_ ); } bool fringe( void ) const { return( fringe_ ); } void set_fringe( bool fringe ) { fringe_ = fringe; } std::set<entry_t*>& parents( void ) { return( parents_ ); } const std::set<entry_t*>& parents( void ) const { return( parents_ ); } void add_parent( entry_t *parent ) { parents_.insert( parent ); } bool visited( void ) const { return( visited_ ); } void visit( void ) { visited_ = true; } void unvisit( void ) { visited_ = false; } size_t marked( void ) const { return( marked_ ); } void mark( size_t test ) { marked_ = test; } void inc_index( void ) { ++index_; } void dec_index( void ) { assert( index_ > 0 ); --index_; } void set_index( size_t index ) { index_ = index; } size_t index( void ) const { return( index_ ); } bool revise( hash_t &hash, qvalue_func_t qvalue ); bool operator()( const entry_t *e1, const entry_t *e2 ) const { return( e1->index() < e2->index() ); } };private: std::map<state_t,entry_t*> nodes_; std::set<entry_t*> tips_; entry_t *root_;public: graph_t() : root_(0) { } entry_t* root( void ) { return( root_ ); } void set_root( entry_t *root ) { root_ = root; } const std::set<entry_t*>& tips( void ) const { return( tips_ ); } entry_t& choose_tip( void ) { return( **tips_.begin() ); } entry_t* add_tip( entry_t *parent, size_t test, size_t obs, const state_t &child ); bool expand( entry_t &n, std::multiset<entry_t*,entry_t> &S ); void expand_and_update( entry_t &n, hash_t &hash, qvalue_func_t qvalue );};std::ostream& operator<<( std::ostream &os, const test_matrix_t &tm ){ tm.print( os, 0 ); return( os );}inline state_tstate_t::next_state( size_t test, size_t obs, const test_matrix_t &matrix ) const{ state_t s( *this ); s.clear_all_states(); for( const_iterator it = begin(); it != end(); ++it ) if( matrix.test( test, *it ) == obs ) s.insert_state( *it ); s.clear_test( test ); return( s );}const test_matrix_t*test_matrix_t::generate_random( size_t m, size_t n ){ test_matrix_t *tm = new test_matrix_t( m, n ); for( size_t s = 1; s < m; ++s ) for( size_t t = 0; t < n; ++t ) tm->set( s, t, (char)(lrand48()%2) ); return( tm );}booltest_matrix_t::solvable( void ) const{ size_t t = 0; for( t = 0; t < n_; ++t ) if( test(t,0) ) return( false ); for( size_t s1 = 0; s1 < m_; ++s1 ) for( size_t s2 = s1+1; s2 < m_; ++s2 ) { for( t = 0; t < n_; ++t ) if( (test(t,s1) && !test(t,s2)) || (!test(t,s1) && test(t,s2)) ) break; if( t == n_ ) return( false ); } return( true );}voidtest_matrix_t::print( std::ostream &os, size_t depth ) const{ for( size_t s = 0; s < m_; ++s ) { indent( os, depth ); for( size_t t = 0; t < n_; ++t ) os << (test(t,s)?'1':'0'); os << std::endl; }}unsignedhash_t::QValueMax( const state_t &s, size_t t ) const{ unsigned qv = 0; for( size_t obs = 0; obs < 2; ++obs ) { state_t s_obs = s.next_state( t, obs, *test_matrix ); unsigned v = value( s_obs ); qv = MAX(qv,v); } return( qv == UMAX ? UMAX : 1 + qv );}unsignedhash_t::QValueAdd( const state_t &s, size_t t ) const{ unsigned qv = 0; for( size_t obs = 0; obs < 2; ++obs ) { state_t s_obs = s.next_state( t, obs, *test_matrix ); qv = ( qv + value( s_obs ) < qv ? UMAX : qv + value( s_obs ) ); } return( qv >= UMAX ? UMAX : 1 + qv );}inline unsignedhash_t::bestQValue( const state_t &s, qvalue_func_t qvalue ) const{ unsigned bqv = UMAX; for( size_t test = 0; test < state_t::max_tests(); ++test ) if( s.available_test( test ) ) { unsigned qv = (this->*qvalue)( s, test ); bqv = MIN(bqv,qv); } return( bqv );}inline std::pair<unsigned,size_t>hash_t::bestAction( const state_t &s, qvalue_func_t qvalue ) const{ size_t best_test = UINT_MAX; unsigned bqv = UMAX; for( size_t test = 0; test < state_t::max_tests(); ++test ) if( s.available_test( test ) ) { unsigned qv = (this->*qvalue)( s, test ); if( qv < bqv ) { bqv = qv; best_test = test; } } return( std::make_pair( bqv, best_test ) );}heuristic_t::heuristic_t( int type, const test_matrix_t &tm ) : type_(type), test_matrix_(tm){ hash_ = new hash_t;}heuristic_t::~heuristic_t(){ delete hash_;}inline unsignedheuristic_t::value( const state_t &s ) const{ return( hash_->value( s ) );}voidheuristic_t::dump( std::ostream &os ) const{ hash_->dump( os );}voidheuristic_t::compute_heuristic( const state_t &s0, size_t iters, qvalue_func_t qvalue ){ std::set<state_t> space; if( iters > 0 ) generateStateSpace( s0, space ); size_t i = 0, ups = iters * space.size(); while( ((type_ == 1) && (i < iters)) || ((type_ == 2) && (hash_->updates() < ups)) ) { ++i; for( std::set<state_t>::const_iterator si = space.begin(); si != space.end(); ++si ) if( (type_ == 1) || ((type_ == 2) && (drand48() < 0.5)) ) { if( !(*si).terminal() ) hash_->update( *si, hash_->bestQValue( *si, qvalue ) ); else hash_->update( *si, (*si).terminal_value() ); if( (type_ == 2) && (hash_->updates() >= ups) ) goto end; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -