1# orm/unitofwork.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""The internals for the unit of work system.
11
12The session's flush() process passes objects to a contextual object
13here, which assembles flush tasks based on mappers and their properties,
14organizes them in order of dependency, and executes.
15
16"""
17
18from __future__ import annotations
19
20from typing import Any
21from typing import Dict
22from typing import Optional
23from typing import Set
24from typing import TYPE_CHECKING
25
26from . import attributes
27from . import exc as orm_exc
28from . import util as orm_util
29from .. import event
30from .. import util
31from ..util import topological
32
33
34if TYPE_CHECKING:
35 from .dependency import DependencyProcessor
36 from .interfaces import MapperProperty
37 from .mapper import Mapper
38 from .session import Session
39 from .session import SessionTransaction
40 from .state import InstanceState
41
42
def track_cascade_events(descriptor, prop):
    """Establish event listeners on object attributes which handle
    cascade-on-set/append.

    Listeners are attached to ``descriptor`` for the ``append_wo_mutation``,
    ``append``, ``remove`` and ``set`` attribute events.  They apply the
    "save-update" cascade to newly associated objects and expunge pending
    orphans when the "delete-orphan" cascade is configured.

    :param descriptor: the instrumented attribute to listen on.
    :param prop: the relationship property; only its ``key`` is captured
     here.  Each listener re-fetches the property from the parent state's
     mapper at event time.
    """
    key = prop.key

    def append(state, item, initiator, **kw):
        # process "save_update" cascade rules for when
        # an instance is appended to the list of another instance
        if item is None:
            return

        sess = state.session
        if sess:
            if sess._warn_on_events:
                sess._flush_warning("collection append")

            # NOTE(review): looked up per event from the parent's own
            # mapper rather than using the closed-over ``prop``; presumably
            # so the property as seen by that mapper is used -- confirm.
            prop = state.manager.mapper._props[key]
            item_state = attributes.instance_state(item)

            if (
                prop._cascade.save_update
                and (key == initiator.key)
                and not sess._contains_state(item_state)
            ):
                sess._save_or_update_state(item_state)
        return item

    def remove(state, item, initiator, **kw):
        if item is None:
            return

        sess = state.session

        prop = state.manager.mapper._props[key]

        if sess and sess._warn_on_events:
            sess._flush_warning(
                "collection remove"
                if prop.uselist
                else "related attribute delete"
            )

        # ``item`` cannot be None here because of the early return above
        if (
            item is not attributes.NEVER_SET
            and item is not attributes.PASSIVE_NO_RESULT
            and prop._cascade.delete_orphan
        ):
            item_state = attributes.instance_state(item)

            if prop.mapper._is_orphan(item_state):
                if sess and item_state in sess._new:
                    # expunge pending orphans
                    sess.expunge(item)
                else:
                    # reached when the parent has no Session, or when the
                    # orphan is not pending in the parent's Session; the
                    # related item may or may not itself be in a Session,
                    # so memoize the orphan status on the item
                    item_state._orphaned_outside_of_session = True

    def set_(state, newvalue, oldvalue, initiator, **kw):
        # process "save_update" cascade rules for when an instance
        # is attached to another instance
        if oldvalue is newvalue:
            return newvalue

        sess = state.session
        if sess:
            if sess._warn_on_events:
                sess._flush_warning("related attribute set")

            prop = state.manager.mapper._props[key]
            if newvalue is not None:
                newvalue_state = attributes.instance_state(newvalue)
                if (
                    prop._cascade.save_update
                    and (key == initiator.key)
                    and not sess._contains_state(newvalue_state)
                ):
                    sess._save_or_update_state(newvalue_state)

            # only a real previous object (not a NEVER_SET /
            # PASSIVE_NO_RESULT placeholder) can become a pending orphan
            if (
                oldvalue is not None
                and oldvalue is not attributes.NEVER_SET
                and oldvalue is not attributes.PASSIVE_NO_RESULT
                and prop._cascade.delete_orphan
            ):
                oldvalue_state = attributes.instance_state(oldvalue)

                if oldvalue_state in sess._new and prop.mapper._is_orphan(
                    oldvalue_state
                ):
                    sess.expunge(oldvalue)
        return newvalue

    event.listen(
        descriptor, "append_wo_mutation", append, raw=True, include_key=True
    )
    event.listen(
        descriptor, "append", append, raw=True, retval=True, include_key=True
    )
    event.listen(
        descriptor, "remove", remove, raw=True, retval=True, include_key=True
    )
    event.listen(
        descriptor, "set", set_, raw=True, retval=True, include_key=True
    )
155
156
class UOWTransaction:
    """Holds the state of a single unit-of-work flush.

    States are registered through :meth:`register_object`.  Presort actions
    (:class:`.Preprocess`) may register further states, after which
    :meth:`execute` runs the postsort actions (:class:`.PostSortRec`) in
    dependency order.  :meth:`finalize_flush_changes` then reports the
    outcome back to the owning :class:`.Session`.
    """

    session: Session
    transaction: SessionTransaction
    attributes: Dict[str, Any]
    deps: util.defaultdict[Mapper[Any], Set[DependencyProcessor]]
    mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]]
    presort_actions: Dict[Any, Preprocess]
    postsort_actions: Dict[Any, PostSortRec]
    dependencies: Set[Any]
    states: Dict[InstanceState[Any], Any]
    post_update_states: util.defaultdict[Mapper[Any], Any]

    def __init__(self, session: Session):
        self.session = session

        # dictionary used by external actors to
        # store arbitrary state information.
        self.attributes = {}

        # dictionary of mappers to sets of
        # DependencyProcessors, which are also
        # set to be part of the sorted flush actions,
        # which have that mapper as a parent.
        self.deps = util.defaultdict(set)

        # dictionary of mappers to sets of InstanceState
        # items pending for flush which have that mapper
        # as a parent.
        self.mappers = util.defaultdict(set)

        # dictionary of Preprocess objects keyed on
        # (dependency processor, fromparent); they gather
        # additional states impacted by the flush
        # and determine if a flush action is needed
        self.presort_actions = {}

        # dictionary of PostSortRec objects keyed on
        # (class, *constructor args); each one issues work
        # during the flush within a certain ordering.
        self.postsort_actions = {}

        # a set of 2-tuples, each containing two
        # PostSortRec objects where the second
        # is dependent on the first being executed
        # first
        self.dependencies = set()

        # dictionary of InstanceState-> (isdelete, listonly)
        # tuples, indicating if this state is to be deleted
        # or insert/updated, or just refreshed
        self.states = {}

        # tracks InstanceStates which will be receiving
        # a "post update" call.  Keys are base mappers,
        # values are a set of states and a set of the
        # columns which should be included in the update.
        self.post_update_states = util.defaultdict(lambda: (set(), set()))

    @property
    def has_work(self):
        """``True`` if at least one state is registered with this flush."""
        return bool(self.states)

    def was_already_deleted(self, state):
        """Return ``True`` if the given state is expired and was deleted
        previously.

        An expired state is refreshed with ``PASSIVE_OFF``.  If that raises
        :class:`.ObjectDeletedError`, the state is passed to
        ``Session._remove_newly_deleted()`` and ``True`` is returned.
        """
        if state.expired:
            try:
                state._load_expired(state, attributes.PASSIVE_OFF)
            except orm_exc.ObjectDeletedError:
                self.session._remove_newly_deleted([state])
                return True
        return False

    def is_deleted(self, state):
        """Return ``True`` if the given state is marked as deleted
        within this uowtransaction."""

        return state in self.states and self.states[state][0]

    def memo(self, key, callable_):
        """Return ``self.attributes[key]``, computing and storing it via
        ``callable_()`` the first time ``key`` is requested."""
        if key in self.attributes:
            return self.attributes[key]
        else:
            self.attributes[key] = ret = callable_()
            return ret

    def remove_state_actions(self, state):
        """Cancel the save/delete work pending for a registered state.

        The state stays registered, keeping its ``isdelete`` flag, but is
        flagged ``listonly``.  Actions that only process non-``listonly``
        states will therefore skip it.  ``state`` must already be
        registered; otherwise ``KeyError`` is raised.
        """

        isdelete = self.states[state][0]

        self.states[state] = (isdelete, True)

    def _load_history(self, hashkey, state, key, load_passive, passive):
        """Load attribute history using ``load_passive`` flags and cache it.

        The cache entry under ``hashkey`` is ``(history, state_history,
        passive)``, recording ``passive`` as the mode of the cached lookup.
        Returns ``state_history``.
        """
        impl = state.manager[key].impl
        # TODO: store the history as (state, object) tuples
        # so we don't have to keep converting here
        history = impl.get_history(
            state,
            state.dict,
            load_passive
            | attributes.LOAD_AGAINST_COMMITTED
            | attributes.NO_RAISE,
        )
        if history and impl.uses_objects:
            state_history = history.as_state()
        else:
            state_history = history
        self.attributes[hashkey] = (history, state_history, passive)
        return state_history

    def get_attribute_history(
        self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
    ):
        """Facade to attributes.get_state_history(), including
        caching of results."""

        hashkey = ("history", state, key)

        # cache the objects, not the states; the strong reference here
        # prevents newly loaded objects from being dereferenced during the
        # flush process
        if hashkey not in self.attributes:
            return self._load_history(hashkey, state, key, passive, passive)

        _, state_history, cached_passive = self.attributes[hashkey]

        # if the cached lookup was "passive" and now
        # we want non-passive, do a non-passive lookup and re-cache
        if (
            not cached_passive & attributes.SQL_OK
            and passive & attributes.SQL_OK
        ):
            return self._load_history(
                hashkey, state, key, attributes.PASSIVE_OFF, passive
            )
        return state_history

    def has_dep(self, processor):
        """Return ``True`` if a ``fromparent=True`` presort action exists
        for ``processor``."""
        return (processor, True) in self.presort_actions

    def register_preprocessor(self, processor, fromparent):
        """Create the :class:`.Preprocess` for ``(processor, fromparent)``
        unless one is already registered."""
        key = (processor, fromparent)
        if key not in self.presort_actions:
            self.presort_actions[key] = Preprocess(processor, fromparent)

    def register_object(
        self,
        state: InstanceState[Any],
        isdelete: bool = False,
        listonly: bool = False,
        cancel_delete: bool = False,
        operation: Optional[str] = None,
        prop: Optional[MapperProperty] = None,
    ) -> bool:
        """Register ``state`` with this flush.

        The first registration of a state records ``(isdelete, listonly)``
        and, for a mapper not seen before, sets up its per-mapper flush
        actions.  On a repeat registration, a non-``listonly`` call with
        ``isdelete`` or ``cancel_delete`` set overwrites the flags with
        ``(isdelete, False)``.  Any other repeat call leaves them unchanged.

        :param operation: cascade operation name, used only in the warning
         emitted when ``state`` is not in the session.
        :param prop: property the cascade travelled along, used only in
         that same warning.
        :return: ``False`` if the state is not in the session, else ``True``.
        """
        if not self.session._contains_state(state):
            # this condition is normal when objects are registered
            # as part of a relationship cascade operation.  it should
            # not occur for the top-level register from Session.flush().
            if not state.deleted and operation is not None:
                util.warn(
                    "Object of type %s not in session, %s operation "
                    "along '%s' will not proceed"
                    % (orm_util.state_class_str(state), operation, prop)
                )
            return False

        if state not in self.states:
            mapper = state.manager.mapper

            if mapper not in self.mappers:
                self._per_mapper_flush_actions(mapper)

            self.mappers[mapper].add(state)
            self.states[state] = (isdelete, listonly)
        else:
            if not listonly and (isdelete or cancel_delete):
                self.states[state] = (isdelete, False)
        return True

    def register_post_update(self, state, post_update_cols):
        """Schedule a "post update" of ``post_update_cols`` for ``state``,
        grouped under the state's base mapper."""
        mapper = state.manager.mapper.base_mapper
        states, cols = self.post_update_states[mapper]
        states.add(state)
        cols.update(post_update_cols)

    def _per_mapper_flush_actions(self, mapper):
        """Create the base mapper's SaveUpdateAll/DeleteAll actions, with
        saves ordered before deletes.  Then call
        ``per_property_preprocessors()`` on each of the mapper's dependency
        processors and on each non-viewonly relationship's processor."""
        saves = SaveUpdateAll(self, mapper.base_mapper)
        deletes = DeleteAll(self, mapper.base_mapper)
        self.dependencies.add((saves, deletes))

        for dep in mapper._dependency_processors:
            dep.per_property_preprocessors(self)

        for prop in mapper.relationships:
            if prop.viewonly:
                continue
            dep = prop._dependency_processor
            dep.per_property_preprocessors(self)

    @util.memoized_property
    def _mapper_for_dep(self):
        """return a dynamic mapping of (Mapper, DependencyProcessor) to
        True or False, indicating if the DependencyProcessor operates
        on objects of that Mapper.

        The result is stored in the dictionary persistently once
        calculated.

        """
        return util.PopulateDict(
            lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop
        )

    def filter_states_for_dep(self, dep, states):
        """Filter the given list of InstanceStates to those relevant to the
        given DependencyProcessor.

        """
        mapper_for_dep = self._mapper_for_dep
        return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]]

    def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
        """Yield registered states in ``mapper``'s whole inheritance
        hierarchy whose flags equal ``(isdelete, listonly)``."""
        checktup = (isdelete, listonly)
        # iterate under a distinct name rather than rebinding ``mapper``
        for sub_mapper in mapper.base_mapper.self_and_descendants:
            for state in self.mappers[sub_mapper]:
                if self.states[state] == checktup:
                    yield state

    def _generate_actions(self):
        """Generate the full, unsorted collection of PostSortRecs as
        well as dependency pairs for this UOWTransaction.

        """
        # execute presort_actions, until all states
        # have been processed.   a presort_action might
        # add new states to the uow.
        while True:
            ret = False
            for action in list(self.presort_actions.values()):
                if action.execute(self):
                    ret = True
            if not ret:
                break

        # see if the graph of mapper dependencies has cycles.
        self.cycles = cycles = topological.find_cycles(
            self.dependencies, list(self.postsort_actions.values())
        )

        if cycles:
            # if yes, break the per-mapper actions into
            # per-state actions
            convert = {
                rec: set(rec.per_state_flush_actions(self)) for rec in cycles
            }

            # rewrite the existing dependencies to point to
            # the per-state actions for those per-mapper actions
            # that were broken up.
            for edge in list(self.dependencies):
                if (
                    None in edge
                    or edge[0].disabled
                    or edge[1].disabled
                    or cycles.issuperset(edge)
                ):
                    self.dependencies.remove(edge)
                elif edge[0] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[0]]:
                        self.dependencies.add((dep, edge[1]))
                elif edge[1] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[1]]:
                        self.dependencies.add((edge[0], dep))

        return {
            a for a in self.postsort_actions.values() if not a.disabled
        }.difference(cycles)

    def execute(self) -> None:
        """Run every enabled postsort action in dependency order.

        Without cycles, each action's ``execute()`` is called in
        topological order.  With cycles, the actions are processed as
        topologically sorted subsets.  Each popped action's
        ``execute_aggregate()`` receives the rest of its subset, so it can
        batch compatible siblings.
        """
        postsort_actions = self._generate_actions()

        # sort by sort_key first so the topological sort sees a
        # deterministic input order
        postsort_actions = sorted(
            postsort_actions,
            key=lambda item: item.sort_key,
        )

        if self.cycles:
            for subset in topological.sort_as_subsets(
                self.dependencies, postsort_actions
            ):
                set_ = set(subset)
                while set_:
                    n = set_.pop()
                    n.execute_aggregate(self, set_)
        else:
            for rec in topological.sort(self.dependencies, postsort_actions):
                rec.execute(self)

    def finalize_flush_changes(self) -> None:
        """Mark processed objects as clean / deleted after a successful
        flush().

        This method is called within the flush() method after the
        execute() method has succeeded and the transaction has been committed.

        """
        if not self.states:
            return

        states = set(self.states)
        isdel = {s for s, (isdelete, _) in self.states.items() if isdelete}
        other = states.difference(isdel)
        if isdel:
            self.session._remove_newly_deleted(isdel)
        if other:
            self.session._register_persistent(other)
488
489
class IterateMappersMixin:
    """Mixin providing :meth:`_mappers` to flush actions that carry a
    ``dependency_processor`` and a ``fromparent`` flag."""

    __slots__ = ()

    def _mappers(self, uow):
        """Return an iterable of the mappers this action applies to.

        Without ``fromparent``, this is simply the processor's
        ``mapper.self_and_descendants``.  With it, the result lazily yields
        the mappers from the processor's ``parent`` hierarchy for which
        ``uow._mapper_for_dep`` is true.
        """
        if not self.fromparent:
            return self.dependency_processor.mapper.self_and_descendants

        processor = self.dependency_processor
        return (
            candidate
            for candidate in processor.parent.self_and_descendants
            if uow._mapper_for_dep[(candidate, processor)]
        )
502
503
class Preprocess(IterateMappersMixin):
    """Presort action bound to one dependency processor.

    Each run passes states this action has not seen yet to the
    processor's ``presort_deletes()`` / ``presort_saves()`` hooks.  The
    first time the processor reports property changes, it is asked to set
    up its per-property flush actions.
    """

    __slots__ = (
        "dependency_processor",
        "fromparent",
        "processed",
        "setup_flush_actions",
    )

    def __init__(self, dependency_processor, fromparent):
        self.dependency_processor = dependency_processor
        self.fromparent = fromparent
        # states already handed to the presort hooks by this action
        self.processed = set()
        # becomes True once per_property_flush_actions() has been called
        self.setup_flush_actions = False

    def execute(self, uow):
        """Run the presort hooks over newly registered states.

        Returns ``True`` when new states were found, in which case
        another presort round is warranted.  Returns ``False`` otherwise.
        """
        pending_deletes = set()
        pending_saves = set()

        for mapper in self._mappers(uow):
            # subscripting (not .get) is intentional: it matches the
            # defaultdict behaviour the rest of the flush relies on
            unseen = uow.mappers[mapper].difference(self.processed)
            for state in unseen:
                isdelete, listonly = uow.states[state]
                if listonly:
                    continue
                if isdelete:
                    pending_deletes.add(state)
                else:
                    pending_saves.add(state)

        processor = self.dependency_processor
        if pending_deletes:
            processor.presort_deletes(uow, pending_deletes)
            self.processed.update(pending_deletes)
        if pending_saves:
            processor.presort_saves(uow, pending_saves)
            self.processed.update(pending_saves)

        if not (pending_deletes or pending_saves):
            return False

        needs_setup = not self.setup_flush_actions and (
            processor.prop_has_changes(uow, pending_deletes, True)
            or processor.prop_has_changes(uow, pending_saves, False)
        )
        if needs_setup:
            processor.per_property_flush_actions(uow)
            self.setup_flush_actions = True
        return True
552
553
class PostSortRec:
    """Base class for postsort flush actions.

    Instances are interned per unit of work.  Constructing a subclass with
    the same ``(cls, *args)`` against the same ``uow`` returns the existing
    record from ``uow.postsort_actions``; ``__init__`` still runs again.
    """

    __slots__ = ("disabled",)

    def __new__(cls, uow, *args):
        registry = uow.postsort_actions
        key = (cls, *args)
        existing = registry.get(key)
        if existing is not None:
            return existing

        rec = object.__new__(cls)
        registry[key] = rec
        rec.disabled = False
        return rec

    def execute_aggregate(self, uow, recs):
        """Default aggregate execution: run only this record, leaving
        ``recs`` untouched."""
        self.execute(uow)
568
569
class ProcessAll(IterateMappersMixin, PostSortRec):
    """Postsort action running a dependency processor's
    ``process_saves()`` or ``process_deletes()`` over all relevant,
    non-``listonly`` states in a single batch."""

    __slots__ = ("dependency_processor", "isdelete", "fromparent", "sort_key")

    def __init__(self, uow, dependency_processor, isdelete, fromparent):
        self.dependency_processor = dependency_processor
        self.sort_key = ("ProcessAll", dependency_processor.sort_key, isdelete)
        self.isdelete = isdelete
        self.fromparent = fromparent
        # index the processor under its parent's base mapper; the
        # per-state breakdown in SaveUpdateAll/DeleteAll reads uow.deps
        uow.deps[dependency_processor.parent.base_mapper].add(
            dependency_processor
        )

    def execute(self, uow):
        processor = self.dependency_processor
        states = self._elements(uow)
        if self.isdelete:
            handler = processor.process_deletes
        else:
            handler = processor.process_saves
        handler(uow, states)

    def per_state_flush_actions(self, uow):
        # this is handled by SaveUpdateAll and DeleteAll,
        # since a ProcessAll should unconditionally be pulled
        # into per-state if either the parent/child mappers
        # are part of a cycle
        return iter(())

    def __repr__(self):
        return "%s(%s, isdelete=%s)" % (
            self.__class__.__name__,
            self.dependency_processor,
            self.isdelete,
        )

    def _elements(self, uow):
        """Lazily yield the non-``listonly`` states, across this action's
        mappers, whose delete flag equals ``self.isdelete``."""
        for mapper in self._mappers(uow):
            for state in uow.mappers[mapper]:
                state_isdelete, state_listonly = uow.states[state]
                if not state_listonly and state_isdelete == self.isdelete:
                    yield state
613
614
class PostUpdateAll(PostSortRec):
    """Postsort action emitting the "post update" for one base mapper.

    Only states whose registered delete flag equals ``isdelete`` are
    included.
    """

    __slots__ = ("mapper", "isdelete", "sort_key")

    def __init__(self, uow, mapper, isdelete):
        self.mapper = mapper
        self.isdelete = isdelete
        self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        persistence = util.preloaded.orm_persistence
        candidates, cols = uow.post_update_states[self.mapper]
        selected = []
        for candidate in candidates:
            if uow.states[candidate][0] == self.isdelete:
                selected.append(candidate)

        persistence.post_update(self.mapper, selected, uow, cols)
630
631
class SaveUpdateAll(PostSortRec):
    """Postsort action saving every non-``listonly``, non-delete state in a
    base mapper's inheritance hierarchy through ``save_obj()``."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("SaveUpdateAll", mapper._sort_key)
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        save_obj = util.preloaded.orm_persistence.save_obj
        pending = uow.states_for_mapper_hierarchy(self.mapper, False, False)
        save_obj(self.mapper, pending, uow)

    def per_state_flush_actions(self, uow):
        """Break this action into one SaveUpdateState per pending save.

        Each yielded action is ordered before the hierarchy's DeleteAll.
        Afterwards, every dependency processor indexed under this mapper
        is asked for its own per-state actions on the relevant states.
        """
        pending = list(
            uow.states_for_mapper_hierarchy(self.mapper, False, False)
        )
        delete_all = DeleteAll(uow, self.mapper.base_mapper)
        for state in pending:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            per_state = SaveUpdateState(uow, state)
            uow.dependencies.add((per_state, delete_all))
            yield per_state

        for dep in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(dep, pending)
            dep.per_state_flush_actions(uow, relevant, False)

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self.mapper)
667
668
class DeleteAll(PostSortRec):
    """Postsort action deleting every non-``listonly`` state marked for
    deletion in a base mapper's inheritance hierarchy through
    ``delete_obj()``."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("DeleteAll", mapper._sort_key)
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        delete_obj = util.preloaded.orm_persistence.delete_obj
        doomed = uow.states_for_mapper_hierarchy(self.mapper, True, False)
        delete_obj(self.mapper, doomed, uow)

    def per_state_flush_actions(self, uow):
        """Break this action into one DeleteState per pending delete.

        Each yielded action is ordered after the hierarchy's SaveUpdateAll.
        Afterwards, every dependency processor indexed under this mapper
        is asked for its own per-state actions on the relevant states.
        """
        doomed = list(
            uow.states_for_mapper_hierarchy(self.mapper, True, False)
        )
        save_all = SaveUpdateAll(uow, self.mapper.base_mapper)
        for state in doomed:
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            per_state = DeleteState(uow, state)
            uow.dependencies.add((save_all, per_state))
            yield per_state

        for dep in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(dep, doomed)
            dep.per_state_flush_actions(uow, relevant, True)

    def __repr__(self):
        return "%s(%s)" % (self.__class__.__name__, self.mapper)
704
705
class ProcessState(PostSortRec):
    """Per-state counterpart of ProcessAll: applies one dependency
    processor's save or delete processing to a single state."""

    __slots__ = ("dependency_processor", "isdelete", "state", "sort_key")

    def __init__(self, uow, dependency_processor, isdelete, state):
        self.dependency_processor = dependency_processor
        self.isdelete = isdelete
        self.state = state
        self.sort_key = ("ProcessState", dependency_processor.sort_key)

    def execute_aggregate(self, uow, recs):
        """Process this state in one batch with every other ProcessState
        in ``recs`` that shares this processor and delete flag.  Those
        records are removed from ``recs``."""
        processor = self.dependency_processor
        flag = self.isdelete
        kind = self.__class__

        siblings = []
        for rec in recs:
            if (
                rec.__class__ is kind
                and rec.dependency_processor is processor
                and rec.isdelete is flag
            ):
                siblings.append(rec)
        recs.difference_update(siblings)

        batch = [self.state, *(rec.state for rec in siblings)]
        if flag:
            processor.process_deletes(uow, batch)
        else:
            processor.process_saves(uow, batch)

    def __repr__(self):
        return "%s(%s, %s, delete=%s)" % (
            self.__class__.__name__,
            self.dependency_processor,
            orm_util.state_str(self.state),
            self.isdelete,
        )
740
741
class SaveUpdateState(PostSortRec):
    """Per-state save action, produced when SaveUpdateAll is broken up
    into individual states."""

    __slots__ = ("state", "mapper", "sort_key")

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        # NOTE(review): tagged "ProcessState" rather than "SaveUpdateState";
        # left unchanged because sort_key drives the pre-sort ordering
        self.sort_key = ("ProcessState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Save this state with every other SaveUpdateState in ``recs`` for
        the same base mapper in one ``save_obj()`` call.  Those records are
        removed from ``recs``."""
        persistence = util.preloaded.orm_persistence
        kind = self.__class__
        target = self.mapper

        siblings = [
            rec
            for rec in recs
            if rec.__class__ is kind and rec.mapper is target
        ]
        recs.difference_update(siblings)

        batch = [self.state, *(rec.state for rec in siblings)]
        persistence.save_obj(target, batch, uow)

    def __repr__(self):
        return "%s(%s)" % (
            self.__class__.__name__,
            orm_util.state_str(self.state),
        )
768
769
class DeleteState(PostSortRec):
    """Per-state delete action, produced when DeleteAll is broken up into
    individual states."""

    __slots__ = ("state", "mapper", "sort_key")

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        self.sort_key = ("DeleteState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        """Delete this state with every other DeleteState in ``recs`` for the
        same base mapper in one ``delete_obj()`` call.  Those records are
        removed from ``recs``.  Only states still registered as deletes in
        ``uow.states`` are passed on."""
        persistence = util.preloaded.orm_persistence
        kind = self.__class__
        target = self.mapper

        siblings = [
            rec
            for rec in recs
            if rec.__class__ is kind and rec.mapper is target
        ]
        recs.difference_update(siblings)

        candidates = [self.state, *(rec.state for rec in siblings)]
        still_deleted = [s for s in candidates if uow.states[s][0]]
        persistence.delete_obj(target, still_deleted, uow)

    def __repr__(self):
        return "%s(%s)" % (
            self.__class__.__name__,
            orm_util.state_str(self.state),
        )