# orm/bulk_persistence.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


"""additional ORM persistence classes related to "bulk" operations,
specifically outside of the flush() process.

"""

from __future__ import annotations

from typing import Any
from typing import cast
from typing import Dict
from typing import Iterable
from typing import Literal
from typing import Optional
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from . import attributes
from . import context
from . import evaluator
from . import exc as orm_exc
from . import loading
from . import persistence
from .base import NO_VALUE
from .context import _AbstractORMCompileState
from .context import _ORMFromStatementCompileState
from .context import FromStatement
from .context import QueryContext
from .interfaces import PropComparator
from .. import exc as sa_exc
from .. import util
from ..engine import Dialect
from ..engine import result as _result
from ..sql import coercions
from ..sql import dml
from ..sql import expression
from ..sql import roles
from ..sql import select
from ..sql import sqltypes
from ..sql.base import _entity_namespace_key
from ..sql.base import CompileState
from ..sql.base import Options
from ..sql.dml import DeleteDMLState
from ..sql.dml import InsertDMLState
from ..sql.dml import UpdateDMLState
from ..util import EMPTY_DICT
from ..util.typing import TupleAny
from ..util.typing import Unpack

if TYPE_CHECKING:
    from ._typing import DMLStrategyArgument
    from ._typing import OrmExecuteOptionsParameter
    from ._typing import SynchronizeSessionArgument
    from .mapper import Mapper
    from .session import _BindArguments
    from .session import ORMExecuteState
    from .session import Session
    from .session import SessionTransaction
    from .state import InstanceState
    from ..engine import Connection
    from ..engine import cursor
    from ..engine.interfaces import _CoreAnyExecuteParams

_O = TypeVar("_O", bound=object)


@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Literal[None] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> None: ...


@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> cursor.CursorResult[Any]: ...


def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = None,
    execution_options: Optional[OrmExecuteOptionsParameter] = None,
) -> Optional[cursor.CursorResult[Any]]:
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if TYPE_CHECKING:
            mappings = cast(Iterable[InstanceState[_O]], mappings)

        if return_defaults:
            # list of states allows us to attach .key for return_defaults case
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        if TYPE_CHECKING:
            mappings = cast(Iterable[Dict[str, Any]], mappings)

        if return_defaults:
            # use dictionaries given, so that newly populated defaults
            # can be delivered back to the caller (see #11661). This is **not**
            # compatible with other use cases such as a session-executed
            # insert() construct, as this will confuse the case of
            # insert-per-subclass for joined inheritance cases (see
            # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
            #
            # So in this conditional, we have **only** called
            # session.bulk_insert_mappings() which does not have this
            # requirement
            mappings = list(mappings)
        else:
            # for all other cases we need to establish a local dictionary
            # so that the incoming dictionaries aren't mutated
            mappings = [dict(m) for m in mappings]
        _expand_other_attrs(mapper, mappings)

    connection = session_transaction.connection(base_mapper)

    return_result: Optional[cursor.CursorResult[Any]] = None

    mappers_to_run = [
        (table, mp)
        for table, mp in base_mapper._sorted_tables.items()
        if table in mapper._pks_by_table
    ]

    if return_defaults:
        # not used by new-style bulk inserts, only used for legacy
        bookkeeping = True
    elif len(mappers_to_run) > 1:
        # if we have more than one table / mapper to run, where we will be
        # either horizontally splicing or copying values between tables,
        # we need the "bookkeeping" / deterministic returning order
        bookkeeping = True
    else:
        bookkeeping = False

    for table, super_mapper in mappers_to_run:
        # find bindparams in the statement. For bulk, we don't really know if
        # a key in the params applies to a different table since we are
        # potentially inserting for multiple tables here; looking at the
        # bindparam() is a lot more direct. in most cases this will
        # use _generate_cache_key() which is memoized, although in practice
        # the ultimate statement that's executed is probably not the same
        # object so that memoization might not matter much.
        extra_bp_names = (
            [
                b.key
                for b in use_orm_insert_stmt._get_embedded_bindparams()
                if b.key in mappings[0]
            ]
            if use_orm_insert_stmt is not None
            else ()
        )

        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in persistence._collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=bookkeeping,
                render_nulls=render_nulls,
                include_bulk_keys=extra_bp_names,
            )
        )

        result = persistence._emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=bookkeeping,
            use_orm_insert_stmt=use_orm_insert_stmt,
            execution_options=execution_options,
        )
        if use_orm_insert_stmt is not None:
            if not use_orm_insert_stmt._returning or return_result is None:
                return_result = result
            elif result.returns_rows:
                assert bookkeeping
                return_result = return_result.splice_horizontally(result)

    if return_defaults and isstates:
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
                None,
            )

    if use_orm_insert_stmt is not None:
        assert return_result is not None
        return return_result


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Literal[None] = ...,
    enable_check_rowcount: bool = True,
) -> None: ...


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = ...,
    enable_check_rowcount: bool = True,
) -> _result.Result[Unpack[TupleAny]]: ...


def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = None,
    enable_check_rowcount: bool = True,
) -> Optional[_result.Result[Unpack[TupleAny]]]:
    base_mapper = mapper.base_mapper

    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        return {
            k: v
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        }

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = [dict(m) for m in mappings]
        _expand_other_attrs(mapper, mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    # find bindparams in the statement. see _bulk_insert for similar
    # notes for the insert case
    extra_bp_names = (
        [
            b.key
            for b in use_orm_update_stmt._get_embedded_bindparams()
            if b.key in mappings[0]
        ]
        if use_orm_update_stmt is not None
        else ()
    )

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
            continue

        records = persistence._collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
            use_orm_update_stmt=use_orm_update_stmt,
            include_bulk_keys=extra_bp_names,
        )
        persistence._emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
            use_orm_update_stmt=use_orm_update_stmt,
            enable_check_rowcount=enable_check_rowcount,
        )

    if use_orm_update_stmt is not None:
        return _result.null_result()


def _expand_other_attrs(
    mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]]
) -> None:
    all_attrs = mapper.all_orm_descriptors

    attr_keys = set(all_attrs.keys())

    bulk_dml_setters = {
        key: setter
        for key, setter in (
            (key, attr._bulk_dml_setter(key))
            for key, attr in (
                (key, _entity_namespace_key(mapper, key, default=NO_VALUE))
                for key in attr_keys
            )
            if attr is not NO_VALUE and isinstance(attr, PropComparator)
        )
        if setter is not None
    }
    setters_todo = set(bulk_dml_setters)
    if not setters_todo:
        return

    for mapping in mappings:
        for key in setters_todo.intersection(mapping):
            bulk_dml_setters[key](mapping)


class _ORMDMLState(_AbstractORMCompileState):
    is_dml_returning = True
    from_statement_ctx: Optional[_ORMFromStatementCompileState] = None

    @classmethod
    def _get_orm_crud_kv_pairs(
        cls, mapper, statement, kv_iterator, needs_to_be_cacheable
    ):
        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, str):
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if not isinstance(desc, PropComparator):
                    yield (
                        coercions.expect(roles.DMLColumnRole, k),
                        (
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            )
                            if needs_to_be_cacheable
                            else v
                        ),
                    )
                else:
                    yield from core_get_crud_kv_pairs(
                        statement,
                        desc._bulk_update_tuples(v),
                        needs_to_be_cacheable,
                    )
            elif "entity_namespace" in k._annotations:
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                assert isinstance(attr, PropComparator)
                yield from core_get_crud_kv_pairs(
                    statement,
                    attr._bulk_update_tuples(v),
                    needs_to_be_cacheable,
                )
            else:
                yield (
                    k,
                    (
                        v
                        if not needs_to_be_cacheable
                        else coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        )
                    ),
                )

    @classmethod
    def _get_dml_plugin_subject(cls, statement):
        plugin_subject = statement.table._propagate_attrs.get("plugin_subject")

        if (
            not plugin_subject
            or not plugin_subject.mapper
            or plugin_subject
            is not statement._propagate_attrs["plugin_subject"]
        ):
            return None
        return plugin_subject

    @classmethod
    def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
        plugin_subject = cls._get_dml_plugin_subject(statement)

        if not plugin_subject:
            return UpdateDMLState._get_multi_crud_kv_pairs(
                statement, kv_iterator
            )

        return [
            dict(
                cls._get_orm_crud_kv_pairs(
                    plugin_subject.mapper, statement, value_dict.items(), False
                )
            )
            for value_dict in kv_iterator
        ]

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
        assert (
            needs_to_be_cacheable
        ), "no test coverage for needs_to_be_cacheable=False"

        plugin_subject = cls._get_dml_plugin_subject(statement)

        if not plugin_subject:
            return UpdateDMLState._get_crud_kv_pairs(
                statement, kv_iterator, needs_to_be_cacheable
            )
        return list(
            cls._get_orm_crud_kv_pairs(
                plugin_subject.mapper,
                statement,
                kv_iterator,
                needs_to_be_cacheable,
            )
        )

    @classmethod
    def get_entity_description(cls, statement):
        ext_info = statement.table._annotations["parententity"]
        mapper = ext_info.mapper
        if ext_info.is_aliased_class:
            _label_name = ext_info.name
        else:
            _label_name = mapper.class_.__name__

        return {
            "name": _label_name,
            "type": mapper.class_,
            "expr": ext_info.entity,
            "entity": ext_info.entity,
            "table": mapper.local_table,
        }

    @classmethod
    def get_returning_column_descriptions(cls, statement):
        def _ent_for_col(c):
            return c._annotations.get("parententity", None)

        def _attr_for_col(c, ent):
            if ent is None:
                return c
            proxy_key = c._annotations.get("proxy_key", None)
            if not proxy_key:
                return c
            else:
                return getattr(ent.entity, proxy_key, c)

        return [
            {
                "name": c.key,
                "type": c.type,
                "expr": _attr_for_col(c, ent),
                "aliased": ent.is_aliased_class,
                "entity": ent.entity,
            }
            for c, ent in [
                (c, _ent_for_col(c)) for c in statement._all_selected_columns
            ]
        ]

    def _setup_orm_returning(
        self,
        compiler,
        orm_level_statement,
        dml_level_statement,
        dml_mapper,
        *,
        use_supplemental_cols=True,
    ):
        """establish ORM column handlers for an INSERT, UPDATE, or DELETE
        which uses explicit returning().

        called within compilation level create_for_statement.

        The _return_orm_returning() method then receives the Result
        after the statement was executed, and applies ORM loading to the
        state that we first established here.

        """
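        # Illustration only (hypothetical "User" mapped class): for an ORM
        # statement such as
        #
        #     insert(User).returning(User.id, User.name)
        #
        # the returning() entities are re-stated as a FromStatement against
        # the DML statement, so that ORM loading can be applied to the rows
        # that RETURNING produces.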

        if orm_level_statement._returning:
            fs = FromStatement(
                orm_level_statement._returning,
                dml_level_statement,
                _adapt_on_names=False,
            )
            fs = fs.execution_options(**orm_level_statement._execution_options)
            fs = fs.options(*orm_level_statement._with_options)
            self.select_statement = fs
            self.from_statement_ctx = fsc = (
                _ORMFromStatementCompileState.create_for_statement(
                    fs, compiler
                )
            )
            fsc.setup_dml_returning_compile_state(dml_mapper)

            dml_level_statement = dml_level_statement._generate()
            dml_level_statement._returning = ()

            cols_to_return = [c for c in fsc.primary_columns if c is not None]

            # since we are splicing result sets together, make sure there
            # are columns of some kind returned in each result set
            if not cols_to_return:
                cols_to_return.extend(dml_mapper.primary_key)

            if use_supplemental_cols:
                dml_level_statement = dml_level_statement.return_defaults(
                    # this is a little weird looking, but by passing
                    # primary key as the main list of cols, this tells
                    # return_defaults to omit server-default cols (and
                    # actually all cols, due to some weird thing we should
                    # clean up in crud.py).
                    # Since we have cols_to_return, just return what we asked
                    # for (plus primary key, which ORM persistence needs since
                    # we likely set bookkeeping=True here, which is another
                    # whole thing...). We don't want to clutter the
                    # statement up with lots of other cols the user didn't
                    # ask for. see #9685
                    *dml_mapper.primary_key,
                    supplemental_cols=cols_to_return,
                )
            else:
                dml_level_statement = dml_level_statement.returning(
                    *cols_to_return
                )

        return dml_level_statement

    @classmethod
    def _return_orm_returning(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        execution_context = result.context
        compile_state = execution_context.compiled.compile_state

        if (
            compile_state.from_statement_ctx
            and not compile_state.from_statement_ctx.compile_options._is_star
        ):
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )

            querycontext = QueryContext(
                compile_state.from_statement_ctx,
                compile_state.select_statement,
                statement,
                params,
                session,
                load_options,
                execution_options,
                bind_arguments,
            )
            return loading.instances(result, querycontext)
        else:
            return result


class _BulkUDCompileState(_ORMDMLState):
    class default_update_options(Options):
        _dml_strategy: DMLStrategyArgument = "auto"
        _synchronize_session: SynchronizeSessionArgument = "auto"
        _can_use_returning: bool = False
        _is_delete_using: bool = False
        _is_update_from: bool = False
        _autoflush: bool = True
        _subject_mapper: Optional[Mapper[Any]] = None
        _resolved_values = EMPTY_DICT
        _eval_condition = None
        _matched_rows = None
        _identity_token = None
        _populate_existing: bool = False
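
        # The public execution options map onto these attributes; e.g.
        # (hypothetical "User" mapped class):
        #
        #     session.execute(
        #         update(User).where(User.id == 5).values(name="x"),
        #         execution_options={"synchronize_session": "fetch"},
        #     )
        #
        # arrives here as _synchronize_session="fetch", via
        # from_execution_options() in orm_pre_session_exec() below.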

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        raise NotImplementedError()

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        (
            update_options,
            execution_options,
        ) = _BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {
                "synchronize_session",
                "autoflush",
                "populate_existing",
                "identity_token",
                "is_delete_using",
                "is_update_from",
                "dml_strategy",
            },
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                update_options += {"_subject_mapper": plugin_subject.mapper}

        if "parententity" not in statement.table._annotations:
            update_options += {"_dml_strategy": "core_only"}
        elif not isinstance(params, list):
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "orm"}
            elif update_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use the "bulk" ORM strategy without '
                    "passing separate parameters"
                )
        else:
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "bulk"}
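
        # e.g. (hypothetical "User" mapped class): a single params dict, or
        # no params, selects the "orm" strategy, while a list of dicts
        # selects "bulk":
        #
        #     session.execute(update(User).values(name="x"))            # orm
        #     session.execute(update(User), [{"id": 1, "name": "y"}])   # bulk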

        sync = update_options._synchronize_session
        if sync is not None:
            if sync not in ("auto", "evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'auto', 'evaluate', 'fetch', False"
                )
            if update_options._dml_strategy == "bulk" and sync == "fetch":
                raise sa_exc.InvalidRequestError(
                    "The 'fetch' synchronization strategy is not available "
                    "for 'bulk' ORM updates (i.e. multiple parameter sets)"
                )

        if not is_pre_event:
            if update_options._autoflush:
                session._autoflush()

            if update_options._dml_strategy == "orm":
                if update_options._synchronize_session == "auto":
                    update_options = cls._do_pre_synchronize_auto(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "evaluate":
                    update_options = cls._do_pre_synchronize_evaluate(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "fetch":
                    update_options = cls._do_pre_synchronize_fetch(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
            elif update_options._dml_strategy == "bulk":
                if update_options._synchronize_session == "auto":
                    update_options += {"_synchronize_session": "evaluate"}

            # indicators from the "pre exec" step that are then
            # added to the DML statement, which will also be part of the cache
            # key. The compile level create_for_statement() method will then
            # consume these at compiler time.
            statement = statement._annotate(
                {
                    "synchronize_session": update_options._synchronize_session,
                    "is_delete_using": update_options._is_delete_using,
                    "is_update_from": update_options._is_update_from,
                    "dml_strategy": update_options._dml_strategy,
                    "can_use_returning": update_options._can_use_returning,
                }
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
            params,
        )

    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        # this stage of the execution is called after the
        # do_orm_execute event hook. meaning for an extension like
        # horizontal sharding, this step happens *within* the horizontal
        # sharding event handler which calls session.execute() re-entrantly
        # and will occur for each backend individually.
        # the sharding extension then returns its own merged result from the
        # individual ones we return here.

        update_options = execution_options["_sa_orm_update_options"]
        if update_options._dml_strategy == "orm":
            if update_options._synchronize_session == "evaluate":
                cls._do_post_synchronize_evaluate(
                    session, statement, result, update_options
                )
            elif update_options._synchronize_session == "fetch":
                cls._do_post_synchronize_fetch(
                    session, statement, result, update_options
                )
        elif update_options._dml_strategy == "bulk":
            if update_options._synchronize_session == "evaluate":
                cls._do_post_synchronize_bulk_evaluate(
                    session, params, result, update_options
                )
            return result

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
        """Apply extra criteria filtering.

        For all distinct single-table-inheritance mappers represented in the
        table being updated or deleted, produce additional WHERE criteria such
        that only the appropriate subtypes are selected from the total results.

        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
        collected from the statement.

        """
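        # Illustration (hypothetical single-table-inheritance "Manager"
        # subclass): this method contributes criteria such as
        # ``employees.type IN ('manager')`` for the subclass, and for a
        # statement option like
        #
        #     delete(Manager).options(
        #         with_loader_criteria(Manager, Manager.active == True)
        #     )
        #
        # it resolves the option's criteria out of ``global_attributes``.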

        return_crit = ()

        adapter = ext_info._adapter if ext_info.is_aliased_class else None

        if (
            "additional_entity_criteria",
            ext_info.mapper,
        ) in global_attributes:
            return_crit += tuple(
                ae._resolve_where_criteria(ext_info)
                for ae in global_attributes[
                    ("additional_entity_criteria", ext_info.mapper)
                ]
                if ae.include_aliases or ae.entity is ext_info
            )

        if ext_info.mapper._single_table_criterion is not None:
            return_crit += (ext_info.mapper._single_table_criterion,)

        if adapter:
            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)

        return return_crit

    @classmethod
    def _interpret_returning_rows(cls, result, mapper, rows):
        """return rows that indicate PK cols in mapper.primary_key position
        for RETURNING rows.

        Prior to 2.0.36, this method seemed to be written for some kind of
        inheritance scenario but the scenario was unused for actual joined
        inheritance, and the function instead seemed to perform some kind of
        partial translation that would remove non-PK cols if the PK cols
        happened to be first in the row, but not otherwise. The joined
        inheritance walk feature here seems to have never been used as it was
        always skipped by the "local_table" check.

        As of 2.0.36 the function strips away non-PK cols and provides the
        PK cols for the table in mapper PK order.

        """
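        # Sketch of the translation performed, assuming a hypothetical
        # composite primary key (id_a, id_b): a RETURNING row such as
        # (email, id_a, id_b) comes back as (id_a, id_b), with non-PK
        # columns discarded.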

        try:
            if mapper.local_table is not mapper.base_mapper.local_table:
                # TODO: dive more into how a local table PK is used for fetch
                # sync, not clear if this is correct as it depends on the
                # downstream routine to fetch rows using
                # local_table.primary_key order
                pk_keys = result._tuple_getter(mapper.local_table.primary_key)
            else:
                pk_keys = result._tuple_getter(mapper.primary_key)
        except KeyError:
            # can't use these rows, they don't have PK cols in them
            # this is an unusual case where the user would have used
            # .return_defaults()
            return []

        return [pk_keys(row) for row in rows]

    @classmethod
    def _get_matched_objects_on_criteria(cls, update_options, states):
        mapper = update_options._subject_mapper
        eval_condition = update_options._eval_condition

        raw_data = [
            (state.obj(), state, state.dict)
            for state in states
            if state.mapper.isa(mapper) and not state.expired
        ]

        identity_token = update_options._identity_token
        if identity_token is not None:
            raw_data = [
                (obj, state, dict_)
                for obj, state, dict_ in raw_data
                if state.identity_token == identity_token
            ]

        result = []
        for obj, state, dict_ in raw_data:
            evaled_condition = eval_condition(obj)

            # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
            # evaluates as True for all comparisons
            if (
                evaled_condition is True
                or evaled_condition is evaluator._EXPIRED_OBJECT
            ):
                result.append(
                    (
                        obj,
                        state,
                        dict_,
                        evaled_condition is evaluator._EXPIRED_OBJECT,
                    )
                )
        return result

    @classmethod
    def _eval_condition_from_statement(cls, update_options, statement):
        mapper = update_options._subject_mapper
        target_cls = mapper.class_

        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        crit = ()
        if statement._where_criteria:
            crit += statement._where_criteria

        global_attributes = {}
        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(global_attributes)

        if global_attributes:
            crit += cls._adjust_for_extra_criteria(global_attributes, mapper)

        if crit:
            eval_condition = evaluator_compiler.process(*crit)
        else:
            # workaround for mypy https://github.com/python/mypy/issues/14027
            def _eval_condition(obj):
                return True

            eval_condition = _eval_condition

        return eval_condition

    @classmethod
    def _do_pre_synchronize_auto(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        """setup auto sync strategy

        "auto" checks if we can use "evaluate" first, then falls back
        to "fetch".

        "evaluate" is vastly more efficient for the common case where the
        session is empty or only has a few objects, while the UPDATE
        statement can potentially match thousands/millions of rows.

        On the other hand, more complex criteria that can't work with
        "evaluate" will, we hope, usually correlate with fewer net rows.

        """
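        # Illustration (hypothetical "User" mapped class):
        #
        #     update(User).where(User.name == "x")     # evaluable in Python,
        #                                              # so "evaluate" is used
        #
        #     update(User).where(func.lower(User.name) == "x")
        #                                              # raises
        #                                              # UnevaluatableError,
        #                                              # falls back to "fetch"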

        try:
            eval_condition = cls._eval_condition_from_statement(
                update_options, statement
            )

        except evaluator.UnevaluatableError:
            pass
        else:
            return update_options + {
                "_eval_condition": eval_condition,
                "_synchronize_session": "evaluate",
            }

        update_options += {"_synchronize_session": "fetch"}
        return cls._do_pre_synchronize_fetch(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            update_options,
        )

    @classmethod
    def _do_pre_synchronize_evaluate(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        try:
            eval_condition = cls._eval_condition_from_statement(
                update_options, statement
            )

        except evaluator.UnevaluatableError as err:
            raise sa_exc.InvalidRequestError(
                'Could not evaluate current criteria in Python: "%s". '
                "Specify 'fetch' or False for the "
                "synchronize_session execution option." % err
            ) from err

        return update_options + {
            "_eval_condition": eval_condition,
        }

    @classmethod
    def _get_resolved_values(cls, mapper, statement):
        if statement._multi_values:
            return []
        elif statement._values:
            return list(statement._values.items())
        else:
            return []

    @classmethod
    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
        values = []
        for k, v in resolved_values:
            if mapper and isinstance(k, expression.ColumnElement):
                try:
                    attr = mapper._columntoproperty[k]
                except orm_exc.UnmappedColumnError:
                    pass
                else:
                    values.append((attr.key, v))
            else:
                raise sa_exc.InvalidRequestError(
                    "Attribute name not found, can't be "
                    "synchronized back to objects: %r" % k
                )
        return values

    @classmethod
    def _do_pre_synchronize_fetch(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        mapper = update_options._subject_mapper

        select_stmt = (
            select(*(mapper.primary_key + (mapper.select_identity_token,)))
            .select_from(mapper)
            .options(*statement._with_options)
        )
        select_stmt._where_criteria = statement._where_criteria

        # conditionally run the SELECT statement for pre-fetch, testing the
        # "bind" for whether we can use RETURNING or not using the
        # do_orm_execute event. If RETURNING is available, the do_orm_execute
        # event will cancel the SELECT from being actually run.
        #
        # The way this is organized seems strange; why don't we just
        # call can_use_returning() before invoking the statement and get the
        # answer? Why does this go through the whole execute phase using an
        # event? Answer: because we are integrating with extensions such
        # as the horizontal sharding extension that "multiplexes" an
        # individual statement run through multiple engines, and it uses
        # do_orm_execute() to do that.

        can_use_returning = None

        def skip_for_returning(orm_context: ORMExecuteState) -> Any:
            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
            nonlocal can_use_returning

            per_bind_result = cls.can_use_returning(
                bind.dialect,
                mapper,
                is_update_from=update_options._is_update_from,
                is_delete_using=update_options._is_delete_using,
                is_executemany=orm_context.is_executemany,
            )

            if can_use_returning is not None:
                if can_use_returning != per_bind_result:
                    raise sa_exc.InvalidRequestError(
                        "For synchronize_session='fetch', can't mix multiple "
                        "backends where some support RETURNING and others "
                        "don't"
                    )
            elif orm_context.is_executemany and not per_bind_result:
                raise sa_exc.InvalidRequestError(
                    "For synchronize_session='fetch', can't use multiple "
                    "parameter sets in ORM mode, which this backend does not "
                    "support with RETURNING"
                )
            else:
                can_use_returning = per_bind_result

            if per_bind_result:
                return _result.null_result()
            else:
                return None

        result = session.execute(
            select_stmt,
            params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            _add_event=skip_for_returning,
        )
        matched_rows = result.fetchall()

        return update_options + {
            "_matched_rows": matched_rows,
            "_can_use_returning": can_use_returning,
        }


@CompileState.plugin_for("orm", "insert")
class _BulkORMInsert(_ORMDMLState, InsertDMLState):
    class default_insert_options(Options):
        _dml_strategy: DMLStrategyArgument = "auto"
        _render_nulls: bool = False
        _return_defaults: bool = False
        _subject_mapper: Optional[Mapper[Any]] = None
        _autoflush: bool = True
        _populate_existing: bool = False

    select_statement: Optional[FromStatement] = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        (
            insert_options,
            execution_options,
        ) = _BulkORMInsert.default_insert_options.from_execution_options(
            "_sa_orm_insert_options",
            {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                insert_options += {"_subject_mapper": plugin_subject.mapper}

        if not params:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "orm"}
            elif insert_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM insert strategy without '
                    "passing separate parameters"
                )
        else:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "bulk"}

        if insert_options._dml_strategy != "raw":
            # for ORM object loading, like ORMContext, we have to disable
            # result set adapt_to_context, because we will be generating a
            # new statement with specific columns that's cached inside of
            # an ORMFromStatementCompileState, which we will re-use for
            # each result.
            if not execution_options:
                execution_options = context._orm_load_exec_options
            else:
                execution_options = execution_options.union(
                    context._orm_load_exec_options
                )

        if not is_pre_event and insert_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"dml_strategy": insert_options._dml_strategy}
        )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_insert_options": insert_options}
            ),
            params,
        )

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Insert,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        insert_options = execution_options.get(
            "_sa_orm_insert_options", cls.default_insert_options
        )

        if insert_options._dml_strategy not in (
            "raw",
            "bulk",
            "orm",
            "auto",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM insert strategy "
                "are 'raw', 'orm', 'bulk', 'auto'"
            )

        result: _result.Result[Unpack[TupleAny]]

        if insert_options._dml_strategy == "raw":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
            return result

        if insert_options._dml_strategy == "bulk":
            mapper = insert_options._subject_mapper

            if (
                statement._post_values_clause is not None
                and mapper._multiple_persistence_tables
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk INSERT with a 'post values' clause "
                    "(typically upsert) not supported for multi-table "
                    f"mapper {mapper}"
                )

            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_insert(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                return_defaults=insert_options._return_defaults,
                render_nulls=insert_options._render_nulls,
                use_orm_insert_stmt=statement,
                execution_options=execution_options,
            )
        elif insert_options._dml_strategy == "orm":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
        else:
            raise AssertionError()

        if not bool(statement._returning):
            return result

        if insert_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw) -> _BulkORMInsert:
        self = cast(
            _BulkORMInsert,
            super().create_for_statement(statement, compiler, **kw),
        )

        if compiler is not None:
            toplevel = not compiler.stack
        else:
            toplevel = True
        if not toplevel:
            return self

        mapper = statement._propagate_attrs["plugin_subject"]
        dml_strategy = statement._annotations.get("dml_strategy", "raw")
        if dml_strategy == "bulk":
            self._setup_for_bulk_insert(compiler)
        elif dml_strategy == "orm":
            self._setup_for_orm_insert(compiler, mapper)

        return self

    @classmethod
    def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
        return {
            col.key if col is not None else k: v
            for col, k, v in (
                (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
            )
        }

    def _setup_for_orm_insert(self, compiler, mapper):
        statement = orm_level_statement = cast(dml.Insert, self.statement)

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=mapper,
            use_supplemental_cols=False,
        )
        self.statement = statement

    def _setup_for_bulk_insert(self, compiler):
        """establish an INSERT statement within the context of
        bulk insert.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_insert_statements().

        """
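        # For example (hypothetical joined-inheritance mapping where
        # "Manager" extends "Employee"), an execution such as
        #
        #     session.execute(
        #         insert(Manager), [{"name": "n1"}, {"name": "n2"}]
        #     )
        #
        # leads persistence._emit_insert_statements() to invoke this
        # statement once per table, with "_emit_insert_table" /
        # "_emit_insert_mapper" annotations set for each run; the parameter
        # dictionary is filtered down to each table's columns below.
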
        statement = orm_level_statement = cast(dml.Insert, self.statement)
        an = statement._annotations

        emit_insert_table, emit_insert_mapper = (
            an["_emit_insert_table"],
            an["_emit_insert_mapper"],
        )

        statement = statement._clone()

        statement.table = emit_insert_table
        if self._dict_parameters:
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_insert_table
            }

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=emit_insert_mapper,
            use_supplemental_cols=True,
        )

        if (
            self.from_statement_ctx is not None
            and self.from_statement_ctx.compile_options._is_star
        ):
            raise sa_exc.CompileError(
                "Can't use RETURNING * with bulk ORM INSERT. "
                "Please use a different INSERT form, such as INSERT..VALUES "
                "or INSERT with a Core Connection"
            )

        self.statement = statement


@CompileState.plugin_for("orm", "update")
class _BulkORMUpdate(_BulkUDCompileState, UpdateDMLState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        toplevel = not compiler.stack

        if toplevel and dml_strategy == "bulk":
            self._setup_for_bulk_update(statement, compiler)
        elif (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            UpdateDMLState.__init__(self, statement, compiler, **kw)
        elif not toplevel or dml_strategy in ("orm", "unspecified"):
            self._setup_for_orm_update(statement, compiler)

        return self

    def _setup_for_orm_update(self, statement, compiler, **kw):
        orm_level_statement = statement

        toplevel = not compiler.stack

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self._resolved_values = self._get_resolved_values(mapper, statement)

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        if statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = statement._clone()

        if new_stmt.table._annotations["parententity"] is mapper:
            new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._values:
            new_stmt._values = self._resolved_values

        new_crit = self._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        # do this first as we need to determine if there is
        # UPDATE..FROM

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
            can_use_returning = compiler._annotations.get(
                "can_use_returning", None
            )
            if can_use_returning is not False:
                # even though pre_exec has determined basic
                # can_use_returning for the dialect, if we are to use
                # RETURNING we need to run can_use_returning() at this level
                # unconditionally because is_delete_using was not known
                # at the pre_exec level
                can_use_returning = (
                    synchronize_session == "fetch"
                    and self.can_use_returning(
                        compiler.dialect, mapper, is_multitable=self.is_multitable
                    )
                )

            if synchronize_session == "fetch" and can_use_returning:
                use_supplemental_cols = True

                # NOTE: we might want to RETURNING the actual columns to be
                # synchronized also. however this is complicated and difficult
                # to align against the behavior of "evaluate". Additionally,
                # in a large number (if not the majority) of cases, we have the
                # "evaluate" answer, usually a fixed value, in memory already and
                # there's no need to re-fetch the same value
                # over and over again. so perhaps if it could be RETURNING just
                # the elements that were based on a SQL expression and not
                # a constant. For now it doesn't quite seem worth it
                new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

    def _setup_for_bulk_update(self, statement, compiler, **kw):
        """establish an UPDATE statement within the context of
        bulk update.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_update_statements().

        """
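        # For example (hypothetical "User" mapped class), an execution such
        # as
        #
        #     session.execute(
        #         update(User),
        #         [{"id": 1, "name": "x"}, {"id": 2, "name": "y"}],
        #     )
        #
        # reaches this method via persistence._emit_update_statements(), with
        # the "_emit_update_table" annotation naming the table currently
        # being updated; parameters are filtered to that table below.
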
        statement = cast(dml.Update, statement)
        an = statement._annotations

        emit_update_table, _ = (
            an["_emit_update_table"],
            an["_emit_update_mapper"],
        )

        statement = statement._clone()
        statement.table = emit_update_table

        UpdateDMLState.__init__(self, statement, compiler, **kw)

        if self._maintain_values_ordering:
            raise sa_exc.InvalidRequestError(
                "bulk ORM UPDATE does not support ordered_values() for "
                "custom UPDATE statements with bulk parameter sets. Use a "
                "non-bulk UPDATE statement or use values()."
            )

        if self._dict_parameters:
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_update_table
            }
        self.statement = statement

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Update,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:

        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        if update_options._dml_strategy not in (
            "orm",
            "auto",
            "bulk",
            "core_only",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM UPDATE strategy "
                "are 'orm', 'auto', 'bulk', 'core_only'"
            )

        result: _result.Result[Unpack[TupleAny]]

        if update_options._dml_strategy == "bulk":
            enable_check_rowcount = not statement._where_criteria

            assert update_options._synchronize_session != "fetch"

            if (
                statement._where_criteria
                and update_options._synchronize_session == "evaluate"
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk synchronize of persistent objects not supported "
                    "when using bulk update with additional WHERE "
                    "criteria right now. Add the synchronize_session=None "
                    "execution option to bypass synchronization of "
                    "persistent objects."
                )
            mapper = update_options._subject_mapper
            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_update(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                update_changed_only=False,
                use_orm_update_stmt=statement,
                enable_check_rowcount=enable_check_rowcount,
            )
            return cls.orm_setup_cursor_result(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                result,
            )
        else:
            return super().orm_execute_statement(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                conn,
            )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.update_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        if is_executemany:
            return dialect.update_executemany_returning

        # these workarounds are currently hypothetical for UPDATE,
        # unlike DELETE where they impact MariaDB
        if is_update_from:
            return dialect.update_returning_multifrom

        elif is_multitable and not dialect.update_returning_multifrom:
            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with UPDATE..FROM; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_update_from=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_bulk_evaluate(
        cls, session, params, result, update_options
    ):
        if not params:
            return

        mapper = update_options._subject_mapper
        pk_keys = [prop.key for prop in mapper._identity_key_props]

        identity_map = session.identity_map

        for param in params:
            identity_key = mapper.identity_key_from_primary_key(
                (param[key] for key in pk_keys),
                update_options._identity_token,
            )
            state = identity_map.fast_get_state(identity_key)
            if not state:
                continue

            evaluated_keys = set(param).difference(pk_keys)

            dict_ = state.dict
            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = param[key]

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = evaluated_keys.intersection(dict_).difference(
                to_evaluate
            )
            if to_expire:
                state._expire_attributes(dict_, to_expire)

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            result.context.compiled_parameters[0],
            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
            result.prefetch_cols(),
            result.postfetch_cols(),
        )

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows
        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                result, target_mapper, returned_defaults_rows
            )
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._identity_token is None
                or identity_token == update_options._identity_token
            ]
            if identity_key in session.identity_map
        ]

        if not objs:
            return

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            result.context.compiled_parameters[0],
            [
                (
                    obj,
                    attributes.instance_state(obj),
                    attributes.instance_dict(obj),
                )
                for obj in objs
            ],
            result.prefetch_cols(),
            result.postfetch_cols(),
        )

    @classmethod
    def _apply_update_set_values_to_objects(
        cls,
        session,
        update_options,
        statement,
        effective_params,
        matched_objects,
        prefetch_cols,
        postfetch_cols,
    ):
        """apply values to objects derived from an update statement, e.g.
        UPDATE..SET <values>

        """
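        # Sketch (hypothetical "User" mapped class): for
        #
        #     update(User).values(name="x", counter=User.counter + 1)
        #
        # "name" gets a constant evaluator and "counter" an expression
        # evaluator below; matched in-session objects then have the new
        # values applied to their __dict__ and committed, rather than being
        # expired and re-fetched.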

        mapper = update_options._subject_mapper
        target_cls = mapper.class_
        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        resolved_values = cls._get_resolved_values(mapper, statement)
        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
            mapper, resolved_values
        )
        value_evaluators = {}
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                pass
            else:
                value_evaluators[key] = _evaluator

        evaluated_keys = list(value_evaluators.keys())
        attrib = {k for k, v in resolved_keys_as_propnames}

        states = set()

        to_prefetch = {
            c
            for c in prefetch_cols
            if c.key in effective_params
            and c in mapper._columntoproperty
            and c.key not in evaluated_keys
        }
        to_expire = {
            mapper._columntoproperty[c].key
            for c in postfetch_cols
            if c in mapper._columntoproperty
        }.difference(evaluated_keys)

        prefetch_transfer = [
            (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
        ]

        for obj, state, dict_ in matched_objects:

            dict_.update(
                {
                    col_to_prop: effective_params[c_key]
                    for col_to_prop, c_key in prefetch_transfer
                }
            )

            state._expire_attributes(state.dict, to_expire)

            to_evaluate = state.unmodified.intersection(evaluated_keys)

            for key in to_evaluate:
                if key in dict_:
                    # only run eval for attributes that are present.
                    dict_[key] = value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)


@CompileState.plugin_for("orm", "delete")
class _BulkORMDelete(_BulkUDCompileState, DeleteDMLState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        if (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            DeleteDMLState.__init__(self, statement, compiler, **kw)
            return self

        toplevel = not compiler.stack

        orm_level_statement = statement

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        new_stmt = statement._clone()

        if new_stmt.table._annotations["parententity"] is mapper:
            new_stmt.table = mapper.local_table

        new_crit = cls._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # do this first as we need to determine if there is
        # DELETE..FROM
        DeleteDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
            can_use_returning = compiler._annotations.get(
                "can_use_returning", None
            )
            if can_use_returning is not False:
                # even though pre_exec has determined basic
                # can_use_returning for the dialect, if we are to use
                # RETURNING we need to run can_use_returning() at this level
                # unconditionally because is_delete_using was not known
                # at the pre_exec level
                can_use_returning = (
                    synchronize_session == "fetch"
                    and self.can_use_returning(
                        compiler.dialect,
                        mapper,
                        is_multitable=self.is_multitable,
                        is_delete_using=compiler._annotations.get(
                            "is_delete_using", False
                        ),
                    )
                )

            if can_use_returning:
                use_supplemental_cols = True

                new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

        return self

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Delete,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy == "bulk":
            raise sa_exc.InvalidRequestError(
                "Bulk ORM DELETE not supported right now. "
                "Statement may be invoked at the "
                "Core level using "
                "session.connection().execute(stmt, parameters)"
            )

        if update_options._dml_strategy not in ("orm", "auto", "core_only"):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
                "'core_only'"
            )

        return super().orm_execute_statement(
            session, statement, params, execution_options, bind_arguments, conn
        )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.delete_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        # now get into special workarounds because MariaDB supports
        # DELETE...RETURNING but not DELETE...USING...RETURNING.
        if is_delete_using:
            # is_delete_using hint was passed. use
            # additional dialect feature (True for PG, False for MariaDB)
            return dialect.delete_returning_multifrom

        elif is_multitable and not dialect.delete_returning_multifrom:
            # is_delete_using hint was not passed, but we determined
            # at compile time that this is in fact a DELETE..USING.
            # it's too late to continue since we did not pre-SELECT.
            # raise that we need that hint up front.

            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with DELETE..USING; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_delete_using=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        to_delete = []

        for _, state, dict_, is_partially_expired in matched_objects:
            if is_partially_expired:
                state._expire(dict_, session.identity_map._modified)
            else:
                to_delete.append(state)

        if to_delete:
            session._remove_newly_deleted(to_delete)

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows

        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                result, target_mapper, returned_defaults_rows
            )

            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        for row in matched_rows:
            primary_key = row[0:-1]
            identity_token = row[-1]

            # TODO: inline this and call remove_newly_deleted
            # once
            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )
            if identity_key in session.identity_map:
                session._remove_newly_deleted(
                    [
                        attributes.instance_state(
                            session.identity_map[identity_key]
                        )
                    ]
                )