1# orm/bulk_persistence.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""additional ORM persistence classes related to "bulk" operations,
11specifically outside of the flush() process.
12
13"""
14
15from __future__ import annotations
16
17from typing import Any
18from typing import cast
19from typing import Dict
20from typing import Iterable
21from typing import Optional
22from typing import overload
23from typing import TYPE_CHECKING
24from typing import TypeVar
25from typing import Union
26
27from . import attributes
28from . import context
29from . import evaluator
30from . import exc as orm_exc
31from . import loading
32from . import persistence
33from .base import NO_VALUE
34from .context import _AbstractORMCompileState
35from .context import _ORMFromStatementCompileState
36from .context import FromStatement
37from .context import QueryContext
38from .interfaces import PropComparator
39from .. import exc as sa_exc
40from .. import util
41from ..engine import Dialect
42from ..engine import result as _result
43from ..sql import coercions
44from ..sql import dml
45from ..sql import expression
46from ..sql import roles
47from ..sql import select
48from ..sql import sqltypes
49from ..sql.base import _entity_namespace_key
50from ..sql.base import CompileState
51from ..sql.base import Options
52from ..sql.dml import DeleteDMLState
53from ..sql.dml import InsertDMLState
54from ..sql.dml import UpdateDMLState
55from ..util import EMPTY_DICT
56from ..util.typing import Literal
57from ..util.typing import TupleAny
58from ..util.typing import Unpack
59
60if TYPE_CHECKING:
61 from ._typing import DMLStrategyArgument
62 from ._typing import OrmExecuteOptionsParameter
63 from ._typing import SynchronizeSessionArgument
64 from .mapper import Mapper
65 from .session import _BindArguments
66 from .session import ORMExecuteState
67 from .session import Session
68 from .session import SessionTransaction
69 from .state import InstanceState
70 from ..engine import Connection
71 from ..engine import cursor
72 from ..engine.interfaces import _CoreAnyExecuteParams
73
74_O = TypeVar("_O", bound=object)
75
76
77@overload
78def _bulk_insert(
79 mapper: Mapper[_O],
80 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
81 session_transaction: SessionTransaction,
82 *,
83 isstates: bool,
84 return_defaults: bool,
85 render_nulls: bool,
86 use_orm_insert_stmt: Literal[None] = ...,
87 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
88) -> None: ...
89
90
91@overload
92def _bulk_insert(
93 mapper: Mapper[_O],
94 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
95 session_transaction: SessionTransaction,
96 *,
97 isstates: bool,
98 return_defaults: bool,
99 render_nulls: bool,
100 use_orm_insert_stmt: Optional[dml.Insert] = ...,
101 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
102) -> cursor.CursorResult[Any]: ...
103
104
105def _bulk_insert(
106 mapper: Mapper[_O],
107 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
108 session_transaction: SessionTransaction,
109 *,
110 isstates: bool,
111 return_defaults: bool,
112 render_nulls: bool,
113 use_orm_insert_stmt: Optional[dml.Insert] = None,
114 execution_options: Optional[OrmExecuteOptionsParameter] = None,
115) -> Optional[cursor.CursorResult[Any]]:
116 base_mapper = mapper.base_mapper
117
118 if session_transaction.session.connection_callable:
119 raise NotImplementedError(
120 "connection_callable / per-instance sharding "
121 "not supported in bulk_insert()"
122 )
123
124 if isstates:
125 if TYPE_CHECKING:
126 mappings = cast(Iterable[InstanceState[_O]], mappings)
127
128 if return_defaults:
129 # list of states allows us to attach .key for return_defaults case
130 states = [(state, state.dict) for state in mappings]
131 mappings = [dict_ for (state, dict_) in states]
132 else:
133 mappings = [state.dict for state in mappings]
134 else:
135 if TYPE_CHECKING:
136 mappings = cast(Iterable[Dict[str, Any]], mappings)
137
138 if return_defaults:
139 # use dictionaries given, so that newly populated defaults
140 # can be delivered back to the caller (see #11661). This is **not**
141 # compatible with other use cases such as a session-executed
142 # insert() construct, as this will confuse the case of
143 # insert-per-subclass for joined inheritance cases (see
144 # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
145 #
146 # So in this conditional, we have **only** called
147 # session.bulk_insert_mappings() which does not have this
148 # requirement
149 mappings = list(mappings)
150 else:
151 # for all other cases we need to establish a local dictionary
152 # so that the incoming dictionaries aren't mutated
153 mappings = [dict(m) for m in mappings]
154 _expand_other_attrs(mapper, mappings)
155
156 connection = session_transaction.connection(base_mapper)
157
158 return_result: Optional[cursor.CursorResult[Any]] = None
159
160 mappers_to_run = [
161 (table, mp)
162 for table, mp in base_mapper._sorted_tables.items()
163 if table in mapper._pks_by_table
164 ]
165
166 if return_defaults:
167 # not used by new-style bulk inserts, only used for legacy
168 bookkeeping = True
169 elif len(mappers_to_run) > 1:
170 # if we have more than one table, mapper to run where we will be
171 # either horizontally splicing, or copying values between tables,
172 # we need the "bookkeeping" / deterministic returning order
173 bookkeeping = True
174 else:
175 bookkeeping = False
176
177 for table, super_mapper in mappers_to_run:
178 # find bindparams in the statement. For bulk, we don't really know if
179 # a key in the params applies to a different table since we are
180 # potentially inserting for multiple tables here; looking at the
181 # bindparam() is a lot more direct. in most cases this will
182 # use _generate_cache_key() which is memoized, although in practice
183 # the ultimate statement that's executed is probably not the same
184 # object so that memoization might not matter much.
185 extra_bp_names = (
186 [
187 b.key
188 for b in use_orm_insert_stmt._get_embedded_bindparams()
189 if b.key in mappings[0]
190 ]
191 if use_orm_insert_stmt is not None
192 else ()
193 )
194
195 records = (
196 (
197 None,
198 state_dict,
199 params,
200 mapper,
201 connection,
202 value_params,
203 has_all_pks,
204 has_all_defaults,
205 )
206 for (
207 state,
208 state_dict,
209 params,
210 mp,
211 conn,
212 value_params,
213 has_all_pks,
214 has_all_defaults,
215 ) in persistence._collect_insert_commands(
216 table,
217 ((None, mapping, mapper, connection) for mapping in mappings),
218 bulk=True,
219 return_defaults=bookkeeping,
220 render_nulls=render_nulls,
221 include_bulk_keys=extra_bp_names,
222 )
223 )
224
225 result = persistence._emit_insert_statements(
226 base_mapper,
227 None,
228 super_mapper,
229 table,
230 records,
231 bookkeeping=bookkeeping,
232 use_orm_insert_stmt=use_orm_insert_stmt,
233 execution_options=execution_options,
234 )
235 if use_orm_insert_stmt is not None:
236 if not use_orm_insert_stmt._returning or return_result is None:
237 return_result = result
238 elif result.returns_rows:
239 assert bookkeeping
240 return_result = return_result.splice_horizontally(result)
241
242 if return_defaults and isstates:
243 identity_cls = mapper._identity_class
244 identity_props = [p.key for p in mapper._identity_key_props]
245 for state, dict_ in states:
246 state.key = (
247 identity_cls,
248 tuple([dict_[key] for key in identity_props]),
249 None,
250 )
251
252 if use_orm_insert_stmt is not None:
253 assert return_result is not None
254 return return_result
255
256
257@overload
258def _bulk_update(
259 mapper: Mapper[Any],
260 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
261 session_transaction: SessionTransaction,
262 *,
263 isstates: bool,
264 update_changed_only: bool,
265 use_orm_update_stmt: Literal[None] = ...,
266 enable_check_rowcount: bool = True,
267) -> None: ...
268
269
270@overload
271def _bulk_update(
272 mapper: Mapper[Any],
273 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
274 session_transaction: SessionTransaction,
275 *,
276 isstates: bool,
277 update_changed_only: bool,
278 use_orm_update_stmt: Optional[dml.Update] = ...,
279 enable_check_rowcount: bool = True,
280) -> _result.Result[Unpack[TupleAny]]: ...
281
282
283def _bulk_update(
284 mapper: Mapper[Any],
285 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
286 session_transaction: SessionTransaction,
287 *,
288 isstates: bool,
289 update_changed_only: bool,
290 use_orm_update_stmt: Optional[dml.Update] = None,
291 enable_check_rowcount: bool = True,
292) -> Optional[_result.Result[Unpack[TupleAny]]]:
293 base_mapper = mapper.base_mapper
294
295 search_keys = mapper._primary_key_propkeys
296 if mapper._version_id_prop:
297 search_keys = {mapper._version_id_prop.key}.union(search_keys)
298
299 def _changed_dict(mapper, state):
300 return {
301 k: v
302 for k, v in state.dict.items()
303 if k in state.committed_state or k in search_keys
304 }
305
306 if isstates:
307 if update_changed_only:
308 mappings = [_changed_dict(mapper, state) for state in mappings]
309 else:
310 mappings = [state.dict for state in mappings]
311 else:
312 mappings = [dict(m) for m in mappings]
313 _expand_other_attrs(mapper, mappings)
314
315 if session_transaction.session.connection_callable:
316 raise NotImplementedError(
317 "connection_callable / per-instance sharding "
318 "not supported in bulk_update()"
319 )
320
321 connection = session_transaction.connection(base_mapper)
322
323 # find bindparams in the statement. see _bulk_insert for similar
324 # notes for the insert case
325 extra_bp_names = (
326 [
327 b.key
328 for b in use_orm_update_stmt._get_embedded_bindparams()
329 if b.key in mappings[0]
330 ]
331 if use_orm_update_stmt is not None
332 else ()
333 )
334
335 for table, super_mapper in base_mapper._sorted_tables.items():
336 if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
337 continue
338
339 records = persistence._collect_update_commands(
340 None,
341 table,
342 (
343 (
344 None,
345 mapping,
346 mapper,
347 connection,
348 (
349 mapping[mapper._version_id_prop.key]
350 if mapper._version_id_prop
351 else None
352 ),
353 )
354 for mapping in mappings
355 ),
356 bulk=True,
357 use_orm_update_stmt=use_orm_update_stmt,
358 include_bulk_keys=extra_bp_names,
359 )
360 persistence._emit_update_statements(
361 base_mapper,
362 None,
363 super_mapper,
364 table,
365 records,
366 bookkeeping=False,
367 use_orm_update_stmt=use_orm_update_stmt,
368 enable_check_rowcount=enable_check_rowcount,
369 )
370
371 if use_orm_update_stmt is not None:
372 return _result.null_result()
373
374
375def _expand_other_attrs(
376 mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]]
377) -> None:
378 all_attrs = mapper.all_orm_descriptors
379
380 attr_keys = set(all_attrs.keys())
381
382 bulk_dml_setters = {
383 key: setter
384 for key, setter in (
385 (key, attr._bulk_dml_setter(key))
386 for key, attr in (
387 (key, _entity_namespace_key(mapper, key, default=NO_VALUE))
388 for key in attr_keys
389 )
390 if attr is not NO_VALUE and isinstance(attr, PropComparator)
391 )
392 if setter is not None
393 }
394 setters_todo = set(bulk_dml_setters)
395 if not setters_todo:
396 return
397
398 for mapping in mappings:
399 for key in setters_todo.intersection(mapping):
400 bulk_dml_setters[key](mapping)
401
402
403class _ORMDMLState(_AbstractORMCompileState):
404 is_dml_returning = True
405 from_statement_ctx: Optional[_ORMFromStatementCompileState] = None
406
407 @classmethod
408 def _get_orm_crud_kv_pairs(
409 cls, mapper, statement, kv_iterator, needs_to_be_cacheable
410 ):
411 core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
412
413 for k, v in kv_iterator:
414 k = coercions.expect(roles.DMLColumnRole, k)
415
416 if isinstance(k, str):
417 desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
418 if not isinstance(desc, PropComparator):
419 yield (
420 coercions.expect(roles.DMLColumnRole, k),
421 (
422 coercions.expect(
423 roles.ExpressionElementRole,
424 v,
425 type_=sqltypes.NullType(),
426 is_crud=True,
427 )
428 if needs_to_be_cacheable
429 else v
430 ),
431 )
432 else:
433 yield from core_get_crud_kv_pairs(
434 statement,
435 desc._bulk_update_tuples(v),
436 needs_to_be_cacheable,
437 )
438 elif "entity_namespace" in k._annotations:
439 k_anno = k._annotations
440 attr = _entity_namespace_key(
441 k_anno["entity_namespace"], k_anno["proxy_key"]
442 )
443 assert isinstance(attr, PropComparator)
444 yield from core_get_crud_kv_pairs(
445 statement,
446 attr._bulk_update_tuples(v),
447 needs_to_be_cacheable,
448 )
449 else:
450 yield (
451 k,
452 (
453 v
454 if not needs_to_be_cacheable
455 else coercions.expect(
456 roles.ExpressionElementRole,
457 v,
458 type_=sqltypes.NullType(),
459 is_crud=True,
460 )
461 ),
462 )
463
464 @classmethod
465 def _get_dml_plugin_subject(cls, statement):
466 plugin_subject = statement.table._propagate_attrs.get("plugin_subject")
467
468 if (
469 not plugin_subject
470 or not plugin_subject.mapper
471 or plugin_subject
472 is not statement._propagate_attrs["plugin_subject"]
473 ):
474 return None
475 return plugin_subject
476
477 @classmethod
478 def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
479 plugin_subject = cls._get_dml_plugin_subject(statement)
480
481 if not plugin_subject:
482 return UpdateDMLState._get_multi_crud_kv_pairs(
483 statement, kv_iterator
484 )
485
486 return [
487 dict(
488 cls._get_orm_crud_kv_pairs(
489 plugin_subject.mapper, statement, value_dict.items(), False
490 )
491 )
492 for value_dict in kv_iterator
493 ]
494
495 @classmethod
496 def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
497 assert (
498 needs_to_be_cacheable
499 ), "no test coverage for needs_to_be_cacheable=False"
500
501 plugin_subject = cls._get_dml_plugin_subject(statement)
502
503 if not plugin_subject:
504 return UpdateDMLState._get_crud_kv_pairs(
505 statement, kv_iterator, needs_to_be_cacheable
506 )
507 return list(
508 cls._get_orm_crud_kv_pairs(
509 plugin_subject.mapper,
510 statement,
511 kv_iterator,
512 needs_to_be_cacheable,
513 )
514 )
515
516 @classmethod
517 def get_entity_description(cls, statement):
518 ext_info = statement.table._annotations["parententity"]
519 mapper = ext_info.mapper
520 if ext_info.is_aliased_class:
521 _label_name = ext_info.name
522 else:
523 _label_name = mapper.class_.__name__
524
525 return {
526 "name": _label_name,
527 "type": mapper.class_,
528 "expr": ext_info.entity,
529 "entity": ext_info.entity,
530 "table": mapper.local_table,
531 }
532
533 @classmethod
534 def get_returning_column_descriptions(cls, statement):
535 def _ent_for_col(c):
536 return c._annotations.get("parententity", None)
537
538 def _attr_for_col(c, ent):
539 if ent is None:
540 return c
541 proxy_key = c._annotations.get("proxy_key", None)
542 if not proxy_key:
543 return c
544 else:
545 return getattr(ent.entity, proxy_key, c)
546
547 return [
548 {
549 "name": c.key,
550 "type": c.type,
551 "expr": _attr_for_col(c, ent),
552 "aliased": ent.is_aliased_class,
553 "entity": ent.entity,
554 }
555 for c, ent in [
556 (c, _ent_for_col(c)) for c in statement._all_selected_columns
557 ]
558 ]
559
560 def _setup_orm_returning(
561 self,
562 compiler,
563 orm_level_statement,
564 dml_level_statement,
565 dml_mapper,
566 *,
567 use_supplemental_cols=True,
568 ):
569 """establish ORM column handlers for an INSERT, UPDATE, or DELETE
570 which uses explicit returning().
571
572 called within compilation level create_for_statement.
573
574 The _return_orm_returning() method then receives the Result
575 after the statement was executed, and applies ORM loading to the
576 state that we first established here.
577
578 """
579
580 if orm_level_statement._returning:
581 fs = FromStatement(
582 orm_level_statement._returning,
583 dml_level_statement,
584 _adapt_on_names=False,
585 )
586 fs = fs.execution_options(**orm_level_statement._execution_options)
587 fs = fs.options(*orm_level_statement._with_options)
588 self.select_statement = fs
589 self.from_statement_ctx = fsc = (
590 _ORMFromStatementCompileState.create_for_statement(
591 fs, compiler
592 )
593 )
594 fsc.setup_dml_returning_compile_state(dml_mapper)
595
596 dml_level_statement = dml_level_statement._generate()
597 dml_level_statement._returning = ()
598
599 cols_to_return = [c for c in fsc.primary_columns if c is not None]
600
601 # since we are splicing result sets together, make sure there
602 # are columns of some kind returned in each result set
603 if not cols_to_return:
604 cols_to_return.extend(dml_mapper.primary_key)
605
606 if use_supplemental_cols:
607 dml_level_statement = dml_level_statement.return_defaults(
608 # this is a little weird looking, but by passing
609 # primary key as the main list of cols, this tells
610 # return_defaults to omit server-default cols (and
611 # actually all cols, due to some weird thing we should
612 # clean up in crud.py).
613 # Since we have cols_to_return, just return what we asked
614 # for (plus primary key, which ORM persistence needs since
615 # we likely set bookkeeping=True here, which is another
616 # whole thing...). We dont want to clutter the
617 # statement up with lots of other cols the user didn't
618 # ask for. see #9685
619 *dml_mapper.primary_key,
620 supplemental_cols=cols_to_return,
621 )
622 else:
623 dml_level_statement = dml_level_statement.returning(
624 *cols_to_return
625 )
626
627 return dml_level_statement
628
629 @classmethod
630 def _return_orm_returning(
631 cls,
632 session,
633 statement,
634 params,
635 execution_options,
636 bind_arguments,
637 result,
638 ):
639 execution_context = result.context
640 compile_state = execution_context.compiled.compile_state
641
642 if (
643 compile_state.from_statement_ctx
644 and not compile_state.from_statement_ctx.compile_options._is_star
645 ):
646 load_options = execution_options.get(
647 "_sa_orm_load_options", QueryContext.default_load_options
648 )
649
650 querycontext = QueryContext(
651 compile_state.from_statement_ctx,
652 compile_state.select_statement,
653 statement,
654 params,
655 session,
656 load_options,
657 execution_options,
658 bind_arguments,
659 )
660 return loading.instances(result, querycontext)
661 else:
662 return result
663
664
665class _BulkUDCompileState(_ORMDMLState):
666 class default_update_options(Options):
667 _dml_strategy: DMLStrategyArgument = "auto"
668 _synchronize_session: SynchronizeSessionArgument = "auto"
669 _can_use_returning: bool = False
670 _is_delete_using: bool = False
671 _is_update_from: bool = False
672 _autoflush: bool = True
673 _subject_mapper: Optional[Mapper[Any]] = None
674 _resolved_values = EMPTY_DICT
675 _eval_condition = None
676 _matched_rows = None
677 _identity_token = None
678 _populate_existing: bool = False
679
680 @classmethod
681 def can_use_returning(
682 cls,
683 dialect: Dialect,
684 mapper: Mapper[Any],
685 *,
686 is_multitable: bool = False,
687 is_update_from: bool = False,
688 is_delete_using: bool = False,
689 is_executemany: bool = False,
690 ) -> bool:
691 raise NotImplementedError()
692
693 @classmethod
694 def orm_pre_session_exec(
695 cls,
696 session,
697 statement,
698 params,
699 execution_options,
700 bind_arguments,
701 is_pre_event,
702 ):
703 (
704 update_options,
705 execution_options,
706 ) = _BulkUDCompileState.default_update_options.from_execution_options(
707 "_sa_orm_update_options",
708 {
709 "synchronize_session",
710 "autoflush",
711 "populate_existing",
712 "identity_token",
713 "is_delete_using",
714 "is_update_from",
715 "dml_strategy",
716 },
717 execution_options,
718 statement._execution_options,
719 )
720 bind_arguments["clause"] = statement
721 try:
722 plugin_subject = statement._propagate_attrs["plugin_subject"]
723 except KeyError:
724 assert False, "statement had 'orm' plugin but no plugin_subject"
725 else:
726 if plugin_subject:
727 bind_arguments["mapper"] = plugin_subject.mapper
728 update_options += {"_subject_mapper": plugin_subject.mapper}
729
730 if "parententity" not in statement.table._annotations:
731 update_options += {"_dml_strategy": "core_only"}
732 elif not isinstance(params, list):
733 if update_options._dml_strategy == "auto":
734 update_options += {"_dml_strategy": "orm"}
735 elif update_options._dml_strategy == "bulk":
736 raise sa_exc.InvalidRequestError(
737 'Can\'t use "bulk" ORM insert strategy without '
738 "passing separate parameters"
739 )
740 else:
741 if update_options._dml_strategy == "auto":
742 update_options += {"_dml_strategy": "bulk"}
743
744 sync = update_options._synchronize_session
745 if sync is not None:
746 if sync not in ("auto", "evaluate", "fetch", False):
747 raise sa_exc.ArgumentError(
748 "Valid strategies for session synchronization "
749 "are 'auto', 'evaluate', 'fetch', False"
750 )
751 if update_options._dml_strategy == "bulk" and sync == "fetch":
752 raise sa_exc.InvalidRequestError(
753 "The 'fetch' synchronization strategy is not available "
754 "for 'bulk' ORM updates (i.e. multiple parameter sets)"
755 )
756
757 if not is_pre_event:
758 if update_options._autoflush:
759 session._autoflush()
760
761 if update_options._dml_strategy == "orm":
762 if update_options._synchronize_session == "auto":
763 update_options = cls._do_pre_synchronize_auto(
764 session,
765 statement,
766 params,
767 execution_options,
768 bind_arguments,
769 update_options,
770 )
771 elif update_options._synchronize_session == "evaluate":
772 update_options = cls._do_pre_synchronize_evaluate(
773 session,
774 statement,
775 params,
776 execution_options,
777 bind_arguments,
778 update_options,
779 )
780 elif update_options._synchronize_session == "fetch":
781 update_options = cls._do_pre_synchronize_fetch(
782 session,
783 statement,
784 params,
785 execution_options,
786 bind_arguments,
787 update_options,
788 )
789 elif update_options._dml_strategy == "bulk":
790 if update_options._synchronize_session == "auto":
791 update_options += {"_synchronize_session": "evaluate"}
792
793 # indicators from the "pre exec" step that are then
794 # added to the DML statement, which will also be part of the cache
795 # key. The compile level create_for_statement() method will then
796 # consume these at compiler time.
797 statement = statement._annotate(
798 {
799 "synchronize_session": update_options._synchronize_session,
800 "is_delete_using": update_options._is_delete_using,
801 "is_update_from": update_options._is_update_from,
802 "dml_strategy": update_options._dml_strategy,
803 "can_use_returning": update_options._can_use_returning,
804 }
805 )
806
807 return (
808 statement,
809 util.immutabledict(execution_options).union(
810 {"_sa_orm_update_options": update_options}
811 ),
812 )
813
814 @classmethod
815 def orm_setup_cursor_result(
816 cls,
817 session,
818 statement,
819 params,
820 execution_options,
821 bind_arguments,
822 result,
823 ):
824 # this stage of the execution is called after the
825 # do_orm_execute event hook. meaning for an extension like
826 # horizontal sharding, this step happens *within* the horizontal
827 # sharding event handler which calls session.execute() re-entrantly
828 # and will occur for each backend individually.
829 # the sharding extension then returns its own merged result from the
830 # individual ones we return here.
831
832 update_options = execution_options["_sa_orm_update_options"]
833 if update_options._dml_strategy == "orm":
834 if update_options._synchronize_session == "evaluate":
835 cls._do_post_synchronize_evaluate(
836 session, statement, result, update_options
837 )
838 elif update_options._synchronize_session == "fetch":
839 cls._do_post_synchronize_fetch(
840 session, statement, result, update_options
841 )
842 elif update_options._dml_strategy == "bulk":
843 if update_options._synchronize_session == "evaluate":
844 cls._do_post_synchronize_bulk_evaluate(
845 session, params, result, update_options
846 )
847 return result
848
849 return cls._return_orm_returning(
850 session,
851 statement,
852 params,
853 execution_options,
854 bind_arguments,
855 result,
856 )
857
858 @classmethod
859 def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
860 """Apply extra criteria filtering.
861
862 For all distinct single-table-inheritance mappers represented in the
863 table being updated or deleted, produce additional WHERE criteria such
864 that only the appropriate subtypes are selected from the total results.
865
866 Additionally, add WHERE criteria originating from LoaderCriteriaOptions
867 collected from the statement.
868
869 """
870
871 return_crit = ()
872
873 adapter = ext_info._adapter if ext_info.is_aliased_class else None
874
875 if (
876 "additional_entity_criteria",
877 ext_info.mapper,
878 ) in global_attributes:
879 return_crit += tuple(
880 ae._resolve_where_criteria(ext_info)
881 for ae in global_attributes[
882 ("additional_entity_criteria", ext_info.mapper)
883 ]
884 if ae.include_aliases or ae.entity is ext_info
885 )
886
887 if ext_info.mapper._single_table_criterion is not None:
888 return_crit += (ext_info.mapper._single_table_criterion,)
889
890 if adapter:
891 return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
892
893 return return_crit
894
895 @classmethod
896 def _interpret_returning_rows(cls, result, mapper, rows):
897 """return rows that indicate PK cols in mapper.primary_key position
898 for RETURNING rows.
899
900 Prior to 2.0.36, this method seemed to be written for some kind of
901 inheritance scenario but the scenario was unused for actual joined
902 inheritance, and the function instead seemed to perform some kind of
903 partial translation that would remove non-PK cols if the PK cols
904 happened to be first in the row, but not otherwise. The joined
905 inheritance walk feature here seems to have never been used as it was
906 always skipped by the "local_table" check.
907
908 As of 2.0.36 the function strips away non-PK cols and provides the
909 PK cols for the table in mapper PK order.
910
911 """
912
913 try:
914 if mapper.local_table is not mapper.base_mapper.local_table:
915 # TODO: dive more into how a local table PK is used for fetch
916 # sync, not clear if this is correct as it depends on the
917 # downstream routine to fetch rows using
918 # local_table.primary_key order
919 pk_keys = result._tuple_getter(mapper.local_table.primary_key)
920 else:
921 pk_keys = result._tuple_getter(mapper.primary_key)
922 except KeyError:
923 # can't use these rows, they don't have PK cols in them
924 # this is an unusual case where the user would have used
925 # .return_defaults()
926 return []
927
928 return [pk_keys(row) for row in rows]
929
930 @classmethod
931 def _get_matched_objects_on_criteria(cls, update_options, states):
932 mapper = update_options._subject_mapper
933 eval_condition = update_options._eval_condition
934
935 raw_data = [
936 (state.obj(), state, state.dict)
937 for state in states
938 if state.mapper.isa(mapper) and not state.expired
939 ]
940
941 identity_token = update_options._identity_token
942 if identity_token is not None:
943 raw_data = [
944 (obj, state, dict_)
945 for obj, state, dict_ in raw_data
946 if state.identity_token == identity_token
947 ]
948
949 result = []
950 for obj, state, dict_ in raw_data:
951 evaled_condition = eval_condition(obj)
952
953 # caution: don't use "in ()" or == here, _EXPIRE_OBJECT
954 # evaluates as True for all comparisons
955 if (
956 evaled_condition is True
957 or evaled_condition is evaluator._EXPIRED_OBJECT
958 ):
959 result.append(
960 (
961 obj,
962 state,
963 dict_,
964 evaled_condition is evaluator._EXPIRED_OBJECT,
965 )
966 )
967 return result
968
969 @classmethod
970 def _eval_condition_from_statement(cls, update_options, statement):
971 mapper = update_options._subject_mapper
972 target_cls = mapper.class_
973
974 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
975 crit = ()
976 if statement._where_criteria:
977 crit += statement._where_criteria
978
979 global_attributes = {}
980 for opt in statement._with_options:
981 if opt._is_criteria_option:
982 opt.get_global_criteria(global_attributes)
983
984 if global_attributes:
985 crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
986
987 if crit:
988 eval_condition = evaluator_compiler.process(*crit)
989 else:
990 # workaround for mypy https://github.com/python/mypy/issues/14027
991 def _eval_condition(obj):
992 return True
993
994 eval_condition = _eval_condition
995
996 return eval_condition
997
998 @classmethod
999 def _do_pre_synchronize_auto(
1000 cls,
1001 session,
1002 statement,
1003 params,
1004 execution_options,
1005 bind_arguments,
1006 update_options,
1007 ):
1008 """setup auto sync strategy
1009
1010
1011 "auto" checks if we can use "evaluate" first, then falls back
1012 to "fetch"
1013
1014 evaluate is vastly more efficient for the common case
1015 where session is empty, only has a few objects, and the UPDATE
1016 statement can potentially match thousands/millions of rows.
1017
1018 OTOH more complex criteria that fails to work with "evaluate"
1019 we would hope usually correlates with fewer net rows.
1020
1021 """
1022
1023 try:
1024 eval_condition = cls._eval_condition_from_statement(
1025 update_options, statement
1026 )
1027
1028 except evaluator.UnevaluatableError:
1029 pass
1030 else:
1031 return update_options + {
1032 "_eval_condition": eval_condition,
1033 "_synchronize_session": "evaluate",
1034 }
1035
1036 update_options += {"_synchronize_session": "fetch"}
1037 return cls._do_pre_synchronize_fetch(
1038 session,
1039 statement,
1040 params,
1041 execution_options,
1042 bind_arguments,
1043 update_options,
1044 )
1045
1046 @classmethod
1047 def _do_pre_synchronize_evaluate(
1048 cls,
1049 session,
1050 statement,
1051 params,
1052 execution_options,
1053 bind_arguments,
1054 update_options,
1055 ):
1056 try:
1057 eval_condition = cls._eval_condition_from_statement(
1058 update_options, statement
1059 )
1060
1061 except evaluator.UnevaluatableError as err:
1062 raise sa_exc.InvalidRequestError(
1063 'Could not evaluate current criteria in Python: "%s". '
1064 "Specify 'fetch' or False for the "
1065 "synchronize_session execution option." % err
1066 ) from err
1067
1068 return update_options + {
1069 "_eval_condition": eval_condition,
1070 }
1071
1072 @classmethod
1073 def _get_resolved_values(cls, mapper, statement):
1074 if statement._multi_values:
1075 return []
1076 elif statement._values:
1077 return list(statement._values.items())
1078 else:
1079 return []
1080
1081 @classmethod
1082 def _resolved_keys_as_propnames(cls, mapper, resolved_values):
1083 values = []
1084 for k, v in resolved_values:
1085 if mapper and isinstance(k, expression.ColumnElement):
1086 try:
1087 attr = mapper._columntoproperty[k]
1088 except orm_exc.UnmappedColumnError:
1089 pass
1090 else:
1091 values.append((attr.key, v))
1092 else:
1093 raise sa_exc.InvalidRequestError(
1094 "Attribute name not found, can't be "
1095 "synchronized back to objects: %r" % k
1096 )
1097 return values
1098
1099 @classmethod
1100 def _do_pre_synchronize_fetch(
1101 cls,
1102 session,
1103 statement,
1104 params,
1105 execution_options,
1106 bind_arguments,
1107 update_options,
1108 ):
1109 mapper = update_options._subject_mapper
1110
1111 select_stmt = (
1112 select(*(mapper.primary_key + (mapper.select_identity_token,)))
1113 .select_from(mapper)
1114 .options(*statement._with_options)
1115 )
1116 select_stmt._where_criteria = statement._where_criteria
1117
1118 # conditionally run the SELECT statement for pre-fetch, testing the
1119 # "bind" for if we can use RETURNING or not using the do_orm_execute
1120 # event. If RETURNING is available, the do_orm_execute event
1121 # will cancel the SELECT from being actually run.
1122 #
1123 # The way this is organized seems strange, why don't we just
1124 # call can_use_returning() before invoking the statement and get
1125 # answer?, why does this go through the whole execute phase using an
1126 # event? Answer: because we are integrating with extensions such
1127 # as the horizontal sharding extention that "multiplexes" an individual
1128 # statement run through multiple engines, and it uses
1129 # do_orm_execute() to do that.
1130
1131 can_use_returning = None
1132
1133 def skip_for_returning(orm_context: ORMExecuteState) -> Any:
1134 bind = orm_context.session.get_bind(**orm_context.bind_arguments)
1135 nonlocal can_use_returning
1136
1137 per_bind_result = cls.can_use_returning(
1138 bind.dialect,
1139 mapper,
1140 is_update_from=update_options._is_update_from,
1141 is_delete_using=update_options._is_delete_using,
1142 is_executemany=orm_context.is_executemany,
1143 )
1144
1145 if can_use_returning is not None:
1146 if can_use_returning != per_bind_result:
1147 raise sa_exc.InvalidRequestError(
1148 "For synchronize_session='fetch', can't mix multiple "
1149 "backends where some support RETURNING and others "
1150 "don't"
1151 )
1152 elif orm_context.is_executemany and not per_bind_result:
1153 raise sa_exc.InvalidRequestError(
1154 "For synchronize_session='fetch', can't use multiple "
1155 "parameter sets in ORM mode, which this backend does not "
1156 "support with RETURNING"
1157 )
1158 else:
1159 can_use_returning = per_bind_result
1160
1161 if per_bind_result:
1162 return _result.null_result()
1163 else:
1164 return None
1165
1166 result = session.execute(
1167 select_stmt,
1168 params,
1169 execution_options=execution_options,
1170 bind_arguments=bind_arguments,
1171 _add_event=skip_for_returning,
1172 )
1173 matched_rows = result.fetchall()
1174
1175 return update_options + {
1176 "_matched_rows": matched_rows,
1177 "_can_use_returning": can_use_returning,
1178 }
1179
1180
1181@CompileState.plugin_for("orm", "insert")
1182class _BulkORMInsert(_ORMDMLState, InsertDMLState):
1183 class default_insert_options(Options):
1184 _dml_strategy: DMLStrategyArgument = "auto"
1185 _render_nulls: bool = False
1186 _return_defaults: bool = False
1187 _subject_mapper: Optional[Mapper[Any]] = None
1188 _autoflush: bool = True
1189 _populate_existing: bool = False
1190
1191 select_statement: Optional[FromStatement] = None
1192
1193 @classmethod
1194 def orm_pre_session_exec(
1195 cls,
1196 session,
1197 statement,
1198 params,
1199 execution_options,
1200 bind_arguments,
1201 is_pre_event,
1202 ):
1203 (
1204 insert_options,
1205 execution_options,
1206 ) = _BulkORMInsert.default_insert_options.from_execution_options(
1207 "_sa_orm_insert_options",
1208 {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
1209 execution_options,
1210 statement._execution_options,
1211 )
1212 bind_arguments["clause"] = statement
1213 try:
1214 plugin_subject = statement._propagate_attrs["plugin_subject"]
1215 except KeyError:
1216 assert False, "statement had 'orm' plugin but no plugin_subject"
1217 else:
1218 if plugin_subject:
1219 bind_arguments["mapper"] = plugin_subject.mapper
1220 insert_options += {"_subject_mapper": plugin_subject.mapper}
1221
1222 if not params:
1223 if insert_options._dml_strategy == "auto":
1224 insert_options += {"_dml_strategy": "orm"}
1225 elif insert_options._dml_strategy == "bulk":
1226 raise sa_exc.InvalidRequestError(
1227 'Can\'t use "bulk" ORM insert strategy without '
1228 "passing separate parameters"
1229 )
1230 else:
1231 if insert_options._dml_strategy == "auto":
1232 insert_options += {"_dml_strategy": "bulk"}
1233
1234 if insert_options._dml_strategy != "raw":
1235 # for ORM object loading, like ORMContext, we have to disable
1236 # result set adapt_to_context, because we will be generating a
1237 # new statement with specific columns that's cached inside of
1238 # an ORMFromStatementCompileState, which we will re-use for
1239 # each result.
1240 if not execution_options:
1241 execution_options = context._orm_load_exec_options
1242 else:
1243 execution_options = execution_options.union(
1244 context._orm_load_exec_options
1245 )
1246
1247 if not is_pre_event and insert_options._autoflush:
1248 session._autoflush()
1249
1250 statement = statement._annotate(
1251 {"dml_strategy": insert_options._dml_strategy}
1252 )
1253
1254 return (
1255 statement,
1256 util.immutabledict(execution_options).union(
1257 {"_sa_orm_insert_options": insert_options}
1258 ),
1259 )
1260
1261 @classmethod
1262 def orm_execute_statement(
1263 cls,
1264 session: Session,
1265 statement: dml.Insert,
1266 params: _CoreAnyExecuteParams,
1267 execution_options: OrmExecuteOptionsParameter,
1268 bind_arguments: _BindArguments,
1269 conn: Connection,
1270 ) -> _result.Result:
1271 insert_options = execution_options.get(
1272 "_sa_orm_insert_options", cls.default_insert_options
1273 )
1274
1275 if insert_options._dml_strategy not in (
1276 "raw",
1277 "bulk",
1278 "orm",
1279 "auto",
1280 ):
1281 raise sa_exc.ArgumentError(
1282 "Valid strategies for ORM insert strategy "
1283 "are 'raw', 'orm', 'bulk', 'auto"
1284 )
1285
1286 result: _result.Result[Unpack[TupleAny]]
1287
1288 if insert_options._dml_strategy == "raw":
1289 result = conn.execute(
1290 statement, params or {}, execution_options=execution_options
1291 )
1292 return result
1293
1294 if insert_options._dml_strategy == "bulk":
1295 mapper = insert_options._subject_mapper
1296
1297 if (
1298 statement._post_values_clause is not None
1299 and mapper._multiple_persistence_tables
1300 ):
1301 raise sa_exc.InvalidRequestError(
1302 "bulk INSERT with a 'post values' clause "
1303 "(typically upsert) not supported for multi-table "
1304 f"mapper {mapper}"
1305 )
1306
1307 assert mapper is not None
1308 assert session._transaction is not None
1309 result = _bulk_insert(
1310 mapper,
1311 cast(
1312 "Iterable[Dict[str, Any]]",
1313 [params] if isinstance(params, dict) else params,
1314 ),
1315 session._transaction,
1316 isstates=False,
1317 return_defaults=insert_options._return_defaults,
1318 render_nulls=insert_options._render_nulls,
1319 use_orm_insert_stmt=statement,
1320 execution_options=execution_options,
1321 )
1322 elif insert_options._dml_strategy == "orm":
1323 result = conn.execute(
1324 statement, params or {}, execution_options=execution_options
1325 )
1326 else:
1327 raise AssertionError()
1328
1329 if not bool(statement._returning):
1330 return result
1331
1332 if insert_options._populate_existing:
1333 load_options = execution_options.get(
1334 "_sa_orm_load_options", QueryContext.default_load_options
1335 )
1336 load_options += {"_populate_existing": True}
1337 execution_options = execution_options.union(
1338 {"_sa_orm_load_options": load_options}
1339 )
1340
1341 return cls._return_orm_returning(
1342 session,
1343 statement,
1344 params,
1345 execution_options,
1346 bind_arguments,
1347 result,
1348 )
1349
1350 @classmethod
1351 def create_for_statement(cls, statement, compiler, **kw) -> _BulkORMInsert:
1352 self = cast(
1353 _BulkORMInsert,
1354 super().create_for_statement(statement, compiler, **kw),
1355 )
1356
1357 if compiler is not None:
1358 toplevel = not compiler.stack
1359 else:
1360 toplevel = True
1361 if not toplevel:
1362 return self
1363
1364 mapper = statement._propagate_attrs["plugin_subject"]
1365 dml_strategy = statement._annotations.get("dml_strategy", "raw")
1366 if dml_strategy == "bulk":
1367 self._setup_for_bulk_insert(compiler)
1368 elif dml_strategy == "orm":
1369 self._setup_for_orm_insert(compiler, mapper)
1370
1371 return self
1372
1373 @classmethod
1374 def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
1375 return {
1376 col.key if col is not None else k: v
1377 for col, k, v in (
1378 (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
1379 )
1380 }
1381
1382 def _setup_for_orm_insert(self, compiler, mapper):
1383 statement = orm_level_statement = cast(dml.Insert, self.statement)
1384
1385 statement = self._setup_orm_returning(
1386 compiler,
1387 orm_level_statement,
1388 statement,
1389 dml_mapper=mapper,
1390 use_supplemental_cols=False,
1391 )
1392 self.statement = statement
1393
1394 def _setup_for_bulk_insert(self, compiler):
1395 """establish an INSERT statement within the context of
1396 bulk insert.
1397
1398 This method will be within the "conn.execute()" call that is invoked
1399 by persistence._emit_insert_statement().
1400
1401 """
1402 statement = orm_level_statement = cast(dml.Insert, self.statement)
1403 an = statement._annotations
1404
1405 emit_insert_table, emit_insert_mapper = (
1406 an["_emit_insert_table"],
1407 an["_emit_insert_mapper"],
1408 )
1409
1410 statement = statement._clone()
1411
1412 statement.table = emit_insert_table
1413 if self._dict_parameters:
1414 self._dict_parameters = {
1415 col: val
1416 for col, val in self._dict_parameters.items()
1417 if col.table is emit_insert_table
1418 }
1419
1420 statement = self._setup_orm_returning(
1421 compiler,
1422 orm_level_statement,
1423 statement,
1424 dml_mapper=emit_insert_mapper,
1425 use_supplemental_cols=True,
1426 )
1427
1428 if (
1429 self.from_statement_ctx is not None
1430 and self.from_statement_ctx.compile_options._is_star
1431 ):
1432 raise sa_exc.CompileError(
1433 "Can't use RETURNING * with bulk ORM INSERT. "
1434 "Please use a different INSERT form, such as INSERT..VALUES "
1435 "or INSERT with a Core Connection"
1436 )
1437
1438 self.statement = statement
1439
1440
1441@CompileState.plugin_for("orm", "update")
1442class _BulkORMUpdate(_BulkUDCompileState, UpdateDMLState):
1443 @classmethod
1444 def create_for_statement(cls, statement, compiler, **kw):
1445 self = cls.__new__(cls)
1446
1447 dml_strategy = statement._annotations.get(
1448 "dml_strategy", "unspecified"
1449 )
1450
1451 toplevel = not compiler.stack
1452
1453 if toplevel and dml_strategy == "bulk":
1454 self._setup_for_bulk_update(statement, compiler)
1455 elif (
1456 dml_strategy == "core_only"
1457 or dml_strategy == "unspecified"
1458 and "parententity" not in statement.table._annotations
1459 ):
1460 UpdateDMLState.__init__(self, statement, compiler, **kw)
1461 elif not toplevel or dml_strategy in ("orm", "unspecified"):
1462 self._setup_for_orm_update(statement, compiler)
1463
1464 return self
1465
1466 def _setup_for_orm_update(self, statement, compiler, **kw):
1467 orm_level_statement = statement
1468
1469 toplevel = not compiler.stack
1470
1471 ext_info = statement.table._annotations["parententity"]
1472
1473 self.mapper = mapper = ext_info.mapper
1474
1475 self._resolved_values = self._get_resolved_values(mapper, statement)
1476
1477 self._init_global_attributes(
1478 statement,
1479 compiler,
1480 toplevel=toplevel,
1481 process_criteria_for_toplevel=toplevel,
1482 )
1483
1484 if statement._values:
1485 self._resolved_values = dict(self._resolved_values)
1486
1487 new_stmt = statement._clone()
1488
1489 if new_stmt.table._annotations["parententity"] is mapper:
1490 new_stmt.table = mapper.local_table
1491
1492 # note if the statement has _multi_values, these
1493 # are passed through to the new statement, which will then raise
1494 # InvalidRequestError because UPDATE doesn't support multi_values
1495 # right now.
1496 if statement._values:
1497 new_stmt._values = self._resolved_values
1498
1499 new_crit = self._adjust_for_extra_criteria(
1500 self.global_attributes, mapper
1501 )
1502 if new_crit:
1503 new_stmt = new_stmt.where(*new_crit)
1504
1505 # if we are against a lambda statement we might not be the
1506 # topmost object that received per-execute annotations
1507
1508 # do this first as we need to determine if there is
1509 # UPDATE..FROM
1510
1511 UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
1512
1513 use_supplemental_cols = False
1514
1515 if not toplevel:
1516 synchronize_session = None
1517 else:
1518 synchronize_session = compiler._annotations.get(
1519 "synchronize_session", None
1520 )
1521 can_use_returning = compiler._annotations.get(
1522 "can_use_returning", None
1523 )
1524 if can_use_returning is not False:
1525 # even though pre_exec has determined basic
1526 # can_use_returning for the dialect, if we are to use
1527 # RETURNING we need to run can_use_returning() at this level
1528 # unconditionally because is_delete_using was not known
1529 # at the pre_exec level
1530 can_use_returning = (
1531 synchronize_session == "fetch"
1532 and self.can_use_returning(
1533 compiler.dialect, mapper, is_multitable=self.is_multitable
1534 )
1535 )
1536
1537 if synchronize_session == "fetch" and can_use_returning:
1538 use_supplemental_cols = True
1539
1540 # NOTE: we might want to RETURNING the actual columns to be
1541 # synchronized also. however this is complicated and difficult
1542 # to align against the behavior of "evaluate". Additionally,
1543 # in a large number (if not the majority) of cases, we have the
1544 # "evaluate" answer, usually a fixed value, in memory already and
1545 # there's no need to re-fetch the same value
1546 # over and over again. so perhaps if it could be RETURNING just
1547 # the elements that were based on a SQL expression and not
1548 # a constant. For now it doesn't quite seem worth it
1549 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1550
1551 if toplevel:
1552 new_stmt = self._setup_orm_returning(
1553 compiler,
1554 orm_level_statement,
1555 new_stmt,
1556 dml_mapper=mapper,
1557 use_supplemental_cols=use_supplemental_cols,
1558 )
1559
1560 self.statement = new_stmt
1561
1562 def _setup_for_bulk_update(self, statement, compiler, **kw):
1563 """establish an UPDATE statement within the context of
1564 bulk insert.
1565
1566 This method will be within the "conn.execute()" call that is invoked
1567 by persistence._emit_update_statement().
1568
1569 """
1570 statement = cast(dml.Update, statement)
1571 an = statement._annotations
1572
1573 emit_update_table, _ = (
1574 an["_emit_update_table"],
1575 an["_emit_update_mapper"],
1576 )
1577
1578 statement = statement._clone()
1579 statement.table = emit_update_table
1580
1581 UpdateDMLState.__init__(self, statement, compiler, **kw)
1582
1583 if self._maintain_values_ordering:
1584 raise sa_exc.InvalidRequestError(
1585 "bulk ORM UPDATE does not support ordered_values() for "
1586 "custom UPDATE statements with bulk parameter sets. Use a "
1587 "non-bulk UPDATE statement or use values()."
1588 )
1589
1590 if self._dict_parameters:
1591 self._dict_parameters = {
1592 col: val
1593 for col, val in self._dict_parameters.items()
1594 if col.table is emit_update_table
1595 }
1596 self.statement = statement
1597
1598 @classmethod
1599 def orm_execute_statement(
1600 cls,
1601 session: Session,
1602 statement: dml.Update,
1603 params: _CoreAnyExecuteParams,
1604 execution_options: OrmExecuteOptionsParameter,
1605 bind_arguments: _BindArguments,
1606 conn: Connection,
1607 ) -> _result.Result:
1608
1609 update_options = execution_options.get(
1610 "_sa_orm_update_options", cls.default_update_options
1611 )
1612
1613 if update_options._populate_existing:
1614 load_options = execution_options.get(
1615 "_sa_orm_load_options", QueryContext.default_load_options
1616 )
1617 load_options += {"_populate_existing": True}
1618 execution_options = execution_options.union(
1619 {"_sa_orm_load_options": load_options}
1620 )
1621
1622 if update_options._dml_strategy not in (
1623 "orm",
1624 "auto",
1625 "bulk",
1626 "core_only",
1627 ):
1628 raise sa_exc.ArgumentError(
1629 "Valid strategies for ORM UPDATE strategy "
1630 "are 'orm', 'auto', 'bulk', 'core_only'"
1631 )
1632
1633 result: _result.Result[Unpack[TupleAny]]
1634
1635 if update_options._dml_strategy == "bulk":
1636 enable_check_rowcount = not statement._where_criteria
1637
1638 assert update_options._synchronize_session != "fetch"
1639
1640 if (
1641 statement._where_criteria
1642 and update_options._synchronize_session == "evaluate"
1643 ):
1644 raise sa_exc.InvalidRequestError(
1645 "bulk synchronize of persistent objects not supported "
1646 "when using bulk update with additional WHERE "
1647 "criteria right now. add synchronize_session=None "
1648 "execution option to bypass synchronize of persistent "
1649 "objects."
1650 )
1651 mapper = update_options._subject_mapper
1652 assert mapper is not None
1653 assert session._transaction is not None
1654 result = _bulk_update(
1655 mapper,
1656 cast(
1657 "Iterable[Dict[str, Any]]",
1658 [params] if isinstance(params, dict) else params,
1659 ),
1660 session._transaction,
1661 isstates=False,
1662 update_changed_only=False,
1663 use_orm_update_stmt=statement,
1664 enable_check_rowcount=enable_check_rowcount,
1665 )
1666 return cls.orm_setup_cursor_result(
1667 session,
1668 statement,
1669 params,
1670 execution_options,
1671 bind_arguments,
1672 result,
1673 )
1674 else:
1675 return super().orm_execute_statement(
1676 session,
1677 statement,
1678 params,
1679 execution_options,
1680 bind_arguments,
1681 conn,
1682 )
1683
1684 @classmethod
1685 def can_use_returning(
1686 cls,
1687 dialect: Dialect,
1688 mapper: Mapper[Any],
1689 *,
1690 is_multitable: bool = False,
1691 is_update_from: bool = False,
1692 is_delete_using: bool = False,
1693 is_executemany: bool = False,
1694 ) -> bool:
1695 # normal answer for "should we use RETURNING" at all.
1696 normal_answer = (
1697 dialect.update_returning and mapper.local_table.implicit_returning
1698 )
1699 if not normal_answer:
1700 return False
1701
1702 if is_executemany:
1703 return dialect.update_executemany_returning
1704
1705 # these workarounds are currently hypothetical for UPDATE,
1706 # unlike DELETE where they impact MariaDB
1707 if is_update_from:
1708 return dialect.update_returning_multifrom
1709
1710 elif is_multitable and not dialect.update_returning_multifrom:
1711 raise sa_exc.CompileError(
1712 f'Dialect "{dialect.name}" does not support RETURNING '
1713 "with UPDATE..FROM; for synchronize_session='fetch', "
1714 "please add the additional execution option "
1715 "'is_update_from=True' to the statement to indicate that "
1716 "a separate SELECT should be used for this backend."
1717 )
1718
1719 return True
1720
1721 @classmethod
1722 def _do_post_synchronize_bulk_evaluate(
1723 cls, session, params, result, update_options
1724 ):
1725 if not params:
1726 return
1727
1728 mapper = update_options._subject_mapper
1729 pk_keys = [prop.key for prop in mapper._identity_key_props]
1730
1731 identity_map = session.identity_map
1732
1733 for param in params:
1734 identity_key = mapper.identity_key_from_primary_key(
1735 (param[key] for key in pk_keys),
1736 update_options._identity_token,
1737 )
1738 state = identity_map.fast_get_state(identity_key)
1739 if not state:
1740 continue
1741
1742 evaluated_keys = set(param).difference(pk_keys)
1743
1744 dict_ = state.dict
1745 # only evaluate unmodified attributes
1746 to_evaluate = state.unmodified.intersection(evaluated_keys)
1747 for key in to_evaluate:
1748 if key in dict_:
1749 dict_[key] = param[key]
1750
1751 state.manager.dispatch.refresh(state, None, to_evaluate)
1752
1753 state._commit(dict_, list(to_evaluate))
1754
1755 # attributes that were formerly modified instead get expired.
1756 # this only gets hit if the session had pending changes
1757 # and autoflush were set to False.
1758 to_expire = evaluated_keys.intersection(dict_).difference(
1759 to_evaluate
1760 )
1761 if to_expire:
1762 state._expire_attributes(dict_, to_expire)
1763
1764 @classmethod
1765 def _do_post_synchronize_evaluate(
1766 cls, session, statement, result, update_options
1767 ):
1768 matched_objects = cls._get_matched_objects_on_criteria(
1769 update_options,
1770 session.identity_map.all_states(),
1771 )
1772
1773 cls._apply_update_set_values_to_objects(
1774 session,
1775 update_options,
1776 statement,
1777 result.context.compiled_parameters[0],
1778 [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
1779 result.prefetch_cols(),
1780 result.postfetch_cols(),
1781 )
1782
1783 @classmethod
1784 def _do_post_synchronize_fetch(
1785 cls, session, statement, result, update_options
1786 ):
1787 target_mapper = update_options._subject_mapper
1788
1789 returned_defaults_rows = result.returned_defaults_rows
1790 if returned_defaults_rows:
1791 pk_rows = cls._interpret_returning_rows(
1792 result, target_mapper, returned_defaults_rows
1793 )
1794 matched_rows = [
1795 tuple(row) + (update_options._identity_token,)
1796 for row in pk_rows
1797 ]
1798 else:
1799 matched_rows = update_options._matched_rows
1800
1801 objs = [
1802 session.identity_map[identity_key]
1803 for identity_key in [
1804 target_mapper.identity_key_from_primary_key(
1805 list(primary_key),
1806 identity_token=identity_token,
1807 )
1808 for primary_key, identity_token in [
1809 (row[0:-1], row[-1]) for row in matched_rows
1810 ]
1811 if update_options._identity_token is None
1812 or identity_token == update_options._identity_token
1813 ]
1814 if identity_key in session.identity_map
1815 ]
1816
1817 if not objs:
1818 return
1819
1820 cls._apply_update_set_values_to_objects(
1821 session,
1822 update_options,
1823 statement,
1824 result.context.compiled_parameters[0],
1825 [
1826 (
1827 obj,
1828 attributes.instance_state(obj),
1829 attributes.instance_dict(obj),
1830 )
1831 for obj in objs
1832 ],
1833 result.prefetch_cols(),
1834 result.postfetch_cols(),
1835 )
1836
1837 @classmethod
1838 def _apply_update_set_values_to_objects(
1839 cls,
1840 session,
1841 update_options,
1842 statement,
1843 effective_params,
1844 matched_objects,
1845 prefetch_cols,
1846 postfetch_cols,
1847 ):
1848 """apply values to objects derived from an update statement, e.g.
1849 UPDATE..SET <values>
1850
1851 """
1852
1853 mapper = update_options._subject_mapper
1854 target_cls = mapper.class_
1855 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
1856 resolved_values = cls._get_resolved_values(mapper, statement)
1857 resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
1858 mapper, resolved_values
1859 )
1860 value_evaluators = {}
1861 for key, value in resolved_keys_as_propnames:
1862 try:
1863 _evaluator = evaluator_compiler.process(
1864 coercions.expect(roles.ExpressionElementRole, value)
1865 )
1866 except evaluator.UnevaluatableError:
1867 pass
1868 else:
1869 value_evaluators[key] = _evaluator
1870
1871 evaluated_keys = list(value_evaluators.keys())
1872 attrib = {k for k, v in resolved_keys_as_propnames}
1873
1874 states = set()
1875
1876 to_prefetch = {
1877 c
1878 for c in prefetch_cols
1879 if c.key in effective_params
1880 and c in mapper._columntoproperty
1881 and c.key not in evaluated_keys
1882 }
1883 to_expire = {
1884 mapper._columntoproperty[c].key
1885 for c in postfetch_cols
1886 if c in mapper._columntoproperty
1887 }.difference(evaluated_keys)
1888
1889 prefetch_transfer = [
1890 (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
1891 ]
1892
1893 for obj, state, dict_ in matched_objects:
1894
1895 dict_.update(
1896 {
1897 col_to_prop: effective_params[c_key]
1898 for col_to_prop, c_key in prefetch_transfer
1899 }
1900 )
1901
1902 state._expire_attributes(state.dict, to_expire)
1903
1904 to_evaluate = state.unmodified.intersection(evaluated_keys)
1905
1906 for key in to_evaluate:
1907 if key in dict_:
1908 # only run eval for attributes that are present.
1909 dict_[key] = value_evaluators[key](obj)
1910
1911 state.manager.dispatch.refresh(state, None, to_evaluate)
1912
1913 state._commit(dict_, list(to_evaluate))
1914
1915 # attributes that were formerly modified instead get expired.
1916 # this only gets hit if the session had pending changes
1917 # and autoflush were set to False.
1918 to_expire = attrib.intersection(dict_).difference(to_evaluate)
1919 if to_expire:
1920 state._expire_attributes(dict_, to_expire)
1921
1922 states.add(state)
1923 session._register_altered(states)
1924
1925
1926@CompileState.plugin_for("orm", "delete")
1927class _BulkORMDelete(_BulkUDCompileState, DeleteDMLState):
1928 @classmethod
1929 def create_for_statement(cls, statement, compiler, **kw):
1930 self = cls.__new__(cls)
1931
1932 dml_strategy = statement._annotations.get(
1933 "dml_strategy", "unspecified"
1934 )
1935
1936 if (
1937 dml_strategy == "core_only"
1938 or dml_strategy == "unspecified"
1939 and "parententity" not in statement.table._annotations
1940 ):
1941 DeleteDMLState.__init__(self, statement, compiler, **kw)
1942 return self
1943
1944 toplevel = not compiler.stack
1945
1946 orm_level_statement = statement
1947
1948 ext_info = statement.table._annotations["parententity"]
1949 self.mapper = mapper = ext_info.mapper
1950
1951 self._init_global_attributes(
1952 statement,
1953 compiler,
1954 toplevel=toplevel,
1955 process_criteria_for_toplevel=toplevel,
1956 )
1957
1958 new_stmt = statement._clone()
1959
1960 if new_stmt.table._annotations["parententity"] is mapper:
1961 new_stmt.table = mapper.local_table
1962
1963 new_crit = cls._adjust_for_extra_criteria(
1964 self.global_attributes, mapper
1965 )
1966 if new_crit:
1967 new_stmt = new_stmt.where(*new_crit)
1968
1969 # do this first as we need to determine if there is
1970 # DELETE..FROM
1971 DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
1972
1973 use_supplemental_cols = False
1974
1975 if not toplevel:
1976 synchronize_session = None
1977 else:
1978 synchronize_session = compiler._annotations.get(
1979 "synchronize_session", None
1980 )
1981 can_use_returning = compiler._annotations.get(
1982 "can_use_returning", None
1983 )
1984 if can_use_returning is not False:
1985 # even though pre_exec has determined basic
1986 # can_use_returning for the dialect, if we are to use
1987 # RETURNING we need to run can_use_returning() at this level
1988 # unconditionally because is_delete_using was not known
1989 # at the pre_exec level
1990 can_use_returning = (
1991 synchronize_session == "fetch"
1992 and self.can_use_returning(
1993 compiler.dialect,
1994 mapper,
1995 is_multitable=self.is_multitable,
1996 is_delete_using=compiler._annotations.get(
1997 "is_delete_using", False
1998 ),
1999 )
2000 )
2001
2002 if can_use_returning:
2003 use_supplemental_cols = True
2004
2005 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
2006
2007 if toplevel:
2008 new_stmt = self._setup_orm_returning(
2009 compiler,
2010 orm_level_statement,
2011 new_stmt,
2012 dml_mapper=mapper,
2013 use_supplemental_cols=use_supplemental_cols,
2014 )
2015
2016 self.statement = new_stmt
2017
2018 return self
2019
2020 @classmethod
2021 def orm_execute_statement(
2022 cls,
2023 session: Session,
2024 statement: dml.Delete,
2025 params: _CoreAnyExecuteParams,
2026 execution_options: OrmExecuteOptionsParameter,
2027 bind_arguments: _BindArguments,
2028 conn: Connection,
2029 ) -> _result.Result:
2030 update_options = execution_options.get(
2031 "_sa_orm_update_options", cls.default_update_options
2032 )
2033
2034 if update_options._dml_strategy == "bulk":
2035 raise sa_exc.InvalidRequestError(
2036 "Bulk ORM DELETE not supported right now. "
2037 "Statement may be invoked at the "
2038 "Core level using "
2039 "session.connection().execute(stmt, parameters)"
2040 )
2041
2042 if update_options._dml_strategy not in ("orm", "auto", "core_only"):
2043 raise sa_exc.ArgumentError(
2044 "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
2045 "'core_only'"
2046 )
2047
2048 return super().orm_execute_statement(
2049 session, statement, params, execution_options, bind_arguments, conn
2050 )
2051
2052 @classmethod
2053 def can_use_returning(
2054 cls,
2055 dialect: Dialect,
2056 mapper: Mapper[Any],
2057 *,
2058 is_multitable: bool = False,
2059 is_update_from: bool = False,
2060 is_delete_using: bool = False,
2061 is_executemany: bool = False,
2062 ) -> bool:
2063 # normal answer for "should we use RETURNING" at all.
2064 normal_answer = (
2065 dialect.delete_returning and mapper.local_table.implicit_returning
2066 )
2067 if not normal_answer:
2068 return False
2069
2070 # now get into special workarounds because MariaDB supports
2071 # DELETE...RETURNING but not DELETE...USING...RETURNING.
2072 if is_delete_using:
2073 # is_delete_using hint was passed. use
2074 # additional dialect feature (True for PG, False for MariaDB)
2075 return dialect.delete_returning_multifrom
2076
2077 elif is_multitable and not dialect.delete_returning_multifrom:
2078 # is_delete_using hint was not passed, but we determined
2079 # at compile time that this is in fact a DELETE..USING.
2080 # it's too late to continue since we did not pre-SELECT.
2081 # raise that we need that hint up front.
2082
2083 raise sa_exc.CompileError(
2084 f'Dialect "{dialect.name}" does not support RETURNING '
2085 "with DELETE..USING; for synchronize_session='fetch', "
2086 "please add the additional execution option "
2087 "'is_delete_using=True' to the statement to indicate that "
2088 "a separate SELECT should be used for this backend."
2089 )
2090
2091 return True
2092
2093 @classmethod
2094 def _do_post_synchronize_evaluate(
2095 cls, session, statement, result, update_options
2096 ):
2097 matched_objects = cls._get_matched_objects_on_criteria(
2098 update_options,
2099 session.identity_map.all_states(),
2100 )
2101
2102 to_delete = []
2103
2104 for _, state, dict_, is_partially_expired in matched_objects:
2105 if is_partially_expired:
2106 state._expire(dict_, session.identity_map._modified)
2107 else:
2108 to_delete.append(state)
2109
2110 if to_delete:
2111 session._remove_newly_deleted(to_delete)
2112
2113 @classmethod
2114 def _do_post_synchronize_fetch(
2115 cls, session, statement, result, update_options
2116 ):
2117 target_mapper = update_options._subject_mapper
2118
2119 returned_defaults_rows = result.returned_defaults_rows
2120
2121 if returned_defaults_rows:
2122 pk_rows = cls._interpret_returning_rows(
2123 result, target_mapper, returned_defaults_rows
2124 )
2125
2126 matched_rows = [
2127 tuple(row) + (update_options._identity_token,)
2128 for row in pk_rows
2129 ]
2130 else:
2131 matched_rows = update_options._matched_rows
2132
2133 for row in matched_rows:
2134 primary_key = row[0:-1]
2135 identity_token = row[-1]
2136
2137 # TODO: inline this and call remove_newly_deleted
2138 # once
2139 identity_key = target_mapper.identity_key_from_primary_key(
2140 list(primary_key),
2141 identity_token=identity_token,
2142 )
2143 if identity_key in session.identity_map:
2144 session._remove_newly_deleted(
2145 [
2146 attributes.instance_state(
2147 session.identity_map[identity_key]
2148 )
2149 ]
2150 )