# orm/bulk_persistence.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


"""additional ORM persistence classes related to "bulk" operations,
specifically outside of the flush() process.

"""

from __future__ import annotations

from typing import Any
from typing import cast
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union

from . import attributes
from . import context
from . import evaluator
from . import exc as orm_exc
from . import loading
from . import persistence
from .base import NO_VALUE
from .context import _AbstractORMCompileState
from .context import _ORMFromStatementCompileState
from .context import FromStatement
from .context import QueryContext
from .. import exc as sa_exc
from .. import util
from ..engine import Dialect
from ..engine import result as _result
from ..sql import coercions
from ..sql import dml
from ..sql import expression
from ..sql import roles
from ..sql import select
from ..sql import sqltypes
from ..sql.base import _entity_namespace_key
from ..sql.base import CompileState
from ..sql.base import Options
from ..sql.dml import DeleteDMLState
from ..sql.dml import InsertDMLState
from ..sql.dml import UpdateDMLState
from ..util import EMPTY_DICT
from ..util.typing import Literal
from ..util.typing import TupleAny
from ..util.typing import Unpack

if TYPE_CHECKING:
    from ._typing import DMLStrategyArgument
    from ._typing import OrmExecuteOptionsParameter
    from ._typing import SynchronizeSessionArgument
    from .mapper import Mapper
    from .session import _BindArguments
    from .session import ORMExecuteState
    from .session import Session
    from .session import SessionTransaction
    from .state import InstanceState
    from ..engine import Connection
    from ..engine import cursor
    from ..engine.interfaces import _CoreAnyExecuteParams

_O = TypeVar("_O", bound=object)


@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Literal[None] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> None: ...


@overload
def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = ...,
    execution_options: Optional[OrmExecuteOptionsParameter] = ...,
) -> cursor.CursorResult[Any]: ...


def _bulk_insert(
    mapper: Mapper[_O],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    return_defaults: bool,
    render_nulls: bool,
    use_orm_insert_stmt: Optional[dml.Insert] = None,
    execution_options: Optional[OrmExecuteOptionsParameter] = None,
) -> Optional[cursor.CursorResult[Any]]:
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if TYPE_CHECKING:
            mappings = cast(Iterable[InstanceState[_O]], mappings)

        if return_defaults:
            # list of states allows us to attach .key for return_defaults case
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        if TYPE_CHECKING:
            mappings = cast(Iterable[Dict[str, Any]], mappings)

        if return_defaults:
            # use dictionaries given, so that newly populated defaults
            # can be delivered back to the caller (see #11661). This is **not**
            # compatible with other use cases such as a session-executed
            # insert() construct, as this will confuse the case of
            # insert-per-subclass for joined inheritance cases (see
            # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
            #
            # So in this conditional, we have **only** called
            # session.bulk_insert_mappings() which does not have this
            # requirement
            mappings = list(mappings)
        else:
            # for all other cases we need to establish a local dictionary
            # so that the incoming dictionaries aren't mutated
            mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)

    connection = session_transaction.connection(base_mapper)

    return_result: Optional[cursor.CursorResult[Any]] = None

    mappers_to_run = [
        (table, mp)
        for table, mp in base_mapper._sorted_tables.items()
        if table in mapper._pks_by_table
    ]

    if return_defaults:
        # not used by new-style bulk inserts, only used for legacy
        bookkeeping = True
    elif len(mappers_to_run) > 1:
        # if we have more than one table / mapper to run, where we will be
        # either horizontally splicing or copying values between tables,
        # we need the "bookkeeping" / deterministic returning order
        bookkeeping = True
    else:
        bookkeeping = False

    for table, super_mapper in mappers_to_run:
        # find bindparams in the statement. For bulk, we don't really know if
        # a key in the params applies to a different table since we are
        # potentially inserting for multiple tables here; looking at the
        # bindparam() is a lot more direct. in most cases this will
        # use _generate_cache_key() which is memoized, although in practice
        # the ultimate statement that's executed is probably not the same
        # object so that memoization might not matter much.
        extra_bp_names = (
            [
                b.key
                for b in use_orm_insert_stmt._get_embedded_bindparams()
                if b.key in mappings[0]
            ]
            if use_orm_insert_stmt is not None
            else ()
        )

        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in persistence._collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=bookkeeping,
                render_nulls=render_nulls,
                include_bulk_keys=extra_bp_names,
            )
        )

        result = persistence._emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=bookkeeping,
            use_orm_insert_stmt=use_orm_insert_stmt,
            execution_options=execution_options,
        )
        if use_orm_insert_stmt is not None:
            if not use_orm_insert_stmt._returning or return_result is None:
                return_result = result
            elif result.returns_rows:
                assert bookkeeping
                return_result = return_result.splice_horizontally(result)

    if return_defaults and isstates:
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
                None,
            )

    if use_orm_insert_stmt is not None:
        assert return_result is not None
        return return_result


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Literal[None] = ...,
    enable_check_rowcount: bool = True,
) -> None: ...


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = ...,
    enable_check_rowcount: bool = True,
) -> _result.Result[Unpack[TupleAny]]: ...


def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = None,
    enable_check_rowcount: bool = True,
) -> Optional[_result.Result[Unpack[TupleAny]]]:
    base_mapper = mapper.base_mapper

    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        return {
            k: v
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        }

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    # find bindparams in the statement. see _bulk_insert for similar
    # notes for the insert case
    extra_bp_names = (
        [
            b.key
            for b in use_orm_update_stmt._get_embedded_bindparams()
            if b.key in mappings[0]
        ]
        if use_orm_update_stmt is not None
        else ()
    )

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
            continue

        records = persistence._collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
            use_orm_update_stmt=use_orm_update_stmt,
            include_bulk_keys=extra_bp_names,
        )
        persistence._emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
            use_orm_update_stmt=use_orm_update_stmt,
            enable_check_rowcount=enable_check_rowcount,
        )

    if use_orm_update_stmt is not None:
        return _result.null_result()


def _expand_composites(mapper, mappings):
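    """Expand composite attribute keys in each given mapping into their
    underlying column-attribute keys, in place (summary assumed from the
    behavior of _populate_composite_bulk_save_mappings_fn() used below).
    """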
    composite_attrs = mapper.composites
    if not composite_attrs:
        return

    composite_keys = set(composite_attrs.keys())
    populators = {
        key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
        for key in composite_keys
    }
    for mapping in mappings:
        for key in composite_keys.intersection(mapping):
            populators[key](mapping)


class _ORMDMLState(_AbstractORMCompileState):
    is_dml_returning = True
    from_statement_ctx: Optional[_ORMFromStatementCompileState] = None

    @classmethod
    def _get_orm_crud_kv_pairs(
        cls, mapper, statement, kv_iterator, needs_to_be_cacheable
    ):
        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, str):
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if desc is NO_VALUE:
                    yield (
                        coercions.expect(roles.DMLColumnRole, k),
                        (
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            )
                            if needs_to_be_cacheable
                            else v
                        ),
                    )
                else:
                    yield from core_get_crud_kv_pairs(
                        statement,
                        desc._bulk_update_tuples(v),
                        needs_to_be_cacheable,
                    )
            elif "entity_namespace" in k._annotations:
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                yield from core_get_crud_kv_pairs(
                    statement,
                    attr._bulk_update_tuples(v),
                    needs_to_be_cacheable,
                )
            else:
                yield (
                    k,
                    (
                        v
                        if not needs_to_be_cacheable
                        else coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        )
                    ),
                )

    @classmethod
    def _get_dml_plugin_subject(cls, statement):
        plugin_subject = statement.table._propagate_attrs.get("plugin_subject")

        if (
            not plugin_subject
            or not plugin_subject.mapper
            or plugin_subject
            is not statement._propagate_attrs["plugin_subject"]
        ):
            return None
        return plugin_subject

    @classmethod
    def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
        plugin_subject = cls._get_dml_plugin_subject(statement)

        if not plugin_subject:
            return UpdateDMLState._get_multi_crud_kv_pairs(
                statement, kv_iterator
            )

        return [
            dict(
                cls._get_orm_crud_kv_pairs(
                    plugin_subject.mapper, statement, value_dict.items(), False
                )
            )
            for value_dict in kv_iterator
        ]

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
        assert (
            needs_to_be_cacheable
        ), "no test coverage for needs_to_be_cacheable=False"

        plugin_subject = cls._get_dml_plugin_subject(statement)

        if not plugin_subject:
            return UpdateDMLState._get_crud_kv_pairs(
                statement, kv_iterator, needs_to_be_cacheable
            )
        return list(
            cls._get_orm_crud_kv_pairs(
                plugin_subject.mapper,
                statement,
                kv_iterator,
                needs_to_be_cacheable,
            )
        )

    @classmethod
    def get_entity_description(cls, statement):
        ext_info = statement.table._annotations["parententity"]
        mapper = ext_info.mapper
        if ext_info.is_aliased_class:
            _label_name = ext_info.name
        else:
            _label_name = mapper.class_.__name__

        return {
            "name": _label_name,
            "type": mapper.class_,
            "expr": ext_info.entity,
            "entity": ext_info.entity,
            "table": mapper.local_table,
        }

    @classmethod
    def get_returning_column_descriptions(cls, statement):
        def _ent_for_col(c):
            return c._annotations.get("parententity", None)

        def _attr_for_col(c, ent):
            if ent is None:
                return c
            proxy_key = c._annotations.get("proxy_key", None)
            if not proxy_key:
                return c
            else:
                return getattr(ent.entity, proxy_key, c)

        return [
            {
                "name": c.key,
                "type": c.type,
                "expr": _attr_for_col(c, ent),
                "aliased": ent.is_aliased_class,
                "entity": ent.entity,
            }
            for c, ent in [
                (c, _ent_for_col(c)) for c in statement._all_selected_columns
            ]
        ]

    def _setup_orm_returning(
        self,
        compiler,
        orm_level_statement,
        dml_level_statement,
        dml_mapper,
        *,
        use_supplemental_cols=True,
    ):
        """establish ORM column handlers for an INSERT, UPDATE, or DELETE
        which uses explicit returning().

        called within compilation level create_for_statement.

        The _return_orm_returning() method then receives the Result
        after the statement was executed, and applies ORM loading to the
        state that we first established here.

        """

        if orm_level_statement._returning:
            fs = FromStatement(
                orm_level_statement._returning,
                dml_level_statement,
                _adapt_on_names=False,
            )
            fs = fs.execution_options(**orm_level_statement._execution_options)
            fs = fs.options(*orm_level_statement._with_options)
            self.select_statement = fs
            self.from_statement_ctx = fsc = (
                _ORMFromStatementCompileState.create_for_statement(
                    fs, compiler
                )
            )
            fsc.setup_dml_returning_compile_state(dml_mapper)

            dml_level_statement = dml_level_statement._generate()
            dml_level_statement._returning = ()

            cols_to_return = [c for c in fsc.primary_columns if c is not None]

            # since we are splicing result sets together, make sure there
            # are columns of some kind returned in each result set
            if not cols_to_return:
                cols_to_return.extend(dml_mapper.primary_key)

            if use_supplemental_cols:
                dml_level_statement = dml_level_statement.return_defaults(
                    # this is a little weird looking, but by passing
                    # primary key as the main list of cols, this tells
                    # return_defaults to omit server-default cols (and
                    # actually all cols, due to some weird thing we should
                    # clean up in crud.py).
                    # Since we have cols_to_return, just return what we asked
                    # for (plus primary key, which ORM persistence needs since
                    # we likely set bookkeeping=True here, which is another
                    # whole thing...). We don't want to clutter the
                    # statement up with lots of other cols the user didn't
                    # ask for. see #9685
                    *dml_mapper.primary_key,
                    supplemental_cols=cols_to_return,
                )
            else:
                dml_level_statement = dml_level_statement.returning(
                    *cols_to_return
                )

        return dml_level_statement

    @classmethod
    def _return_orm_returning(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        execution_context = result.context
        compile_state = execution_context.compiled.compile_state

        if (
            compile_state.from_statement_ctx
            and not compile_state.from_statement_ctx.compile_options._is_star
        ):
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )

            querycontext = QueryContext(
                compile_state.from_statement_ctx,
                compile_state.select_statement,
                statement,
                params,
                session,
                load_options,
                execution_options,
                bind_arguments,
            )
            return loading.instances(result, querycontext)
        else:
            return result


class _BulkUDCompileState(_ORMDMLState):
    class default_update_options(Options):
        _dml_strategy: DMLStrategyArgument = "auto"
        _synchronize_session: SynchronizeSessionArgument = "auto"
        _can_use_returning: bool = False
        _is_delete_using: bool = False
        _is_update_from: bool = False
        _autoflush: bool = True
        _subject_mapper: Optional[Mapper[Any]] = None
        _resolved_values = EMPTY_DICT
        _eval_condition = None
        _matched_rows = None
        _identity_token = None
        _populate_existing: bool = False
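        # note: the options above are populated from user-facing execution
        # options such as "synchronize_session" and "dml_strategy" via
        # from_execution_options() in orm_pre_session_exec() below.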

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        raise NotImplementedError()

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        (
            update_options,
            execution_options,
        ) = _BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {
                "synchronize_session",
                "autoflush",
                "populate_existing",
                "identity_token",
                "is_delete_using",
                "is_update_from",
                "dml_strategy",
            },
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                update_options += {"_subject_mapper": plugin_subject.mapper}

        if "parententity" not in statement.table._annotations:
            update_options += {"_dml_strategy": "core_only"}
        elif not isinstance(params, list):
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "orm"}
            elif update_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM update strategy without '
                    "passing separate parameters"
                )
        else:
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "bulk"}

        sync = update_options._synchronize_session
        if sync is not None:
            if sync not in ("auto", "evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'auto', 'evaluate', 'fetch', False"
                )
            if update_options._dml_strategy == "bulk" and sync == "fetch":
                raise sa_exc.InvalidRequestError(
                    "The 'fetch' synchronization strategy is not available "
                    "for 'bulk' ORM updates (i.e. multiple parameter sets)"
                )

        if not is_pre_event:
            if update_options._autoflush:
                session._autoflush()

            if update_options._dml_strategy == "orm":
                if update_options._synchronize_session == "auto":
                    update_options = cls._do_pre_synchronize_auto(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "evaluate":
                    update_options = cls._do_pre_synchronize_evaluate(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "fetch":
                    update_options = cls._do_pre_synchronize_fetch(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
            elif update_options._dml_strategy == "bulk":
                if update_options._synchronize_session == "auto":
                    update_options += {"_synchronize_session": "evaluate"}

            # indicators from the "pre exec" step that are then
            # added to the DML statement, which will also be part of the cache
            # key. The compile level create_for_statement() method will then
            # consume these at compiler time.
            statement = statement._annotate(
                {
                    "synchronize_session": update_options._synchronize_session,
                    "is_delete_using": update_options._is_delete_using,
                    "is_update_from": update_options._is_update_from,
                    "dml_strategy": update_options._dml_strategy,
                    "can_use_returning": update_options._can_use_returning,
                }
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
        )

    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        # this stage of the execution is called after the
        # do_orm_execute event hook. meaning for an extension like
        # horizontal sharding, this step happens *within* the horizontal
        # sharding event handler which calls session.execute() re-entrantly
        # and will occur for each backend individually.
        # the sharding extension then returns its own merged result from the
        # individual ones we return here.

        update_options = execution_options["_sa_orm_update_options"]
        if update_options._dml_strategy == "orm":
            if update_options._synchronize_session == "evaluate":
                cls._do_post_synchronize_evaluate(
                    session, statement, result, update_options
                )
            elif update_options._synchronize_session == "fetch":
                cls._do_post_synchronize_fetch(
                    session, statement, result, update_options
                )
        elif update_options._dml_strategy == "bulk":
            if update_options._synchronize_session == "evaluate":
                cls._do_post_synchronize_bulk_evaluate(
                    session, params, result, update_options
                )
            return result

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
        """Apply extra criteria filtering.

        For all distinct single-table-inheritance mappers represented in the
        table being updated or deleted, produce additional WHERE criteria such
        that only the appropriate subtypes are selected from the total results.

        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
        collected from the statement.

        """

        return_crit = ()

        adapter = ext_info._adapter if ext_info.is_aliased_class else None

        if (
            "additional_entity_criteria",
            ext_info.mapper,
        ) in global_attributes:
            return_crit += tuple(
                ae._resolve_where_criteria(ext_info)
                for ae in global_attributes[
                    ("additional_entity_criteria", ext_info.mapper)
                ]
                if ae.include_aliases or ae.entity is ext_info
            )

        if ext_info.mapper._single_table_criterion is not None:
            return_crit += (ext_info.mapper._single_table_criterion,)

        if adapter:
            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)

        return return_crit

    @classmethod
    def _interpret_returning_rows(cls, result, mapper, rows):
        """return rows that indicate PK cols in mapper.primary_key position
        for RETURNING rows.

        Prior to 2.0.36, this method seemed to be written for some kind of
        inheritance scenario but the scenario was unused for actual joined
        inheritance, and the function instead seemed to perform some kind of
        partial translation that would remove non-PK cols if the PK cols
        happened to be first in the row, but not otherwise. The joined
        inheritance walk feature here seems to have never been used as it was
        always skipped by the "local_table" check.

        As of 2.0.36 the function strips away non-PK cols and provides the
        PK cols for the table in mapper PK order.

        """

        try:
            if mapper.local_table is not mapper.base_mapper.local_table:
                # TODO: dive more into how a local table PK is used for fetch
                # sync, not clear if this is correct as it depends on the
                # downstream routine to fetch rows using
                # local_table.primary_key order
                pk_keys = result._tuple_getter(mapper.local_table.primary_key)
            else:
                pk_keys = result._tuple_getter(mapper.primary_key)
        except KeyError:
            # can't use these rows, they don't have PK cols in them
            # this is an unusual case where the user would have used
            # .return_defaults()
            return []

        return [pk_keys(row) for row in rows]

    @classmethod
    def _get_matched_objects_on_criteria(cls, update_options, states):
        mapper = update_options._subject_mapper
        eval_condition = update_options._eval_condition

        raw_data = [
            (state.obj(), state, state.dict)
            for state in states
            if state.mapper.isa(mapper) and not state.expired
        ]

        identity_token = update_options._identity_token
        if identity_token is not None:
            raw_data = [
                (obj, state, dict_)
                for obj, state, dict_ in raw_data
                if state.identity_token == identity_token
            ]

        result = []
        for obj, state, dict_ in raw_data:
            evaled_condition = eval_condition(obj)

            # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
            # evaluates as True for all comparisons
            if (
                evaled_condition is True
                or evaled_condition is evaluator._EXPIRED_OBJECT
            ):
                result.append(
                    (
                        obj,
                        state,
                        dict_,
                        evaled_condition is evaluator._EXPIRED_OBJECT,
                    )
                )
        return result

    @classmethod
    def _eval_condition_from_statement(cls, update_options, statement):
        mapper = update_options._subject_mapper
        target_cls = mapper.class_

        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        crit = ()
        if statement._where_criteria:
            crit += statement._where_criteria

        global_attributes = {}
        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(global_attributes)

        if global_attributes:
            crit += cls._adjust_for_extra_criteria(global_attributes, mapper)

        if crit:
            eval_condition = evaluator_compiler.process(*crit)
        else:
            # workaround for mypy https://github.com/python/mypy/issues/14027
            def _eval_condition(obj):
                return True

            eval_condition = _eval_condition

        return eval_condition

    @classmethod
    def _do_pre_synchronize_auto(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        """setup auto sync strategy


        "auto" checks if we can use "evaluate" first, then falls back
        to "fetch"

        evaluate is vastly more efficient for the common case
        where session is empty, only has a few objects, and the UPDATE
        statement can potentially match thousands/millions of rows.
        OTOH, more complex criteria that fail to work with "evaluate"
        will usually, we would hope, correlate with fewer net rows.

        """

        try:
            eval_condition = cls._eval_condition_from_statement(
                update_options, statement
            )

        except evaluator.UnevaluatableError:
            pass
        else:
            return update_options + {
                "_eval_condition": eval_condition,
                "_synchronize_session": "evaluate",
            }

        update_options += {"_synchronize_session": "fetch"}
        return cls._do_pre_synchronize_fetch(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            update_options,
        )

    @classmethod
    def _do_pre_synchronize_evaluate(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        try:
            eval_condition = cls._eval_condition_from_statement(
                update_options, statement
            )

        except evaluator.UnevaluatableError as err:
            raise sa_exc.InvalidRequestError(
                'Could not evaluate current criteria in Python: "%s". '
                "Specify 'fetch' or False for the "
                "synchronize_session execution option." % err
            ) from err

        return update_options + {
            "_eval_condition": eval_condition,
        }

    @classmethod
    def _get_resolved_values(cls, mapper, statement):
        if statement._multi_values:
            return []
        elif statement._values:
            return list(statement._values.items())
        else:
            return []

    @classmethod
    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
        values = []
        for k, v in resolved_values:
            if mapper and isinstance(k, expression.ColumnElement):
                try:
                    attr = mapper._columntoproperty[k]
                except orm_exc.UnmappedColumnError:
                    pass
                else:
                    values.append((attr.key, v))
            else:
                raise sa_exc.InvalidRequestError(
                    "Attribute name not found, can't be "
                    "synchronized back to objects: %r" % k
                )
        return values

    @classmethod
    def _do_pre_synchronize_fetch(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        mapper = update_options._subject_mapper

        select_stmt = (
            select(*(mapper.primary_key + (mapper.select_identity_token,)))
            .select_from(mapper)
            .options(*statement._with_options)
        )
        select_stmt._where_criteria = statement._where_criteria

        # conditionally run the SELECT statement for pre-fetch, testing the
        # "bind" for if we can use RETURNING or not using the do_orm_execute
        # event. If RETURNING is available, the do_orm_execute event
        # will cancel the SELECT from being actually run.
        #
        # The way this is organized seems strange: why don't we just
        # call can_use_returning() before invoking the statement and get
        # the answer, instead of going through the whole execute phase
        # using an event? Answer: because we are integrating with
        # extensions such as the horizontal sharding extension that
        # "multiplexes" an individual statement run through multiple
        # engines, and it uses do_orm_execute() to do that.

        can_use_returning = None

        def skip_for_returning(orm_context: ORMExecuteState) -> Any:
            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
            nonlocal can_use_returning

            per_bind_result = cls.can_use_returning(
                bind.dialect,
                mapper,
                is_update_from=update_options._is_update_from,
                is_delete_using=update_options._is_delete_using,
                is_executemany=orm_context.is_executemany,
            )

            if can_use_returning is not None:
                if can_use_returning != per_bind_result:
                    raise sa_exc.InvalidRequestError(
                        "For synchronize_session='fetch', can't mix multiple "
                        "backends where some support RETURNING and others "
                        "don't"
                    )
            elif orm_context.is_executemany and not per_bind_result:
                raise sa_exc.InvalidRequestError(
                    "For synchronize_session='fetch', can't use multiple "
                    "parameter sets in ORM mode, which this backend does not "
                    "support with RETURNING"
                )
            else:
                can_use_returning = per_bind_result

            if per_bind_result:
                return _result.null_result()
            else:
                return None

        result = session.execute(
            select_stmt,
            params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            _add_event=skip_for_returning,
        )
        matched_rows = result.fetchall()

        return update_options + {
            "_matched_rows": matched_rows,
            "_can_use_returning": can_use_returning,
        }


@CompileState.plugin_for("orm", "insert")
class _BulkORMInsert(_ORMDMLState, InsertDMLState):
    class default_insert_options(Options):
        _dml_strategy: DMLStrategyArgument = "auto"
        _render_nulls: bool = False
        _return_defaults: bool = False
        _subject_mapper: Optional[Mapper[Any]] = None
        _autoflush: bool = True
        _populate_existing: bool = False

    select_statement: Optional[FromStatement] = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        (
            insert_options,
            execution_options,
        ) = _BulkORMInsert.default_insert_options.from_execution_options(
            "_sa_orm_insert_options",
            {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                insert_options += {"_subject_mapper": plugin_subject.mapper}

        if not params:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "orm"}
            elif insert_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM insert strategy without '
                    "passing separate parameters"
                )
        else:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "bulk"}

        if insert_options._dml_strategy != "raw":
            # for ORM object loading, like ORMContext, we have to disable
            # result set adapt_to_context, because we will be generating a
            # new statement with specific columns that's cached inside of
            # an ORMFromStatementCompileState, which we will re-use for
            # each result.
            if not execution_options:
                execution_options = context._orm_load_exec_options
            else:
                execution_options = execution_options.union(
                    context._orm_load_exec_options
                )

        if not is_pre_event and insert_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"dml_strategy": insert_options._dml_strategy}
        )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_insert_options": insert_options}
            ),
        )

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Insert,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        insert_options = execution_options.get(
            "_sa_orm_insert_options", cls.default_insert_options
        )

        if insert_options._dml_strategy not in (
            "raw",
            "bulk",
            "orm",
            "auto",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM insert strategy "
1268 "are 'raw', 'orm', 'bulk', 'auto"
            )

        result: _result.Result[Unpack[TupleAny]]

        if insert_options._dml_strategy == "raw":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
            return result

        if insert_options._dml_strategy == "bulk":
            mapper = insert_options._subject_mapper

            if (
                statement._post_values_clause is not None
                and mapper._multiple_persistence_tables
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk INSERT with a 'post values' clause "
                    "(typically upsert) not supported for multi-table "
                    f"mapper {mapper}"
                )

            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_insert(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                return_defaults=insert_options._return_defaults,
                render_nulls=insert_options._render_nulls,
                use_orm_insert_stmt=statement,
                execution_options=execution_options,
            )
        elif insert_options._dml_strategy == "orm":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
        else:
            raise AssertionError()

        if not bool(statement._returning):
            return result

        if insert_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw) -> _BulkORMInsert:
        self = cast(
            _BulkORMInsert,
            super().create_for_statement(statement, compiler, **kw),
        )

        if compiler is not None:
            toplevel = not compiler.stack
        else:
            toplevel = True
        if not toplevel:
            return self

        mapper = statement._propagate_attrs["plugin_subject"]
        dml_strategy = statement._annotations.get("dml_strategy", "raw")
        if dml_strategy == "bulk":
            self._setup_for_bulk_insert(compiler)
        elif dml_strategy == "orm":
            self._setup_for_orm_insert(compiler, mapper)

        return self

    @classmethod
    def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
        return {
            col.key if col is not None else k: v
            for col, k, v in (
                (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
            )
        }

    def _setup_for_orm_insert(self, compiler, mapper):
        statement = orm_level_statement = cast(dml.Insert, self.statement)

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=mapper,
            use_supplemental_cols=False,
        )
        self.statement = statement

    def _setup_for_bulk_insert(self, compiler):
        """establish an INSERT statement within the context of
        bulk insert.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_insert_statements().

        """
        statement = orm_level_statement = cast(dml.Insert, self.statement)
        an = statement._annotations

        emit_insert_table, emit_insert_mapper = (
            an["_emit_insert_table"],
            an["_emit_insert_mapper"],
        )

        statement = statement._clone()

        statement.table = emit_insert_table
        if self._dict_parameters:
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_insert_table
            }

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=emit_insert_mapper,
            use_supplemental_cols=True,
        )

        if (
            self.from_statement_ctx is not None
            and self.from_statement_ctx.compile_options._is_star
        ):
            raise sa_exc.CompileError(
                "Can't use RETURNING * with bulk ORM INSERT. "
                "Please use a different INSERT form, such as INSERT..VALUES "
                "or INSERT with a Core Connection"
            )

        self.statement = statement


@CompileState.plugin_for("orm", "update")
class _BulkORMUpdate(_BulkUDCompileState, UpdateDMLState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        toplevel = not compiler.stack

        if toplevel and dml_strategy == "bulk":
            self._setup_for_bulk_update(statement, compiler)
        elif (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            UpdateDMLState.__init__(self, statement, compiler, **kw)
        elif not toplevel or dml_strategy in ("orm", "unspecified"):
            self._setup_for_orm_update(statement, compiler)

        return self

    def _setup_for_orm_update(self, statement, compiler, **kw):
        orm_level_statement = statement

        toplevel = not compiler.stack

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self._resolved_values = self._get_resolved_values(mapper, statement)

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        if statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = statement._clone()

        if new_stmt.table._annotations["parententity"] is mapper:
            new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._values:
            new_stmt._values = self._resolved_values

        new_crit = self._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        # do this first as we need to determine if there is
        # UPDATE..FROM

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
            can_use_returning = compiler._annotations.get(
                "can_use_returning", None
            )
            if can_use_returning is not False:
                # even though pre_exec has determined basic
                # can_use_returning for the dialect, if we are to use
                # RETURNING we need to run can_use_returning() at this level
                # unconditionally because is_delete_using was not known
                # at the pre_exec level
                can_use_returning = (
                    synchronize_session == "fetch"
                    and self.can_use_returning(
                        compiler.dialect, mapper, is_multitable=self.is_multitable
                    )
                )

            if synchronize_session == "fetch" and can_use_returning:
                use_supplemental_cols = True

                # NOTE: we might want to RETURNING the actual columns to be
                # synchronized also. however this is complicated and difficult
                # to align against the behavior of "evaluate". Additionally,
                # in a large number (if not the majority) of cases, we have the
                # "evaluate" answer, usually a fixed value, in memory already and
                # there's no need to re-fetch the same value
                # over and over again. so perhaps if it could be RETURNING just
                # the elements that were based on a SQL expression and not
                # a constant. For now it doesn't quite seem worth it
                new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

    def _setup_for_bulk_update(self, statement, compiler, **kw):
1548 """establish an UPDATE statement within the context of
1549 bulk insert.
1550
1551 This method will be within the "conn.execute()" call that is invoked
1552 by persistence._emit_update_statement().
1553
1554 """
        statement = cast(dml.Update, statement)
        an = statement._annotations

        emit_update_table, _ = (
            an["_emit_update_table"],
            an["_emit_update_mapper"],
        )

        statement = statement._clone()
        statement.table = emit_update_table

        UpdateDMLState.__init__(self, statement, compiler, **kw)

        if self._maintain_values_ordering:
            raise sa_exc.InvalidRequestError(
                "bulk ORM UPDATE does not support ordered_values() for "
                "custom UPDATE statements with bulk parameter sets. Use a "
                "non-bulk UPDATE statement or use values()."
            )

        if self._dict_parameters:
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_update_table
            }
        self.statement = statement

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Update,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:

        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        if update_options._dml_strategy not in (
            "orm",
            "auto",
            "bulk",
            "core_only",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM UPDATE strategy "
                "are 'orm', 'auto', 'bulk', 'core_only'"
            )

        result: _result.Result[Unpack[TupleAny]]

        if update_options._dml_strategy == "bulk":
            enable_check_rowcount = not statement._where_criteria

            assert update_options._synchronize_session != "fetch"

            if (
                statement._where_criteria
                and update_options._synchronize_session == "evaluate"
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk synchronize of persistent objects not supported "
                    "when using bulk update with additional WHERE "
                    "criteria right now. add synchronize_session=None "
                    "execution option to bypass synchronize of persistent "
                    "objects."
                )
            mapper = update_options._subject_mapper
            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_update(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                update_changed_only=False,
                use_orm_update_stmt=statement,
                enable_check_rowcount=enable_check_rowcount,
            )
            return cls.orm_setup_cursor_result(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                result,
            )
        else:
            return super().orm_execute_statement(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                conn,
            )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.update_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        if is_executemany:
            return dialect.update_executemany_returning

        # these workarounds are currently hypothetical for UPDATE,
        # unlike DELETE where they impact MariaDB
        if is_update_from:
            return dialect.update_returning_multifrom

        elif is_multitable and not dialect.update_returning_multifrom:
            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with UPDATE..FROM; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_update_from=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_bulk_evaluate(
        cls, session, params, result, update_options
    ):
        if not params:
            return

        mapper = update_options._subject_mapper
        pk_keys = [prop.key for prop in mapper._identity_key_props]

        identity_map = session.identity_map

        for param in params:
            identity_key = mapper.identity_key_from_primary_key(
                (param[key] for key in pk_keys),
                update_options._identity_token,
            )
            state = identity_map.fast_get_state(identity_key)
            if not state:
                continue

            evaluated_keys = set(param).difference(pk_keys)

            dict_ = state.dict
            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = param[key]

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = evaluated_keys.intersection(dict_).difference(
                to_evaluate
            )
            if to_expire:
                state._expire_attributes(dict_, to_expire)

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            result.context.compiled_parameters[0],
            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
            result.prefetch_cols(),
            result.postfetch_cols(),
        )

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows
        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                result, target_mapper, returned_defaults_rows
            )
            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._identity_token is None
                or identity_token == update_options._identity_token
            ]
            if identity_key in session.identity_map
        ]

        if not objs:
            return

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            result.context.compiled_parameters[0],
            [
                (
                    obj,
                    attributes.instance_state(obj),
                    attributes.instance_dict(obj),
                )
                for obj in objs
            ],
            result.prefetch_cols(),
            result.postfetch_cols(),
        )

    @classmethod
    def _apply_update_set_values_to_objects(
        cls,
        session,
        update_options,
        statement,
        effective_params,
        matched_objects,
        prefetch_cols,
        postfetch_cols,
    ):
        """apply values to objects derived from an update statement, e.g.
        UPDATE..SET <values>

        """

        mapper = update_options._subject_mapper
        target_cls = mapper.class_
        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        resolved_values = cls._get_resolved_values(mapper, statement)
        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
            mapper, resolved_values
        )
        value_evaluators = {}
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                pass
            else:
                value_evaluators[key] = _evaluator

        evaluated_keys = list(value_evaluators.keys())
        attrib = {k for k, v in resolved_keys_as_propnames}

        states = set()

        to_prefetch = {
            c
            for c in prefetch_cols
            if c.key in effective_params
            and c in mapper._columntoproperty
            and c.key not in evaluated_keys
        }
        to_expire = {
            mapper._columntoproperty[c].key
            for c in postfetch_cols
            if c in mapper._columntoproperty
        }.difference(evaluated_keys)

        prefetch_transfer = [
            (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
        ]

        for obj, state, dict_ in matched_objects:

            dict_.update(
                {
                    col_to_prop: effective_params[c_key]
                    for col_to_prop, c_key in prefetch_transfer
                }
            )

            state._expire_attributes(state.dict, to_expire)

            to_evaluate = state.unmodified.intersection(evaluated_keys)

            for key in to_evaluate:
                if key in dict_:
                    # only run eval for attributes that are present.
                    dict_[key] = value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)


@CompileState.plugin_for("orm", "delete")
class _BulkORMDelete(_BulkUDCompileState, DeleteDMLState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        if (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            DeleteDMLState.__init__(self, statement, compiler, **kw)
            return self

        toplevel = not compiler.stack

        orm_level_statement = statement

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        new_stmt = statement._clone()

        if new_stmt.table._annotations["parententity"] is mapper:
            new_stmt.table = mapper.local_table

        new_crit = cls._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # do this first as we need to determine if there is
        # DELETE..FROM
        DeleteDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
            can_use_returning = compiler._annotations.get(
                "can_use_returning", None
            )
            if can_use_returning is not False:
                # even though pre_exec has determined basic
                # can_use_returning for the dialect, if we are to use
                # RETURNING we need to run can_use_returning() at this level
                # unconditionally because is_delete_using was not known
                # at the pre_exec level
                can_use_returning = (
                    synchronize_session == "fetch"
                    and self.can_use_returning(
                        compiler.dialect,
                        mapper,
                        is_multitable=self.is_multitable,
                        is_delete_using=compiler._annotations.get(
                            "is_delete_using", False
                        ),
                    )
                )

            if can_use_returning:
                use_supplemental_cols = True

                new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

        return self

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Delete,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy == "bulk":
            raise sa_exc.InvalidRequestError(
                "Bulk ORM DELETE not supported right now. "
                "Statement may be invoked at the "
                "Core level using "
                "session.connection().execute(stmt, parameters)"
            )

        if update_options._dml_strategy not in ("orm", "auto", "core_only"):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
                "'core_only'"
            )

        return super().orm_execute_statement(
            session, statement, params, execution_options, bind_arguments, conn
        )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.delete_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        # now get into special workarounds because MariaDB supports
        # DELETE...RETURNING but not DELETE...USING...RETURNING.
        if is_delete_using:
            # is_delete_using hint was passed. use
            # additional dialect feature (True for PG, False for MariaDB)
            return dialect.delete_returning_multifrom

        elif is_multitable and not dialect.delete_returning_multifrom:
            # is_delete_using hint was not passed, but we determined
            # at compile time that this is in fact a DELETE..USING.
            # it's too late to continue since we did not pre-SELECT.
            # raise that we need that hint up front.

            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with DELETE..USING; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_delete_using=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        to_delete = []

        for _, state, dict_, is_partially_expired in matched_objects:
            if is_partially_expired:
                state._expire(dict_, session.identity_map._modified)
            else:
                to_delete.append(state)

        if to_delete:
            session._remove_newly_deleted(to_delete)

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows

        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                result, target_mapper, returned_defaults_rows
            )

            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        for row in matched_rows:
            primary_key = row[0:-1]
            identity_token = row[-1]

            # TODO: inline this and call remove_newly_deleted
            # once
            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )
            if identity_key in session.identity_map:
                session._remove_newly_deleted(
                    [
                        attributes.instance_state(
                            session.identity_map[identity_key]
                        )
                    ]
                )