1# orm/unitofwork.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""The internals for the unit of work system.
11
12The session's flush() process passes objects to a contextual object
13here, which assembles flush tasks based on mappers and their properties,
14organizes them in order of dependency, and executes.
15
16"""
17
18from __future__ import annotations
19
20from typing import Any
21from typing import Dict
22from typing import Optional
23from typing import Set
24from typing import TYPE_CHECKING
25
26from . import attributes
27from . import exc as orm_exc
28from . import util as orm_util
29from .. import event
30from .. import util
31from ..util import topological
32
33
34if TYPE_CHECKING:
35 from .dependency import _DependencyProcessor
36 from .interfaces import MapperProperty
37 from .mapper import Mapper
38 from .session import Session
39 from .session import SessionTransaction
40 from .state import InstanceState
41
42
def _track_cascade_events(descriptor, prop):
    """Attach the cascade-on-set/append listeners for one attribute.

    Listeners are registered on ``descriptor`` for the
    "append_wo_mutation", "append", "remove" and "set" events.  Each
    listener re-resolves the property from the mapper of the state that
    received the event, using the key captured from ``prop``.

    """
    key = prop.key

    def _is_present(value):
        # True when ``value`` is a real object, i.e. not None and neither
        # of the NEVER_SET / PASSIVE_NO_RESULT symbols.
        return (
            value is not None
            and value is not attributes.NEVER_SET
            and value is not attributes.PASSIVE_NO_RESULT
        )

    def append(state, item, initiator, **kw):
        # "save_update" cascade for an instance appended to the
        # collection of another instance
        if item is None:
            return None

        session = state.session
        if not session:
            return item

        if session._warn_on_events:
            session._flush_warning("collection append")

        rel = state.manager.mapper._props[key]
        child_state = attributes.instance_state(item)

        if (
            rel._cascade.save_update
            and key == initiator.key
            and not session._contains_state(child_state)
        ):
            session._save_or_update_state(child_state)
        return item

    def remove(state, item, initiator, **kw):
        # "delete_orphan" handling for an instance removed from a
        # collection or scalar relationship
        if item is None:
            return None

        session = state.session
        rel = state.manager.mapper._props[key]

        if session and session._warn_on_events:
            if rel.uselist:
                message = "collection remove"
            else:
                message = "related attribute delete"
            session._flush_warning(message)

        if not (_is_present(item) and rel._cascade.delete_orphan):
            return None

        child_state = attributes.instance_state(item)
        if not rel.mapper._is_orphan(child_state):
            return None

        if session and child_state in session._new:
            # pending orphans are expunged
            session.expunge(item)
        else:
            # the related item may or may not itself be in a Session,
            # however the parent for which the event is being caught is
            # not in a session, so memoize this on the item
            child_state._orphaned_outside_of_session = True
        return None

    def set_(state, newvalue, oldvalue, initiator, **kw):
        # "save_update" cascade for an instance attached to another
        # instance, plus orphan handling for the value being replaced
        if oldvalue is newvalue:
            return newvalue

        session = state.session
        if not session:
            return newvalue

        if session._warn_on_events:
            session._flush_warning("related attribute set")

        rel = state.manager.mapper._props[key]

        if newvalue is not None:
            incoming = attributes.instance_state(newvalue)
            if (
                rel._cascade.save_update
                and key == initiator.key
                and not session._contains_state(incoming)
            ):
                session._save_or_update_state(incoming)

        if _is_present(oldvalue) and rel._cascade.delete_orphan:
            # possible to reach here with attributes.NEVER_SET ?
            outgoing = attributes.instance_state(oldvalue)
            if outgoing in session._new and rel.mapper._is_orphan(outgoing):
                session.expunge(oldvalue)
        return newvalue

    shared = {"raw": True, "include_key": True}

    # "append_wo_mutation" is the only registration without retval
    event.listen(descriptor, "append_wo_mutation", append, **shared)
    for identifier, listener in (
        ("append", append),
        ("remove", remove),
        ("set", set_),
    ):
        event.listen(descriptor, identifier, listener, retval=True, **shared)
155
156
class UOWTransaction:
    """Holds the working state of a single unit of work flush."""

    session: Session
    transaction: SessionTransaction
    attributes: Dict[str, Any]
    deps: util.defaultdict[Mapper[Any], Set[_DependencyProcessor]]
    mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]]

    def __init__(self, session: Session):
        self.session = session

        # free-form storage for external actors; memo() and
        # get_attribute_history() also cache their results here
        self.attributes = {}

        # mapper -> set of DependencyProcessors that have that mapper
        # as a parent and take part in the sorted flush actions
        self.deps = util.defaultdict(set)

        # mapper -> set of InstanceStates pending for flush which have
        # that mapper as a parent
        self.mappers = util.defaultdict(set)

        # Preprocess objects; these gather additional states impacted
        # by the flush and decide whether a flush action is needed
        self.presort_actions = {}

        # PostSortRec objects; each issues work during the flush
        # within a certain ordering
        self.postsort_actions = {}

        # 2-tuples of PostSortRec objects; the first member has to be
        # executed before the second
        self.dependencies = set()

        # InstanceState -> (isdelete, listonly); says whether the state
        # is to be deleted, inserted/updated, or just refreshed
        self.states = {}

        # states which will receive a "post update" call; keyed by
        # mapper, each value is a (set of states, set of columns to
        # include in the update) pair
        self.post_update_states = util.defaultdict(lambda: (set(), set()))

    @property
    def has_work(self):
        """True when at least one state has been registered."""
        return True if self.states else False

    def was_already_deleted(self, state):
        """Return ``True`` if the given state is expired and was deleted
        previously.

        An expired state is loaded; when the load raises
        ``ObjectDeletedError`` the state is handed to the session's
        ``_remove_newly_deleted()``.
        """
        if not state.expired:
            return False

        try:
            state._load_expired(state, attributes.PASSIVE_OFF)
        except orm_exc.ObjectDeletedError:
            self.session._remove_newly_deleted([state])
            return True
        else:
            return False

    def is_deleted(self, state):
        """Return the "isdelete" flag recorded for the given state in
        this uowtransaction, or ``False`` for an unregistered state."""

        if state not in self.states:
            return False
        return self.states[state][0]

    def memo(self, key, callable_):
        """Return the value stored in ``attributes`` under ``key``,
        first producing it with ``callable_()`` if it is absent."""
        if key not in self.attributes:
            self.attributes[key] = callable_()
        return self.attributes[key]

    def remove_state_actions(self, state):
        """Remove pending actions for a state from the uowtransaction.

        The state stays registered; its entry is switched to "listonly"
        while the "isdelete" flag is left as it was.
        """
        self.states[state] = (self.states[state][0], True)

    def get_attribute_history(
        self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
    ):
        """Facade to the attribute implementation's get_history(),
        including caching of results.

        A cached entry is reused unless it was produced without SQL_OK
        and the present call asks for SQL_OK; in that case the history
        is fetched again with PASSIVE_OFF and the entry replaced.
        """
        cache_key = ("history", state, key)

        # the objects are cached, not the states; the strong reference
        # prevents newly loaded objects from being dereferenced during
        # the flush process
        if cache_key in self.attributes:
            history, state_history, cached_passive = self.attributes[
                cache_key
            ]
            needs_reload = (
                not cached_passive & attributes.SQL_OK
                and passive & attributes.SQL_OK
            )
            if not needs_reload:
                return state_history
            # the cached lookup was "passive" and a non-passive one is
            # wanted now
            load_flags = attributes.PASSIVE_OFF
        else:
            load_flags = passive

        impl = state.manager[key].impl
        # TODO: store the history as (state, object) tuples
        # so we don't have to keep converting here
        history = impl.get_history(
            state,
            state.dict,
            load_flags
            | attributes.LOAD_AGAINST_COMMITTED
            | attributes.NO_RAISE,
        )
        state_history = (
            history.as_state() if history and impl.uses_objects else history
        )
        self.attributes[cache_key] = (history, state_history, passive)
        return state_history

    def has_dep(self, processor):
        """True when a from-parent preprocessor exists for
        ``processor``."""
        lookup = (processor, True)
        return lookup in self.presort_actions

    def register_preprocessor(self, processor, fromparent):
        """Create the _Preprocess for (processor, fromparent) unless one
        is already present."""
        lookup = (processor, fromparent)
        if lookup in self.presort_actions:
            return
        self.presort_actions[lookup] = _Preprocess(processor, fromparent)

    def register_object(
        self,
        state: InstanceState[Any],
        isdelete: bool = False,
        listonly: bool = False,
        cancel_delete: bool = False,
        operation: Optional[str] = None,
        prop: Optional[MapperProperty] = None,
    ) -> bool:
        """Register ``state`` with this flush.

        Returns ``False`` when the session does not contain the state,
        ``True`` otherwise.  An already registered state only has its
        entry rewritten when ``listonly`` is false and either
        ``isdelete`` or ``cancel_delete`` is set.
        """
        if not self.session._contains_state(state):
            # this condition is normal when objects are registered
            # as part of a relationship cascade operation.  it should
            # not occur for the top-level register from Session.flush().
            if not state.deleted and operation is not None:
                util.warn(
                    "Object of type %s not in session, %s operation "
                    "along '%s' will not proceed"
                    % (orm_util.state_class_str(state), operation, prop)
                )
            return False

        if state in self.states:
            if not listonly and (isdelete or cancel_delete):
                self.states[state] = (isdelete, False)
            return True

        mapper = state.manager.mapper

        # the per-mapper actions are set up the first time a mapper is
        # encountered
        if mapper not in self.mappers:
            self._per_mapper_flush_actions(mapper)

        self.mappers[mapper].add(state)
        self.states[state] = (isdelete, listonly)
        return True

    def register_post_update(self, state, post_update_cols):
        """Record ``state`` and ``post_update_cols`` for post update
        under the state's base mapper."""
        pending_states, pending_cols = self.post_update_states[
            state.manager.mapper.base_mapper
        ]
        pending_states.add(state)
        pending_cols.update(post_update_cols)

    def _per_mapper_flush_actions(self, mapper):
        """Create the save and delete records for the mapper's base
        mapper and run the per-property preprocessors."""
        base = mapper.base_mapper
        save_rec = _SaveUpdateAll(self, base)
        delete_rec = _DeleteAll(self, base)
        self.dependencies.add((save_rec, delete_rec))

        for processor in mapper._dependency_processors:
            processor.per_property_preprocessors(self)

        for rel in mapper.relationships:
            if not rel.viewonly:
                rel._dependency_processor.per_property_preprocessors(self)

    @util.memoized_property
    def _mapper_for_dep(self):
        """return a dynamic mapping of (Mapper, DependencyProcessor) to
        True or False, indicating if the DependencyProcessor operates
        on objects of that Mapper.

        The result is stored in the dictionary persistently once
        calculated.

        """

        def operates_on(pair):
            mapper, dep = pair[0], pair[1]
            return mapper._props.get(dep.key) is dep.prop

        return util.PopulateDict(operates_on)

    def filter_states_for_dep(self, dep, states):
        """Filter the given list of InstanceStates to those relevant to the
        given DependencyProcessor.

        """
        applies = self._mapper_for_dep
        return [st for st in states if applies[(st.manager.mapper, dep)]]

    def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
        """Yield the states registered under the base mapper of
        ``mapper`` or its descendants whose entry equals
        ``(isdelete, listonly)``."""
        wanted = (isdelete, listonly)
        for member in mapper.base_mapper.self_and_descendants:
            yield from (
                st for st in self.mappers[member] if self.states[st] == wanted
            )

    def _generate_actions(self):
        """Generate the full, unsorted collection of PostSortRecs as
        well as dependency pairs for this UOWTransaction.

        """
        # run the presort actions over and over until a full pass
        # reports no work; a presort action might add new states to
        # the uow.
        progressed = True
        while progressed:
            progressed = False
            for preprocess in list(self.presort_actions.values()):
                if preprocess.execute(self):
                    progressed = True

        # see if the graph of mapper dependencies has cycles.
        cycles = topological.find_cycles(
            self.dependencies, list(self.postsort_actions.values())
        )
        self.cycles = cycles

        if cycles:
            # break the per-mapper actions that are part of a cycle
            # into per-state actions
            per_state = {}
            for rec in cycles:
                per_state[rec] = set(rec.per_state_flush_actions(self))

            # rewrite the existing dependencies to point to the
            # per-state actions of the records that were broken up.
            for edge in list(self.dependencies):
                first, second = edge
                if (
                    None in edge
                    or first.disabled
                    or second.disabled
                    or cycles.issuperset(edge)
                ):
                    self.dependencies.remove(edge)
                    continue

                if first in cycles:
                    self.dependencies.remove(edge)
                    self.dependencies.update(
                        (sub, second) for sub in per_state[first]
                    )
                elif second in cycles:
                    self.dependencies.remove(edge)
                    self.dependencies.update(
                        (first, sub) for sub in per_state[second]
                    )

        enabled = {
            rec for rec in self.postsort_actions.values() if not rec.disabled
        }
        return enabled.difference(cycles)

    def execute(self) -> None:
        """Generate the actions, order them by ``sort_key`` and run them
        in topological order.

        When cycles were found the records are run subset by subset via
        ``execute_aggregate()``; otherwise each record's ``execute()``
        is called.
        """
        ordered = sorted(
            self._generate_actions(),
            key=lambda rec: rec.sort_key,
        )

        if not self.cycles:
            for rec in topological.sort(self.dependencies, ordered):
                rec.execute(self)
            return

        for subset in topological.sort_as_subsets(self.dependencies, ordered):
            remaining = set(subset)
            while remaining:
                rec = remaining.pop()
                # the record may pull sibling records out of
                # ``remaining`` and handle them in the same call
                rec.execute_aggregate(self, remaining)

    def finalize_flush_changes(self) -> None:
        """Mark processed objects as clean / deleted after a successful
        flush().

        This method is called within the flush() method after the
        execute() method has succeeded and the transaction has been committed.

        """
        if not self.states:
            return

        every_state = set(self.states)
        deleted = {
            st for (st, (isdelete, _listonly)) in self.states.items()
            if isdelete
        }
        surviving = every_state.difference(deleted)

        if deleted:
            self.session._remove_newly_deleted(deleted)
        if surviving:
            self.session._register_persistent(surviving)
490
491
class _IterateMappersMixin:
    """Mixin supplying ``_mappers()`` to records that carry
    ``dependency_processor`` and ``fromparent`` attributes."""

    __slots__ = ()

    def _mappers(self, uow):
        """Return the mappers this record applies to.

        Without ``fromparent`` this is the ``self_and_descendants`` of
        the processor's ``mapper``.  With ``fromparent`` it is an
        iterator over the ``self_and_descendants`` of the processor's
        ``parent``, limited to entries for which
        ``uow._mapper_for_dep`` is true.
        """
        if not self.fromparent:
            return self.dependency_processor.mapper.self_and_descendants

        processor = self.dependency_processor
        applies = uow._mapper_for_dep
        return iter(
            candidate
            for candidate in processor.parent.self_and_descendants
            if applies[(candidate, processor)]
        )
504
505
class _Preprocess(_IterateMappersMixin):
    """Presort step for one (dependency processor, fromparent) pair.

    ``processed`` holds the states already passed to the processor and
    ``setup_flush_actions`` records whether the processor's
    per-property flush actions have been set up.
    """

    __slots__ = (
        "dependency_processor",
        "fromparent",
        "processed",
        "setup_flush_actions",
    )

    def __init__(self, dependency_processor, fromparent):
        self.dependency_processor = dependency_processor
        self.fromparent = fromparent
        self.processed = set()
        self.setup_flush_actions = False

    def execute(self, uow):
        """Pass not-yet-processed states to the dependency processor.

        Returns ``True`` if any delete or save state was handled in
        this call and ``False`` otherwise.
        """
        to_delete = set()
        to_save = set()

        for mapper in self._mappers(uow):
            for state in uow.mappers[mapper].difference(self.processed):
                isdelete, listonly = uow.states[state]
                if listonly:
                    continue
                bucket = to_delete if isdelete else to_save
                bucket.add(state)

        processor = self.dependency_processor

        if to_delete:
            processor.presort_deletes(uow, to_delete)
            self.processed.update(to_delete)
        if to_save:
            processor.presort_saves(uow, to_save)
            self.processed.update(to_save)

        if not (to_delete or to_save):
            return False

        # the flush actions are set up at most once, the first time
        # the processor reports changes
        if not self.setup_flush_actions and (
            processor.prop_has_changes(uow, to_delete, True)
            or processor.prop_has_changes(uow, to_save, False)
        ):
            processor.per_property_flush_actions(uow)
            self.setup_flush_actions = True
        return True
554
555
class _PostSortRec:
    """Base for flush records registered in ``uow.postsort_actions``.

    Construction is keyed on the class plus the positional arguments
    after ``uow``; asking for an already registered key returns the
    existing record instead of a new one.
    """

    __slots__ = ("disabled",)

    def __new__(cls, uow, *args):
        registry = uow.postsort_actions
        key = (cls, *args)
        if key not in registry:
            rec = object.__new__(cls)
            registry[key] = rec
            rec.disabled = False
        return registry[key]

    def execute_aggregate(self, uow, recs):
        """Run this record on its own; ``recs`` is not consulted."""
        self.execute(uow)
570
571
class _ProcessAll(_IterateMappersMixin, _PostSortRec):
    """Record that runs a dependency processor over all matching
    states of the mappers it applies to."""

    __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, fromparent):
        self.dependency_processor = dependency_processor
        self.sort_key = (
            "ProcessAll",
            dependency_processor.sort_key,
            isdelete,
        )
        self.isdelete = isdelete
        self.fromparent = fromparent

        # make the processor known to the uow under the base mapper of
        # its parent
        parent_base = dependency_processor.parent.base_mapper
        uow.deps[parent_base].add(dependency_processor)

    def execute(self, uow):
        """Hand the matching states to ``process_deletes()`` or
        ``process_saves()`` depending on ``isdelete``."""
        matching = self._elements(uow)
        processor = self.dependency_processor
        if self.isdelete:
            handler = processor.process_deletes
        else:
            handler = processor.process_saves
        handler(uow, matching)

    def per_state_flush_actions(self, uow):
        # this is handled by SaveUpdateAll and DeleteAll,
        # since a ProcessAll should unconditionally be pulled
        # into per-state if either the parent/child mappers
        # are part of a cycle
        return iter(())

    def __repr__(self):
        return "%s(%s, isdelete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            self.isdelete,
        )

    def _elements(self, uow):
        """Yield the non-listonly states whose isdelete flag equals
        this record's."""
        for mapper in self._mappers(uow):
            for state in uow.mappers[mapper]:
                isdelete, listonly = uow.states[state]
                if isdelete != self.isdelete or listonly:
                    continue
                yield state
615
616
class _PostUpdateAll(_PostSortRec):
    """Record that issues the post update for one mapper."""

    __slots__ = "mapper", "isdelete", "sort_key"

    def __init__(self, uow, mapper, isdelete):
        self.mapper = mapper
        self.isdelete = isdelete
        self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        """Call ``persistence._post_update()`` for the registered
        states whose isdelete flag equals this record's."""
        persistence = util.preloaded.orm_persistence
        registered, cols = uow.post_update_states[self.mapper]
        selected = [
            st for st in registered if uow.states[st][0] == self.isdelete
        ]
        persistence._post_update(self.mapper, selected, uow, cols)
632
633
class _SaveUpdateAll(_PostSortRec):
    """Record that saves all pending states of a base mapper's
    hierarchy."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("SaveUpdateAll", mapper._sort_key)
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        """Pass the (isdelete=False, listonly=False) states of the
        hierarchy to ``persistence._save_obj()``."""
        persistence = util.preloaded.orm_persistence
        pending = uow.states_for_mapper_hierarchy(self.mapper, False, False)
        persistence._save_obj(self.mapper, pending, uow)

    def per_state_flush_actions(self, uow):
        """Yield one _SaveUpdateState per pending state, then ask each
        dependency processor registered for the mapper to produce its
        per-state actions."""
        pending = list(
            uow.states_for_mapper_hierarchy(self.mapper, False, False)
        )
        deletes = _DeleteAll(uow, self.mapper.base_mapper)

        for state in pending:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            rec = _SaveUpdateState(uow, state)
            uow.dependencies.add((rec, deletes))
            yield rec

        for processor in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(processor, pending)
            processor.per_state_flush_actions(uow, relevant, False)

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
669
670
class _DeleteAll(_PostSortRec):
    """Record that deletes all states marked for deletion in a base
    mapper's hierarchy."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("DeleteAll", mapper._sort_key)
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        """Pass the (isdelete=True, listonly=False) states of the
        hierarchy to ``persistence._delete_obj()``."""
        persistence = util.preloaded.orm_persistence
        doomed = uow.states_for_mapper_hierarchy(self.mapper, True, False)
        persistence._delete_obj(self.mapper, doomed, uow)

    def per_state_flush_actions(self, uow):
        """Yield one _DeleteState per state to be deleted, then ask
        each dependency processor registered for the mapper to produce
        its per-state actions."""
        doomed = list(
            uow.states_for_mapper_hierarchy(self.mapper, True, False)
        )
        saves = _SaveUpdateAll(uow, self.mapper.base_mapper)

        for state in doomed:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            rec = _DeleteState(uow, state)
            uow.dependencies.add((saves, rec))
            yield rec

        for processor in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(processor, doomed)
            processor.per_state_flush_actions(uow, relevant, True)

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
706
707
class _ProcessState(_PostSortRec):
    """Record that runs a dependency processor for a single state."""

    __slots__ = "dependency_processor", "isdelete", "state", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, state):
        self.dependency_processor = dependency_processor
        self.sort_key = ("ProcessState", dependency_processor.sort_key)
        self.isdelete = isdelete
        self.state = state

    def execute_aggregate(self, uow, recs):
        """Process this state together with the compatible records
        found in ``recs``.

        Records of the same class, dependency processor and isdelete
        flag are removed from ``recs`` and their states handled in the
        same ``process_deletes()`` / ``process_saves()`` call.
        """
        kind = self.__class__
        processor = self.dependency_processor
        deleting = self.isdelete

        batch = [
            other
            for other in recs
            if other.__class__ is kind
            and other.dependency_processor is processor
            and other.isdelete is deleting
        ]
        recs.difference_update(batch)

        states = [self.state]
        states.extend(other.state for other in batch)

        if deleting:
            handler = processor.process_deletes
        else:
            handler = processor.process_saves
        handler(uow, states)

    def __repr__(self):
        return "%s(%s, %s, delete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            orm_util.state_str(self.state),
            self.isdelete,
        )
742
743
class _SaveUpdateState(_PostSortRec):
    """Record that saves a single state."""

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        # NOTE(review): the tag is "ProcessState", the same one used by
        # _ProcessState; kept as is since it takes part in ordering
        self.sort_key = ("ProcessState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Save this state together with those of the records in
        ``recs`` that have the same class and mapper.

        The matching records are removed from ``recs``.
        """
        persistence = util.preloaded.orm_persistence
        kind = self.__class__
        mapper = self.mapper

        batch = [
            other
            for other in recs
            if other.__class__ is kind and other.mapper is mapper
        ]
        recs.difference_update(batch)

        states = [self.state]
        states.extend(other.state for other in batch)
        persistence._save_obj(mapper, states, uow)

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )
770
771
class _DeleteState(_PostSortRec):
    """Record that deletes a single state."""

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        self.sort_key = ("DeleteState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Delete this state together with those of the records in
        ``recs`` that have the same class and mapper.

        The matching records are removed from ``recs``.  Only states
        whose isdelete flag in ``uow.states`` is true are passed to
        ``persistence._delete_obj()``.
        """
        persistence = util.preloaded.orm_persistence
        kind = self.__class__
        mapper = self.mapper

        batch = [
            other
            for other in recs
            if other.__class__ is kind and other.mapper is mapper
        ]
        recs.difference_update(batch)

        candidates = [self.state]
        candidates.extend(other.state for other in batch)
        doomed = [st for st in candidates if uow.states[st][0]]
        persistence._delete_obj(mapper, doomed, uow)

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )