📄 mapper.py
字号:
extlist = util.OrderedSet() extension = self.extension if extension is not None: for ext_obj in util.to_list(extension): # local MapperExtensions have already instrumented the class extlist.add(ext_obj) if self.inherits is not None: for ext in self.inherits.extension: if ext not in extlist: extlist.add(ext) ext.instrument_class(self, self.class_) else: for ext in global_extensions: if isinstance(ext, type): ext = ext() if ext not in extlist: extlist.add(ext) ext.instrument_class(self, self.class_) self.extension = ExtensionCarrier() for ext in extlist: self.extension.append(ext) def _compile_inheritance(self): """Determine if this Mapper inherits from another mapper, and if so calculates the mapped_table for this Mapper taking the inherited mapper into account. For joined table inheritance, creates a ``SyncRule`` that will synchronize column values between the joined tables. also initializes polymorphic variables used in polymorphic loads. """ if self.inherits is not None: if isinstance(self.inherits, type): self.inherits = class_mapper(self.inherits, compile=False) else: self.inherits = self.inherits if not issubclass(self.class_, self.inherits.class_): raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__)) if self.non_primary != self.inherits.non_primary: np = not self.non_primary and "primary" or "non-primary" raise exceptions.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np)) # inherit_condition is optional. if self.local_table is None: self.local_table = self.inherits.local_table self.single = True if not self.local_table is self.inherits.local_table: if self.concrete: self._synchronizer= None self.mapped_table = self.local_table else: if self.inherit_condition is None: # figure out inherit condition from our table to the immediate table # of the inherited mapper, not its full table which could pull in other # stuff we dont want (allows test/inheritance.InheritTest4 to pass) self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition) # generate sync rules. similarly to creating the on clause, specify a # stricter set of tables to create "sync rules" by,based on the immediate # inherited table, rather than all inherited tables self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) if self.inherit_foreign_keys: fks = util.Set(self.inherit_foreign_keys) else: fks = None self._synchronizer.compile(self.mapped_table.onclause, foreign_keys=fks) else: self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity is not None: self.inherits.polymorphic_map[self.polymorphic_identity] = self if self.polymorphic_on is None: for mapper in self.iterate_to_root(): # try to set up polymorphic on using correesponding_column(); else leave # as None if mapper.polymorphic_on: self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on) break else: # TODO: this exception not covered raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) if self.polymorphic_identity is not None and not self.concrete: self._identity_class = self.inherits._identity_class else: self._identity_class = self.class_ if self.version_id_col is None: self.version_id_col = self.inherits.version_id_col for mapper in self.iterate_to_root(): if hasattr(mapper, '_genned_equivalent_columns'): del mapper._genned_equivalent_columns if self.order_by is False: self.order_by = self.inherits.order_by self.polymorphic_map = self.inherits.polymorphic_map self.batch = self.inherits.batch self.inherits._inheriting_mappers.add(self) self.base_mapper = self.inherits.base_mapper self._all_tables = self.inherits._all_tables else: self._all_tables = util.Set() self.base_mapper = self self._synchronizer = None self.mapped_table = self.local_table if self.polymorphic_identity is not None: if self.polymorphic_on is None: raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) self.polymorphic_map[self.polymorphic_identity] = self self._identity_class = self.class_ if self.mapped_table is None: raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self)) def _compile_tables(self): # summary of the various Selectable units: # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table) # local_table - the Selectable that was passed to this Mapper's constructor, if any # select_table - the Selectable that will be used during queries. if this is specified # as a constructor keyword argument, it takes precendence over mapped_table, otherwise its mapped_table # this is either select_table if it was given explicitly, or in the case of a mapper that inherits # its local_table # tables - a collection of underlying Table objects pulled from mapped_table if self.select_table is None: self.select_table = self.mapped_table # locate all tables contained within the "table" passed in, which # may be a join or other construct self.tables = sqlutil.find_tables(self.mapped_table) if not self.tables: raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) def _compile_pks(self): self._pks_by_table = {} self._cols_by_table = {} all_cols = util.Set(chain(*[c2 for c2 in [col.proxy_set for col in [c for c in self._columntoproperty]]])) pk_cols = util.Set([c for c in all_cols if c.primary_key]) # identify primary key columns which are also mapped by this mapper. for t in util.Set(self.tables + [self.mapped_table]): self._all_tables.add(t) if t.primary_key and pk_cols.issuperset(t.primary_key): # ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get()) self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols) self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols) # if explicit PK argument sent, add those columns to the primary key mappings if self.primary_key_argument: for k in self.primary_key_argument: if k.table not in self._pks_by_table: self._pks_by_table[k.table] = util.OrderedSet() self._pks_by_table[k.table].add(k) if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0: raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) if self.inherits is not None and not self.concrete and not self.primary_key_argument: # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit) self.primary_key = self.inherits.primary_key self._get_clause = self.inherits._get_clause else: # determine primary key from argument or mapped_table pks - reduce to the minimal set of columns if self.primary_key_argument: primary_key = sqlutil.reduce_columns([self.mapped_table.corresponding_column(c) for c in self.primary_key_argument]) else: primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table]) if len(primary_key) == 0: raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) self.primary_key = primary_key self.__log("Identified primary key columns: " + str(primary_key)) # create a "get clause" based on the primary key. this is used # by query.get() and many-to-one lazyloads to load this item # by primary key. _get_clause = sql.and_() _get_params = {} for primary_key in self.primary_key: bind = sql.bindparam(None, type_=primary_key.type) _get_params[primary_key] = bind _get_clause.clauses.append(primary_key == bind) self._get_clause = (_get_clause, _get_params) def __get_equivalent_columns(self): """Create a map of all *equivalent* columns, based on the determination of column pairs that are equated to one another either by an established foreign key relationship or by a joined-table inheritance join. This is used to determine the minimal set of primary key columns for the mapper, as well as when relating columns to those of a polymorphic selectable (i.e. a UNION of several mapped tables), as that selectable usually only contains one column in its columns clause out of a group of several which are equated to each other. The resulting structure is a dictionary of columns mapped to lists of equivalent columns, i.e. { tablea.col1: set([tableb.col1, tablec.col1]), tablea.col2: set([tabled.col2]) } """ result = {} def visit_binary(binary): if binary.operator == operators.eq: if binary.left in result: result[binary.left].add(binary.right) else: result[binary.left] = util.Set([binary.right]) if binary.right in result: result[binary.right].add(binary.left) else: result[binary.right] = util.Set([binary.left]) for mapper in self.base_mapper.polymorphic_iterator(): if mapper.inherit_condition is not None: visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary) # TODO: matching of cols to foreign keys might better be generalized # into general column translation (i.e. corresponding_column) # recursively descend into the foreign key collection of the given column # and assemble each FK-related col as an "equivalent" for the given column def equivs(col, recursive, equiv): if col in recursive: return recursive.add(col) for fk in col.foreign_keys: if fk.column not in result: result[fk.column] = util.Set() result[fk.column].add(equiv) equivs(fk.column, recursive, col) for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]): for col in column.proxy_set: if not col.foreign_keys: if col not in result: result[col] = util.Set() result[col].add(col) else: equivs(col, util.Set(), col) return result def _equivalent_columns(self): if hasattr(self, '_genned_equivalent_columns'): return self._genned_equivalent_columns else: self._genned_equivalent_columns = self.__get_equivalent_columns() return self._genned_equivalent_columns _equivalent_columns = property(_equivalent_columns) class _CompileOnAttr(PropComparator): """placeholder class attribute which fires mapper compilation on access""" def __init__(self, class_, key): self.class_ = class_ self.key = key self.existing_prop = getattr(class_, key, None) def __getattribute__(self, key): cls = object.__getattribute__(self, 'class_') clskey = object.__getattribute__(self, 'key') if key.startswith('__'): return object.__getattribute__(self, key)
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -