📄 testing.py
字号:
item = (statement, entry) except KeyError: assert False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement) (query, params) = item if callable(params): params = params(ctx) if params is not None and not isinstance(params, list): params = [params] parameters = ctx.compiled_parameters query = self.convert_statement(query) testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters))) testdata.sql_count += 1 self.ctx.post_execution() def convert_statement(self, query): paramstyle = self.ctx.dialect.paramstyle if paramstyle == 'named': pass elif paramstyle =='pyformat': query = re.sub(r':([\w_]+)', r"%(\1)s", query) else: # positional params repl = None if paramstyle=='qmark': repl = "?" elif paramstyle=='format': repl = r"%s" elif paramstyle=='numeric': repl = None query = re.sub(r':([\w_]+)', repl, query) return queryclass TestBase(unittest.TestCase): # A sequence of dialect names to exclude from the test class. __unsupported_on__ = () # If present, test class is only runnable for the *single* specified # dialect. If you need multiple, use __unsupported_on__ and invert. __only_on__ = None # A sequence of no-arg callables. If any are True, the entire testcase is # skipped. __skip_if__ = None def __init__(self, *args, **params): unittest.TestCase.__init__(self, *args, **params) def setUpAll(self): pass def tearDownAll(self): pass def shortDescription(self): """overridden to not return docstrings""" return None if not hasattr(unittest.TestCase, 'assertTrue'): assertTrue = unittest.TestCase.failUnless if not hasattr(unittest.TestCase, 'assertFalse'): assertFalse = unittest.TestCase.failIfclass AssertsCompiledSQL(object): def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None): if dialect is None: dialect = getattr(self, '__dialect__', None) if params is None: keys = None else: keys = params.keys() c = clause.compile(column_keys=keys, dialect=dialect) print "\nSQL String:\n" + str(c) + repr(c.params) cc = re.sub(r'\n', '', str(c)) self.assertEquals(cc, result) if checkparams is not None: self.assertEquals(c.construct_params(params), checkparams)class ComparesTables(object): def assert_tables_equal(self, table, reflected_table): global sqltypes, schema if sqltypes is None: import sqlalchemy.types as sqltypes if schema is None: import sqlalchemy.schema as schema base_mro = sqltypes.TypeEngine.__mro__ assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): self.assertEquals(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] self.assertEquals(c.primary_key, reflected_c.primary_key) self.assertEquals(c.nullable, reflected_c.nullable) assert len( set(type(reflected_c.type).__mro__).difference(base_mro).intersection( set(type(c.type).__mro__).difference(base_mro) ) ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) if isinstance(c.type, sqltypes.String): self.assertEquals(c.type.length, reflected_c.type.length) self.assertEquals(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) if c.default: assert isinstance(reflected_c.default, schema.PassiveDefault) elif not c.primary_key or not against('postgres'): assert reflected_c.default is None assert len(table.primary_key) == len(reflected_table.primary_key) for c in table.primary_key: assert reflected_table.primary_key.columns[c.name] class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): result = list(result) print repr(result) self.assert_list(result, class_, objects) def assert_list(self, result, class_, list): self.assert_(len(result) == len(list), "result list is not the same size as test list, " + "for class " + class_.__name__) for i in range(0, len(list)): self.assert_row(class_, result[i], list[i]) def assert_row(self, class_, rowobj, desc): self.assert_(rowobj.__class__ is class_, "item class is not " + repr(class_)) for key, value in desc.iteritems(): if isinstance(value, tuple): if isinstance(value[1], list): self.assert_list(getattr(rowobj, key), value[0], value[1]) else: self.assert_row(value[0], getattr(rowobj, key), value[1]) else: self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % ( key, getattr(rowobj, key), value)) def assert_unordered_result(self, result, cls, *expected): """As assert_result, but the order of objects is not considered. The algorithm is very expensive but not a big deal for the small numbers of rows that the test suite manipulates. """ global util if util is None: from sqlalchemy import util class frozendict(dict): def __hash__(self): return id(self) found = util.IdentitySet(result) expected = set([frozendict(e) for e in expected]) for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found): self.fail('Unexpected type "%s", expected "%s"' % ( type(wrong).__name__, cls.__name__)) if len(found) != len(expected): self.fail('Unexpected object count "%s", expected "%s"' % ( len(found), len(expected))) NOVALUE = object() def _compare_item(obj, spec): for key, value in spec.iteritems(): if isinstance(value, tuple): try: self.assert_unordered_result( getattr(obj, key), value[0], *value[1]) except AssertionError: return False else: if getattr(obj, key, NOVALUE) != value: return False return True for expected_item in expected: for found_item in found: if _compare_item(found_item, expected_item): found.remove(found_item) break else: self.fail( "Expected %s instance with attributes %s not found." % ( cls.__name__, repr(expected_item))) return True def assert_sql(self, db, callable_, list, with_sequences=None): global testdata testdata = TestData() if with_sequences is not None and config.db.name in ('firebird', 'oracle', 'postgres'): testdata.set_assert_list(self, with_sequences) else: testdata.set_assert_list(self, list) try: callable_() finally: testdata.set_assert_list(None, None) def assert_sql_count(self, db, callable_, count): global testdata testdata = TestData() callable_() self.assert_(testdata.sql_count == count, "desired statement count %d does not match %d" % ( count, testdata.sql_count)) def capture_sql(self, db, callable_): global testdata testdata = TestData() buffer = StringIO() testdata.buffer = buffer try: callable_() return buffer.getvalue() finally: testdata.buffer = None_otest_metadata = Noneclass ORMTest(TestBase, AssertsExecutionResults): keep_mappers = False keep_data = False metadata = None def setUpAll(self): global MetaData, _otest_metadata if MetaData is None: from sqlalchemy import MetaData if self.metadata is None: _otest_metadata = MetaData(config.db) else: _otest_metadata = self.metadata if self.metadata.bind is None: _otest_metadata.bind = config.db self.define_tables(_otest_metadata) _otest_metadata.create_all() self.insert_data() def define_tables(self, _otest_metadata): raise NotImplementedError() def insert_data(self): pass def get_metadata(self): return _otest_metadata def tearDownAll(self): global clear_mappers if clear_mappers is None: from sqlalchemy.orm import clear_mappers clear_mappers() _otest_metadata.drop_all() def tearDown(self): global Session if Session is None: from sqlalchemy.orm.session import Session Session.close_all() global clear_mappers if clear_mappers is None: from sqlalchemy.orm import clear_mappers if not self.keep_mappers: clear_mappers() if not self.keep_data: for t in _otest_metadata.table_iterator(reverse=True): try: t.delete().execute().close() except Exception, e: print "EXCEPTION DELETING...", eclass TTestSuite(unittest.TestSuite): """A TestSuite with once per TestCase setUpAll() and tearDownAll()""" def __init__(self, tests=()): if len(tests) > 0 and isinstance(tests[0], TestBase): self._initTest = tests[0] else: self._initTest = None unittest.TestSuite.__init__(self, tests) def do_run(self, result): # nice job unittest ! you switched __call__ and run() between py2.3 # and 2.4 thereby making straight subclassing impossible ! for test in self._tests: if result.shouldStop: break test(result) return result def run(self, result): return self(result) def __call__(self, result): init = getattr(self, '_initTest', None) if init is not None: if (hasattr(init, '__unsupported_on__') and config.db.name in init.__unsupported_on__): print "'%s' unsupported on DB implementation '%s'" % ( init.__class__.__name__, config.db.name) return True if (getattr(init, '__only_on__', None) not in (None,config.db.name)): print "'%s' unsupported on DB implementation '%s'" % ( init.__class__.__name__, config.db.name) return True if (getattr(init, '__skip_if__', False)): for c in getattr(init, '__skip_if__'): if c(): print "'%s' skipped by %s" % ( init.__class__.__name__, c.__name__) return True for rule in getattr(init, '__excluded_on__', ()): if _is_excluded(*rule): print "'%s' unsupported on DB %s version %s" % ( init.__class__.__name__, config.db.name, _server_version()) return True try: resetwarnings() init.setUpAll() except: # skip tests if global setup fails ex = self.__exc_info() for test in self._tests: result.addError(test, ex) return False try: resetwarnings() return self.do_run(result) finally: try: resetwarnings() if init is not None: init.tearDownAll() except: result.addError(init, self.__exc_info()) pass def __exc_info(self): """Return a version of sys.exc_info() with the traceback frame minimised; usually the top level of the traceback frame is not needed. ripped off out of unittest module since its double __ """ exctype, excvalue, tb = sys.exc_info() if sys.platform[:4] == 'java': ## tracebacks look different in Jython return (exctype, excvalue, tb) return (exctype, excvalue, tb)# monkeypatchunittest.TestLoader.suiteClass = TTestSuiteclass DevNullWriter(object): def write(self, msg): pass def flush(self): passdef runTests(suite): verbose = config.options.verbose quiet = config.options.quiet orig_stdout = sys.stdout try: if not verbose or quiet: sys.stdout = DevNullWriter() runner = unittest.TextTestRunner(verbosity = quiet and 1 or 2) return runner.run(suite) finally: if not verbose or quiet: sys.stdout = orig_stdoutdef main(suite=None): if not suite: if sys.argv[1:]: suite =unittest.TestLoader().loadTestsFromNames( sys.argv[1:], __import__('__main__')) else: suite = unittest.TestLoader().loadTestsFromModule( __import__('__main__')) result = runTests(suite) sys.exit(not result.wasSuccessful())
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -