📄 unitofwork.py
字号:
# orm/unitofwork.py# Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com## This module is part of SQLAlchemy and is released under# the MIT License: http://www.opensource.org/licenses/mit-license.php"""The internals for the Unit Of Work system.Includes hooks into the attributes package enabling the routing ofchange events to Unit Of Work objects, as well as the flush()mechanism which creates a dependency structure that executes changeoperations.A Unit of Work is essentially a system of maintaining a graph ofin-memory objects and their modified state. Objects are maintained asunique against their primary key identity using an *identity map*pattern. The Unit of Work then maintains lists of objects that arenew, dirty, or deleted and provides the capability to flush all thosechanges at once."""import StringIO, weakreffrom sqlalchemy import util, logging, topological, exceptionsfrom sqlalchemy.orm import attributes, interfacesfrom sqlalchemy.orm import util as mapperutilfrom sqlalchemy.orm.mapper import object_mapper, _state_mapper# Load lazilyobject_session = Noneclass UOWEventHandler(interfaces.AttributeExtension): """An event handler added to all relation attributes which handles session cascade operations. """ def __init__(self, key, class_, cascade=None): self.key = key self.class_ = class_ self.cascade = cascade def append(self, obj, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance sess = object_session(obj) if sess is not None: if self.cascade is not None and self.cascade.save_update and item not in sess: mapper = object_mapper(obj) prop = mapper.get_property(self.key) ename = prop.mapper.entity_name sess.save_or_update(item, entity_name=ename) def remove(self, obj, item, initiator): # currently no cascade rules for removing an item from a list # (i.e. it stays in the Session) pass def set(self, obj, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance sess = object_session(obj) if sess is not None: if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess: mapper = object_mapper(obj) prop = mapper.get_property(self.key) ename = prop.mapper.entity_name sess.save_or_update(newvalue, entity_name=ename)def register_attribute(class_, key, *args, **kwargs): """overrides attributes.register_attribute() to add UOW event handlers to new InstrumentedAttributes. """ cascade = kwargs.pop('cascade', None) useobject = kwargs.get('useobject', False) if useobject: # for object-holding attributes, instrument UOWEventHandler # to process per-attribute cascades extension = util.to_list(kwargs.pop('extension', None) or []) extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) kwargs['extension'] = extension return attributes.register_attribute(class_, key, *args, **kwargs) class UnitOfWork(object): """Main UOW object which stores lists of dirty/new/deleted objects. Provides top-level *flush* functionality as well as the default transaction boundaries involved in a write operation. """ def __init__(self, session): if session.weak_identity_map: self.identity_map = attributes.WeakInstanceDict() else: self.identity_map = attributes.StrongInstanceDict() self.new = {} # InstanceState->object, strong refs object self.deleted = {} # same self.logger = logging.instance_logger(self, echoflag=session.echo_uow) def _remove_deleted(self, state): if '_instance_key' in state.dict: del self.identity_map[state.dict['_instance_key']] self.deleted.pop(state, None) self.new.pop(state, None) def _is_valid(self, state): if '_instance_key' in state.dict: return state.dict['_instance_key'] in self.identity_map else: return state in self.new def _register_clean(self, state): """register the given object as 'clean' (i.e. persistent) within this unit of work, after a save operation has taken place.""" mapper = _state_mapper(state) instance_key = mapper._identity_key_from_state(state) if '_instance_key' not in state.dict: state.dict['_instance_key'] = instance_key elif state.dict['_instance_key'] != instance_key: # primary key switch del self.identity_map[state.dict['_instance_key']] state.dict['_instance_key'] = instance_key if hasattr(state, 'insert_order'): delattr(state, 'insert_order') self.identity_map[state.dict['_instance_key']] = state.obj() state.commit_all() # remove from new last, might be the last strong ref self.new.pop(state, None) def register_new(self, obj): """register the given object as 'new' (i.e. unsaved) within this unit of work.""" if hasattr(obj, '_instance_key'): raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) if obj._state not in self.new: self.new[obj._state] = obj obj._state.insert_order = len(self.new) def register_deleted(self, obj): """register the given persistent object as 'to be deleted' within this unit of work.""" self.deleted[obj._state] = obj def locate_dirty(self): """return a set of all persistent instances within this unit of work which either contain changes or are marked as deleted. """ # a little bit of inlining for speed return util.IdentitySet([x for x in self.identity_map.values() if x._state not in self.deleted and ( x._state.modified or (x.__class__._class_state.has_mutable_scalars and x._state.is_modified()) ) ]) def flush(self, session, objects=None): """create a dependency tree of all pending SQL operations within this unit of work and execute.""" dirty = [x for x in self.identity_map.all_states() if x.modified or (x.class_._class_state.has_mutable_scalars and x.is_modified()) ] if not dirty and not self.deleted and not self.new: return deleted = util.Set(self.deleted) new = util.Set(self.new) dirty = util.Set(dirty).difference(deleted) flush_context = UOWTransaction(self, session) if session.extension is not None: session.extension.before_flush(session, flush_context, objects) # create the set of all objects we want to operate upon if objects: # specific list passed in objset = util.Set([o._state for o in objects]) else: # or just everything objset = util.Set(self.identity_map.all_states()).union(new) # store objects whose fate has been decided processed = util.Set() # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted. for state in new.union(dirty).intersection(objset).difference(deleted): if state in processed: continue flush_context.register_object(state, isdelete=_state_mapper(state)._is_orphan(state.obj())) processed.add(state) # put all remaining deletes into the flush context. for state in deleted.intersection(objset).difference(processed): flush_context.register_object(state, isdelete=True) if len(flush_context.tasks) == 0: return session.create_transaction(autoflush=False) flush_context.transaction = session.transaction try: flush_context.execute() if session.extension is not None: session.extension.after_flush(session, flush_context) session.commit() except: session.rollback() raise flush_context.post_exec() if session.extension is not None: session.extension.after_flush_postexec(session, flush_context) def prune_identity_map(self): """Removes unreferenced instances cached in a strong-referencing identity map. Note that this method is only meaningful if "weak_identity_map" on the parent Session is set to False and therefore this UnitOfWork's identity map is a regular dictionary Removes any object in the identity map that is not referenced in user code or scheduled for a unit of work operation. Returns the number of objects pruned. """ if isinstance(self.identity_map, attributes.WeakInstanceDict): return 0 ref_count = len(self.identity_map) dirty = self.locate_dirty() keepers = weakref.WeakValueDictionary(self.identity_map) self.identity_map.clear() self.identity_map.update(keepers) return ref_count - len(self.identity_map)class UOWTransaction(object): """Handles the details of organizing and executing transaction tasks during a UnitOfWork object's flush() operation. The central operation is to form a graph of nodes represented by the ``UOWTask`` class, which is then traversed by a ``UOWExecutor`` object that issues SQL and instance-synchronizing operations via the related packages. """ def __init__(self, uow, session): self.uow = uow self.session = session self.mapper_flush_opts = session._mapper_flush_opts # stores tuples of mapper/dependent mapper pairs, # representing a partial ordering fed into topological sort self.dependencies = util.Set() # dictionary of mappers to UOWTasks self.tasks = {} # dictionary used by external actors to store arbitrary state # information. self.attributes = {} self.logger = logging.instance_logger(self, echoflag=session.echo_uow) def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) # cache the objects, not the states; the strong reference here # prevents newly loaded objects from being dereferenced during the # flush process if hashkey in self.attributes: (added, unchanged, deleted, cached_passive) = self.attributes[hashkey] # if the cached lookup was "passive" and now we want non-passive, do a non-passive # lookup and re-cache if cached_passive and not passive: (added, unchanged, deleted) = attributes.get_history(state, key, passive=False) self.attributes[hashkey] = (added, unchanged, deleted, passive) else: (added, unchanged, deleted) = attributes.get_history(state, key, passive=passive) self.attributes[hashkey] = (added, unchanged, deleted, passive) if added is None: return (added, unchanged, deleted) else: return ( [getattr(c, '_state', c) for c in added], [getattr(c, '_state', c) for c in unchanged], [getattr(c, '_state', c) for c in deleted], ) def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs): # if object is not in the overall session, do nothing if not self.uow._is_valid(state): if self._should_log_debug: self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state))) return if self._should_log_debug: self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate)) mapper = _state_mapper(state) task = self.get_task_by_mapper(mapper) if postupdate: task.append_postupdate(state, post_update_cols) else: task.append(state, listonly, isdelete=isdelete, **kwargs) def set_row_switch(self, state): """mark a deleted object as a 'row switch'.
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -