1# orm/unitofwork.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""The internals for the unit of work system.
11
12The session's flush() process passes objects to a contextual object
13here, which assembles flush tasks based on mappers and their properties,
14organizes them in order of dependency, and executes.
15
16"""
17
18from __future__ import annotations
19
20from typing import Any
21from typing import Dict
22from typing import Optional
23from typing import Set
24from typing import TYPE_CHECKING
25
26from . import attributes
27from . import exc as orm_exc
28from . import util as orm_util
29from .. import event
30from .. import util
31from ..util import topological
32
33if TYPE_CHECKING:
34 from .dependency import DependencyProcessor
35 from .interfaces import MapperProperty
36 from .mapper import Mapper
37 from .session import Session
38 from .session import SessionTransaction
39 from .state import InstanceState
40
41
def track_cascade_events(descriptor, prop):
    """Establish event listeners on object attributes which handle
    cascade-on-set/append.

    Listeners for the "append_wo_mutation", "append", "remove" and "set"
    events are attached to ``descriptor``.  They apply the "save-update"
    cascade when a related object is attached to a parent that belongs to
    a session, and expunge pending objects that became orphans when the
    relationship's cascade includes "delete-orphan".

    :param descriptor: the instrumented attribute to listen on.
    :param prop: the relationship property.  Only its ``key`` is captured;
        each listener re-resolves the property from the parent state's
        mapper (``state.manager.mapper._props[key]``) when it fires.
    """
    key = prop.key

    def _is_related_object(value):
        # None and the attribute-system sentinels NEVER_SET /
        # PASSIVE_NO_RESULT are not objects that can become orphans
        return (
            value is not None
            and value is not attributes.NEVER_SET
            and value is not attributes.PASSIVE_NO_RESULT
        )

    def _cascade_save_update(sess, live_prop, target_state, initiator):
        # add the target to the session when the relationship cascades
        # save-update, the event was initiated on this same attribute key,
        # and the target is not in the session yet
        if (
            live_prop._cascade.save_update
            and key == initiator.key
            and not sess._contains_state(target_state)
        ):
            sess._save_or_update_state(target_state)

    def append(state, item, initiator, **kw):
        # "save_update" cascade for an instance appended to the
        # collection of another instance
        if item is None:
            return

        sess = state.session
        if not sess:
            return item

        if sess._warn_on_events:
            sess._flush_warning("collection append")

        live_prop = state.manager.mapper._props[key]
        _cascade_save_update(
            sess, live_prop, attributes.instance_state(item), initiator
        )
        return item

    def remove(state, item, initiator, **kw):
        # "delete-orphan" handling for an instance removed from the
        # relationship
        if item is None:
            return

        sess = state.session
        live_prop = state.manager.mapper._props[key]

        if sess and sess._warn_on_events:
            sess._flush_warning(
                "collection remove"
                if live_prop.uselist
                else "related attribute delete"
            )

        if not (
            _is_related_object(item) and live_prop._cascade.delete_orphan
        ):
            return

        # expunge pending orphans
        item_state = attributes.instance_state(item)
        if not live_prop.mapper._is_orphan(item_state):
            return

        if sess and item_state in sess._new:
            sess.expunge(item)
        else:
            # reached when the parent is not in a session, or the orphan
            # is not pending in it; the related item may or may not itself
            # be in a Session, so memoize the orphan status on the item
            item_state._orphaned_outside_of_session = True

    def set_(state, newvalue, oldvalue, initiator, **kw):
        # "save_update" cascade for an instance attached to another
        # instance, plus expunge of a replaced value that is now a
        # pending orphan
        if oldvalue is newvalue:
            return newvalue

        sess = state.session
        if not sess:
            return newvalue

        if sess._warn_on_events:
            sess._flush_warning("related attribute set")

        live_prop = state.manager.mapper._props[key]
        if newvalue is not None:
            _cascade_save_update(
                sess,
                live_prop,
                attributes.instance_state(newvalue),
                initiator,
            )

        if _is_related_object(oldvalue) and live_prop._cascade.delete_orphan:
            oldvalue_state = attributes.instance_state(oldvalue)
            if oldvalue_state in sess._new and live_prop.mapper._is_orphan(
                oldvalue_state
            ):
                sess.expunge(oldvalue)
        return newvalue

    # "append_wo_mutation" reuses the append handler, registered without
    # retval
    event.listen(
        descriptor, "append_wo_mutation", append, raw=True, include_key=True
    )
    for event_name, listener in (
        ("append", append),
        ("remove", remove),
        ("set", set_),
    ):
        event.listen(
            descriptor,
            event_name,
            listener,
            raw=True,
            retval=True,
            include_key=True,
        )
154
155
class UOWTransaction:
    """Holds the state of a single unit-of-work flush.

    Collects the InstanceState objects taking part in the flush
    (``states`` / ``mappers``), the presort actions (``presort_actions``,
    :class:`.Preprocess` objects) that may pull in further states, the
    postsort actions (``postsort_actions``, :class:`.PostSortRec` objects)
    that perform the actual work, and the ordering edges between those
    actions (``dependencies``).  :meth:`execute` runs the postsort actions
    in dependency order; :meth:`finalize_flush_changes` then updates the
    session's bookkeeping.

    """

    session: Session
    # declared for type checking only; it is not assigned in __init__.
    # NOTE(review): presumably assigned by whoever creates the
    # UOWTransaction - confirm against the caller.
    transaction: SessionTransaction
    # keys are not limited to str: get_attribute_history() stores
    # ("history", state, key) tuples here and memo() accepts any key.
    attributes: Dict[Any, Any]
    deps: util.defaultdict[Mapper[Any], Set[DependencyProcessor]]
    mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]]

    def __init__(self, session: Session):
        self.session = session

        # dictionary used by external actors to
        # store arbitrary state information.  Also backs memo() and the
        # history cache of get_attribute_history().
        self.attributes = {}

        # dictionary of mappers to sets of
        # DependencyProcessors, which are also
        # set to be part of the sorted flush actions,
        # which have that mapper as a parent.
        # (ProcessAll.__init__ adds entries keyed on the processor's
        # parent base mapper.)
        self.deps = util.defaultdict(set)

        # dictionary of mappers to sets of InstanceState
        # items pending for flush; a state is filed under its own
        # mapper (``state.manager.mapper``) by register_object().
        self.mappers = util.defaultdict(set)

        # a dictionary of Preprocess objects keyed by
        # (DependencyProcessor, fromparent), which gather
        # additional states impacted by the flush
        # and determine if a flush action is needed
        self.presort_actions = {}

        # dictionary of PostSortRec objects, keyed by (class, *args)
        # (see PostSortRec.__new__); each
        # one issues work during the flush within
        # a certain ordering.
        self.postsort_actions = {}

        # a set of 2-tuples, each containing two
        # PostSortRec objects where the second
        # is dependent on the first being executed
        # first
        self.dependencies = set()

        # dictionary of InstanceState-> (isdelete, listonly)
        # tuples, indicating if this state is to be deleted
        # or insert/updated, or just refreshed
        self.states = {}

        # tracks InstanceStates which will be receiving
        # a "post update" call.  Keys are base mappers (see
        # register_post_update()); values are a 2-tuple of a set of
        # states and a set of the columns which should be included in
        # the update.
        self.post_update_states = util.defaultdict(lambda: (set(), set()))

    @property
    def has_work(self) -> bool:
        """``True`` if at least one state is registered with this flush."""
        return bool(self.states)

    def was_already_deleted(self, state) -> bool:
        """Return ``True`` if the given state is expired and was deleted
        previously.

        For an expired state, this attempts to load its expired attributes
        with ``PASSIVE_OFF``.  If that raises
        :class:`.ObjectDeletedError`, the state is passed to
        ``Session._remove_newly_deleted()`` and ``True`` is returned.
        States that are not expired always give ``False``.
        """
        if state.expired:
            try:
                state._load_expired(state, attributes.PASSIVE_OFF)
            except orm_exc.ObjectDeletedError:
                self.session._remove_newly_deleted([state])
                return True
        return False

    def is_deleted(self, state):
        """Return ``True`` if the given state is marked as deleted
        within this uowtransaction.

        States that were never registered give ``False``.
        """

        return state in self.states and self.states[state][0]

    def memo(self, key, callable_):
        """Return ``self.attributes[key]``, computing it with
        ``callable_()`` and storing it on first access.

        ``callable_`` is not invoked when ``key`` is already present.
        """
        if key in self.attributes:
            return self.attributes[key]
        else:
            self.attributes[key] = ret = callable_()
            return ret

    def remove_state_actions(self, state) -> None:
        """Remove pending actions for a state from the uowtransaction.

        The state stays registered, but its entry is switched to
        ``listonly`` while its delete flag is kept.  The save/delete
        actions in this module skip listonly states.

        Raises ``KeyError`` if ``state`` was never registered.
        """

        isdelete = self.states[state][0]

        self.states[state] = (isdelete, True)

    def get_attribute_history(
        self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
    ):
        """Facade to the attribute implementation's ``get_history()``,
        including caching of results.

        Results are cached in ``self.attributes`` under
        ``("history", state, key)`` as
        ``(history, state_history, passive)``.  Suppose the cached result
        was obtained with a ``passive`` lacking ``SQL_OK``, and this call
        passes a ``passive`` that includes it.  Then the history is
        reloaded with ``PASSIVE_OFF`` and re-cached.  Otherwise the cached
        value is reused.

        :return: ``history.as_state()`` when the history is truthy and the
         attribute implementation ``uses_objects``; otherwise the history
         object itself.
        """

        hashkey = ("history", state, key)

        # cache the objects, not the states; the strong reference here
        # prevents newly loaded objects from being dereferenced during the
        # flush process

        if hashkey in self.attributes:
            history, state_history, cached_passive = self.attributes[hashkey]
            # if the cached lookup was "passive" and now
            # we want non-passive, do a non-passive lookup and re-cache

            if (
                not cached_passive & attributes.SQL_OK
                and passive & attributes.SQL_OK
            ):
                impl = state.manager[key].impl
                history = impl.get_history(
                    state,
                    state.dict,
                    attributes.PASSIVE_OFF
                    | attributes.LOAD_AGAINST_COMMITTED
                    | attributes.NO_RAISE,
                )
                if history and impl.uses_objects:
                    state_history = history.as_state()
                else:
                    state_history = history
                self.attributes[hashkey] = (history, state_history, passive)
        else:
            impl = state.manager[key].impl
            # TODO: store the history as (state, object) tuples
            # so we don't have to keep converting here
            history = impl.get_history(
                state,
                state.dict,
                passive
                | attributes.LOAD_AGAINST_COMMITTED
                | attributes.NO_RAISE,
            )
            if history and impl.uses_objects:
                state_history = history.as_state()
            else:
                state_history = history
            self.attributes[hashkey] = (history, state_history, passive)

        return state_history

    def has_dep(self, processor):
        """Return ``True`` if a Preprocess was registered for
        ``processor`` with ``fromparent=True``."""
        return (processor, True) in self.presort_actions

    def register_preprocessor(self, processor, fromparent) -> None:
        """Ensure one Preprocess exists for ``(processor, fromparent)``.

        Repeated calls with the same arguments are no-ops.
        """
        key = (processor, fromparent)
        if key not in self.presort_actions:
            self.presort_actions[key] = Preprocess(processor, fromparent)

    def register_object(
        self,
        state: InstanceState[Any],
        isdelete: bool = False,
        listonly: bool = False,
        cancel_delete: bool = False,
        operation: Optional[str] = None,
        prop: Optional[MapperProperty] = None,
    ) -> bool:
        """Register ``state`` with this flush.

        :param isdelete: mark the state for deletion.
        :param listonly: register the state without save/delete work.
        :param cancel_delete: when the state is already registered, allow
         its entry to be overwritten with ``(isdelete, False)`` even when
         ``isdelete`` is False.
        :param operation: cascade operation name.  It is only used in the
         warning emitted for states that are not in the session.
        :param prop: the property being cascaded along.  It is likewise
         only used in that warning.
        :return: ``False`` if the state is not in the session (nothing is
         registered); otherwise ``True``.
        """
        if not self.session._contains_state(state):
            # this condition is normal when objects are registered
            # as part of a relationship cascade operation. it should
            # not occur for the top-level register from Session.flush().
            if not state.deleted and operation is not None:
                util.warn(
                    "Object of type %s not in session, %s operation "
                    "along '%s' will not proceed"
                    % (orm_util.state_class_str(state), operation, prop)
                )
            return False

        if state not in self.states:
            mapper = state.manager.mapper

            # first state seen for this mapper: set up its per-mapper
            # flush actions (the ``in`` test does not populate the
            # defaultdict; the ``.add`` below does)
            if mapper not in self.mappers:
                self._per_mapper_flush_actions(mapper)

            self.mappers[mapper].add(state)
            self.states[state] = (isdelete, listonly)
        else:
            # an existing entry is only overwritten by a non-listonly
            # request that is a delete or an explicit cancel_delete; in
            # both cases the listonly flag is cleared
            if not listonly and (isdelete or cancel_delete):
                self.states[state] = (isdelete, False)
        return True

    def register_post_update(self, state, post_update_cols) -> None:
        """Queue ``state`` for a post update under its base mapper, adding
        ``post_update_cols`` to the columns collected for that mapper."""
        mapper = state.manager.mapper.base_mapper
        states, cols = self.post_update_states[mapper]
        states.add(state)
        cols.update(post_update_cols)

    def _per_mapper_flush_actions(self, mapper) -> None:
        """Set up the flush actions for ``mapper``.

        register_object() calls this once per mapper.  It creates (or
        fetches, since PostSortRec instances are interned per uow) the
        SaveUpdateAll / DeleteAll records of the base mapper and orders
        saves before deletes.  Each dependency processor of the mapper,
        and each non-viewonly relationship's processor, then registers
        its preprocessors.
        """
        saves = SaveUpdateAll(self, mapper.base_mapper)
        deletes = DeleteAll(self, mapper.base_mapper)
        self.dependencies.add((saves, deletes))

        for dep in mapper._dependency_processors:
            dep.per_property_preprocessors(self)

        for prop in mapper.relationships:
            if prop.viewonly:
                continue
            dep = prop._dependency_processor
            dep.per_property_preprocessors(self)

    @util.memoized_property
    def _mapper_for_dep(self):
        """return a dynamic mapping of (Mapper, DependencyProcessor) to
        True or False, indicating if the DependencyProcessor operates
        on objects of that Mapper.

        A mapper matches when its ``_props`` entry for the processor's
        ``key`` is the processor's own ``prop`` (identity comparison).

        The result is stored in the dictionary persistently once
        calculated.

        """
        return util.PopulateDict(
            lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop
        )

    def filter_states_for_dep(self, dep, states):
        """Filter the given list of InstanceStates to those relevant to the
        given DependencyProcessor.

        Returns a new list; the order of ``states`` is preserved.

        """
        mapper_for_dep = self._mapper_for_dep
        return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]]

    def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
        """Yield the registered states of ``mapper``'s base mapper and all
        of its descendants whose ``(isdelete, listonly)`` entry equals the
        given values."""
        checktup = (isdelete, listonly)
        # note: the loop variable shadows the ``mapper`` argument; the
        # hierarchy is computed from the argument before the loop starts
        for mapper in mapper.base_mapper.self_and_descendants:
            for state in self.mappers[mapper]:
                if self.states[state] == checktup:
                    yield state

    def _generate_actions(self):
        """Generate the full, unsorted collection of PostSortRecs as
        well as dependency pairs for this UOWTransaction.

        This method also sets ``self.cycles``.  If the dependency graph
        has cycles, each per-mapper record involved is split into
        per-state records via its ``per_state_flush_actions()``, and
        ``self.dependencies`` is rewritten to point at those records.

        :return: the set of non-disabled postsort records, excluding the
         records that were part of a cycle.

        """
        # execute presort_actions, until all states
        # have been processed. a presort_action might
        # add new states to the uow.
        while True:
            ret = False
            # iterate a snapshot: executing an action may register more
            # presort actions
            for action in list(self.presort_actions.values()):
                if action.execute(self):
                    ret = True
            if not ret:
                break

        # see if the graph of mapper dependencies has cycles.
        self.cycles = cycles = topological.find_cycles(
            self.dependencies, list(self.postsort_actions.values())
        )

        if cycles:
            # if yes, break the per-mapper actions into
            # per-state actions
            convert = {
                rec: set(rec.per_state_flush_actions(self)) for rec in cycles
            }

            # rewrite the existing dependencies to point to
            # the per-state actions for those per-mapper actions
            # that were broken up.  A snapshot is iterated because edges
            # are removed and added inside the loop.
            for edge in list(self.dependencies):
                # drop edges that contain None, touch a disabled record,
                # or lie entirely within the cycle set
                if (
                    None in edge
                    or edge[0].disabled
                    or edge[1].disabled
                    or cycles.issuperset(edge)
                ):
                    self.dependencies.remove(edge)
                # otherwise redirect the split-up end of the edge to each
                # of its per-state replacements
                elif edge[0] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[0]]:
                        self.dependencies.add((dep, edge[1]))
                elif edge[1] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[1]]:
                        self.dependencies.add((edge[0], dep))

        return {
            a for a in self.postsort_actions.values() if not a.disabled
        }.difference(cycles)

    def execute(self) -> None:
        """Run all flush actions in dependency order.

        Records are first sorted by ``sort_key``.  NOTE(review): this
        presumably gives the topological sort a deterministic input order;
        confirm.  The records are then executed in topological order.
        When cycles were found, each topological subset is processed with
        ``execute_aggregate()``.  That call receives the rest of the
        subset, so a record can remove the sibling records it batches
        with.
        """
        postsort_actions = self._generate_actions()

        postsort_actions = sorted(
            postsort_actions,
            key=lambda item: item.sort_key,
        )
        # debugging aid, left commented out:
        # sort = topological.sort(self.dependencies, postsort_actions)
        # print "--------------"
        # print "\ndependencies:", self.dependencies
        # print "\ncycles:", self.cycles
        # print "\nsort:", list(sort)
        # print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions)

        # execute
        if self.cycles:
            for subset in topological.sort_as_subsets(
                self.dependencies, postsort_actions
            ):
                set_ = set(subset)
                while set_:
                    n = set_.pop()
                    # may shrink set_ by consuming batchable siblings
                    n.execute_aggregate(self, set_)
        else:
            for rec in topological.sort(self.dependencies, postsort_actions):
                rec.execute(self)

    def finalize_flush_changes(self) -> None:
        """Mark processed objects as clean / deleted after a successful
        flush().

        This method is called within the flush() method after the
        execute() method has succeeded and the transaction has been committed.

        States flagged for deletion are passed to
        ``Session._remove_newly_deleted()``.  Every other registered
        state, including listonly ones, is passed to
        ``Session._register_persistent()``.

        """
        if not self.states:
            return

        states = set(self.states)
        isdel = {
            s for (s, (isdelete, listonly)) in self.states.items() if isdelete
        }
        other = states.difference(isdel)
        if isdel:
            self.session._remove_newly_deleted(isdel)
        if other:
            self.session._register_persistent(other)
487
488
class IterateMappersMixin:
    """Mixin that supplies the mappers a dependency-processor action covers.

    The host class must provide ``fromparent`` and
    ``dependency_processor`` attributes.
    """

    __slots__ = ()

    def _mappers(self, uow):
        """Return an iterable of the mappers to visit.

        When ``fromparent`` is false, the result is the target mapper's
        ``self_and_descendants`` object itself.  Otherwise it is a lazy
        generator over the parent mapper's hierarchy.  That generator only
        keeps mappers for which ``uow._mapper_for_dep[(mapper, processor)]``
        is true.
        """
        if not self.fromparent:
            return self.dependency_processor.mapper.self_and_descendants

        processor = self.dependency_processor
        return (
            candidate
            for candidate in processor.parent.self_and_descendants
            if uow._mapper_for_dep[(candidate, processor)]
        )
501
502
class Preprocess(IterateMappersMixin):
    """Presort action that feeds newly registered states to one
    DependencyProcessor.

    UOWTransaction._generate_actions() calls :meth:`execute` repeatedly
    until no presort action reports new work.  ``processed`` remembers the
    states already handed to the processor, so each state is presorted at
    most once by a given Preprocess.
    """

    __slots__ = (
        "dependency_processor",
        "fromparent",
        "processed",
        "setup_flush_actions",
    )

    def __init__(self, dependency_processor, fromparent):
        self.dependency_processor = dependency_processor
        self.fromparent = fromparent
        # states already passed to presort_saves / presort_deletes
        self.processed = set()
        # whether per_property_flush_actions() has been requested yet
        self.setup_flush_actions = False

    def execute(self, uow):
        """Presort every not-yet-processed, non-listonly state.

        :return: ``True`` if any state was presorted on this call, meaning
         the caller should run another round; otherwise ``False``.
        """
        to_delete = set()
        to_save = set()

        for mapper in self._mappers(uow):
            unseen = uow.mappers[mapper].difference(self.processed)
            for state in unseen:
                isdelete, listonly = uow.states[state]
                if listonly:
                    continue
                (to_delete if isdelete else to_save).add(state)

        processor = self.dependency_processor
        if to_delete:
            processor.presort_deletes(uow, to_delete)
            self.processed.update(to_delete)
        if to_save:
            processor.presort_saves(uow, to_save)
            self.processed.update(to_save)

        if not (to_delete or to_save):
            return False

        # the property's flush actions are set up at most once, and only
        # once the processor reports actual changes
        if not self.setup_flush_actions and (
            processor.prop_has_changes(uow, to_delete, True)
            or processor.prop_has_changes(uow, to_save, False)
        ):
            processor.per_property_flush_actions(uow)
            self.setup_flush_actions = True
        return True
551
552
class PostSortRec:
    """Base class for the flush actions that UOWTransaction.execute()
    sorts topologically and runs.

    Records are interned per UOWTransaction.  Constructing
    ``cls(uow, *args)`` returns the record already stored in
    ``uow.postsort_actions`` under ``(cls, *args)``, if there is one.  As
    with any ``__new__`` that returns an instance of ``cls``, the subclass
    ``__init__`` still runs on the returned object.
    """

    __slots__ = ("disabled",)

    def __new__(cls, uow, *args):
        registry = uow.postsort_actions
        cache_key = (cls, *args)
        existing = registry.get(cache_key)
        if existing is not None:
            return existing

        rec = object.__new__(cls)
        rec.disabled = False
        registry[cache_key] = rec
        return rec

    def execute_aggregate(self, uow, recs):
        """Execute this record on its own.

        ``recs`` holds the remaining records of the same topological
        subset.  This default leaves it untouched; subclasses that can
        batch work remove the records they absorb.
        """
        self.execute(uow)
567
568
class ProcessAll(IterateMappersMixin, PostSortRec):
    """Run one DependencyProcessor's ``process_saves`` or
    ``process_deletes`` over all eligible states in a single call."""

    __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, fromparent):
        self.dependency_processor = dependency_processor
        self.isdelete = isdelete
        self.fromparent = fromparent
        self.sort_key = ("ProcessAll", dependency_processor.sort_key, isdelete)
        # index the processor under its parent's base mapper; SaveUpdateAll
        # and DeleteAll read uow.deps when splitting into per-state actions
        uow.deps[dependency_processor.parent.base_mapper].add(
            dependency_processor
        )

    def execute(self, uow):
        processor = self.dependency_processor
        if self.isdelete:
            run = processor.process_deletes
        else:
            run = processor.process_saves
        run(uow, self._elements(uow))

    def per_state_flush_actions(self, uow):
        # this is handled by SaveUpdateAll and DeleteAll,
        # since a ProcessAll should unconditionally be pulled
        # into per-state if either the parent/child mappers
        # are part of a cycle
        return iter(())

    def __repr__(self):
        return "%s(%s, isdelete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            self.isdelete,
        )

    def _elements(self, uow):
        """Yield the non-listonly states of the covered mappers whose
        delete flag equals ``self.isdelete``."""
        for mapper in self._mappers(uow):
            for state in uow.mappers[mapper]:
                state_isdelete, state_listonly = uow.states[state]
                if state_isdelete == self.isdelete and not state_listonly:
                    yield state
612
613
class PostUpdateAll(PostSortRec):
    """Issue the "post update" statements collected for one base mapper.

    Only states whose delete flag equals ``isdelete`` are included.
    """

    __slots__ = "mapper", "isdelete", "sort_key"

    def __init__(self, uow, mapper, isdelete):
        self.mapper = mapper
        self.isdelete = isdelete
        self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        persistence_mod = util.preloaded.orm_persistence
        queued, cols = uow.post_update_states[self.mapper]
        selected = [
            st for st in queued if uow.states[st][0] == self.isdelete
        ]
        persistence_mod.post_update(self.mapper, selected, uow, cols)
629
630
class SaveUpdateAll(PostSortRec):
    """INSERT/UPDATE every pending non-deleted, non-listonly state of a
    base mapper's hierarchy with a single ``save_obj`` call."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("SaveUpdateAll", mapper._sort_key)
        # only ever built for a base mapper; subclasses are reached via
        # states_for_mapper_hierarchy()
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        pending = uow.states_for_mapper_hierarchy(self.mapper, False, False)
        util.preloaded.orm_persistence.save_obj(self.mapper, pending, uow)

    def per_state_flush_actions(self, uow):
        """Replace this record with one SaveUpdateState per state.

        This is a generator.  As each SaveUpdateState is yielded, the base
        mapper's DeleteAll is made to depend on it.  After all are yielded,
        every dependency processor indexed under this mapper in
        ``uow.deps`` sets up its own per-state actions for the relevant
        states.
        """
        pending = list(
            uow.states_for_mapper_hierarchy(self.mapper, False, False)
        )
        delete_all = DeleteAll(uow, self.mapper.base_mapper)
        for state in pending:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            save_one = SaveUpdateState(uow, state)
            uow.dependencies.add((save_one, delete_all))
            yield save_one

        for dep in uow.deps[self.mapper]:
            dep.per_state_flush_actions(
                uow, uow.filter_states_for_dep(dep, pending), False
            )

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
666
667
class DeleteAll(PostSortRec):
    """DELETE every non-listonly state marked for deletion across a base
    mapper's hierarchy with a single ``delete_obj`` call."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("DeleteAll", mapper._sort_key)
        # only ever built for a base mapper; subclasses are reached via
        # states_for_mapper_hierarchy()
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        doomed = uow.states_for_mapper_hierarchy(self.mapper, True, False)
        util.preloaded.orm_persistence.delete_obj(self.mapper, doomed, uow)

    def per_state_flush_actions(self, uow):
        """Replace this record with one DeleteState per state.

        This is a generator.  As each DeleteState is yielded, it is made
        to depend on the base mapper's SaveUpdateAll.  After all are
        yielded, every dependency processor indexed under this mapper in
        ``uow.deps`` sets up its own per-state actions for the relevant
        states.
        """
        doomed = list(
            uow.states_for_mapper_hierarchy(self.mapper, True, False)
        )
        save_all = SaveUpdateAll(uow, self.mapper.base_mapper)
        for state in doomed:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            delete_one = DeleteState(uow, state)
            uow.dependencies.add((save_all, delete_one))
            yield delete_one

        for dep in uow.deps[self.mapper]:
            dep.per_state_flush_actions(
                uow, uow.filter_states_for_dep(dep, doomed), True
            )

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
703
704
class ProcessState(PostSortRec):
    """Per-state counterpart of ProcessAll.

    NOTE(review): presumably created by a DependencyProcessor's
    ``per_state_flush_actions()`` when the flush contains cycles; confirm
    against dependency.py.
    """

    __slots__ = "dependency_processor", "isdelete", "state", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, state):
        self.dependency_processor = dependency_processor
        self.isdelete = isdelete
        self.state = state
        self.sort_key = ("ProcessState", dependency_processor.sort_key)

    def execute_aggregate(self, uow, recs):
        """Process this state together with every ProcessState in
        ``recs`` that has the same processor and delete flag.

        The absorbed records are removed from ``recs``.
        """
        processor = self.dependency_processor
        isdelete = self.isdelete
        same_kind = self.__class__
        batch_recs = [
            rec
            for rec in recs
            if rec.__class__ is same_kind
            and rec.dependency_processor is processor
            and rec.isdelete is isdelete
        ]
        recs.difference_update(batch_recs)
        states = [self.state]
        states.extend(rec.state for rec in batch_recs)
        if isdelete:
            processor.process_deletes(uow, states)
        else:
            processor.process_saves(uow, states)

    def __repr__(self):
        return "%s(%s, %s, delete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            orm_util.state_str(self.state),
            self.isdelete,
        )
739
740
class SaveUpdateState(PostSortRec):
    """Per-state INSERT/UPDATE record; see
    SaveUpdateAll.per_state_flush_actions()."""

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        # NOTE(review): tagged "ProcessState" rather than "SaveUpdateState".
        # In this module, sort_key only feeds the pre-sort that precedes
        # the topological sort.  It is kept unchanged so flush ordering
        # stays the same.
        self.sort_key = ("ProcessState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Save this state plus every SaveUpdateState in ``recs`` that has
        the same base mapper, using one ``save_obj`` call.

        The absorbed records are removed from ``recs``.
        """
        persistence_mod = util.preloaded.orm_persistence
        base_mapper = self.mapper
        same_kind = self.__class__
        batch_recs = [
            rec
            for rec in recs
            if rec.__class__ is same_kind and rec.mapper is base_mapper
        ]
        recs.difference_update(batch_recs)
        states = [self.state]
        states.extend(rec.state for rec in batch_recs)
        persistence_mod.save_obj(base_mapper, states, uow)

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )
767
768
class DeleteState(PostSortRec):
    """Per-state DELETE record; see DeleteAll.per_state_flush_actions()."""

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        self.sort_key = ("DeleteState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Delete this state plus every DeleteState in ``recs`` that has
        the same base mapper, using one ``delete_obj`` call.

        The absorbed records are removed from ``recs``.  States whose
        ``uow.states`` entry is not flagged for deletion are skipped.
        """
        persistence_mod = util.preloaded.orm_persistence
        base_mapper = self.mapper
        same_kind = self.__class__
        batch_recs = [
            rec
            for rec in recs
            if rec.__class__ is same_kind and rec.mapper is base_mapper
        ]
        recs.difference_update(batch_recs)
        candidates = [self.state, *(rec.state for rec in batch_recs)]
        persistence_mod.delete_obj(
            base_mapper,
            [st for st in candidates if uow.states[st][0]],
            uow,
        )

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )