📄 testing.py
字号:
"""TestCase and TestSuite artifacts and testing decorators."""# monkeypatches unittest.TestLoader.suiteClass at import timeimport itertools, os, operator, re, sys, unittest, warningsfrom cStringIO import StringIOimport testlib.config as configfrom testlib.compat import *sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, Nonesa_exceptions = None__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL')_ops = { '<': operator.lt, '>': operator.gt, '==': operator.eq, '!=': operator.ne, '<=': operator.le, '>=': operator.ge, 'in': operator.contains, 'between': lambda val, pair: val >= pair[0] and val <= pair[1], }# sugar ('testing.db'); set here by config() at runtimedb = Nonedef fails_if(callable_): """Mark a test as expected to fail if callable_ returns True. If the callable returns false, the test is run and reported as normal. However if the callable returns true, the test is expected to fail and the unit test logic is inverted: if the test fails, a success is reported. If the test succeeds, a failure is reported. """ docstring = getattr(callable_, '__doc__', None) or callable_.__name__ description = docstring.split('\n')[0] def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): if not callable_(): return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected (condition: %s): %s " % ( fn_name, description, str(ex))) return True else: raise AssertionError( "Unexpected success for '%s' (condition: %s)" % (fn_name, description)) return _function_named(maybe, fn_name) return decoratedef future(fn): """Mark a test as expected to unconditionally fail. Takes no arguments, omit parens when using as a decorator. """ fn_name = fn.__name__ def decorated(*args, **kw): try: fn(*args, **kw) except Exception, ex: print ("Future test '%s' failed as expected: %s " % ( fn_name, str(ex))) return True else: raise AssertionError( "Unexpected success for future test '%s'" % fn_name) return _function_named(decorated, fn_name)def fails_on(*dbs): """Mark a test as expected to fail on one or more database implementations. Unlike ``unsupported``, tests marked as ``fails_on`` will be run for the named databases. The test is expected to fail and the unit test logic is inverted: if the test fails, a success is reported. If the test succeeds, a failure is reported. """ def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): if config.db.name not in dbs: return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " "'%s': %s" % ( fn_name, config.db.name, str(ex))) return True else: raise AssertionError( "Unexpected success for '%s' on DB implementation '%s'" % (fn_name, config.db.name)) return _function_named(maybe, fn_name) return decoratedef fails_on_everything_except(*dbs): """Mark a test as expected to fail on most database implementations. Like ``fails_on``, except failure is the expected outcome on all databases except those listed. """ def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): if config.db.name in dbs: return fn(*args, **kw) else: try: fn(*args, **kw) except Exception, ex: print ("'%s' failed as expected on DB implementation " "'%s': %s" % ( fn_name, config.db.name, str(ex))) return True else: raise AssertionError( "Unexpected success for '%s' on DB implementation '%s'" % (fn_name, config.db.name)) return _function_named(maybe, fn_name) return decoratedef unsupported(*dbs): """Mark a test as unsupported by one or more database implementations. 'unsupported' tests will be skipped unconditionally. Useful for feature tests that cause deadlocks or other fatal problems. """ def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): if config.db.name in dbs: print "'%s' unsupported on DB implementation '%s'" % ( fn_name, config.db.name) return True else: return fn(*args, **kw) return _function_named(maybe, fn_name) return decoratedef exclude(db, op, spec): """Mark a test as unsupported by specific database server versions. Stackable, both with other excludes and other decorators. Examples:: # Not supported by mydb versions less than 1, 0 @exclude('mydb', '<', (1,0)) # Other operators work too @exclude('bigdb', '==', (9,0,9)) @exclude('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) """ def decorate(fn): fn_name = fn.__name__ def maybe(*args, **kw): if _is_excluded(db, op, spec): print "'%s' unsupported on DB %s version '%s'" % ( fn_name, config.db.name, _server_version()) return True else: return fn(*args, **kw) return _function_named(maybe, fn_name) return decoratedef _is_excluded(db, op, spec): """Return True if the configured db matches an exclusion specification. db: A dialect name op: An operator or stringified operator, such as '==' spec: A value that will be compared to the dialect's server_version_info using the supplied operator. Examples:: # Not supported by mydb versions less than 1, 0 _is_excluded('mydb', '<', (1,0)) # Other operators work too _is_excluded('bigdb', '==', (9,0,9)) _is_excluded('yikesdb', 'in', ((0, 3, 'alpha2'), (0, 3, 'alpha3'))) """ if config.db.name != db: return False version = _server_version() oper = hasattr(op, '__call__') and op or _ops[op] return oper(version, spec)def _server_version(bind=None): """Return a server_version_info tuple.""" if bind is None: bind = config.db return bind.dialect.server_version_info(bind.contextual_connect())def emits_warning(*messages): """Mark a test as emitting a warning. With no arguments, squelches all SAWarning failures. Or pass one or more strings; these will be matched to the root of the warning description by warnings.filterwarnings(). """ # TODO: it would be nice to assert that a named warning was # emitted. should work with some monkeypatching of warnings, # and may work on non-CPython if they keep to the spirit of # warnings.showwarning's docstring. # - update: jython looks ok, it uses cpython's module def decorate(fn): def safe(*args, **kw): global sa_exceptions if sa_exceptions is None: import sqlalchemy.exceptions as sa_exceptions if not messages: filters = [dict(action='ignore', category=sa_exceptions.SAWarning)] else: filters = [dict(action='ignore', message=message, category=sa_exceptions.SAWarning) for message in messages ] for f in filters: warnings.filterwarnings(**f) try: return fn(*args, **kw) finally: resetwarnings() return _function_named(safe, fn.__name__) return decoratedef uses_deprecated(*messages): """Mark a test as immune from fatal deprecation warnings. With no arguments, squelches all SADeprecationWarning failures. Or pass one or more strings; these will be matched to the root of the warning description by warnings.filterwarnings(). As a special case, you may pass a function name prefixed with // and it will be re-written as needed to match the standard warning verbiage emitted by the sqlalchemy.util.deprecated decorator. """ def decorate(fn): def safe(*args, **kw): global sa_exceptions if sa_exceptions is None: import sqlalchemy.exceptions as sa_exceptions if not messages: filters = [dict(action='ignore', category=sa_exceptions.SADeprecationWarning)] else: filters = [dict(action='ignore', message=message, category=sa_exceptions.SADeprecationWarning) for message in [ (m.startswith('//') and ('Call to deprecated function ' + m[2:]) or m) for m in messages] ] for f in filters: warnings.filterwarnings(**f) try: return fn(*args, **kw) finally: resetwarnings() return _function_named(safe, fn.__name__) return decoratedef resetwarnings(): """Reset warning behavior to testing defaults.""" global sa_exceptions if sa_exceptions is None: import sqlalchemy.exceptions as sa_exceptions warnings.resetwarnings() warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning) warnings.filterwarnings('error', category=sa_exceptions.SAWarning) if sys.version_info < (2, 4): warnings.filterwarnings('ignore', category=FutureWarning)def against(*queries): """Boolean predicate, compares to testing database configuration. Given one or more dialect names, returns True if one is the configured database engine. Also supports comparison to database version when provided with one or more 3-tuples of dialect name, operator, and version specification:: testing.against('mysql', 'postgres') testing.against(('mysql', '>=', (5, 0, 0)) """ for query in queries: if isinstance(query, basestring): if config.db.name == query: return True else: name, op, spec = query if config.db.name != name: continue have = config.db.dialect.server_version_info( config.db.contextual_connect()) oper = hasattr(op, '__call__') and op or _ops[op] if oper(have, spec): return True return Falsedef rowset(results): """Converts the results of sql execution into a plain set of column tuples. Useful for asserting the results of an unordered query. """ return set([tuple(row) for row in results])class TestData(object): """Tracks SQL expressions as they are executed via an instrumented ExecutionContext.""" def __init__(self): self.set_assert_list(None, None) self.sql_count = 0 self.buffer = None def set_assert_list(self, unittest, list): self.unittest = unittest self.assert_list = list if list is not None: self.assert_list.reverse()testdata = TestData()class ExecutionContextWrapper(object): """instruments the ExecutionContext created by the Engine so that SQL expressions can be tracked.""" def __init__(self, ctx): global sql if sql is None: from sqlalchemy import sql self.__dict__['ctx'] = ctx def __getattr__(self, key): return getattr(self.ctx, key) def __setattr__(self, key, value): setattr(self.ctx, key, value) def post_execution(self): ctx = self.ctx statement = unicode(ctx.compiled) statement = re.sub(r'\n', '', ctx.statement) if config.db.name == 'mssql' and statement.endswith('; select scope_identity()'): statement = statement[:-25] if testdata.buffer is not None: testdata.buffer.write(statement + "\n") if testdata.assert_list is not None: assert len(testdata.assert_list), "Received query but no more assertions: %s" % statement item = testdata.assert_list[-1] if not isinstance(item, dict): item = testdata.assert_list.pop() else: # asserting a dictionary of statements->parameters # this is to specify query assertions where the queries can be in # multiple orderings if '_converted' not in item: for key in item.keys(): ckey = self.convert_statement(key) item[ckey] = item[key] if ckey != key: del item[key] item['_converted'] = True try: entry = item.pop(statement) if len(item) == 1: testdata.assert_list.pop()
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -