1# orm/unitofwork.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""The internals for the unit of work system.
9
The session's flush() process passes objects to a contextual object
here, which assembles flush tasks based on mappers and their properties,
organizes them in order of dependency, and executes them.
13
14"""
15
16from . import attributes
17from . import exc as orm_exc
18from . import util as orm_util
19from .. import event
20from .. import util
21from ..util import topological
22
23
def _warn_for_cascade_backrefs(state, prop):
    """Emit the SQLAlchemy 2.0 deprecation warning for a backref cascade.

    Called when an object is about to be cascaded into a Session along
    the reverse (backref) side of a relationship, which 2.0 no longer does.

    :param state: ``InstanceState`` of the object being cascaded; only its
     class name is used in the message.
    :param prop: the relationship property, interpolated into the message.
    """
    template = (
        '"%s" object is being merged into a Session along the backref '
        'cascade path for relationship "%s"; in SQLAlchemy 2.0, this '
        "reverse cascade will not take place. Set cascade_backrefs to "
        "False in either the relationship() or backref() function for "
        "the 2.0 behavior; or to set globally for the whole "
        "Session, set the future=True flag"
    )
    message = template % (state.class_.__name__, prop)
    util.warn_deprecated_20(message, code="s9r1")
34
35
def track_cascade_events(descriptor, prop):
    """Establish event listeners on object attributes which handle
    cascade-on-set/append.

    Listeners are registered on ``descriptor`` for the ``append``,
    ``append_wo_mutation``, ``remove`` and ``set`` attribute events (all
    with ``raw=True``; the listeners work with the parent's
    ``InstanceState``).  They apply the "save-update" cascade to newly
    associated objects and expunge pending objects that become orphans
    under a "delete-orphan" cascade.

    :param descriptor: the attribute to listen on.
    :param prop: the relationship property.  Only its ``key`` is captured;
     each listener re-resolves the property through the mapper of the
     state that receives the event.
    """
    key = prop.key

    def _cascade_save_update(sess, prop, item_state, initiator):
        # Shared "save-update" cascade for append and set: add the newly
        # associated object to the parent's session if the cascade applies
        # and the object is not already there.  When the event was
        # initiated through a different attribute (the backref path), the
        # cascade only happens in legacy (non-future) mode with
        # cascade_backrefs enabled, and emits a 2.0 deprecation warning.
        if (
            prop._cascade.save_update
            and (
                (prop.cascade_backrefs and not sess.future)
                or key == initiator.key
            )
            and not sess._contains_state(item_state)
        ):
            if key != initiator.key:
                _warn_for_cascade_backrefs(item_state, prop)
            sess._save_or_update_state(item_state)

    def append(state, item, initiator):
        # process "save_update" cascade rules for when
        # an instance is appended to the list of another instance
        if item is None:
            return

        sess = state.session
        if sess:
            if sess._warn_on_events:
                sess._flush_warning("collection append")

            prop = state.manager.mapper._props[key]
            item_state = attributes.instance_state(item)
            _cascade_save_update(sess, prop, item_state, initiator)
        return item

    def remove(state, item, initiator):
        # process "delete-orphan" cascade rules for when an instance is
        # removed from another instance's collection / attribute
        if item is None:
            return

        sess = state.session

        prop = state.manager.mapper._props[key]

        if sess and sess._warn_on_events:
            sess._flush_warning(
                "collection remove"
                if prop.uselist
                else "related attribute delete"
            )

        # item is known not to be None here (early return above)
        if (
            item is not attributes.NEVER_SET
            and item is not attributes.PASSIVE_NO_RESULT
            and prop._cascade.delete_orphan
        ):
            # expunge pending orphans
            item_state = attributes.instance_state(item)

            if prop.mapper._is_orphan(item_state):
                if sess and item_state in sess._new:
                    sess.expunge(item)
                else:
                    # either the parent for which we are catching the
                    # event is not in a Session, or the item is not
                    # pending in that Session; memoize the orphaned status
                    # on the item itself
                    item_state._orphaned_outside_of_session = True

    def set_(state, newvalue, oldvalue, initiator):
        # process "save_update" cascade rules for when an instance
        # is attached to another instance, and "delete-orphan" rules for
        # the instance it replaces
        if oldvalue is newvalue:
            return newvalue

        sess = state.session
        if sess:
            if sess._warn_on_events:
                sess._flush_warning("related attribute set")

            prop = state.manager.mapper._props[key]
            if newvalue is not None:
                newvalue_state = attributes.instance_state(newvalue)
                _cascade_save_update(sess, prop, newvalue_state, initiator)

            if (
                oldvalue is not None
                and oldvalue is not attributes.NEVER_SET
                and oldvalue is not attributes.PASSIVE_NO_RESULT
                and prop._cascade.delete_orphan
            ):
                # expunge the replaced object if it is pending and has
                # become an orphan
                oldvalue_state = attributes.instance_state(oldvalue)

                if oldvalue_state in sess._new and prop.mapper._is_orphan(
                    oldvalue_state
                ):
                    sess.expunge(oldvalue)
        return newvalue

    event.listen(descriptor, "append_wo_mutation", append, raw=True)
    event.listen(descriptor, "append", append, raw=True, retval=True)
    event.listen(descriptor, "remove", remove, raw=True, retval=True)
    event.listen(descriptor, "set", set_, raw=True, retval=True)
151
152
class UOWTransaction(object):
    """State for one unit-of-work flush of ``session``.

    Collects the ``InstanceState`` objects registered for the flush, the
    presort actions (:class:`.Preprocess`) and postsort actions
    (:class:`.PostSortRec` subclasses) derived from them, and the
    dependency edges between postsort actions.  :meth:`execute` then runs
    the postsort actions in dependency order.
    """

    def __init__(self, session):
        self.session = session

        # dictionary used by external actors to
        # store arbitrary state information.
        self.attributes = {}

        # dictionary of (base) mappers to sets of
        # DependencyProcessors, which are also
        # set to be part of the sorted flush actions,
        # which have that mapper as a parent.
        self.deps = util.defaultdict(set)

        # dictionary of mappers to sets of InstanceState
        # items pending for flush, keyed on each state's
        # own mapper (state.manager.mapper).
        self.mappers = util.defaultdict(set)

        # a dictionary of Preprocess objects, keyed on
        # (dependency processor, fromparent), which gather
        # additional states impacted by the flush
        # and determine if a flush action is needed
        self.presort_actions = {}

        # dictionary of PostSortRec objects, keyed on
        # (record class,) + constructor args; each
        # one issues work during the flush within
        # a certain ordering.
        self.postsort_actions = {}

        # a set of 2-tuples, each containing two
        # PostSortRec objects where the second
        # is dependent on the first being executed
        # first
        self.dependencies = set()

        # dictionary of InstanceState-> (isdelete, listonly)
        # tuples, indicating if this state is to be deleted
        # or insert/updated, or just refreshed
        self.states = {}

        # tracks InstanceStates which will be receiving
        # a "post update" call.  Keys are base mappers,
        # values are a set of states and a set of the
        # columns which should be included in the update.
        self.post_update_states = util.defaultdict(lambda: (set(), set()))

    @property
    def has_work(self):
        """True if any state has been registered with this flush."""
        return bool(self.states)

    def was_already_deleted(self, state):
        """Return ``True`` if the given state is expired and was deleted
        previously.

        An expired state is reloaded; if the reload raises
        ``ObjectDeletedError`` the state is reported to the session as
        newly deleted and ``True`` is returned.
        """
        if state.expired:
            try:
                state._load_expired(state, attributes.PASSIVE_OFF)
            except orm_exc.ObjectDeletedError:
                self.session._remove_newly_deleted([state])
                return True
        return False

    def is_deleted(self, state):
        """Return ``True`` if the given state is marked as deleted
        within this uowtransaction."""

        return state in self.states and self.states[state][0]

    def memo(self, key, callable_):
        """Return ``self.attributes[key]``, computing and storing it with
        ``callable_()`` on first use."""
        if key in self.attributes:
            return self.attributes[key]
        else:
            self.attributes[key] = ret = callable_()
            return ret

    def remove_state_actions(self, state):
        """Cancel the pending save/delete for a registered state.

        The state stays registered with its ``isdelete`` flag unchanged,
        but is marked ``listonly``; listonly states are skipped by
        :class:`.Preprocess` and by the save/delete actions, which only
        select states with ``listonly=False``.  Raises ``KeyError`` if the
        state was never registered.
        """

        isdelete = self.states[state][0]

        self.states[state] = (isdelete, True)

    def get_attribute_history(
        self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
    ):
        """Facade to attributes.get_state_history(), including
        caching of results.

        Results are cached in ``self.attributes`` under
        ``("history", state, key)``.  A cached result obtained without
        ``SQL_OK`` is reloaded with ``PASSIVE_OFF`` (and re-cached) when a
        caller now asks for a lookup that allows SQL.
        """

        hashkey = ("history", state, key)

        # cache the objects, not the states; the strong reference here
        # prevents newly loaded objects from being dereferenced during the
        # flush process

        if hashkey in self.attributes:
            history, state_history, cached_passive = self.attributes[hashkey]
            # if the cached lookup was "passive" and now
            # we want non-passive, do a non-passive lookup and re-cache
            if (
                not cached_passive & attributes.SQL_OK
                and passive & attributes.SQL_OK
            ):
                history, state_history = self._history_for(
                    state, key, attributes.PASSIVE_OFF
                )
                self.attributes[hashkey] = (history, state_history, passive)
        else:
            history, state_history = self._history_for(state, key, passive)
            self.attributes[hashkey] = (history, state_history, passive)

        return state_history

    def _history_for(self, state, key, passive):
        """Load history for attribute ``key`` of ``state``.

        ``LOAD_AGAINST_COMMITTED`` and ``NO_RAISE`` are always added to
        ``passive``.  Returns ``(history, state_history)``; for attributes
        whose impl ``uses_objects`` and whose history is non-empty,
        ``state_history`` is ``history.as_state()``, otherwise it is
        ``history`` itself.
        """
        impl = state.manager[key].impl
        # TODO: store the history as (state, object) tuples
        # so we don't have to keep converting here
        history = impl.get_history(
            state,
            state.dict,
            passive
            | attributes.LOAD_AGAINST_COMMITTED
            | attributes.NO_RAISE,
        )
        if history and impl.uses_objects:
            state_history = history.as_state()
        else:
            state_history = history
        return history, state_history

    def has_dep(self, processor):
        """True if a ``fromparent=True`` preprocessor is registered for
        ``processor``."""
        return (processor, True) in self.presort_actions

    def register_preprocessor(self, processor, fromparent):
        """Create the :class:`.Preprocess` for ``(processor, fromparent)``
        if it does not already exist."""
        key = (processor, fromparent)
        if key not in self.presort_actions:
            self.presort_actions[key] = Preprocess(processor, fromparent)

    def register_object(
        self,
        state,
        isdelete=False,
        listonly=False,
        cancel_delete=False,
        operation=None,
        prop=None,
    ):
        """Register ``state`` for this flush.

        Returns ``False`` (warning if ``operation`` is given and the state
        is not deleted) when the state is not in the session; otherwise
        records or updates its ``(isdelete, listonly)`` entry, setting up
        per-mapper flush actions the first time its mapper is seen, and
        returns ``True``.
        """
        if not self.session._contains_state(state):
            # this condition is normal when objects are registered
            # as part of a relationship cascade operation.  it should
            # not occur for the top-level register from Session.flush().
            if not state.deleted and operation is not None:
                util.warn(
                    "Object of type %s not in session, %s operation "
                    "along '%s' will not proceed"
                    % (orm_util.state_class_str(state), operation, prop)
                )
            return False

        if state not in self.states:
            mapper = state.manager.mapper

            if mapper not in self.mappers:
                self._per_mapper_flush_actions(mapper)

            self.mappers[mapper].add(state)
            self.states[state] = (isdelete, listonly)
        else:
            # an already-registered state is only upgraded to a full
            # (non-listonly) save/delete, never downgraded
            if not listonly and (isdelete or cancel_delete):
                self.states[state] = (isdelete, False)
        return True

    def register_post_update(self, state, post_update_cols):
        """Schedule a post-update of ``post_update_cols`` for ``state``,
        grouped under the state's base mapper."""
        mapper = state.manager.mapper.base_mapper
        states, cols = self.post_update_states[mapper]
        states.add(state)
        cols.update(post_update_cols)

    def _per_mapper_flush_actions(self, mapper):
        """Create the save/delete actions for ``mapper``'s base mapper
        (saves ordered before deletes) and register the presort
        preprocessors of its dependency processors and non-viewonly
        relationships."""
        saves = SaveUpdateAll(self, mapper.base_mapper)
        deletes = DeleteAll(self, mapper.base_mapper)
        self.dependencies.add((saves, deletes))

        for dep in mapper._dependency_processors:
            dep.per_property_preprocessors(self)

        for prop in mapper.relationships:
            if prop.viewonly:
                continue
            dep = prop._dependency_processor
            dep.per_property_preprocessors(self)

    @util.memoized_property
    def _mapper_for_dep(self):
        """return a dynamic mapping of (Mapper, DependencyProcessor) to
        True or False, indicating if the DependencyProcessor operates
        on objects of that Mapper.

        The result is stored in the dictionary persistently once
        calculated.

        """
        return util.PopulateDict(
            lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop
        )

    def filter_states_for_dep(self, dep, states):
        """Filter the given list of InstanceStates to those relevant to the
        given DependencyProcessor.

        """
        mapper_for_dep = self._mapper_for_dep
        return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]]

    def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
        """Yield registered states belonging to any mapper in the
        hierarchy of ``mapper.base_mapper`` whose ``(isdelete, listonly)``
        entry equals the given flags."""
        checktup = (isdelete, listonly)
        for hierarchy_mapper in mapper.base_mapper.self_and_descendants:
            for state in self.mappers[hierarchy_mapper]:
                if self.states[state] == checktup:
                    yield state

    def _generate_actions(self):
        """Generate the full, unsorted collection of PostSortRecs as
        well as dependency pairs for this UOWTransaction.

        Sets ``self.cycles`` to the postsort actions found to be part of a
        dependency cycle; those are replaced by per-state actions and are
        excluded from the returned set, as are disabled actions.
        """
        # execute presort_actions, until all states
        # have been processed.  a presort_action might
        # add new states to the uow.
        while True:
            ret = False
            for action in list(self.presort_actions.values()):
                if action.execute(self):
                    ret = True
            if not ret:
                break

        # see if the graph of mapper dependencies has cycles.
        self.cycles = cycles = topological.find_cycles(
            self.dependencies, list(self.postsort_actions.values())
        )

        if cycles:
            # if yes, break the per-mapper actions into
            # per-state actions
            convert = {
                rec: set(rec.per_state_flush_actions(self)) for rec in cycles
            }

            # rewrite the existing dependencies to point to
            # the per-state actions for those per-mapper actions
            # that were broken up.
            for edge in list(self.dependencies):
                if (
                    None in edge
                    or edge[0].disabled
                    or edge[1].disabled
                    or cycles.issuperset(edge)
                ):
                    self.dependencies.remove(edge)
                elif edge[0] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[0]]:
                        self.dependencies.add((dep, edge[1]))
                elif edge[1] in cycles:
                    self.dependencies.remove(edge)
                    for dep in convert[edge[1]]:
                        self.dependencies.add((edge[0], dep))

        return {
            a for a in self.postsort_actions.values() if not a.disabled
        }.difference(cycles)

    def execute(self):
        """Generate the postsort actions and run them in dependency order.

        Actions are pre-sorted on ``sort_key`` before the topological
        sort.  Without cycles each action's ``execute()`` is called; with
        cycles the actions are run per topological subset through
        ``execute_aggregate()``, which may consume peers from the subset.
        """
        postsort_actions = self._generate_actions()

        postsort_actions = sorted(
            postsort_actions,
            key=lambda item: item.sort_key,
        )

        # execute
        if self.cycles:
            for subset in topological.sort_as_subsets(
                self.dependencies, postsort_actions
            ):
                set_ = set(subset)
                while set_:
                    n = set_.pop()
                    n.execute_aggregate(self, set_)
        else:
            for rec in topological.sort(self.dependencies, postsort_actions):
                rec.execute(self)

    def finalize_flush_changes(self):
        """Mark processed objects as clean / deleted after a successful
        flush().

        This method is called within the flush() method after the
        execute() method has succeeded and the transaction has been committed.

        """
        if not self.states:
            return

        states = set(self.states)
        isdel = {s for s, (isdelete, _) in self.states.items() if isdelete}
        other = states.difference(isdel)
        if isdel:
            self.session._remove_newly_deleted(isdel)
        if other:
            self.session._register_persistent(other)
478
479
class IterateMappersMixin(object):
    """Mixin supplying ``_mappers()`` to actions built around a
    DependencyProcessor.

    Users of the mixin must provide ``dependency_processor`` and
    ``fromparent`` attributes.
    """

    def _mappers(self, uow):
        """Return the mappers whose states this action covers.

        With ``fromparent`` set, this is a lazy iterator over the parent
        mapper and its descendants, restricted to those for which
        ``uow._mapper_for_dep`` reports that the processor applies.
        Otherwise it is the processor's target mapper's
        ``self_and_descendants`` collection itself.
        """
        processor = self.dependency_processor
        if not self.fromparent:
            return processor.mapper.self_and_descendants

        return (
            candidate
            for candidate in processor.parent.self_and_descendants
            if uow._mapper_for_dep[(candidate, processor)]
        )
490
491
class Preprocess(IterateMappersMixin):
    """Presort action that hands newly registered states to a
    DependencyProcessor's presort hooks.

    One instance exists per ``(dependency_processor, fromparent)`` pair in
    ``UOWTransaction.presort_actions``.  Each state is handed over at most
    once, which is tracked in ``processed``.
    """

    __slots__ = (
        "dependency_processor",
        "fromparent",
        "processed",
        "setup_flush_actions",
    )

    def __init__(self, dependency_processor, fromparent):
        self.dependency_processor = dependency_processor
        self.fromparent = fromparent
        # states already passed to presort_deletes / presort_saves
        self.processed = set()
        # becomes True once per_property_flush_actions() has been invoked
        self.setup_flush_actions = False

    def execute(self, uow):
        """Run the presort hooks over not-yet-processed, non-listonly
        states.

        Returns ``True`` when any such states were found.  The caller
        (``UOWTransaction._generate_actions``) keeps looping while some
        preprocessor returns ``True``, since presort hooks may register
        more states.
        """
        pending_deletes = set()
        pending_saves = set()

        for mapper in self._mappers(uow):
            unseen = uow.mappers[mapper].difference(self.processed)
            for state in unseen:
                isdelete, listonly = uow.states[state]
                if listonly:
                    continue
                target = pending_deletes if isdelete else pending_saves
                target.add(state)

        processor = self.dependency_processor
        if pending_deletes:
            processor.presort_deletes(uow, pending_deletes)
            self.processed.update(pending_deletes)
        if pending_saves:
            processor.presort_saves(uow, pending_saves)
            self.processed.update(pending_saves)

        if not (pending_deletes or pending_saves):
            return False

        # set up the per-property flush actions only once, and only if the
        # property actually has changes among the states seen here
        if not self.setup_flush_actions and (
            processor.prop_has_changes(uow, pending_deletes, True)
            or processor.prop_has_changes(uow, pending_saves, False)
        ):
            processor.per_property_flush_actions(uow)
            self.setup_flush_actions = True
        return True
540
541
class PostSortRec(object):
    """Base for actions run during the sorted phase of a flush.

    Construction is memoized per unit of work: calling
    ``SomeRec(uow, *args)`` returns the instance already stored in
    ``uow.postsort_actions`` under ``(SomeRec,) + args`` when there is one,
    and otherwise creates, stores and returns a new one.  Because the
    returned object is an instance of the class being called, Python runs
    that class's ``__init__`` on it in either case.
    """

    __slots__ = ("disabled",)

    def __new__(cls, uow, *args):
        registry = uow.postsort_actions
        lookup_key = (cls,) + args
        try:
            return registry[lookup_key]
        except KeyError:
            pass

        rec = object.__new__(cls)
        registry[lookup_key] = rec
        # disabled records are dropped by UOWTransaction when it builds
        # the final set of actions and dependency edges
        rec.disabled = False
        return rec

    def execute_aggregate(self, uow, recs):
        """Execute this record; this default implementation ignores ``recs``.

        ``recs`` holds the other records of the same topological subset.
        Subclasses may override this to remove peers from it and process
        them in one batch.
        """
        self.execute(uow)
556
557
class ProcessAll(IterateMappersMixin, PostSortRec):
    """Postsort action that runs a DependencyProcessor's
    ``process_deletes`` or ``process_saves`` over all matching states in
    one call."""

    __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, fromparent):
        self.dependency_processor = dependency_processor
        self.sort_key = (
            "ProcessAll",
            dependency_processor.sort_key,
            isdelete,
        )
        self.isdelete = isdelete
        self.fromparent = fromparent
        # record the processor under its parent's base mapper, where
        # SaveUpdateAll / DeleteAll look it up when splitting into
        # per-state actions
        parent_base = dependency_processor.parent.base_mapper
        uow.deps[parent_base].add(dependency_processor)

    def execute(self, uow):
        states = self._elements(uow)
        processor = self.dependency_processor
        if not self.isdelete:
            processor.process_saves(uow, states)
        else:
            processor.process_deletes(uow, states)

    def per_state_flush_actions(self, uow):
        # this is handled by SaveUpdateAll and DeleteAll,
        # since a ProcessAll should unconditionally be pulled
        # into per-state if either the parent/child mappers
        # are part of a cycle
        return iter(())

    def __repr__(self):
        return "%s(%s, isdelete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            self.isdelete,
        )

    def _elements(self, uow):
        """Yield the non-listonly states of the covered mappers whose
        delete flag matches this action's ``isdelete``."""
        for mapper in self._mappers(uow):
            for state in uow.mappers[mapper]:
                state_isdelete, listonly = uow.states[state]
                if not listonly and state_isdelete == self.isdelete:
                    yield state
601
602
class PostUpdateAll(PostSortRec):
    """Postsort action issuing the "post update" for every state queued
    in ``uow.post_update_states`` under ``mapper`` whose delete flag equals
    ``isdelete``."""

    __slots__ = "mapper", "isdelete", "sort_key"

    def __init__(self, uow, mapper, isdelete):
        self.mapper = mapper
        self.isdelete = isdelete
        self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        persistence = util.preloaded.orm_persistence
        queued_states, cols = uow.post_update_states[self.mapper]
        matching = [
            state
            for state in queued_states
            if uow.states[state][0] == self.isdelete
        ]
        persistence.post_update(self.mapper, matching, uow, cols)
618
619
class SaveUpdateAll(PostSortRec):
    """Postsort action saving (INSERT/UPDATE) every pending non-listonly,
    non-delete state in a base mapper's hierarchy in one batch."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("SaveUpdateAll", mapper._sort_key)
        # only ever created for a base mapper; subclass mappers are
        # covered through states_for_mapper_hierarchy()
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        save_obj = util.preloaded.orm_persistence.save_obj
        pending = uow.states_for_mapper_hierarchy(self.mapper, False, False)
        save_obj(self.mapper, pending, uow)

    def per_state_flush_actions(self, uow):
        """Yield one SaveUpdateState per pending state, each ordered
        before this mapper's DeleteAll.

        Once every action has been yielded, the dependency processors
        registered for this mapper get their per-state actions too.
        """
        pending = list(
            uow.states_for_mapper_hierarchy(self.mapper, False, False)
        )
        delete_all = DeleteAll(uow, self.mapper.base_mapper)
        for state in pending:
            per_state = SaveUpdateState(uow, state)
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            uow.dependencies.add((per_state, delete_all))
            yield per_state

        for dep in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(dep, pending)
            dep.per_state_flush_actions(uow, relevant, False)

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
655
656
class DeleteAll(PostSortRec):
    """Postsort action deleting every pending non-listonly delete state in
    a base mapper's hierarchy in one batch."""

    __slots__ = ("mapper", "sort_key")

    def __init__(self, uow, mapper):
        self.mapper = mapper
        self.sort_key = ("DeleteAll", mapper._sort_key)
        # only ever created for a base mapper; subclass mappers are
        # covered through states_for_mapper_hierarchy()
        assert mapper is mapper.base_mapper

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute(self, uow):
        delete_obj = util.preloaded.orm_persistence.delete_obj
        pending = uow.states_for_mapper_hierarchy(self.mapper, True, False)
        delete_obj(self.mapper, pending, uow)

    def per_state_flush_actions(self, uow):
        """Yield one DeleteState per pending delete, each ordered after
        this mapper's SaveUpdateAll.

        Once every action has been yielded, the dependency processors
        registered for this mapper get their per-state actions too.
        """
        pending = list(
            uow.states_for_mapper_hierarchy(self.mapper, True, False)
        )
        save_all = SaveUpdateAll(uow, self.mapper.base_mapper)
        for state in pending:
            per_state = DeleteState(uow, state)
            # keep saves before deletes -
            # this ensures 'row switch' operations work
            uow.dependencies.add((save_all, per_state))
            yield per_state

        for dep in uow.deps[self.mapper]:
            relevant = uow.filter_states_for_dep(dep, pending)
            dep.per_state_flush_actions(uow, relevant, True)

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self.mapper)
692
693
class ProcessState(PostSortRec):
    """Per-state variant of ProcessAll for one state.

    ``execute_aggregate`` takes every other ProcessState with the same
    processor and delete flag out of ``recs``, then processes all of their
    states in one call.
    """

    __slots__ = "dependency_processor", "isdelete", "state", "sort_key"

    def __init__(self, uow, dependency_processor, isdelete, state):
        self.dependency_processor = dependency_processor
        self.sort_key = ("ProcessState", dependency_processor.sort_key)
        self.isdelete = isdelete
        self.state = state

    def execute_aggregate(self, uow, recs):
        processor = self.dependency_processor
        isdelete = self.isdelete
        peers = [
            rec
            for rec in recs
            if rec.__class__ is self.__class__
            and rec.dependency_processor is processor
            and rec.isdelete is isdelete
        ]
        recs.difference_update(peers)

        states = [self.state]
        states.extend(rec.state for rec in peers)
        if not isdelete:
            processor.process_saves(uow, states)
        else:
            processor.process_deletes(uow, states)

    def __repr__(self):
        return "%s(%s, %s, delete=%s)" % (
            type(self).__name__,
            self.dependency_processor,
            orm_util.state_str(self.state),
            self.isdelete,
        )
728
729
class SaveUpdateState(PostSortRec):
    """Per-state save action used once SaveUpdateAll has been split up
    because of a dependency cycle.

    ``execute_aggregate`` takes every other SaveUpdateState for the same
    base mapper out of ``recs`` and saves all of their states together.
    """

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        # NOTE(review): the tag is "ProcessState", not "SaveUpdateState";
        # it feeds the pre-sort in UOWTransaction.execute, so it is left
        # unchanged to keep the existing flush ordering.
        self.sort_key = ("ProcessState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        persistence = util.preloaded.orm_persistence
        base_mapper = self.mapper
        peers = [
            rec
            for rec in recs
            if rec.__class__ is self.__class__ and rec.mapper is base_mapper
        ]
        recs.difference_update(peers)

        states = [self.state]
        states.extend(rec.state for rec in peers)
        persistence.save_obj(base_mapper, states, uow)

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )
756
757
class DeleteState(PostSortRec):
    """Per-state delete action used once DeleteAll has been split up
    because of a dependency cycle.

    ``execute_aggregate`` takes every other DeleteState for the same base
    mapper out of ``recs`` and deletes those states that are still flagged
    for deletion in the unit of work.
    """

    __slots__ = "state", "mapper", "sort_key"

    def __init__(self, uow, state):
        self.state = state
        self.mapper = state.mapper.base_mapper
        self.sort_key = ("DeleteState", self.mapper._sort_key)

    @util.preload_module("sqlalchemy.orm.persistence")
    def execute_aggregate(self, uow, recs):
        persistence = util.preloaded.orm_persistence
        base_mapper = self.mapper
        peers = [
            rec
            for rec in recs
            if rec.__class__ is self.__class__ and rec.mapper is base_mapper
        ]
        recs.difference_update(peers)

        candidates = [self.state]
        candidates.extend(rec.state for rec in peers)
        # skip states whose isdelete flag is no longer set
        to_delete = [s for s in candidates if uow.states[s][0]]
        persistence.delete_obj(base_mapper, to_delete, uow)

    def __repr__(self):
        return "%s(%s)" % (
            type(self).__name__,
            orm_util.state_str(self.state),
        )