1# orm/unitofwork.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""The internals for the unit of work system.
11
12The session's flush() process passes objects to a contextual object
13here, which assembles flush tasks based on mappers and their properties,
14organizes them in order of dependency, and executes.
15
16"""
17
18from __future__ import annotations
19
20from typing import Any
21from typing import Dict
22from typing import Optional
23from typing import Set
24from typing import TYPE_CHECKING
25
26from . import attributes
27from . import exc as orm_exc
28from . import util as orm_util
29from .. import event
30from .. import util
31from ..util import topological
32
33if TYPE_CHECKING:
34 from .dependency import _DependencyProcessor
35 from .interfaces import MapperProperty
36 from .mapper import Mapper
37 from .session import Session
38 from .session import SessionTransaction
39 from .state import InstanceState
40
41
def _track_cascade_events(descriptor, prop):
    """Attach cascade-handling listeners to a relationship attribute.

    Listeners for the ``append_wo_mutation``, ``append``, ``remove`` and
    ``set`` attribute events are registered on ``descriptor``.  They apply
    the "save_update" cascade when an object is attached to a parent that
    belongs to a session, and take care of pending "delete_orphan" objects
    when one is detached.

    :param descriptor: event target handed to ``event.listen()``.
    :param prop: the property being instrumented; only its ``key`` is read
     here.  Each listener looks the property up again from the mapper of
     the state that received the event.

    """
    key = prop.key

    def append(state, item, initiator, **kw):
        # "save_update" cascade for an item added to a collection
        if item is None:
            return

        session = state.session
        if not session:
            # the parent is not in a session, nothing to cascade into
            return item

        if session._warn_on_events:
            session._flush_warning("collection append")

        local_prop = state.manager.mapper._props[key]
        item_state = attributes.instance_state(item)

        should_cascade = (
            local_prop._cascade.save_update
            and key == initiator.key
            and not session._contains_state(item_state)
        )
        if should_cascade:
            session._save_or_update_state(item_state)
        return item

    def remove(state, item, initiator, **kw):
        # "delete_orphan" handling for an item detached from its parent
        if item is None:
            return

        session = state.session
        local_prop = state.manager.mapper._props[key]

        if session and session._warn_on_events:
            if local_prop.uselist:
                session._flush_warning("collection remove")
            else:
                session._flush_warning("related attribute delete")

        if (
            item is attributes.NEVER_SET
            or item is attributes.PASSIVE_NO_RESULT
            or not local_prop._cascade.delete_orphan
        ):
            return

        item_state = attributes.instance_state(item)
        if not local_prop.mapper._is_orphan(item_state):
            return

        if session and item_state in session._new:
            # a pending object that just became an orphan is expunged
            session.expunge(item)
        else:
            # the item itself may or may not be in a Session, but the
            # parent receiving this event is not handling it as a pending
            # object of its session; record the orphaning on the item's
            # state instead
            item_state._orphaned_outside_of_session = True

    def set_(state, newvalue, oldvalue, initiator, **kw):
        # "save_update" / "delete_orphan" handling for a scalar
        # relationship attribute being replaced
        if oldvalue is newvalue:
            return newvalue

        session = state.session
        if not session:
            return newvalue

        if session._warn_on_events:
            session._flush_warning("related attribute set")

        local_prop = state.manager.mapper._props[key]

        if newvalue is not None:
            new_state = attributes.instance_state(newvalue)
            if (
                local_prop._cascade.save_update
                and key == initiator.key
                and not session._contains_state(new_state)
            ):
                session._save_or_update_state(new_state)

        old_is_object = not (
            oldvalue is None
            or oldvalue is attributes.NEVER_SET
            or oldvalue is attributes.PASSIVE_NO_RESULT
        )
        if old_is_object and local_prop._cascade.delete_orphan:
            old_state = attributes.instance_state(oldvalue)
            if old_state in session._new and local_prop.mapper._is_orphan(
                old_state
            ):
                session.expunge(oldvalue)
        return newvalue

    # (event name, listener, extra listen() options), registered in order;
    # "append_wo_mutation" is the only one registered without retval=True
    registrations = (
        ("append_wo_mutation", append, {}),
        ("append", append, {"retval": True}),
        ("remove", remove, {"retval": True}),
        ("set", set_, {"retval": True}),
    )
    for event_name, listener, options in registrations:
        event.listen(
            descriptor,
            event_name,
            listener,
            raw=True,
            include_key=True,
            **options,
        )
154
155
class UOWTransaction:
    """Holds the working state of a single unit of work flush.

    States are registered with :meth:`register_object`; "presort" actions
    then gather any further states, and :meth:`execute` runs the resulting
    "postsort" actions in dependency order.

    """

    session: Session
    transaction: SessionTransaction
    attributes: Dict[str, Any]
    deps: util.defaultdict[Mapper[Any], Set[_DependencyProcessor]]
    mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]]

    def __init__(self, session: Session):
        self.session = session

        # free-form storage for external actors; memo() and the history
        # cache of get_attribute_history() live in here as well.
        self.attributes = {}

        # mapper -> set of DependencyProcessors which have that mapper
        # as a parent and which take part in the sorted flush actions.
        self.deps = util.defaultdict(set)

        # mapper -> set of InstanceStates pending for flush which have
        # that mapper as a parent.
        self.mappers = util.defaultdict(set)

        # (processor, fromparent) -> _Preprocess.  These gather the
        # additional states impacted by the flush and decide whether
        # a flush action is needed.
        self.presort_actions = {}

        # key -> _PostSortRec.  Each rec issues work during the flush
        # within a certain ordering.
        self.postsort_actions = {}

        # 2-tuples of _PostSortRec objects; the second member of each
        # pair depends on the first one having been executed.
        self.dependencies = set()

        # InstanceState -> (isdelete, listonly), telling whether the
        # state is to be deleted, inserted/updated, or just refreshed.
        self.states = {}

        # mapper -> (set of states, set of columns) for the states that
        # will receive a "post update" call, along with the columns
        # to be included in that update.
        self.post_update_states = util.defaultdict(lambda: (set(), set()))

    @property
    def has_work(self):
        """``True`` once at least one state has been registered."""
        return bool(self.states)

    def was_already_deleted(self, state):
        """Return ``True`` if the given state is expired and its row
        turns out to be gone.

        An expired state is reloaded; if that raises
        ``ObjectDeletedError`` the state is passed to the session's
        ``_remove_newly_deleted()`` and ``True`` is returned.

        """
        if not state.expired:
            return False

        try:
            state._load_expired(state, attributes.PASSIVE_OFF)
        except orm_exc.ObjectDeletedError:
            self.session._remove_newly_deleted([state])
            return True
        return False

    def is_deleted(self, state):
        """Return the "isdelete" flag registered for the given state, or
        ``False`` if the state is not registered in this uowtransaction."""

        if state not in self.states:
            return False
        return self.states[state][0]

    def memo(self, key, callable_):
        """Return the value stored under ``key`` in ``self.attributes``,
        calling ``callable_()`` to produce and store it on first use."""
        if key not in self.attributes:
            value = callable_()
            self.attributes[key] = value
            return value
        return self.attributes[key]

    def remove_state_actions(self, state):
        """Flag a registered state as "listonly", keeping its "isdelete"
        flag as it was."""

        self.states[state] = (self.states[state][0], True)

    def get_attribute_history(
        self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
    ):
        """Return the state-level history of attribute ``key`` on
        ``state``, caching the result in ``self.attributes``.

        A cached entry is reused unless it was produced without
        ``SQL_OK`` and the current call asks for ``SQL_OK``; in that
        case the history is fetched again using ``PASSIVE_OFF`` and the
        cache entry is replaced.

        """
        cache_key = ("history", state, key)

        # the cache holds the object-level history next to the state-level
        # one; that strong reference keeps newly loaded objects from
        # being dereferenced while the flush is in progress.
        if cache_key in self.attributes:
            history, state_history, cached_passive = self.attributes[
                cache_key
            ]
            upgrade_needed = (
                not cached_passive & attributes.SQL_OK
                and passive & attributes.SQL_OK
            )
            if not upgrade_needed:
                return state_history
            # cached lookup was passive, a non-passive one is wanted now
            lookup_passive = attributes.PASSIVE_OFF
        else:
            lookup_passive = passive

        impl = state.manager[key].impl
        # TODO: store the history as (state, object) tuples
        # so that no conversion is needed here
        history = impl.get_history(
            state,
            state.dict,
            lookup_passive
            | attributes.LOAD_AGAINST_COMMITTED
            | attributes.NO_RAISE,
        )
        state_history = (
            history.as_state() if history and impl.uses_objects else history
        )
        # the entry records the ``passive`` value the caller asked for
        self.attributes[cache_key] = (history, state_history, passive)
        return state_history

    def has_dep(self, processor):
        """Return ``True`` if a "from parent" preprocessor is registered
        for the given processor."""
        return (processor, True) in self.presort_actions

    def register_preprocessor(self, processor, fromparent):
        """Create the ``_Preprocess`` for ``(processor, fromparent)``
        unless one is registered already."""
        lookup = (processor, fromparent)
        if lookup in self.presort_actions:
            return
        self.presort_actions[lookup] = _Preprocess(processor, fromparent)

    def register_object(
        self,
        state: InstanceState[Any],
        isdelete: bool = False,
        listonly: bool = False,
        cancel_delete: bool = False,
        operation: Optional[str] = None,
        prop: Optional[MapperProperty] = None,
    ) -> bool:
        """Register a state with this flush.

        Returns ``False``, without registering anything, when the state
        is not contained in the session; a warning is emitted in that
        case if the state is not deleted and ``operation`` was given.
        Returns ``True`` otherwise.

        A state seen for the first time is stored with the given
        ``(isdelete, listonly)`` flags.  A state that is already
        registered is only changed, to ``(isdelete, False)``, when
        ``listonly`` is false and either ``isdelete`` or
        ``cancel_delete`` is set.

        """
        if not self.session._contains_state(state):
            # normal for objects registered as part of a relationship
            # cascade operation; should not occur for the top-level
            # register from Session.flush().
            if not state.deleted and operation is not None:
                util.warn(
                    "Object of type %s not in session, %s operation "
                    "along '%s' will not proceed"
                    % (orm_util.state_class_str(state), operation, prop)
                )
            return False

        if state in self.states:
            if not listonly and (isdelete or cancel_delete):
                self.states[state] = (isdelete, False)
            return True

        mapper = state.manager.mapper
        if mapper not in self.mappers:
            # first state seen for this mapper
            self._per_mapper_flush_actions(mapper)

        self.mappers[mapper].add(state)
        self.states[state] = (isdelete, listonly)
        return True

    def register_post_update(self, state, post_update_cols):
        """Record a state and columns for a "post update", collected
        under the base mapper of the state's mapper."""
        base = state.manager.mapper.base_mapper
        pending_states, pending_cols = self.post_update_states[base]
        pending_states.add(state)
        pending_cols.update(post_update_cols)

    def _per_mapper_flush_actions(self, mapper):
        """Set up the save/delete recs for the mapper's base mapper and
        let the mapper's dependency processors register their
        preprocessors."""
        base = mapper.base_mapper
        save_all = _SaveUpdateAll(self, base)
        delete_all = _DeleteAll(self, base)
        self.dependencies.add((save_all, delete_all))

        for processor in mapper._dependency_processors:
            processor.per_property_preprocessors(self)

        for rel in mapper.relationships:
            if not rel.viewonly:
                rel._dependency_processor.per_property_preprocessors(self)

    @util.memoized_property
    def _mapper_for_dep(self):
        """return a dynamic mapping of (Mapper, DependencyProcessor) to
        True or False, indicating if the DependencyProcessor operates
        on objects of that Mapper.

        The result is stored in the dictionary persistently once
        calculated.

        """

        def operates_on(pair):
            mapper, dep = pair
            return mapper._props.get(dep.key) is dep.prop

        return util.PopulateDict(operates_on)

    def filter_states_for_dep(self, dep, states):
        """Return a list of those InstanceStates in ``states`` whose
        mapper is relevant to the given DependencyProcessor.

        """
        relevant = self._mapper_for_dep
        return [st for st in states if relevant[(st.manager.mapper, dep)]]

    def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
        """Yield the states registered for any mapper in the hierarchy of
        ``mapper`` whose flags equal ``(isdelete, listonly)``."""
        wanted = (isdelete, listonly)
        for submapper in mapper.base_mapper.self_and_descendants:
            for st in self.mappers[submapper]:
                if self.states[st] == wanted:
                    yield st

    def _generate_actions(self):
        """Generate the full, unsorted collection of PostSortRecs as
        well as dependency pairs for this UOWTransaction.

        Returns the set of postsort recs that are not disabled and not
        part of a cycle; ``self.cycles`` is assigned as a side effect.

        """
        # keep running all presort actions until a complete pass reports
        # no work.  a presort action might add new states to the uow,
        # which the following pass then picks up.
        progressed = True
        while progressed:
            progressed = False
            for presort in list(self.presort_actions.values()):
                if presort.execute(self):
                    progressed = True

        # see if the graph of mapper dependencies has cycles.
        cycles = topological.find_cycles(
            self.dependencies, list(self.postsort_actions.values())
        )
        self.cycles = cycles

        if cycles:
            # break each per-mapper action taking part in a cycle
            # into per-state actions
            per_state = {
                rec: set(rec.per_state_flush_actions(self)) for rec in cycles
            }

            # re-point the existing dependencies at the per-state
            # actions of those per-mapper actions that were broken up.
            dependencies = self.dependencies
            for edge in list(dependencies):
                before, after = edge
                if (
                    None in edge
                    or before.disabled
                    or after.disabled
                    or cycles.issuperset(edge)
                ):
                    dependencies.remove(edge)
                elif before in cycles:
                    dependencies.remove(edge)
                    for replacement in per_state[before]:
                        dependencies.add((replacement, after))
                elif after in cycles:
                    dependencies.remove(edge)
                    for replacement in per_state[after]:
                        dependencies.add((before, replacement))

        return {
            rec for rec in self.postsort_actions.values() if not rec.disabled
        }.difference(cycles)

    def execute(self) -> None:
        """Generate the flush actions and run them in dependency order.

        The recs are first ordered by their ``sort_key`` and then sorted
        topologically against ``self.dependencies``.  When cycles were
        found, each subset produced by the sort is consumed through
        ``execute_aggregate()``, which may take further recs out of the
        subset; otherwise each rec's ``execute()`` is called in turn.

        """
        ordered = sorted(
            self._generate_actions(),
            key=lambda rec: rec.sort_key,
        )

        if not self.cycles:
            for rec in topological.sort(self.dependencies, ordered):
                rec.execute(self)
            return

        for subset in topological.sort_as_subsets(self.dependencies, ordered):
            remaining = set(subset)
            while remaining:
                rec = remaining.pop()
                rec.execute_aggregate(self, remaining)

    def finalize_flush_changes(self) -> None:
        """Mark processed objects as clean / deleted after a successful
        flush().

        This method is called within the flush() method after the
        execute() method has succeeded and the transaction has been committed.

        """
        if not self.states:
            return

        all_states = set(self.states)
        deleted = {
            st for (st, (isdelete, listonly)) in self.states.items() if isdelete
        }
        persisted = all_states.difference(deleted)

        if deleted:
            self.session._remove_newly_deleted(deleted)
        if persisted:
            self.session._register_persistent(persisted)
489
490
class _IterateMappersMixin:
    """Mixin supplying ``_mappers()`` to records that carry
    ``dependency_processor`` and ``fromparent`` attributes."""

    __slots__ = ()

    def _mappers(self, uow):
        """Return an iterable of the mappers this record applies to.

        Without ``fromparent``, that is the ``self_and_descendants``
        collection of the processor's ``mapper``.  With ``fromparent``,
        the hierarchy of the processor's ``parent`` mapper is filtered
        lazily through ``uow._mapper_for_dep``.

        """
        if not self.fromparent:
            return self.dependency_processor.mapper.self_and_descendants

        return (
            candidate
            for candidate in (
                self.dependency_processor.parent.self_and_descendants
            )
            if uow._mapper_for_dep[(candidate, self.dependency_processor)]
        )
503
504
class _Preprocess(_IterateMappersMixin):
    """Presort action for one dependency processor.

    Hands the not-yet-processed states of the relevant mappers to the
    processor's ``presort_deletes()`` / ``presort_saves()`` and, the
    first time the processor reports changes, has it set up its flush
    actions.

    """

    __slots__ = (
        "dependency_processor",
        "fromparent",
        "processed",
        "setup_flush_actions",
    )

    def __init__(self, dependency_processor, fromparent):
        self.dependency_processor = dependency_processor
        self.fromparent = fromparent
        # states already passed to the presort hooks
        self.processed = set()
        # becomes True once per_property_flush_actions() has been invoked
        self.setup_flush_actions = False

    def execute(self, uow):
        """Process the states not seen so far.

        Returns ``True`` if at least one non-"listonly" state was handled
        in this call, ``False`` otherwise.

        """
        to_delete = set()
        to_save = set()

        for mapper in self._mappers(uow):
            unseen = uow.mappers[mapper].difference(self.processed)
            for state in unseen:
                isdelete, listonly = uow.states[state]
                if listonly:
                    continue
                bucket = to_delete if isdelete else to_save
                bucket.add(state)

        processor = self.dependency_processor

        if to_delete:
            processor.presort_deletes(uow, to_delete)
            self.processed.update(to_delete)
        if to_save:
            processor.presort_saves(uow, to_save)
            self.processed.update(to_save)

        if not (to_delete or to_save):
            return False

        if not self.setup_flush_actions and (
            processor.prop_has_changes(uow, to_delete, True)
            or processor.prop_has_changes(uow, to_save, False)
        ):
            processor.per_property_flush_actions(uow)
            self.setup_flush_actions = True
        return True
