1# orm/bulk_persistence.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""additional ORM persistence classes related to "bulk" operations,
11specifically outside of the flush() process.
12
13"""
14
15from __future__ import annotations
16
17from typing import Any
18from typing import cast
19from typing import Dict
20from typing import Iterable
21from typing import Optional
22from typing import overload
23from typing import TYPE_CHECKING
24from typing import TypeVar
25from typing import Union
26
27from . import attributes
28from . import context
29from . import evaluator
30from . import exc as orm_exc
31from . import loading
32from . import persistence
33from .base import NO_VALUE
34from .context import AbstractORMCompileState
35from .context import FromStatement
36from .context import ORMFromStatementCompileState
37from .context import QueryContext
38from .. import exc as sa_exc
39from .. import util
40from ..engine import Dialect
41from ..engine import result as _result
42from ..sql import coercions
43from ..sql import dml
44from ..sql import expression
45from ..sql import roles
46from ..sql import select
47from ..sql import sqltypes
48from ..sql.base import _entity_namespace_key
49from ..sql.base import CompileState
50from ..sql.base import Options
51from ..sql.dml import DeleteDMLState
52from ..sql.dml import InsertDMLState
53from ..sql.dml import UpdateDMLState
54from ..util import EMPTY_DICT
55from ..util.typing import Literal
56
57if TYPE_CHECKING:
58 from ._typing import DMLStrategyArgument
59 from ._typing import OrmExecuteOptionsParameter
60 from ._typing import SynchronizeSessionArgument
61 from .mapper import Mapper
62 from .session import _BindArguments
63 from .session import ORMExecuteState
64 from .session import Session
65 from .session import SessionTransaction
66 from .state import InstanceState
67 from ..engine import Connection
68 from ..engine import cursor
69 from ..engine.interfaces import _CoreAnyExecuteParams
70
71_O = TypeVar("_O", bound=object)
72
73
74@overload
75def _bulk_insert(
76 mapper: Mapper[_O],
77 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
78 session_transaction: SessionTransaction,
79 *,
80 isstates: bool,
81 return_defaults: bool,
82 render_nulls: bool,
83 use_orm_insert_stmt: Literal[None] = ...,
84 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
85) -> None: ...
86
87
88@overload
89def _bulk_insert(
90 mapper: Mapper[_O],
91 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
92 session_transaction: SessionTransaction,
93 *,
94 isstates: bool,
95 return_defaults: bool,
96 render_nulls: bool,
97 use_orm_insert_stmt: Optional[dml.Insert] = ...,
98 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
99) -> cursor.CursorResult[Any]: ...
100
101
102def _bulk_insert(
103 mapper: Mapper[_O],
104 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
105 session_transaction: SessionTransaction,
106 *,
107 isstates: bool,
108 return_defaults: bool,
109 render_nulls: bool,
110 use_orm_insert_stmt: Optional[dml.Insert] = None,
111 execution_options: Optional[OrmExecuteOptionsParameter] = None,
112) -> Optional[cursor.CursorResult[Any]]:
113 base_mapper = mapper.base_mapper
114
115 if session_transaction.session.connection_callable:
116 raise NotImplementedError(
117 "connection_callable / per-instance sharding "
118 "not supported in bulk_insert()"
119 )
120
121 if isstates:
122 if TYPE_CHECKING:
123 mappings = cast(Iterable[InstanceState[_O]], mappings)
124
125 if return_defaults:
126 # list of states allows us to attach .key for return_defaults case
127 states = [(state, state.dict) for state in mappings]
128 mappings = [dict_ for (state, dict_) in states]
129 else:
130 mappings = [state.dict for state in mappings]
131 else:
132 if TYPE_CHECKING:
133 mappings = cast(Iterable[Dict[str, Any]], mappings)
134
135 if return_defaults:
136 # use dictionaries given, so that newly populated defaults
137 # can be delivered back to the caller (see #11661). This is **not**
138 # compatible with other use cases such as a session-executed
139 # insert() construct, as this will confuse the case of
140 # insert-per-subclass for joined inheritance cases (see
141 # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
142 #
143 # So in this conditional, we have **only** called
144 # session.bulk_insert_mappings() which does not have this
145 # requirement
146 mappings = list(mappings)
147 else:
148 # for all other cases we need to establish a local dictionary
149 # so that the incoming dictionaries aren't mutated
150 mappings = [dict(m) for m in mappings]
151 _expand_composites(mapper, mappings)
152
153 connection = session_transaction.connection(base_mapper)
154
155 return_result: Optional[cursor.CursorResult[Any]] = None
156
157 mappers_to_run = [
158 (table, mp)
159 for table, mp in base_mapper._sorted_tables.items()
160 if table in mapper._pks_by_table
161 ]
162
163 if return_defaults:
164 # not used by new-style bulk inserts, only used for legacy
165 bookkeeping = True
166 elif len(mappers_to_run) > 1:
167 # if we have more than one table, mapper to run where we will be
168 # either horizontally splicing, or copying values between tables,
169 # we need the "bookkeeping" / deterministic returning order
170 bookkeeping = True
171 else:
172 bookkeeping = False
173
174 for table, super_mapper in mappers_to_run:
175 # find bindparams in the statement. For bulk, we don't really know if
176 # a key in the params applies to a different table since we are
177 # potentially inserting for multiple tables here; looking at the
178 # bindparam() is a lot more direct. in most cases this will
179 # use _generate_cache_key() which is memoized, although in practice
180 # the ultimate statement that's executed is probably not the same
181 # object so that memoization might not matter much.
182 extra_bp_names = (
183 [
184 b.key
185 for b in use_orm_insert_stmt._get_embedded_bindparams()
186 if b.key in mappings[0]
187 ]
188 if use_orm_insert_stmt is not None
189 else ()
190 )
191
192 records = (
193 (
194 None,
195 state_dict,
196 params,
197 mapper,
198 connection,
199 value_params,
200 has_all_pks,
201 has_all_defaults,
202 )
203 for (
204 state,
205 state_dict,
206 params,
207 mp,
208 conn,
209 value_params,
210 has_all_pks,
211 has_all_defaults,
212 ) in persistence._collect_insert_commands(
213 table,
214 ((None, mapping, mapper, connection) for mapping in mappings),
215 bulk=True,
216 return_defaults=bookkeeping,
217 render_nulls=render_nulls,
218 include_bulk_keys=extra_bp_names,
219 )
220 )
221
222 result = persistence._emit_insert_statements(
223 base_mapper,
224 None,
225 super_mapper,
226 table,
227 records,
228 bookkeeping=bookkeeping,
229 use_orm_insert_stmt=use_orm_insert_stmt,
230 execution_options=execution_options,
231 )
232 if use_orm_insert_stmt is not None:
233 if not use_orm_insert_stmt._returning or return_result is None:
234 return_result = result
235 elif result.returns_rows:
236 assert bookkeeping
237 return_result = return_result.splice_horizontally(result)
238
239 if return_defaults and isstates:
240 identity_cls = mapper._identity_class
241 identity_props = [p.key for p in mapper._identity_key_props]
242 for state, dict_ in states:
243 state.key = (
244 identity_cls,
245 tuple([dict_[key] for key in identity_props]),
246 None,
247 )
248
249 if use_orm_insert_stmt is not None:
250 assert return_result is not None
251 return return_result
252
253
254@overload
255def _bulk_update(
256 mapper: Mapper[Any],
257 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
258 session_transaction: SessionTransaction,
259 *,
260 isstates: bool,
261 update_changed_only: bool,
262 use_orm_update_stmt: Literal[None] = ...,
263 enable_check_rowcount: bool = True,
264) -> None: ...
265
266
267@overload
268def _bulk_update(
269 mapper: Mapper[Any],
270 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
271 session_transaction: SessionTransaction,
272 *,
273 isstates: bool,
274 update_changed_only: bool,
275 use_orm_update_stmt: Optional[dml.Update] = ...,
276 enable_check_rowcount: bool = True,
277) -> _result.Result[Any]: ...
278
279
280def _bulk_update(
281 mapper: Mapper[Any],
282 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
283 session_transaction: SessionTransaction,
284 *,
285 isstates: bool,
286 update_changed_only: bool,
287 use_orm_update_stmt: Optional[dml.Update] = None,
288 enable_check_rowcount: bool = True,
289) -> Optional[_result.Result[Any]]:
290 base_mapper = mapper.base_mapper
291
292 search_keys = mapper._primary_key_propkeys
293 if mapper._version_id_prop:
294 search_keys = {mapper._version_id_prop.key}.union(search_keys)
295
296 def _changed_dict(mapper, state):
297 return {
298 k: v
299 for k, v in state.dict.items()
300 if k in state.committed_state or k in search_keys
301 }
302
303 if isstates:
304 if update_changed_only:
305 mappings = [_changed_dict(mapper, state) for state in mappings]
306 else:
307 mappings = [state.dict for state in mappings]
308 else:
309 mappings = [dict(m) for m in mappings]
310 _expand_composites(mapper, mappings)
311
312 if session_transaction.session.connection_callable:
313 raise NotImplementedError(
314 "connection_callable / per-instance sharding "
315 "not supported in bulk_update()"
316 )
317
318 connection = session_transaction.connection(base_mapper)
319
320 # find bindparams in the statement. see _bulk_insert for similar
321 # notes for the insert case
322 extra_bp_names = (
323 [
324 b.key
325 for b in use_orm_update_stmt._get_embedded_bindparams()
326 if b.key in mappings[0]
327 ]
328 if use_orm_update_stmt is not None
329 else ()
330 )
331
332 for table, super_mapper in base_mapper._sorted_tables.items():
333 if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
334 continue
335
336 records = persistence._collect_update_commands(
337 None,
338 table,
339 (
340 (
341 None,
342 mapping,
343 mapper,
344 connection,
345 (
346 mapping[mapper._version_id_prop.key]
347 if mapper._version_id_prop
348 else None
349 ),
350 )
351 for mapping in mappings
352 ),
353 bulk=True,
354 use_orm_update_stmt=use_orm_update_stmt,
355 include_bulk_keys=extra_bp_names,
356 )
357 persistence._emit_update_statements(
358 base_mapper,
359 None,
360 super_mapper,
361 table,
362 records,
363 bookkeeping=False,
364 use_orm_update_stmt=use_orm_update_stmt,
365 enable_check_rowcount=enable_check_rowcount,
366 )
367
368 if use_orm_update_stmt is not None:
369 return _result.null_result()
370
371
372def _expand_composites(mapper, mappings):
373 composite_attrs = mapper.composites
374 if not composite_attrs:
375 return
376
377 composite_keys = set(composite_attrs.keys())
378 populators = {
379 key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
380 for key in composite_keys
381 }
382 for mapping in mappings:
383 for key in composite_keys.intersection(mapping):
384 populators[key](mapping)
385
386
387class ORMDMLState(AbstractORMCompileState):
388 is_dml_returning = True
389 from_statement_ctx: Optional[ORMFromStatementCompileState] = None
390
391 @classmethod
392 def _get_orm_crud_kv_pairs(
393 cls, mapper, statement, kv_iterator, needs_to_be_cacheable
394 ):
395 core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
396
397 for k, v in kv_iterator:
398 k = coercions.expect(roles.DMLColumnRole, k)
399
400 if isinstance(k, str):
401 desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
402 if desc is NO_VALUE:
403 yield (
404 coercions.expect(roles.DMLColumnRole, k),
405 (
406 coercions.expect(
407 roles.ExpressionElementRole,
408 v,
409 type_=sqltypes.NullType(),
410 is_crud=True,
411 )
412 if needs_to_be_cacheable
413 else v
414 ),
415 )
416 else:
417 yield from core_get_crud_kv_pairs(
418 statement,
419 desc._bulk_update_tuples(v),
420 needs_to_be_cacheable,
421 )
422 elif "entity_namespace" in k._annotations:
423 k_anno = k._annotations
424 attr = _entity_namespace_key(
425 k_anno["entity_namespace"], k_anno["proxy_key"]
426 )
427 yield from core_get_crud_kv_pairs(
428 statement,
429 attr._bulk_update_tuples(v),
430 needs_to_be_cacheable,
431 )
432 else:
433 yield (
434 k,
435 (
436 v
437 if not needs_to_be_cacheable
438 else coercions.expect(
439 roles.ExpressionElementRole,
440 v,
441 type_=sqltypes.NullType(),
442 is_crud=True,
443 )
444 ),
445 )
446
447 @classmethod
448 def _get_dml_plugin_subject(cls, statement):
449 plugin_subject = statement.table._propagate_attrs.get("plugin_subject")
450
451 if (
452 not plugin_subject
453 or not plugin_subject.mapper
454 or plugin_subject
455 is not statement._propagate_attrs["plugin_subject"]
456 ):
457 return None
458 return plugin_subject
459
460 @classmethod
461 def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
462 plugin_subject = cls._get_dml_plugin_subject(statement)
463
464 if not plugin_subject:
465 return UpdateDMLState._get_multi_crud_kv_pairs(
466 statement, kv_iterator
467 )
468
469 return [
470 dict(
471 cls._get_orm_crud_kv_pairs(
472 plugin_subject.mapper, statement, value_dict.items(), False
473 )
474 )
475 for value_dict in kv_iterator
476 ]
477
478 @classmethod
479 def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
480 assert (
481 needs_to_be_cacheable
482 ), "no test coverage for needs_to_be_cacheable=False"
483
484 plugin_subject = cls._get_dml_plugin_subject(statement)
485
486 if not plugin_subject:
487 return UpdateDMLState._get_crud_kv_pairs(
488 statement, kv_iterator, needs_to_be_cacheable
489 )
490 return list(
491 cls._get_orm_crud_kv_pairs(
492 plugin_subject.mapper,
493 statement,
494 kv_iterator,
495 needs_to_be_cacheable,
496 )
497 )
498
499 @classmethod
500 def get_entity_description(cls, statement):
501 ext_info = statement.table._annotations["parententity"]
502 mapper = ext_info.mapper
503 if ext_info.is_aliased_class:
504 _label_name = ext_info.name
505 else:
506 _label_name = mapper.class_.__name__
507
508 return {
509 "name": _label_name,
510 "type": mapper.class_,
511 "expr": ext_info.entity,
512 "entity": ext_info.entity,
513 "table": mapper.local_table,
514 }
515
516 @classmethod
517 def get_returning_column_descriptions(cls, statement):
518 def _ent_for_col(c):
519 return c._annotations.get("parententity", None)
520
521 def _attr_for_col(c, ent):
522 if ent is None:
523 return c
524 proxy_key = c._annotations.get("proxy_key", None)
525 if not proxy_key:
526 return c
527 else:
528 return getattr(ent.entity, proxy_key, c)
529
530 return [
531 {
532 "name": c.key,
533 "type": c.type,
534 "expr": _attr_for_col(c, ent),
535 "aliased": ent.is_aliased_class,
536 "entity": ent.entity,
537 }
538 for c, ent in [
539 (c, _ent_for_col(c)) for c in statement._all_selected_columns
540 ]
541 ]
542
543 def _setup_orm_returning(
544 self,
545 compiler,
546 orm_level_statement,
547 dml_level_statement,
548 dml_mapper,
549 *,
550 use_supplemental_cols=True,
551 ):
552 """establish ORM column handlers for an INSERT, UPDATE, or DELETE
553 which uses explicit returning().
554
555 called within compilation level create_for_statement.
556
557 The _return_orm_returning() method then receives the Result
558 after the statement was executed, and applies ORM loading to the
559 state that we first established here.
560
561 """
562
563 if orm_level_statement._returning:
564 fs = FromStatement(
565 orm_level_statement._returning,
566 dml_level_statement,
567 _adapt_on_names=False,
568 )
569 fs = fs.execution_options(**orm_level_statement._execution_options)
570 fs = fs.options(*orm_level_statement._with_options)
571 self.select_statement = fs
572 self.from_statement_ctx = fsc = (
573 ORMFromStatementCompileState.create_for_statement(fs, compiler)
574 )
575 fsc.setup_dml_returning_compile_state(dml_mapper)
576
577 dml_level_statement = dml_level_statement._generate()
578 dml_level_statement._returning = ()
579
580 cols_to_return = [c for c in fsc.primary_columns if c is not None]
581
582 # since we are splicing result sets together, make sure there
583 # are columns of some kind returned in each result set
584 if not cols_to_return:
585 cols_to_return.extend(dml_mapper.primary_key)
586
587 if use_supplemental_cols:
588 dml_level_statement = dml_level_statement.return_defaults(
589 # this is a little weird looking, but by passing
590 # primary key as the main list of cols, this tells
591 # return_defaults to omit server-default cols (and
592 # actually all cols, due to some weird thing we should
593 # clean up in crud.py).
594 # Since we have cols_to_return, just return what we asked
595 # for (plus primary key, which ORM persistence needs since
596 # we likely set bookkeeping=True here, which is another
597 # whole thing...). We dont want to clutter the
598 # statement up with lots of other cols the user didn't
599 # ask for. see #9685
600 *dml_mapper.primary_key,
601 supplemental_cols=cols_to_return,
602 )
603 else:
604 dml_level_statement = dml_level_statement.returning(
605 *cols_to_return
606 )
607
608 return dml_level_statement
609
610 @classmethod
611 def _return_orm_returning(
612 cls,
613 session,
614 statement,
615 params,
616 execution_options,
617 bind_arguments,
618 result,
619 ):
620 execution_context = result.context
621 compile_state = execution_context.compiled.compile_state
622
623 if (
624 compile_state.from_statement_ctx
625 and not compile_state.from_statement_ctx.compile_options._is_star
626 ):
627 load_options = execution_options.get(
628 "_sa_orm_load_options", QueryContext.default_load_options
629 )
630
631 querycontext = QueryContext(
632 compile_state.from_statement_ctx,
633 compile_state.select_statement,
634 statement,
635 params,
636 session,
637 load_options,
638 execution_options,
639 bind_arguments,
640 )
641 return loading.instances(result, querycontext)
642 else:
643 return result
644
645
646class BulkUDCompileState(ORMDMLState):
647 class default_update_options(Options):
648 _dml_strategy: DMLStrategyArgument = "auto"
649 _synchronize_session: SynchronizeSessionArgument = "auto"
650 _can_use_returning: bool = False
651 _is_delete_using: bool = False
652 _is_update_from: bool = False
653 _autoflush: bool = True
654 _subject_mapper: Optional[Mapper[Any]] = None
655 _resolved_values = EMPTY_DICT
656 _eval_condition = None
657 _matched_rows = None
658 _identity_token = None
659 _populate_existing: bool = False
660
661 @classmethod
662 def can_use_returning(
663 cls,
664 dialect: Dialect,
665 mapper: Mapper[Any],
666 *,
667 is_multitable: bool = False,
668 is_update_from: bool = False,
669 is_delete_using: bool = False,
670 is_executemany: bool = False,
671 ) -> bool:
672 raise NotImplementedError()
673
674 @classmethod
675 def orm_pre_session_exec(
676 cls,
677 session,
678 statement,
679 params,
680 execution_options,
681 bind_arguments,
682 is_pre_event,
683 ):
684 (
685 update_options,
686 execution_options,
687 ) = BulkUDCompileState.default_update_options.from_execution_options(
688 "_sa_orm_update_options",
689 {
690 "synchronize_session",
691 "autoflush",
692 "populate_existing",
693 "identity_token",
694 "is_delete_using",
695 "is_update_from",
696 "dml_strategy",
697 },
698 execution_options,
699 statement._execution_options,
700 )
701 bind_arguments["clause"] = statement
702 try:
703 plugin_subject = statement._propagate_attrs["plugin_subject"]
704 except KeyError:
705 assert False, "statement had 'orm' plugin but no plugin_subject"
706 else:
707 if plugin_subject:
708 bind_arguments["mapper"] = plugin_subject.mapper
709 update_options += {"_subject_mapper": plugin_subject.mapper}
710
711 if "parententity" not in statement.table._annotations:
712 update_options += {"_dml_strategy": "core_only"}
713 elif not isinstance(params, list):
714 if update_options._dml_strategy == "auto":
715 update_options += {"_dml_strategy": "orm"}
716 elif update_options._dml_strategy == "bulk":
717 raise sa_exc.InvalidRequestError(
718 'Can\'t use "bulk" ORM insert strategy without '
719 "passing separate parameters"
720 )
721 else:
722 if update_options._dml_strategy == "auto":
723 update_options += {"_dml_strategy": "bulk"}
724
725 sync = update_options._synchronize_session
726 if sync is not None:
727 if sync not in ("auto", "evaluate", "fetch", False):
728 raise sa_exc.ArgumentError(
729 "Valid strategies for session synchronization "
730 "are 'auto', 'evaluate', 'fetch', False"
731 )
732 if update_options._dml_strategy == "bulk" and sync == "fetch":
733 raise sa_exc.InvalidRequestError(
734 "The 'fetch' synchronization strategy is not available "
735 "for 'bulk' ORM updates (i.e. multiple parameter sets)"
736 )
737
738 if not is_pre_event:
739 if update_options._autoflush:
740 session._autoflush()
741
742 if update_options._dml_strategy == "orm":
743 if update_options._synchronize_session == "auto":
744 update_options = cls._do_pre_synchronize_auto(
745 session,
746 statement,
747 params,
748 execution_options,
749 bind_arguments,
750 update_options,
751 )
752 elif update_options._synchronize_session == "evaluate":
753 update_options = cls._do_pre_synchronize_evaluate(
754 session,
755 statement,
756 params,
757 execution_options,
758 bind_arguments,
759 update_options,
760 )
761 elif update_options._synchronize_session == "fetch":
762 update_options = cls._do_pre_synchronize_fetch(
763 session,
764 statement,
765 params,
766 execution_options,
767 bind_arguments,
768 update_options,
769 )
770 elif update_options._dml_strategy == "bulk":
771 if update_options._synchronize_session == "auto":
772 update_options += {"_synchronize_session": "evaluate"}
773
774 # indicators from the "pre exec" step that are then
775 # added to the DML statement, which will also be part of the cache
776 # key. The compile level create_for_statement() method will then
777 # consume these at compiler time.
778 statement = statement._annotate(
779 {
780 "synchronize_session": update_options._synchronize_session,
781 "is_delete_using": update_options._is_delete_using,
782 "is_update_from": update_options._is_update_from,
783 "dml_strategy": update_options._dml_strategy,
784 "can_use_returning": update_options._can_use_returning,
785 }
786 )
787
788 return (
789 statement,
790 util.immutabledict(execution_options).union(
791 {"_sa_orm_update_options": update_options}
792 ),
793 )
794
795 @classmethod
796 def orm_setup_cursor_result(
797 cls,
798 session,
799 statement,
800 params,
801 execution_options,
802 bind_arguments,
803 result,
804 ):
805 # this stage of the execution is called after the
806 # do_orm_execute event hook. meaning for an extension like
807 # horizontal sharding, this step happens *within* the horizontal
808 # sharding event handler which calls session.execute() re-entrantly
809 # and will occur for each backend individually.
810 # the sharding extension then returns its own merged result from the
811 # individual ones we return here.
812
813 update_options = execution_options["_sa_orm_update_options"]
814 if update_options._dml_strategy == "orm":
815 if update_options._synchronize_session == "evaluate":
816 cls._do_post_synchronize_evaluate(
817 session, statement, result, update_options
818 )
819 elif update_options._synchronize_session == "fetch":
820 cls._do_post_synchronize_fetch(
821 session, statement, result, update_options
822 )
823 elif update_options._dml_strategy == "bulk":
824 if update_options._synchronize_session == "evaluate":
825 cls._do_post_synchronize_bulk_evaluate(
826 session, params, result, update_options
827 )
828 return result
829
830 return cls._return_orm_returning(
831 session,
832 statement,
833 params,
834 execution_options,
835 bind_arguments,
836 result,
837 )
838
839 @classmethod
840 def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
841 """Apply extra criteria filtering.
842
843 For all distinct single-table-inheritance mappers represented in the
844 table being updated or deleted, produce additional WHERE criteria such
845 that only the appropriate subtypes are selected from the total results.
846
847 Additionally, add WHERE criteria originating from LoaderCriteriaOptions
848 collected from the statement.
849
850 """
851
852 return_crit = ()
853
854 adapter = ext_info._adapter if ext_info.is_aliased_class else None
855
856 if (
857 "additional_entity_criteria",
858 ext_info.mapper,
859 ) in global_attributes:
860 return_crit += tuple(
861 ae._resolve_where_criteria(ext_info)
862 for ae in global_attributes[
863 ("additional_entity_criteria", ext_info.mapper)
864 ]
865 if ae.include_aliases or ae.entity is ext_info
866 )
867
868 if ext_info.mapper._single_table_criterion is not None:
869 return_crit += (ext_info.mapper._single_table_criterion,)
870
871 if adapter:
872 return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
873
874 return return_crit
875
876 @classmethod
877 def _interpret_returning_rows(cls, result, mapper, rows):
878 """return rows that indicate PK cols in mapper.primary_key position
879 for RETURNING rows.
880
881 Prior to 2.0.36, this method seemed to be written for some kind of
882 inheritance scenario but the scenario was unused for actual joined
883 inheritance, and the function instead seemed to perform some kind of
884 partial translation that would remove non-PK cols if the PK cols
885 happened to be first in the row, but not otherwise. The joined
886 inheritance walk feature here seems to have never been used as it was
887 always skipped by the "local_table" check.
888
889 As of 2.0.36 the function strips away non-PK cols and provides the
890 PK cols for the table in mapper PK order.
891
892 """
893
894 try:
895 if mapper.local_table is not mapper.base_mapper.local_table:
896 # TODO: dive more into how a local table PK is used for fetch
897 # sync, not clear if this is correct as it depends on the
898 # downstream routine to fetch rows using
899 # local_table.primary_key order
900 pk_keys = result._tuple_getter(mapper.local_table.primary_key)
901 else:
902 pk_keys = result._tuple_getter(mapper.primary_key)
903 except KeyError:
904 # can't use these rows, they don't have PK cols in them
905 # this is an unusual case where the user would have used
906 # .return_defaults()
907 return []
908
909 return [pk_keys(row) for row in rows]
910
911 @classmethod
912 def _get_matched_objects_on_criteria(cls, update_options, states):
913 mapper = update_options._subject_mapper
914 eval_condition = update_options._eval_condition
915
916 raw_data = [
917 (state.obj(), state, state.dict)
918 for state in states
919 if state.mapper.isa(mapper) and not state.expired
920 ]
921
922 identity_token = update_options._identity_token
923 if identity_token is not None:
924 raw_data = [
925 (obj, state, dict_)
926 for obj, state, dict_ in raw_data
927 if state.identity_token == identity_token
928 ]
929
930 result = []
931 for obj, state, dict_ in raw_data:
932 evaled_condition = eval_condition(obj)
933
934 # caution: don't use "in ()" or == here, _EXPIRE_OBJECT
935 # evaluates as True for all comparisons
936 if (
937 evaled_condition is True
938 or evaled_condition is evaluator._EXPIRED_OBJECT
939 ):
940 result.append(
941 (
942 obj,
943 state,
944 dict_,
945 evaled_condition is evaluator._EXPIRED_OBJECT,
946 )
947 )
948 return result
949
950 @classmethod
951 def _eval_condition_from_statement(cls, update_options, statement):
952 mapper = update_options._subject_mapper
953 target_cls = mapper.class_
954
955 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
956 crit = ()
957 if statement._where_criteria:
958 crit += statement._where_criteria
959
960 global_attributes = {}
961 for opt in statement._with_options:
962 if opt._is_criteria_option:
963 opt.get_global_criteria(global_attributes)
964
965 if global_attributes:
966 crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
967
968 if crit:
969 eval_condition = evaluator_compiler.process(*crit)
970 else:
971 # workaround for mypy https://github.com/python/mypy/issues/14027
972 def _eval_condition(obj):
973 return True
974
975 eval_condition = _eval_condition
976
977 return eval_condition
978
979 @classmethod
980 def _do_pre_synchronize_auto(
981 cls,
982 session,
983 statement,
984 params,
985 execution_options,
986 bind_arguments,
987 update_options,
988 ):
989 """setup auto sync strategy
990
991
992 "auto" checks if we can use "evaluate" first, then falls back
993 to "fetch"
994
995 evaluate is vastly more efficient for the common case
996 where session is empty, only has a few objects, and the UPDATE
997 statement can potentially match thousands/millions of rows.
998
999 OTOH more complex criteria that fails to work with "evaluate"
1000 we would hope usually correlates with fewer net rows.
1001
1002 """
1003
1004 try:
1005 eval_condition = cls._eval_condition_from_statement(
1006 update_options, statement
1007 )
1008
1009 except evaluator.UnevaluatableError:
1010 pass
1011 else:
1012 return update_options + {
1013 "_eval_condition": eval_condition,
1014 "_synchronize_session": "evaluate",
1015 }
1016
1017 update_options += {"_synchronize_session": "fetch"}
1018 return cls._do_pre_synchronize_fetch(
1019 session,
1020 statement,
1021 params,
1022 execution_options,
1023 bind_arguments,
1024 update_options,
1025 )
1026
1027 @classmethod
1028 def _do_pre_synchronize_evaluate(
1029 cls,
1030 session,
1031 statement,
1032 params,
1033 execution_options,
1034 bind_arguments,
1035 update_options,
1036 ):
1037 try:
1038 eval_condition = cls._eval_condition_from_statement(
1039 update_options, statement
1040 )
1041
1042 except evaluator.UnevaluatableError as err:
1043 raise sa_exc.InvalidRequestError(
1044 'Could not evaluate current criteria in Python: "%s". '
1045 "Specify 'fetch' or False for the "
1046 "synchronize_session execution option." % err
1047 ) from err
1048
1049 return update_options + {
1050 "_eval_condition": eval_condition,
1051 }
1052
1053 @classmethod
1054 def _get_resolved_values(cls, mapper, statement):
1055 if statement._multi_values:
1056 return []
1057 elif statement._ordered_values:
1058 return list(statement._ordered_values)
1059 elif statement._values:
1060 return list(statement._values.items())
1061 else:
1062 return []
1063
1064 @classmethod
1065 def _resolved_keys_as_propnames(cls, mapper, resolved_values):
1066 values = []
1067 for k, v in resolved_values:
1068 if mapper and isinstance(k, expression.ColumnElement):
1069 try:
1070 attr = mapper._columntoproperty[k]
1071 except orm_exc.UnmappedColumnError:
1072 pass
1073 else:
1074 values.append((attr.key, v))
1075 else:
1076 raise sa_exc.InvalidRequestError(
1077 "Attribute name not found, can't be "
1078 "synchronized back to objects: %r" % k
1079 )
1080 return values
1081
1082 @classmethod
1083 def _do_pre_synchronize_fetch(
1084 cls,
1085 session,
1086 statement,
1087 params,
1088 execution_options,
1089 bind_arguments,
1090 update_options,
1091 ):
1092 mapper = update_options._subject_mapper
1093
1094 select_stmt = (
1095 select(*(mapper.primary_key + (mapper.select_identity_token,)))
1096 .select_from(mapper)
1097 .options(*statement._with_options)
1098 )
1099 select_stmt._where_criteria = statement._where_criteria
1100
1101 # conditionally run the SELECT statement for pre-fetch, testing the
1102 # "bind" for if we can use RETURNING or not using the do_orm_execute
1103 # event. If RETURNING is available, the do_orm_execute event
1104 # will cancel the SELECT from being actually run.
1105 #
1106 # The way this is organized seems strange, why don't we just
1107 # call can_use_returning() before invoking the statement and get
1108 # answer?, why does this go through the whole execute phase using an
1109 # event? Answer: because we are integrating with extensions such
1110 # as the horizontal sharding extention that "multiplexes" an individual
1111 # statement run through multiple engines, and it uses
1112 # do_orm_execute() to do that.
1113
1114 can_use_returning = None
1115
1116 def skip_for_returning(orm_context: ORMExecuteState) -> Any:
1117 bind = orm_context.session.get_bind(**orm_context.bind_arguments)
1118 nonlocal can_use_returning
1119
1120 per_bind_result = cls.can_use_returning(
1121 bind.dialect,
1122 mapper,
1123 is_update_from=update_options._is_update_from,
1124 is_delete_using=update_options._is_delete_using,
1125 is_executemany=orm_context.is_executemany,
1126 )
1127
1128 if can_use_returning is not None:
1129 if can_use_returning != per_bind_result:
1130 raise sa_exc.InvalidRequestError(
1131 "For synchronize_session='fetch', can't mix multiple "
1132 "backends where some support RETURNING and others "
1133 "don't"
1134 )
1135 elif orm_context.is_executemany and not per_bind_result:
1136 raise sa_exc.InvalidRequestError(
1137 "For synchronize_session='fetch', can't use multiple "
1138 "parameter sets in ORM mode, which this backend does not "
1139 "support with RETURNING"
1140 )
1141 else:
1142 can_use_returning = per_bind_result
1143
1144 if per_bind_result:
1145 return _result.null_result()
1146 else:
1147 return None
1148
1149 result = session.execute(
1150 select_stmt,
1151 params,
1152 execution_options=execution_options,
1153 bind_arguments=bind_arguments,
1154 _add_event=skip_for_returning,
1155 )
1156 matched_rows = result.fetchall()
1157
1158 return update_options + {
1159 "_matched_rows": matched_rows,
1160 "_can_use_returning": can_use_returning,
1161 }
1162
1163
1164@CompileState.plugin_for("orm", "insert")
1165class BulkORMInsert(ORMDMLState, InsertDMLState):
1166 class default_insert_options(Options):
1167 _dml_strategy: DMLStrategyArgument = "auto"
1168 _render_nulls: bool = False
1169 _return_defaults: bool = False
1170 _subject_mapper: Optional[Mapper[Any]] = None
1171 _autoflush: bool = True
1172 _populate_existing: bool = False
1173
1174 select_statement: Optional[FromStatement] = None
1175
1176 @classmethod
1177 def orm_pre_session_exec(
1178 cls,
1179 session,
1180 statement,
1181 params,
1182 execution_options,
1183 bind_arguments,
1184 is_pre_event,
1185 ):
1186 (
1187 insert_options,
1188 execution_options,
1189 ) = BulkORMInsert.default_insert_options.from_execution_options(
1190 "_sa_orm_insert_options",
1191 {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
1192 execution_options,
1193 statement._execution_options,
1194 )
1195 bind_arguments["clause"] = statement
1196 try:
1197 plugin_subject = statement._propagate_attrs["plugin_subject"]
1198 except KeyError:
1199 assert False, "statement had 'orm' plugin but no plugin_subject"
1200 else:
1201 if plugin_subject:
1202 bind_arguments["mapper"] = plugin_subject.mapper
1203 insert_options += {"_subject_mapper": plugin_subject.mapper}
1204
1205 if not params:
1206 if insert_options._dml_strategy == "auto":
1207 insert_options += {"_dml_strategy": "orm"}
1208 elif insert_options._dml_strategy == "bulk":
1209 raise sa_exc.InvalidRequestError(
1210 'Can\'t use "bulk" ORM insert strategy without '
1211 "passing separate parameters"
1212 )
1213 else:
1214 if insert_options._dml_strategy == "auto":
1215 insert_options += {"_dml_strategy": "bulk"}
1216
1217 if insert_options._dml_strategy != "raw":
1218 # for ORM object loading, like ORMContext, we have to disable
1219 # result set adapt_to_context, because we will be generating a
1220 # new statement with specific columns that's cached inside of
1221 # an ORMFromStatementCompileState, which we will re-use for
1222 # each result.
1223 if not execution_options:
1224 execution_options = context._orm_load_exec_options
1225 else:
1226 execution_options = execution_options.union(
1227 context._orm_load_exec_options
1228 )
1229
1230 if not is_pre_event and insert_options._autoflush:
1231 session._autoflush()
1232
1233 statement = statement._annotate(
1234 {"dml_strategy": insert_options._dml_strategy}
1235 )
1236
1237 return (
1238 statement,
1239 util.immutabledict(execution_options).union(
1240 {"_sa_orm_insert_options": insert_options}
1241 ),
1242 )
1243
1244 @classmethod
1245 def orm_execute_statement(
1246 cls,
1247 session: Session,
1248 statement: dml.Insert,
1249 params: _CoreAnyExecuteParams,
1250 execution_options: OrmExecuteOptionsParameter,
1251 bind_arguments: _BindArguments,
1252 conn: Connection,
1253 ) -> _result.Result:
1254 insert_options = execution_options.get(
1255 "_sa_orm_insert_options", cls.default_insert_options
1256 )
1257
1258 if insert_options._dml_strategy not in (
1259 "raw",
1260 "bulk",
1261 "orm",
1262 "auto",
1263 ):
1264 raise sa_exc.ArgumentError(
1265 "Valid strategies for ORM insert strategy "
1266 "are 'raw', 'orm', 'bulk', 'auto"
1267 )
1268
1269 result: _result.Result[Any]
1270
1271 if insert_options._dml_strategy == "raw":
1272 result = conn.execute(
1273 statement, params or {}, execution_options=execution_options
1274 )
1275 return result
1276
1277 if insert_options._dml_strategy == "bulk":
1278 mapper = insert_options._subject_mapper
1279
1280 if (
1281 statement._post_values_clause is not None
1282 and mapper._multiple_persistence_tables
1283 ):
1284 raise sa_exc.InvalidRequestError(
1285 "bulk INSERT with a 'post values' clause "
1286 "(typically upsert) not supported for multi-table "
1287 f"mapper {mapper}"
1288 )
1289
1290 assert mapper is not None
1291 assert session._transaction is not None
1292 result = _bulk_insert(
1293 mapper,
1294 cast(
1295 "Iterable[Dict[str, Any]]",
1296 [params] if isinstance(params, dict) else params,
1297 ),
1298 session._transaction,
1299 isstates=False,
1300 return_defaults=insert_options._return_defaults,
1301 render_nulls=insert_options._render_nulls,
1302 use_orm_insert_stmt=statement,
1303 execution_options=execution_options,
1304 )
1305 elif insert_options._dml_strategy == "orm":
1306 result = conn.execute(
1307 statement, params or {}, execution_options=execution_options
1308 )
1309 else:
1310 raise AssertionError()
1311
1312 if not bool(statement._returning):
1313 return result
1314
1315 if insert_options._populate_existing:
1316 load_options = execution_options.get(
1317 "_sa_orm_load_options", QueryContext.default_load_options
1318 )
1319 load_options += {"_populate_existing": True}
1320 execution_options = execution_options.union(
1321 {"_sa_orm_load_options": load_options}
1322 )
1323
1324 return cls._return_orm_returning(
1325 session,
1326 statement,
1327 params,
1328 execution_options,
1329 bind_arguments,
1330 result,
1331 )
1332
1333 @classmethod
1334 def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
1335 self = cast(
1336 BulkORMInsert,
1337 super().create_for_statement(statement, compiler, **kw),
1338 )
1339
1340 if compiler is not None:
1341 toplevel = not compiler.stack
1342 else:
1343 toplevel = True
1344 if not toplevel:
1345 return self
1346
1347 mapper = statement._propagate_attrs["plugin_subject"]
1348 dml_strategy = statement._annotations.get("dml_strategy", "raw")
1349 if dml_strategy == "bulk":
1350 self._setup_for_bulk_insert(compiler)
1351 elif dml_strategy == "orm":
1352 self._setup_for_orm_insert(compiler, mapper)
1353
1354 return self
1355
1356 @classmethod
1357 def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
1358 return {
1359 col.key if col is not None else k: v
1360 for col, k, v in (
1361 (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
1362 )
1363 }
1364
1365 def _setup_for_orm_insert(self, compiler, mapper):
1366 statement = orm_level_statement = cast(dml.Insert, self.statement)
1367
1368 statement = self._setup_orm_returning(
1369 compiler,
1370 orm_level_statement,
1371 statement,
1372 dml_mapper=mapper,
1373 use_supplemental_cols=False,
1374 )
1375 self.statement = statement
1376
1377 def _setup_for_bulk_insert(self, compiler):
1378 """establish an INSERT statement within the context of
1379 bulk insert.
1380
1381 This method will be within the "conn.execute()" call that is invoked
1382 by persistence._emit_insert_statement().
1383
1384 """
1385 statement = orm_level_statement = cast(dml.Insert, self.statement)
1386 an = statement._annotations
1387
1388 emit_insert_table, emit_insert_mapper = (
1389 an["_emit_insert_table"],
1390 an["_emit_insert_mapper"],
1391 )
1392
1393 statement = statement._clone()
1394
1395 statement.table = emit_insert_table
1396 if self._dict_parameters:
1397 self._dict_parameters = {
1398 col: val
1399 for col, val in self._dict_parameters.items()
1400 if col.table is emit_insert_table
1401 }
1402
1403 statement = self._setup_orm_returning(
1404 compiler,
1405 orm_level_statement,
1406 statement,
1407 dml_mapper=emit_insert_mapper,
1408 use_supplemental_cols=True,
1409 )
1410
1411 if (
1412 self.from_statement_ctx is not None
1413 and self.from_statement_ctx.compile_options._is_star
1414 ):
1415 raise sa_exc.CompileError(
1416 "Can't use RETURNING * with bulk ORM INSERT. "
1417 "Please use a different INSERT form, such as INSERT..VALUES "
1418 "or INSERT with a Core Connection"
1419 )
1420
1421 self.statement = statement
1422
1423
1424@CompileState.plugin_for("orm", "update")
1425class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
1426 @classmethod
1427 def create_for_statement(cls, statement, compiler, **kw):
1428 self = cls.__new__(cls)
1429
1430 dml_strategy = statement._annotations.get(
1431 "dml_strategy", "unspecified"
1432 )
1433
1434 toplevel = not compiler.stack
1435
1436 if toplevel and dml_strategy == "bulk":
1437 self._setup_for_bulk_update(statement, compiler)
1438 elif (
1439 dml_strategy == "core_only"
1440 or dml_strategy == "unspecified"
1441 and "parententity" not in statement.table._annotations
1442 ):
1443 UpdateDMLState.__init__(self, statement, compiler, **kw)
1444 elif not toplevel or dml_strategy in ("orm", "unspecified"):
1445 self._setup_for_orm_update(statement, compiler)
1446
1447 return self
1448
1449 def _setup_for_orm_update(self, statement, compiler, **kw):
1450 orm_level_statement = statement
1451
1452 toplevel = not compiler.stack
1453
1454 ext_info = statement.table._annotations["parententity"]
1455
1456 self.mapper = mapper = ext_info.mapper
1457
1458 self._resolved_values = self._get_resolved_values(mapper, statement)
1459
1460 self._init_global_attributes(
1461 statement,
1462 compiler,
1463 toplevel=toplevel,
1464 process_criteria_for_toplevel=toplevel,
1465 )
1466
1467 if statement._values:
1468 self._resolved_values = dict(self._resolved_values)
1469
1470 new_stmt = statement._clone()
1471
1472 if new_stmt.table._annotations["parententity"] is mapper:
1473 new_stmt.table = mapper.local_table
1474
1475 # note if the statement has _multi_values, these
1476 # are passed through to the new statement, which will then raise
1477 # InvalidRequestError because UPDATE doesn't support multi_values
1478 # right now.
1479 if statement._ordered_values:
1480 new_stmt._ordered_values = self._resolved_values
1481 elif statement._values:
1482 new_stmt._values = self._resolved_values
1483
1484 new_crit = self._adjust_for_extra_criteria(
1485 self.global_attributes, mapper
1486 )
1487 if new_crit:
1488 new_stmt = new_stmt.where(*new_crit)
1489
1490 # if we are against a lambda statement we might not be the
1491 # topmost object that received per-execute annotations
1492
1493 # do this first as we need to determine if there is
1494 # UPDATE..FROM
1495
1496 UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
1497
1498 use_supplemental_cols = False
1499
1500 if not toplevel:
1501 synchronize_session = None
1502 else:
1503 synchronize_session = compiler._annotations.get(
1504 "synchronize_session", None
1505 )
1506 can_use_returning = compiler._annotations.get(
1507 "can_use_returning", None
1508 )
1509 if can_use_returning is not False:
1510 # even though pre_exec has determined basic
1511 # can_use_returning for the dialect, if we are to use
1512 # RETURNING we need to run can_use_returning() at this level
1513 # unconditionally because is_delete_using was not known
1514 # at the pre_exec level
1515 can_use_returning = (
1516 synchronize_session == "fetch"
1517 and self.can_use_returning(
1518 compiler.dialect, mapper, is_multitable=self.is_multitable
1519 )
1520 )
1521
1522 if synchronize_session == "fetch" and can_use_returning:
1523 use_supplemental_cols = True
1524
1525 # NOTE: we might want to RETURNING the actual columns to be
1526 # synchronized also. however this is complicated and difficult
1527 # to align against the behavior of "evaluate". Additionally,
1528 # in a large number (if not the majority) of cases, we have the
1529 # "evaluate" answer, usually a fixed value, in memory already and
1530 # there's no need to re-fetch the same value
1531 # over and over again. so perhaps if it could be RETURNING just
1532 # the elements that were based on a SQL expression and not
1533 # a constant. For now it doesn't quite seem worth it
1534 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1535
1536 if toplevel:
1537 new_stmt = self._setup_orm_returning(
1538 compiler,
1539 orm_level_statement,
1540 new_stmt,
1541 dml_mapper=mapper,
1542 use_supplemental_cols=use_supplemental_cols,
1543 )
1544
1545 self.statement = new_stmt
1546
1547 def _setup_for_bulk_update(self, statement, compiler, **kw):
1548 """establish an UPDATE statement within the context of
1549 bulk insert.
1550
1551 This method will be within the "conn.execute()" call that is invoked
1552 by persistence._emit_update_statement().
1553
1554 """
1555 statement = cast(dml.Update, statement)
1556 an = statement._annotations
1557
1558 emit_update_table, _ = (
1559 an["_emit_update_table"],
1560 an["_emit_update_mapper"],
1561 )
1562
1563 statement = statement._clone()
1564 statement.table = emit_update_table
1565
1566 UpdateDMLState.__init__(self, statement, compiler, **kw)
1567
1568 if self._ordered_values:
1569 raise sa_exc.InvalidRequestError(
1570 "bulk ORM UPDATE does not support ordered_values() for "
1571 "custom UPDATE statements with bulk parameter sets. Use a "
1572 "non-bulk UPDATE statement or use values()."
1573 )
1574
1575 if self._dict_parameters:
1576 self._dict_parameters = {
1577 col: val
1578 for col, val in self._dict_parameters.items()
1579 if col.table is emit_update_table
1580 }
1581 self.statement = statement
1582
1583 @classmethod
1584 def orm_execute_statement(
1585 cls,
1586 session: Session,
1587 statement: dml.Update,
1588 params: _CoreAnyExecuteParams,
1589 execution_options: OrmExecuteOptionsParameter,
1590 bind_arguments: _BindArguments,
1591 conn: Connection,
1592 ) -> _result.Result:
1593
1594 update_options = execution_options.get(
1595 "_sa_orm_update_options", cls.default_update_options
1596 )
1597
1598 if update_options._populate_existing:
1599 load_options = execution_options.get(
1600 "_sa_orm_load_options", QueryContext.default_load_options
1601 )
1602 load_options += {"_populate_existing": True}
1603 execution_options = execution_options.union(
1604 {"_sa_orm_load_options": load_options}
1605 )
1606
1607 if update_options._dml_strategy not in (
1608 "orm",
1609 "auto",
1610 "bulk",
1611 "core_only",
1612 ):
1613 raise sa_exc.ArgumentError(
1614 "Valid strategies for ORM UPDATE strategy "
1615 "are 'orm', 'auto', 'bulk', 'core_only'"
1616 )
1617
1618 result: _result.Result[Any]
1619
1620 if update_options._dml_strategy == "bulk":
1621 enable_check_rowcount = not statement._where_criteria
1622
1623 assert update_options._synchronize_session != "fetch"
1624
1625 if (
1626 statement._where_criteria
1627 and update_options._synchronize_session == "evaluate"
1628 ):
1629 raise sa_exc.InvalidRequestError(
1630 "bulk synchronize of persistent objects not supported "
1631 "when using bulk update with additional WHERE "
1632 "criteria right now. add synchronize_session=None "
1633 "execution option to bypass synchronize of persistent "
1634 "objects."
1635 )
1636 mapper = update_options._subject_mapper
1637 assert mapper is not None
1638 assert session._transaction is not None
1639 result = _bulk_update(
1640 mapper,
1641 cast(
1642 "Iterable[Dict[str, Any]]",
1643 [params] if isinstance(params, dict) else params,
1644 ),
1645 session._transaction,
1646 isstates=False,
1647 update_changed_only=False,
1648 use_orm_update_stmt=statement,
1649 enable_check_rowcount=enable_check_rowcount,
1650 )
1651 return cls.orm_setup_cursor_result(
1652 session,
1653 statement,
1654 params,
1655 execution_options,
1656 bind_arguments,
1657 result,
1658 )
1659 else:
1660 return super().orm_execute_statement(
1661 session,
1662 statement,
1663 params,
1664 execution_options,
1665 bind_arguments,
1666 conn,
1667 )
1668
1669 @classmethod
1670 def can_use_returning(
1671 cls,
1672 dialect: Dialect,
1673 mapper: Mapper[Any],
1674 *,
1675 is_multitable: bool = False,
1676 is_update_from: bool = False,
1677 is_delete_using: bool = False,
1678 is_executemany: bool = False,
1679 ) -> bool:
1680 # normal answer for "should we use RETURNING" at all.
1681 normal_answer = (
1682 dialect.update_returning and mapper.local_table.implicit_returning
1683 )
1684 if not normal_answer:
1685 return False
1686
1687 if is_executemany:
1688 return dialect.update_executemany_returning
1689
1690 # these workarounds are currently hypothetical for UPDATE,
1691 # unlike DELETE where they impact MariaDB
1692 if is_update_from:
1693 return dialect.update_returning_multifrom
1694
1695 elif is_multitable and not dialect.update_returning_multifrom:
1696 raise sa_exc.CompileError(
1697 f'Dialect "{dialect.name}" does not support RETURNING '
1698 "with UPDATE..FROM; for synchronize_session='fetch', "
1699 "please add the additional execution option "
1700 "'is_update_from=True' to the statement to indicate that "
1701 "a separate SELECT should be used for this backend."
1702 )
1703
1704 return True
1705
1706 @classmethod
1707 def _do_post_synchronize_bulk_evaluate(
1708 cls, session, params, result, update_options
1709 ):
1710 if not params:
1711 return
1712
1713 mapper = update_options._subject_mapper
1714 pk_keys = [prop.key for prop in mapper._identity_key_props]
1715
1716 identity_map = session.identity_map
1717
1718 for param in params:
1719 identity_key = mapper.identity_key_from_primary_key(
1720 (param[key] for key in pk_keys),
1721 update_options._identity_token,
1722 )
1723 state = identity_map.fast_get_state(identity_key)
1724 if not state:
1725 continue
1726
1727 evaluated_keys = set(param).difference(pk_keys)
1728
1729 dict_ = state.dict
1730 # only evaluate unmodified attributes
1731 to_evaluate = state.unmodified.intersection(evaluated_keys)
1732 for key in to_evaluate:
1733 if key in dict_:
1734 dict_[key] = param[key]
1735
1736 state.manager.dispatch.refresh(state, None, to_evaluate)
1737
1738 state._commit(dict_, list(to_evaluate))
1739
1740 # attributes that were formerly modified instead get expired.
1741 # this only gets hit if the session had pending changes
1742 # and autoflush were set to False.
1743 to_expire = evaluated_keys.intersection(dict_).difference(
1744 to_evaluate
1745 )
1746 if to_expire:
1747 state._expire_attributes(dict_, to_expire)
1748
1749 @classmethod
1750 def _do_post_synchronize_evaluate(
1751 cls, session, statement, result, update_options
1752 ):
1753 matched_objects = cls._get_matched_objects_on_criteria(
1754 update_options,
1755 session.identity_map.all_states(),
1756 )
1757
1758 cls._apply_update_set_values_to_objects(
1759 session,
1760 update_options,
1761 statement,
1762 result.context.compiled_parameters[0],
1763 [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
1764 result.prefetch_cols(),
1765 result.postfetch_cols(),
1766 )
1767
1768 @classmethod
1769 def _do_post_synchronize_fetch(
1770 cls, session, statement, result, update_options
1771 ):
1772 target_mapper = update_options._subject_mapper
1773
1774 returned_defaults_rows = result.returned_defaults_rows
1775 if returned_defaults_rows:
1776 pk_rows = cls._interpret_returning_rows(
1777 result, target_mapper, returned_defaults_rows
1778 )
1779 matched_rows = [
1780 tuple(row) + (update_options._identity_token,)
1781 for row in pk_rows
1782 ]
1783 else:
1784 matched_rows = update_options._matched_rows
1785
1786 objs = [
1787 session.identity_map[identity_key]
1788 for identity_key in [
1789 target_mapper.identity_key_from_primary_key(
1790 list(primary_key),
1791 identity_token=identity_token,
1792 )
1793 for primary_key, identity_token in [
1794 (row[0:-1], row[-1]) for row in matched_rows
1795 ]
1796 if update_options._identity_token is None
1797 or identity_token == update_options._identity_token
1798 ]
1799 if identity_key in session.identity_map
1800 ]
1801
1802 if not objs:
1803 return
1804
1805 cls._apply_update_set_values_to_objects(
1806 session,
1807 update_options,
1808 statement,
1809 result.context.compiled_parameters[0],
1810 [
1811 (
1812 obj,
1813 attributes.instance_state(obj),
1814 attributes.instance_dict(obj),
1815 )
1816 for obj in objs
1817 ],
1818 result.prefetch_cols(),
1819 result.postfetch_cols(),
1820 )
1821
1822 @classmethod
1823 def _apply_update_set_values_to_objects(
1824 cls,
1825 session,
1826 update_options,
1827 statement,
1828 effective_params,
1829 matched_objects,
1830 prefetch_cols,
1831 postfetch_cols,
1832 ):
1833 """apply values to objects derived from an update statement, e.g.
1834 UPDATE..SET <values>
1835
1836 """
1837
1838 mapper = update_options._subject_mapper
1839 target_cls = mapper.class_
1840 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
1841 resolved_values = cls._get_resolved_values(mapper, statement)
1842 resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
1843 mapper, resolved_values
1844 )
1845 value_evaluators = {}
1846 for key, value in resolved_keys_as_propnames:
1847 try:
1848 _evaluator = evaluator_compiler.process(
1849 coercions.expect(roles.ExpressionElementRole, value)
1850 )
1851 except evaluator.UnevaluatableError:
1852 pass
1853 else:
1854 value_evaluators[key] = _evaluator
1855
1856 evaluated_keys = list(value_evaluators.keys())
1857 attrib = {k for k, v in resolved_keys_as_propnames}
1858
1859 states = set()
1860
1861 to_prefetch = {
1862 c
1863 for c in prefetch_cols
1864 if c.key in effective_params
1865 and c in mapper._columntoproperty
1866 and c.key not in evaluated_keys
1867 }
1868 to_expire = {
1869 mapper._columntoproperty[c].key
1870 for c in postfetch_cols
1871 if c in mapper._columntoproperty
1872 }.difference(evaluated_keys)
1873
1874 prefetch_transfer = [
1875 (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
1876 ]
1877
1878 for obj, state, dict_ in matched_objects:
1879
1880 dict_.update(
1881 {
1882 col_to_prop: effective_params[c_key]
1883 for col_to_prop, c_key in prefetch_transfer
1884 }
1885 )
1886
1887 state._expire_attributes(state.dict, to_expire)
1888
1889 to_evaluate = state.unmodified.intersection(evaluated_keys)
1890
1891 for key in to_evaluate:
1892 if key in dict_:
1893 # only run eval for attributes that are present.
1894 dict_[key] = value_evaluators[key](obj)
1895
1896 state.manager.dispatch.refresh(state, None, to_evaluate)
1897
1898 state._commit(dict_, list(to_evaluate))
1899
1900 # attributes that were formerly modified instead get expired.
1901 # this only gets hit if the session had pending changes
1902 # and autoflush were set to False.
1903 to_expire = attrib.intersection(dict_).difference(to_evaluate)
1904 if to_expire:
1905 state._expire_attributes(dict_, to_expire)
1906
1907 states.add(state)
1908 session._register_altered(states)
1909
1910
1911@CompileState.plugin_for("orm", "delete")
1912class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
1913 @classmethod
1914 def create_for_statement(cls, statement, compiler, **kw):
1915 self = cls.__new__(cls)
1916
1917 dml_strategy = statement._annotations.get(
1918 "dml_strategy", "unspecified"
1919 )
1920
1921 if (
1922 dml_strategy == "core_only"
1923 or dml_strategy == "unspecified"
1924 and "parententity" not in statement.table._annotations
1925 ):
1926 DeleteDMLState.__init__(self, statement, compiler, **kw)
1927 return self
1928
1929 toplevel = not compiler.stack
1930
1931 orm_level_statement = statement
1932
1933 ext_info = statement.table._annotations["parententity"]
1934 self.mapper = mapper = ext_info.mapper
1935
1936 self._init_global_attributes(
1937 statement,
1938 compiler,
1939 toplevel=toplevel,
1940 process_criteria_for_toplevel=toplevel,
1941 )
1942
1943 new_stmt = statement._clone()
1944
1945 if new_stmt.table._annotations["parententity"] is mapper:
1946 new_stmt.table = mapper.local_table
1947
1948 new_crit = cls._adjust_for_extra_criteria(
1949 self.global_attributes, mapper
1950 )
1951 if new_crit:
1952 new_stmt = new_stmt.where(*new_crit)
1953
1954 # do this first as we need to determine if there is
1955 # DELETE..FROM
1956 DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
1957
1958 use_supplemental_cols = False
1959
1960 if not toplevel:
1961 synchronize_session = None
1962 else:
1963 synchronize_session = compiler._annotations.get(
1964 "synchronize_session", None
1965 )
1966 can_use_returning = compiler._annotations.get(
1967 "can_use_returning", None
1968 )
1969 if can_use_returning is not False:
1970 # even though pre_exec has determined basic
1971 # can_use_returning for the dialect, if we are to use
1972 # RETURNING we need to run can_use_returning() at this level
1973 # unconditionally because is_delete_using was not known
1974 # at the pre_exec level
1975 can_use_returning = (
1976 synchronize_session == "fetch"
1977 and self.can_use_returning(
1978 compiler.dialect,
1979 mapper,
1980 is_multitable=self.is_multitable,
1981 is_delete_using=compiler._annotations.get(
1982 "is_delete_using", False
1983 ),
1984 )
1985 )
1986
1987 if can_use_returning:
1988 use_supplemental_cols = True
1989
1990 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1991
1992 if toplevel:
1993 new_stmt = self._setup_orm_returning(
1994 compiler,
1995 orm_level_statement,
1996 new_stmt,
1997 dml_mapper=mapper,
1998 use_supplemental_cols=use_supplemental_cols,
1999 )
2000
2001 self.statement = new_stmt
2002
2003 return self
2004
2005 @classmethod
2006 def orm_execute_statement(
2007 cls,
2008 session: Session,
2009 statement: dml.Delete,
2010 params: _CoreAnyExecuteParams,
2011 execution_options: OrmExecuteOptionsParameter,
2012 bind_arguments: _BindArguments,
2013 conn: Connection,
2014 ) -> _result.Result:
2015 update_options = execution_options.get(
2016 "_sa_orm_update_options", cls.default_update_options
2017 )
2018
2019 if update_options._dml_strategy == "bulk":
2020 raise sa_exc.InvalidRequestError(
2021 "Bulk ORM DELETE not supported right now. "
2022 "Statement may be invoked at the "
2023 "Core level using "
2024 "session.connection().execute(stmt, parameters)"
2025 )
2026
2027 if update_options._dml_strategy not in ("orm", "auto", "core_only"):
2028 raise sa_exc.ArgumentError(
2029 "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
2030 "'core_only'"
2031 )
2032
2033 return super().orm_execute_statement(
2034 session, statement, params, execution_options, bind_arguments, conn
2035 )
2036
2037 @classmethod
2038 def can_use_returning(
2039 cls,
2040 dialect: Dialect,
2041 mapper: Mapper[Any],
2042 *,
2043 is_multitable: bool = False,
2044 is_update_from: bool = False,
2045 is_delete_using: bool = False,
2046 is_executemany: bool = False,
2047 ) -> bool:
2048 # normal answer for "should we use RETURNING" at all.
2049 normal_answer = (
2050 dialect.delete_returning and mapper.local_table.implicit_returning
2051 )
2052 if not normal_answer:
2053 return False
2054
2055 # now get into special workarounds because MariaDB supports
2056 # DELETE...RETURNING but not DELETE...USING...RETURNING.
2057 if is_delete_using:
2058 # is_delete_using hint was passed. use
2059 # additional dialect feature (True for PG, False for MariaDB)
2060 return dialect.delete_returning_multifrom
2061
2062 elif is_multitable and not dialect.delete_returning_multifrom:
2063 # is_delete_using hint was not passed, but we determined
2064 # at compile time that this is in fact a DELETE..USING.
2065 # it's too late to continue since we did not pre-SELECT.
2066 # raise that we need that hint up front.
2067
2068 raise sa_exc.CompileError(
2069 f'Dialect "{dialect.name}" does not support RETURNING '
2070 "with DELETE..USING; for synchronize_session='fetch', "
2071 "please add the additional execution option "
2072 "'is_delete_using=True' to the statement to indicate that "
2073 "a separate SELECT should be used for this backend."
2074 )
2075
2076 return True
2077
2078 @classmethod
2079 def _do_post_synchronize_evaluate(
2080 cls, session, statement, result, update_options
2081 ):
2082 matched_objects = cls._get_matched_objects_on_criteria(
2083 update_options,
2084 session.identity_map.all_states(),
2085 )
2086
2087 to_delete = []
2088
2089 for _, state, dict_, is_partially_expired in matched_objects:
2090 if is_partially_expired:
2091 state._expire(dict_, session.identity_map._modified)
2092 else:
2093 to_delete.append(state)
2094
2095 if to_delete:
2096 session._remove_newly_deleted(to_delete)
2097
2098 @classmethod
2099 def _do_post_synchronize_fetch(
2100 cls, session, statement, result, update_options
2101 ):
2102 target_mapper = update_options._subject_mapper
2103
2104 returned_defaults_rows = result.returned_defaults_rows
2105
2106 if returned_defaults_rows:
2107 pk_rows = cls._interpret_returning_rows(
2108 result, target_mapper, returned_defaults_rows
2109 )
2110
2111 matched_rows = [
2112 tuple(row) + (update_options._identity_token,)
2113 for row in pk_rows
2114 ]
2115 else:
2116 matched_rows = update_options._matched_rows
2117
2118 for row in matched_rows:
2119 primary_key = row[0:-1]
2120 identity_token = row[-1]
2121
2122 # TODO: inline this and call remove_newly_deleted
2123 # once
2124 identity_key = target_mapper.identity_key_from_primary_key(
2125 list(primary_key),
2126 identity_token=identity_token,
2127 )
2128 if identity_key in session.identity_map:
2129 session._remove_newly_deleted(
2130 [
2131 attributes.instance_state(
2132 session.identity_map[identity_key]
2133 )
2134 ]
2135 )