553
554
class _PostSortRec:
    """Base for flush records that are unique per ``uow`` and argument
    list.

    Constructing a record looks up ``(cls,) + args`` in
    ``uow.postsort_actions`` and hands back the record stored there, if
    any; otherwise a new one is created, stored, and given
    ``disabled = False``.

    """

    __slots__ = ("disabled",)

    def __new__(cls, uow, *args):
        registry = uow.postsort_actions
        identity = (cls,) + args
        if identity not in registry:
            rec = object.__new__(cls)
            registry[identity] = rec
            rec.disabled = False
            return rec
        return registry[identity]

    def execute_aggregate(self, uow, recs):
        """Default aggregate execution: run just this record, leaving
        ``recs`` untouched."""
        self.execute(uow)
569
570
class _ProcessAll(_IterateMappersMixin, _PostSortRec):
    """Runs a dependency processor's ``process_saves()`` or
    ``process_deletes()`` over all matching states of its mappers."""

    __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, fromparent):
        self.dependency_processor = dependency_processor
        self.isdelete = isdelete
        self.fromparent = fromparent
        self.sort_key = (
            "ProcessAll",
            dependency_processor.sort_key,
            isdelete,
        )
        # make the processor reachable from its parent's base mapper
        parent_base = dependency_processor.parent.base_mapper
        uow.deps[parent_base].add(dependency_processor)

    def execute(self, uow):
        """Pass the matching states to the processor hook selected by
        ``isdelete``."""
        states = self._elements(uow)
        processor = self.dependency_processor
        if self.isdelete:
            processor.process_deletes(uow, states)
        else:
            processor.process_saves(uow, states)

    def per_state_flush_actions(self, uow):
        # nothing to produce here; SaveUpdateAll and DeleteAll take care
        # of it, since a ProcessAll should unconditionally be pulled
        # into per-state if either the parent/child mappers
        # are part of a cycle
        return iter(())

    def __repr__(self):
        return "%s(%s, isdelete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            self.isdelete,
        )

    def _elements(self, uow):
        """Yield the non-"listonly" states of this record's mappers whose
        "isdelete" flag equals the record's."""
        for mapper in self._mappers(uow):
            for candidate in uow.mappers[mapper]:
                isdelete, listonly = uow.states[candidate]
                if isdelete == self.isdelete and not listonly:
                    yield candidate
614
615
class _PostUpdateAll(_PostSortRec):
    """Issues the "post update" for the states collected under one
    mapper in ``uow.post_update_states``."""

    __slots__ = "mapper", "isdelete", "sort_key"

    def __init__(self, uow, mapper, isdelete):
        self.mapper = mapper
        self.isdelete = isdelete
        self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        """Run ``persistence._post_update()`` for those collected states
        whose "isdelete" flag equals this record's."""
        collected, cols = uow.post_update_states[self.mapper]
        matching = [
            st for st in collected if uow.states[st][0] == self.isdelete
        ]

        util.preloaded.orm_persistence._post_update(
            self.mapper, matching, uow, cols
        )
631
632
class _SaveUpdateAll(_PostSortRec):
    """Saves all pending, non-"listonly" states of a base mapper's
    hierarchy in one ``_save_obj()`` call."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("SaveUpdateAll", mapper._sort_key)
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        persistence = util.preloaded.orm_persistence
        to_save = uow.states_for_mapper_hierarchy(self.mapper, False, False)
        persistence._save_obj(self.mapper, to_save, uow)

    def per_state_flush_actions(self, uow):
        """Yield one ``_SaveUpdateState`` per state to be saved, then have
        the mapper's dependency processors set up their per-state
        actions."""
        pending = list(
            uow.states_for_mapper_hierarchy(self.mapper, False, False)
        )
        delete_all = _DeleteAll(uow, self.mapper.base_mapper)

        for state in pending:
            # saves stay ahead of deletes -
            # this ensures 'row switch' operations work
            save_state = _SaveUpdateState(uow, state)
            uow.dependencies.add((save_state, delete_all))
            yield save_state

        for dep in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(dep, pending)
            dep.per_state_flush_actions(uow, relevant, False)

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
668
669
class _DeleteAll(_PostSortRec):
    """Deletes all non-"listonly" states marked for deletion in a base
    mapper's hierarchy in one ``_delete_obj()`` call."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("DeleteAll", mapper._sort_key)
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        persistence = util.preloaded.orm_persistence
        to_delete = uow.states_for_mapper_hierarchy(self.mapper, True, False)
        persistence._delete_obj(self.mapper, to_delete, uow)

    def per_state_flush_actions(self, uow):
        """Yield one ``_DeleteState`` per state to be deleted, then have
        the mapper's dependency processors set up their per-state
        actions."""
        pending = list(
            uow.states_for_mapper_hierarchy(self.mapper, True, False)
        )
        save_all = _SaveUpdateAll(uow, self.mapper.base_mapper)

        for state in pending:
            # saves stay ahead of deletes -
            # this ensures 'row switch' operations work
            delete_state = _DeleteState(uow, state)
            uow.dependencies.add((save_all, delete_state))
            yield delete_state

        for dep in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(dep, pending)
            dep.per_state_flush_actions(uow, relevant, True)

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
705
706
class _ProcessState(_PostSortRec):
    """Per-state counterpart of ``_ProcessAll``: applies a dependency
    processor to individual states."""

    __slots__ = "dependency_processor", "isdelete", "state", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, state):
        self.dependency_processor = dependency_processor
        self.isdelete = isdelete
        self.state = state
        self.sort_key = ("ProcessState", dependency_processor.sort_key)

    def execute_aggregate(self, uow, recs):
        """Process this record's state along with the states of all
        records in ``recs`` of the same class, processor and "isdelete"
        flag; those records are removed from ``recs``."""
        processor = self.dependency_processor
        isdelete = self.isdelete
        own_class = self.__class__

        batch = [
            rec
            for rec in recs
            if rec.__class__ is own_class
            and rec.dependency_processor is processor
            and rec.isdelete is isdelete
        ]
        recs.difference_update(batch)

        states = [self.state, *(rec.state for rec in batch)]
        if isdelete:
            processor.process_deletes(uow, states)
        else:
            processor.process_saves(uow, states)

    def __repr__(self):
        return "%s(%s, %s, delete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            orm_util.state_str(self.state),
            self.isdelete,
        )
741
742
class _SaveUpdateState(_PostSortRec):
    """Per-state counterpart of ``_SaveUpdateAll``."""

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        # NOTE(review): the tag here is "ProcessState", the same literal
        # _ProcessState uses, rather than this class's own name.  It is
        # kept as-is because the tag takes part in ordering the recs.
        self.sort_key = ("ProcessState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Save this record's state along with the states of all records
        in ``recs`` of the same class and base mapper; those records are
        removed from ``recs``."""
        mapper = self.mapper
        own_class = self.__class__

        batch = [
            rec
            for rec in recs
            if rec.__class__ is own_class and rec.mapper is mapper
        ]
        recs.difference_update(batch)

        states = [self.state, *(rec.state for rec in batch)]
        util.preloaded.orm_persistence._save_obj(mapper, states, uow)

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )
769
770
class _DeleteState(_PostSortRec):
    """Per-state counterpart of ``_DeleteAll``."""

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        self.sort_key = ("DeleteState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Delete this record's state along with the states of all
        records in ``recs`` of the same class and base mapper; those
        records are removed from ``recs``.

        Only states whose "isdelete" flag in ``uow.states`` is still
        truthy are handed to ``_delete_obj()``.

        """
        mapper = self.mapper
        own_class = self.__class__

        batch = [
            rec
            for rec in recs
            if rec.__class__ is own_class and rec.mapper is mapper
        ]
        recs.difference_update(batch)

        candidates = [self.state, *(rec.state for rec in batch)]
        still_deleted = [st for st in candidates if uow.states[st][0]]
        util.preloaded.orm_persistence._delete_obj(mapper, still_deleted, uow)

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )