1# orm/bulk_persistence.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""additional ORM persistence classes related to "bulk" operations,
11specifically outside of the flush() process.
12
13"""
14
15from __future__ import annotations
16
17from typing import Any
18from typing import cast
19from typing import Dict
20from typing import Iterable
21from typing import Optional
22from typing import overload
23from typing import TYPE_CHECKING
24from typing import TypeVar
25from typing import Union
26
27from . import attributes
28from . import context
29from . import evaluator
30from . import exc as orm_exc
31from . import loading
32from . import persistence
33from .base import NO_VALUE
34from .context import AbstractORMCompileState
35from .context import FromStatement
36from .context import ORMFromStatementCompileState
37from .context import QueryContext
38from .. import exc as sa_exc
39from .. import util
40from ..engine import Dialect
41from ..engine import result as _result
42from ..sql import coercions
43from ..sql import dml
44from ..sql import expression
45from ..sql import roles
46from ..sql import select
47from ..sql import sqltypes
48from ..sql.base import _entity_namespace_key
49from ..sql.base import CompileState
50from ..sql.base import Options
51from ..sql.dml import DeleteDMLState
52from ..sql.dml import InsertDMLState
53from ..sql.dml import UpdateDMLState
54from ..util import EMPTY_DICT
55from ..util.typing import Literal
56
57if TYPE_CHECKING:
58 from ._typing import DMLStrategyArgument
59 from ._typing import OrmExecuteOptionsParameter
60 from ._typing import SynchronizeSessionArgument
61 from .mapper import Mapper
62 from .session import _BindArguments
63 from .session import ORMExecuteState
64 from .session import Session
65 from .session import SessionTransaction
66 from .state import InstanceState
67 from ..engine import Connection
68 from ..engine import cursor
69 from ..engine.interfaces import _CoreAnyExecuteParams
70
71_O = TypeVar("_O", bound=object)
72
73
74@overload
75def _bulk_insert(
76 mapper: Mapper[_O],
77 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
78 session_transaction: SessionTransaction,
79 *,
80 isstates: bool,
81 return_defaults: bool,
82 render_nulls: bool,
83 use_orm_insert_stmt: Literal[None] = ...,
84 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
85) -> None: ...
86
87
88@overload
89def _bulk_insert(
90 mapper: Mapper[_O],
91 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
92 session_transaction: SessionTransaction,
93 *,
94 isstates: bool,
95 return_defaults: bool,
96 render_nulls: bool,
97 use_orm_insert_stmt: Optional[dml.Insert] = ...,
98 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
99) -> cursor.CursorResult[Any]: ...
100
101
102def _bulk_insert(
103 mapper: Mapper[_O],
104 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
105 session_transaction: SessionTransaction,
106 *,
107 isstates: bool,
108 return_defaults: bool,
109 render_nulls: bool,
110 use_orm_insert_stmt: Optional[dml.Insert] = None,
111 execution_options: Optional[OrmExecuteOptionsParameter] = None,
112) -> Optional[cursor.CursorResult[Any]]:
113 base_mapper = mapper.base_mapper
114
115 if session_transaction.session.connection_callable:
116 raise NotImplementedError(
117 "connection_callable / per-instance sharding "
118 "not supported in bulk_insert()"
119 )
120
121 if isstates:
122 if TYPE_CHECKING:
123 mappings = cast(Iterable[InstanceState[_O]], mappings)
124
125 if return_defaults:
126 # list of states allows us to attach .key for return_defaults case
127 states = [(state, state.dict) for state in mappings]
128 mappings = [dict_ for (state, dict_) in states]
129 else:
130 mappings = [state.dict for state in mappings]
131 else:
132 if TYPE_CHECKING:
133 mappings = cast(Iterable[Dict[str, Any]], mappings)
134
135 if return_defaults:
136 # use dictionaries given, so that newly populated defaults
137 # can be delivered back to the caller (see #11661). This is **not**
138 # compatible with other use cases such as a session-executed
139 # insert() construct, as this will confuse the case of
140 # insert-per-subclass for joined inheritance cases (see
141 # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
142 #
143 # So in this conditional, we have **only** called
144 # session.bulk_insert_mappings() which does not have this
145 # requirement
146 mappings = list(mappings)
147 else:
148 # for all other cases we need to establish a local dictionary
149 # so that the incoming dictionaries aren't mutated
150 mappings = [dict(m) for m in mappings]
151 _expand_composites(mapper, mappings)
152
153 connection = session_transaction.connection(base_mapper)
154
155 return_result: Optional[cursor.CursorResult[Any]] = None
156
157 mappers_to_run = [
158 (table, mp)
159 for table, mp in base_mapper._sorted_tables.items()
160 if table in mapper._pks_by_table
161 ]
162
163 if return_defaults:
164 # not used by new-style bulk inserts, only used for legacy
165 bookkeeping = True
166 elif len(mappers_to_run) > 1:
167 # if we have more than one table, mapper to run where we will be
168 # either horizontally splicing, or copying values between tables,
169 # we need the "bookkeeping" / deterministic returning order
170 bookkeeping = True
171 else:
172 bookkeeping = False
173
174 for table, super_mapper in mappers_to_run:
175 # find bindparams in the statement. For bulk, we don't really know if
176 # a key in the params applies to a different table since we are
177 # potentially inserting for multiple tables here; looking at the
178 # bindparam() is a lot more direct. in most cases this will
179 # use _generate_cache_key() which is memoized, although in practice
180 # the ultimate statement that's executed is probably not the same
181 # object so that memoization might not matter much.
182 extra_bp_names = (
183 [
184 b.key
185 for b in use_orm_insert_stmt._get_embedded_bindparams()
186 if b.key in mappings[0]
187 ]
188 if use_orm_insert_stmt is not None
189 else ()
190 )
191
192 records = (
193 (
194 None,
195 state_dict,
196 params,
197 mapper,
198 connection,
199 value_params,
200 has_all_pks,
201 has_all_defaults,
202 )
203 for (
204 state,
205 state_dict,
206 params,
207 mp,
208 conn,
209 value_params,
210 has_all_pks,
211 has_all_defaults,
212 ) in persistence._collect_insert_commands(
213 table,
214 ((None, mapping, mapper, connection) for mapping in mappings),
215 bulk=True,
216 return_defaults=bookkeeping,
217 render_nulls=render_nulls,
218 include_bulk_keys=extra_bp_names,
219 )
220 )
221
222 result = persistence._emit_insert_statements(
223 base_mapper,
224 None,
225 super_mapper,
226 table,
227 records,
228 bookkeeping=bookkeeping,
229 use_orm_insert_stmt=use_orm_insert_stmt,
230 execution_options=execution_options,
231 )
232 if use_orm_insert_stmt is not None:
233 if not use_orm_insert_stmt._returning or return_result is None:
234 return_result = result
235 elif result.returns_rows:
236 assert bookkeeping
237 return_result = return_result.splice_horizontally(result)
238
239 if return_defaults and isstates:
240 identity_cls = mapper._identity_class
241 identity_props = [p.key for p in mapper._identity_key_props]
242 for state, dict_ in states:
243 state.key = (
244 identity_cls,
245 tuple([dict_[key] for key in identity_props]),
246 None,
247 )
248
249 if use_orm_insert_stmt is not None:
250 assert return_result is not None
251 return return_result
252
253
254@overload
255def _bulk_update(
256 mapper: Mapper[Any],
257 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
258 session_transaction: SessionTransaction,
259 *,
260 isstates: bool,
261 update_changed_only: bool,
262 use_orm_update_stmt: Literal[None] = ...,
263 enable_check_rowcount: bool = True,
264) -> None: ...
265
266
267@overload
268def _bulk_update(
269 mapper: Mapper[Any],
270 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
271 session_transaction: SessionTransaction,
272 *,
273 isstates: bool,
274 update_changed_only: bool,
275 use_orm_update_stmt: Optional[dml.Update] = ...,
276 enable_check_rowcount: bool = True,
277) -> _result.Result[Any]: ...
278
279
280def _bulk_update(
281 mapper: Mapper[Any],
282 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
283 session_transaction: SessionTransaction,
284 *,
285 isstates: bool,
286 update_changed_only: bool,
287 use_orm_update_stmt: Optional[dml.Update] = None,
288 enable_check_rowcount: bool = True,
289) -> Optional[_result.Result[Any]]:
290 base_mapper = mapper.base_mapper
291
292 search_keys = mapper._primary_key_propkeys
293 if mapper._version_id_prop:
294 search_keys = {mapper._version_id_prop.key}.union(search_keys)
295
296 def _changed_dict(mapper, state):
297 return {
298 k: v
299 for k, v in state.dict.items()
300 if k in state.committed_state or k in search_keys
301 }
302
303 if isstates:
304 if update_changed_only:
305 mappings = [_changed_dict(mapper, state) for state in mappings]
306 else:
307 mappings = [state.dict for state in mappings]
308 else:
309 mappings = [dict(m) for m in mappings]
310 _expand_composites(mapper, mappings)
311
312 if session_transaction.session.connection_callable:
313 raise NotImplementedError(
314 "connection_callable / per-instance sharding "
315 "not supported in bulk_update()"
316 )
317
318 connection = session_transaction.connection(base_mapper)
319
320 # find bindparams in the statement. see _bulk_insert for similar
321 # notes for the insert case
322 extra_bp_names = (
323 [
324 b.key
325 for b in use_orm_update_stmt._get_embedded_bindparams()
326 if b.key in mappings[0]
327 ]
328 if use_orm_update_stmt is not None
329 else ()
330 )
331
332 for table, super_mapper in base_mapper._sorted_tables.items():
333 if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
334 continue
335
336 records = persistence._collect_update_commands(
337 None,
338 table,
339 (
340 (
341 None,
342 mapping,
343 mapper,
344 connection,
345 (
346 mapping[mapper._version_id_prop.key]
347 if mapper._version_id_prop
348 else None
349 ),
350 )
351 for mapping in mappings
352 ),
353 bulk=True,
354 use_orm_update_stmt=use_orm_update_stmt,
355 include_bulk_keys=extra_bp_names,
356 )
357 persistence._emit_update_statements(
358 base_mapper,
359 None,
360 super_mapper,
361 table,
362 records,
363 bookkeeping=False,
364 use_orm_update_stmt=use_orm_update_stmt,
365 enable_check_rowcount=enable_check_rowcount,
366 )
367
368 if use_orm_update_stmt is not None:
369 return _result.null_result()
370
371
372def _expand_composites(mapper, mappings):
373 composite_attrs = mapper.composites
374 if not composite_attrs:
375 return
376
377 composite_keys = set(composite_attrs.keys())
378 populators = {
379 key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
380 for key in composite_keys
381 }
382 for mapping in mappings:
383 for key in composite_keys.intersection(mapping):
384 populators[key](mapping)
385
386
387class ORMDMLState(AbstractORMCompileState):
388 is_dml_returning = True
389 from_statement_ctx: Optional[ORMFromStatementCompileState] = None
390
391 @classmethod
392 def _get_orm_crud_kv_pairs(
393 cls, mapper, statement, kv_iterator, needs_to_be_cacheable
394 ):
395 core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
396
397 for k, v in kv_iterator:
398 k = coercions.expect(roles.DMLColumnRole, k)
399
400 if isinstance(k, str):
401 desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
402 if desc is NO_VALUE:
403 yield (
404 coercions.expect(roles.DMLColumnRole, k),
405 (
406 coercions.expect(
407 roles.ExpressionElementRole,
408 v,
409 type_=sqltypes.NullType(),
410 is_crud=True,
411 )
412 if needs_to_be_cacheable
413 else v
414 ),
415 )
416 else:
417 yield from core_get_crud_kv_pairs(
418 statement,
419 desc._bulk_update_tuples(v),
420 needs_to_be_cacheable,
421 )
422 elif "entity_namespace" in k._annotations:
423 k_anno = k._annotations
424 attr = _entity_namespace_key(
425 k_anno["entity_namespace"], k_anno["proxy_key"]
426 )
427 yield from core_get_crud_kv_pairs(
428 statement,
429 attr._bulk_update_tuples(v),
430 needs_to_be_cacheable,
431 )
432 else:
433 yield (
434 k,
435 (
436 v
437 if not needs_to_be_cacheable
438 else coercions.expect(
439 roles.ExpressionElementRole,
440 v,
441 type_=sqltypes.NullType(),
442 is_crud=True,
443 )
444 ),
445 )
446
447 @classmethod
448 def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
449 plugin_subject = statement._propagate_attrs["plugin_subject"]
450
451 if not plugin_subject or not plugin_subject.mapper:
452 return UpdateDMLState._get_multi_crud_kv_pairs(
453 statement, kv_iterator
454 )
455
456 return [
457 dict(
458 cls._get_orm_crud_kv_pairs(
459 plugin_subject.mapper, statement, value_dict.items(), False
460 )
461 )
462 for value_dict in kv_iterator
463 ]
464
465 @classmethod
466 def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
467 assert (
468 needs_to_be_cacheable
469 ), "no test coverage for needs_to_be_cacheable=False"
470
471 plugin_subject = statement._propagate_attrs["plugin_subject"]
472
473 if not plugin_subject or not plugin_subject.mapper:
474 return UpdateDMLState._get_crud_kv_pairs(
475 statement, kv_iterator, needs_to_be_cacheable
476 )
477
478 return list(
479 cls._get_orm_crud_kv_pairs(
480 plugin_subject.mapper,
481 statement,
482 kv_iterator,
483 needs_to_be_cacheable,
484 )
485 )
486
487 @classmethod
488 def get_entity_description(cls, statement):
489 ext_info = statement.table._annotations["parententity"]
490 mapper = ext_info.mapper
491 if ext_info.is_aliased_class:
492 _label_name = ext_info.name
493 else:
494 _label_name = mapper.class_.__name__
495
496 return {
497 "name": _label_name,
498 "type": mapper.class_,
499 "expr": ext_info.entity,
500 "entity": ext_info.entity,
501 "table": mapper.local_table,
502 }
503
504 @classmethod
505 def get_returning_column_descriptions(cls, statement):
506 def _ent_for_col(c):
507 return c._annotations.get("parententity", None)
508
509 def _attr_for_col(c, ent):
510 if ent is None:
511 return c
512 proxy_key = c._annotations.get("proxy_key", None)
513 if not proxy_key:
514 return c
515 else:
516 return getattr(ent.entity, proxy_key, c)
517
518 return [
519 {
520 "name": c.key,
521 "type": c.type,
522 "expr": _attr_for_col(c, ent),
523 "aliased": ent.is_aliased_class,
524 "entity": ent.entity,
525 }
526 for c, ent in [
527 (c, _ent_for_col(c)) for c in statement._all_selected_columns
528 ]
529 ]
530
531 def _setup_orm_returning(
532 self,
533 compiler,
534 orm_level_statement,
535 dml_level_statement,
536 dml_mapper,
537 *,
538 use_supplemental_cols=True,
539 ):
540 """establish ORM column handlers for an INSERT, UPDATE, or DELETE
541 which uses explicit returning().
542
543 called within compilation level create_for_statement.
544
545 The _return_orm_returning() method then receives the Result
546 after the statement was executed, and applies ORM loading to the
547 state that we first established here.
548
549 """
550
551 if orm_level_statement._returning:
552 fs = FromStatement(
553 orm_level_statement._returning,
554 dml_level_statement,
555 _adapt_on_names=False,
556 )
557 fs = fs.execution_options(**orm_level_statement._execution_options)
558 fs = fs.options(*orm_level_statement._with_options)
559 self.select_statement = fs
560 self.from_statement_ctx = fsc = (
561 ORMFromStatementCompileState.create_for_statement(fs, compiler)
562 )
563 fsc.setup_dml_returning_compile_state(dml_mapper)
564
565 dml_level_statement = dml_level_statement._generate()
566 dml_level_statement._returning = ()
567
568 cols_to_return = [c for c in fsc.primary_columns if c is not None]
569
570 # since we are splicing result sets together, make sure there
571 # are columns of some kind returned in each result set
572 if not cols_to_return:
573 cols_to_return.extend(dml_mapper.primary_key)
574
575 if use_supplemental_cols:
576 dml_level_statement = dml_level_statement.return_defaults(
577 # this is a little weird looking, but by passing
578 # primary key as the main list of cols, this tells
579 # return_defaults to omit server-default cols (and
580 # actually all cols, due to some weird thing we should
581 # clean up in crud.py).
582 # Since we have cols_to_return, just return what we asked
583 # for (plus primary key, which ORM persistence needs since
584 # we likely set bookkeeping=True here, which is another
585 # whole thing...). We dont want to clutter the
586 # statement up with lots of other cols the user didn't
587 # ask for. see #9685
588 *dml_mapper.primary_key,
589 supplemental_cols=cols_to_return,
590 )
591 else:
592 dml_level_statement = dml_level_statement.returning(
593 *cols_to_return
594 )
595
596 return dml_level_statement
597
598 @classmethod
599 def _return_orm_returning(
600 cls,
601 session,
602 statement,
603 params,
604 execution_options,
605 bind_arguments,
606 result,
607 ):
608 execution_context = result.context
609 compile_state = execution_context.compiled.compile_state
610
611 if (
612 compile_state.from_statement_ctx
613 and not compile_state.from_statement_ctx.compile_options._is_star
614 ):
615 load_options = execution_options.get(
616 "_sa_orm_load_options", QueryContext.default_load_options
617 )
618
619 querycontext = QueryContext(
620 compile_state.from_statement_ctx,
621 compile_state.select_statement,
622 statement,
623 params,
624 session,
625 load_options,
626 execution_options,
627 bind_arguments,
628 )
629 return loading.instances(result, querycontext)
630 else:
631 return result
632
633
634class BulkUDCompileState(ORMDMLState):
635 class default_update_options(Options):
636 _dml_strategy: DMLStrategyArgument = "auto"
637 _synchronize_session: SynchronizeSessionArgument = "auto"
638 _can_use_returning: bool = False
639 _is_delete_using: bool = False
640 _is_update_from: bool = False
641 _autoflush: bool = True
642 _subject_mapper: Optional[Mapper[Any]] = None
643 _resolved_values = EMPTY_DICT
644 _eval_condition = None
645 _matched_rows = None
646 _identity_token = None
647 _populate_existing: bool = False
648
649 @classmethod
650 def can_use_returning(
651 cls,
652 dialect: Dialect,
653 mapper: Mapper[Any],
654 *,
655 is_multitable: bool = False,
656 is_update_from: bool = False,
657 is_delete_using: bool = False,
658 is_executemany: bool = False,
659 ) -> bool:
660 raise NotImplementedError()
661
662 @classmethod
663 def orm_pre_session_exec(
664 cls,
665 session,
666 statement,
667 params,
668 execution_options,
669 bind_arguments,
670 is_pre_event,
671 ):
672 (
673 update_options,
674 execution_options,
675 ) = BulkUDCompileState.default_update_options.from_execution_options(
676 "_sa_orm_update_options",
677 {
678 "synchronize_session",
679 "autoflush",
680 "populate_existing",
681 "identity_token",
682 "is_delete_using",
683 "is_update_from",
684 "dml_strategy",
685 },
686 execution_options,
687 statement._execution_options,
688 )
689 bind_arguments["clause"] = statement
690 try:
691 plugin_subject = statement._propagate_attrs["plugin_subject"]
692 except KeyError:
693 assert False, "statement had 'orm' plugin but no plugin_subject"
694 else:
695 if plugin_subject:
696 bind_arguments["mapper"] = plugin_subject.mapper
697 update_options += {"_subject_mapper": plugin_subject.mapper}
698
699 if "parententity" not in statement.table._annotations:
700 update_options += {"_dml_strategy": "core_only"}
701 elif not isinstance(params, list):
702 if update_options._dml_strategy == "auto":
703 update_options += {"_dml_strategy": "orm"}
704 elif update_options._dml_strategy == "bulk":
705 raise sa_exc.InvalidRequestError(
706 'Can\'t use "bulk" ORM insert strategy without '
707 "passing separate parameters"
708 )
709 else:
710 if update_options._dml_strategy == "auto":
711 update_options += {"_dml_strategy": "bulk"}
712
713 sync = update_options._synchronize_session
714 if sync is not None:
715 if sync not in ("auto", "evaluate", "fetch", False):
716 raise sa_exc.ArgumentError(
717 "Valid strategies for session synchronization "
718 "are 'auto', 'evaluate', 'fetch', False"
719 )
720 if update_options._dml_strategy == "bulk" and sync == "fetch":
721 raise sa_exc.InvalidRequestError(
722 "The 'fetch' synchronization strategy is not available "
723 "for 'bulk' ORM updates (i.e. multiple parameter sets)"
724 )
725
726 if not is_pre_event:
727 if update_options._autoflush:
728 session._autoflush()
729
730 if update_options._dml_strategy == "orm":
731 if update_options._synchronize_session == "auto":
732 update_options = cls._do_pre_synchronize_auto(
733 session,
734 statement,
735 params,
736 execution_options,
737 bind_arguments,
738 update_options,
739 )
740 elif update_options._synchronize_session == "evaluate":
741 update_options = cls._do_pre_synchronize_evaluate(
742 session,
743 statement,
744 params,
745 execution_options,
746 bind_arguments,
747 update_options,
748 )
749 elif update_options._synchronize_session == "fetch":
750 update_options = cls._do_pre_synchronize_fetch(
751 session,
752 statement,
753 params,
754 execution_options,
755 bind_arguments,
756 update_options,
757 )
758 elif update_options._dml_strategy == "bulk":
759 if update_options._synchronize_session == "auto":
760 update_options += {"_synchronize_session": "evaluate"}
761
762 # indicators from the "pre exec" step that are then
763 # added to the DML statement, which will also be part of the cache
764 # key. The compile level create_for_statement() method will then
765 # consume these at compiler time.
766 statement = statement._annotate(
767 {
768 "synchronize_session": update_options._synchronize_session,
769 "is_delete_using": update_options._is_delete_using,
770 "is_update_from": update_options._is_update_from,
771 "dml_strategy": update_options._dml_strategy,
772 "can_use_returning": update_options._can_use_returning,
773 }
774 )
775
776 return (
777 statement,
778 util.immutabledict(execution_options).union(
779 {"_sa_orm_update_options": update_options}
780 ),
781 )
782
783 @classmethod
784 def orm_setup_cursor_result(
785 cls,
786 session,
787 statement,
788 params,
789 execution_options,
790 bind_arguments,
791 result,
792 ):
793 # this stage of the execution is called after the
794 # do_orm_execute event hook. meaning for an extension like
795 # horizontal sharding, this step happens *within* the horizontal
796 # sharding event handler which calls session.execute() re-entrantly
797 # and will occur for each backend individually.
798 # the sharding extension then returns its own merged result from the
799 # individual ones we return here.
800
801 update_options = execution_options["_sa_orm_update_options"]
802 if update_options._dml_strategy == "orm":
803 if update_options._synchronize_session == "evaluate":
804 cls._do_post_synchronize_evaluate(
805 session, statement, result, update_options
806 )
807 elif update_options._synchronize_session == "fetch":
808 cls._do_post_synchronize_fetch(
809 session, statement, result, update_options
810 )
811 elif update_options._dml_strategy == "bulk":
812 if update_options._synchronize_session == "evaluate":
813 cls._do_post_synchronize_bulk_evaluate(
814 session, params, result, update_options
815 )
816 return result
817
818 return cls._return_orm_returning(
819 session,
820 statement,
821 params,
822 execution_options,
823 bind_arguments,
824 result,
825 )
826
827 @classmethod
828 def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
829 """Apply extra criteria filtering.
830
831 For all distinct single-table-inheritance mappers represented in the
832 table being updated or deleted, produce additional WHERE criteria such
833 that only the appropriate subtypes are selected from the total results.
834
835 Additionally, add WHERE criteria originating from LoaderCriteriaOptions
836 collected from the statement.
837
838 """
839
840 return_crit = ()
841
842 adapter = ext_info._adapter if ext_info.is_aliased_class else None
843
844 if (
845 "additional_entity_criteria",
846 ext_info.mapper,
847 ) in global_attributes:
848 return_crit += tuple(
849 ae._resolve_where_criteria(ext_info)
850 for ae in global_attributes[
851 ("additional_entity_criteria", ext_info.mapper)
852 ]
853 if ae.include_aliases or ae.entity is ext_info
854 )
855
856 if ext_info.mapper._single_table_criterion is not None:
857 return_crit += (ext_info.mapper._single_table_criterion,)
858
859 if adapter:
860 return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
861
862 return return_crit
863
864 @classmethod
865 def _interpret_returning_rows(cls, result, mapper, rows):
866 """return rows that indicate PK cols in mapper.primary_key position
867 for RETURNING rows.
868
869 Prior to 2.0.36, this method seemed to be written for some kind of
870 inheritance scenario but the scenario was unused for actual joined
871 inheritance, and the function instead seemed to perform some kind of
872 partial translation that would remove non-PK cols if the PK cols
873 happened to be first in the row, but not otherwise. The joined
874 inheritance walk feature here seems to have never been used as it was
875 always skipped by the "local_table" check.
876
877 As of 2.0.36 the function strips away non-PK cols and provides the
878 PK cols for the table in mapper PK order.
879
880 """
881
882 try:
883 if mapper.local_table is not mapper.base_mapper.local_table:
884 # TODO: dive more into how a local table PK is used for fetch
885 # sync, not clear if this is correct as it depends on the
886 # downstream routine to fetch rows using
887 # local_table.primary_key order
888 pk_keys = result._tuple_getter(mapper.local_table.primary_key)
889 else:
890 pk_keys = result._tuple_getter(mapper.primary_key)
891 except KeyError:
892 # can't use these rows, they don't have PK cols in them
893 # this is an unusual case where the user would have used
894 # .return_defaults()
895 return []
896
897 return [pk_keys(row) for row in rows]
898
899 @classmethod
900 def _get_matched_objects_on_criteria(cls, update_options, states):
901 mapper = update_options._subject_mapper
902 eval_condition = update_options._eval_condition
903
904 raw_data = [
905 (state.obj(), state, state.dict)
906 for state in states
907 if state.mapper.isa(mapper) and not state.expired
908 ]
909
910 identity_token = update_options._identity_token
911 if identity_token is not None:
912 raw_data = [
913 (obj, state, dict_)
914 for obj, state, dict_ in raw_data
915 if state.identity_token == identity_token
916 ]
917
918 result = []
919 for obj, state, dict_ in raw_data:
920 evaled_condition = eval_condition(obj)
921
922 # caution: don't use "in ()" or == here, _EXPIRE_OBJECT
923 # evaluates as True for all comparisons
924 if (
925 evaled_condition is True
926 or evaled_condition is evaluator._EXPIRED_OBJECT
927 ):
928 result.append(
929 (
930 obj,
931 state,
932 dict_,
933 evaled_condition is evaluator._EXPIRED_OBJECT,
934 )
935 )
936 return result
937
938 @classmethod
939 def _eval_condition_from_statement(cls, update_options, statement):
940 mapper = update_options._subject_mapper
941 target_cls = mapper.class_
942
943 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
944 crit = ()
945 if statement._where_criteria:
946 crit += statement._where_criteria
947
948 global_attributes = {}
949 for opt in statement._with_options:
950 if opt._is_criteria_option:
951 opt.get_global_criteria(global_attributes)
952
953 if global_attributes:
954 crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
955
956 if crit:
957 eval_condition = evaluator_compiler.process(*crit)
958 else:
959 # workaround for mypy https://github.com/python/mypy/issues/14027
960 def _eval_condition(obj):
961 return True
962
963 eval_condition = _eval_condition
964
965 return eval_condition
966
967 @classmethod
968 def _do_pre_synchronize_auto(
969 cls,
970 session,
971 statement,
972 params,
973 execution_options,
974 bind_arguments,
975 update_options,
976 ):
977 """setup auto sync strategy
978
979
980 "auto" checks if we can use "evaluate" first, then falls back
981 to "fetch"
982
983 evaluate is vastly more efficient for the common case
984 where session is empty, only has a few objects, and the UPDATE
985 statement can potentially match thousands/millions of rows.
986
987 OTOH more complex criteria that fails to work with "evaluate"
988 we would hope usually correlates with fewer net rows.
989
990 """
991
992 try:
993 eval_condition = cls._eval_condition_from_statement(
994 update_options, statement
995 )
996
997 except evaluator.UnevaluatableError:
998 pass
999 else:
1000 return update_options + {
1001 "_eval_condition": eval_condition,
1002 "_synchronize_session": "evaluate",
1003 }
1004
1005 update_options += {"_synchronize_session": "fetch"}
1006 return cls._do_pre_synchronize_fetch(
1007 session,
1008 statement,
1009 params,
1010 execution_options,
1011 bind_arguments,
1012 update_options,
1013 )
1014
1015 @classmethod
1016 def _do_pre_synchronize_evaluate(
1017 cls,
1018 session,
1019 statement,
1020 params,
1021 execution_options,
1022 bind_arguments,
1023 update_options,
1024 ):
1025 try:
1026 eval_condition = cls._eval_condition_from_statement(
1027 update_options, statement
1028 )
1029
1030 except evaluator.UnevaluatableError as err:
1031 raise sa_exc.InvalidRequestError(
1032 'Could not evaluate current criteria in Python: "%s". '
1033 "Specify 'fetch' or False for the "
1034 "synchronize_session execution option." % err
1035 ) from err
1036
1037 return update_options + {
1038 "_eval_condition": eval_condition,
1039 }
1040
1041 @classmethod
1042 def _get_resolved_values(cls, mapper, statement):
1043 if statement._multi_values:
1044 return []
1045 elif statement._ordered_values:
1046 return list(statement._ordered_values)
1047 elif statement._values:
1048 return list(statement._values.items())
1049 else:
1050 return []
1051
1052 @classmethod
1053 def _resolved_keys_as_propnames(cls, mapper, resolved_values):
1054 values = []
1055 for k, v in resolved_values:
1056 if mapper and isinstance(k, expression.ColumnElement):
1057 try:
1058 attr = mapper._columntoproperty[k]
1059 except orm_exc.UnmappedColumnError:
1060 pass
1061 else:
1062 values.append((attr.key, v))
1063 else:
1064 raise sa_exc.InvalidRequestError(
1065 "Attribute name not found, can't be "
1066 "synchronized back to objects: %r" % k
1067 )
1068 return values
1069
1070 @classmethod
1071 def _do_pre_synchronize_fetch(
1072 cls,
1073 session,
1074 statement,
1075 params,
1076 execution_options,
1077 bind_arguments,
1078 update_options,
1079 ):
1080 mapper = update_options._subject_mapper
1081
1082 select_stmt = (
1083 select(*(mapper.primary_key + (mapper.select_identity_token,)))
1084 .select_from(mapper)
1085 .options(*statement._with_options)
1086 )
1087 select_stmt._where_criteria = statement._where_criteria
1088
1089 # conditionally run the SELECT statement for pre-fetch, testing the
1090 # "bind" for if we can use RETURNING or not using the do_orm_execute
1091 # event. If RETURNING is available, the do_orm_execute event
1092 # will cancel the SELECT from being actually run.
1093 #
1094 # The way this is organized seems strange, why don't we just
1095 # call can_use_returning() before invoking the statement and get
1096 # answer?, why does this go through the whole execute phase using an
1097 # event? Answer: because we are integrating with extensions such
1098 # as the horizontal sharding extention that "multiplexes" an individual
1099 # statement run through multiple engines, and it uses
1100 # do_orm_execute() to do that.
1101
1102 can_use_returning = None
1103
1104 def skip_for_returning(orm_context: ORMExecuteState) -> Any:
1105 bind = orm_context.session.get_bind(**orm_context.bind_arguments)
1106 nonlocal can_use_returning
1107
1108 per_bind_result = cls.can_use_returning(
1109 bind.dialect,
1110 mapper,
1111 is_update_from=update_options._is_update_from,
1112 is_delete_using=update_options._is_delete_using,
1113 is_executemany=orm_context.is_executemany,
1114 )
1115
1116 if can_use_returning is not None:
1117 if can_use_returning != per_bind_result:
1118 raise sa_exc.InvalidRequestError(
1119 "For synchronize_session='fetch', can't mix multiple "
1120 "backends where some support RETURNING and others "
1121 "don't"
1122 )
1123 elif orm_context.is_executemany and not per_bind_result:
1124 raise sa_exc.InvalidRequestError(
1125 "For synchronize_session='fetch', can't use multiple "
1126 "parameter sets in ORM mode, which this backend does not "
1127 "support with RETURNING"
1128 )
1129 else:
1130 can_use_returning = per_bind_result
1131
1132 if per_bind_result:
1133 return _result.null_result()
1134 else:
1135 return None
1136
1137 result = session.execute(
1138 select_stmt,
1139 params,
1140 execution_options=execution_options,
1141 bind_arguments=bind_arguments,
1142 _add_event=skip_for_returning,
1143 )
1144 matched_rows = result.fetchall()
1145
1146 return update_options + {
1147 "_matched_rows": matched_rows,
1148 "_can_use_returning": can_use_returning,
1149 }
1150
1151
1152@CompileState.plugin_for("orm", "insert")
1153class BulkORMInsert(ORMDMLState, InsertDMLState):
1154 class default_insert_options(Options):
1155 _dml_strategy: DMLStrategyArgument = "auto"
1156 _render_nulls: bool = False
1157 _return_defaults: bool = False
1158 _subject_mapper: Optional[Mapper[Any]] = None
1159 _autoflush: bool = True
1160 _populate_existing: bool = False
1161
1162 select_statement: Optional[FromStatement] = None
1163
1164 @classmethod
1165 def orm_pre_session_exec(
1166 cls,
1167 session,
1168 statement,
1169 params,
1170 execution_options,
1171 bind_arguments,
1172 is_pre_event,
1173 ):
1174 (
1175 insert_options,
1176 execution_options,
1177 ) = BulkORMInsert.default_insert_options.from_execution_options(
1178 "_sa_orm_insert_options",
1179 {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
1180 execution_options,
1181 statement._execution_options,
1182 )
1183 bind_arguments["clause"] = statement
1184 try:
1185 plugin_subject = statement._propagate_attrs["plugin_subject"]
1186 except KeyError:
1187 assert False, "statement had 'orm' plugin but no plugin_subject"
1188 else:
1189 if plugin_subject:
1190 bind_arguments["mapper"] = plugin_subject.mapper
1191 insert_options += {"_subject_mapper": plugin_subject.mapper}
1192
1193 if not params:
1194 if insert_options._dml_strategy == "auto":
1195 insert_options += {"_dml_strategy": "orm"}
1196 elif insert_options._dml_strategy == "bulk":
1197 raise sa_exc.InvalidRequestError(
1198 'Can\'t use "bulk" ORM insert strategy without '
1199 "passing separate parameters"
1200 )
1201 else:
1202 if insert_options._dml_strategy == "auto":
1203 insert_options += {"_dml_strategy": "bulk"}
1204
1205 if insert_options._dml_strategy != "raw":
1206 # for ORM object loading, like ORMContext, we have to disable
1207 # result set adapt_to_context, because we will be generating a
1208 # new statement with specific columns that's cached inside of
1209 # an ORMFromStatementCompileState, which we will re-use for
1210 # each result.
1211 if not execution_options:
1212 execution_options = context._orm_load_exec_options
1213 else:
1214 execution_options = execution_options.union(
1215 context._orm_load_exec_options
1216 )
1217
1218 if not is_pre_event and insert_options._autoflush:
1219 session._autoflush()
1220
1221 statement = statement._annotate(
1222 {"dml_strategy": insert_options._dml_strategy}
1223 )
1224
1225 return (
1226 statement,
1227 util.immutabledict(execution_options).union(
1228 {"_sa_orm_insert_options": insert_options}
1229 ),
1230 )
1231
1232 @classmethod
1233 def orm_execute_statement(
1234 cls,
1235 session: Session,
1236 statement: dml.Insert,
1237 params: _CoreAnyExecuteParams,
1238 execution_options: OrmExecuteOptionsParameter,
1239 bind_arguments: _BindArguments,
1240 conn: Connection,
1241 ) -> _result.Result:
1242 insert_options = execution_options.get(
1243 "_sa_orm_insert_options", cls.default_insert_options
1244 )
1245
1246 if insert_options._dml_strategy not in (
1247 "raw",
1248 "bulk",
1249 "orm",
1250 "auto",
1251 ):
1252 raise sa_exc.ArgumentError(
1253 "Valid strategies for ORM insert strategy "
1254 "are 'raw', 'orm', 'bulk', 'auto"
1255 )
1256
1257 result: _result.Result[Any]
1258
1259 if insert_options._dml_strategy == "raw":
1260 result = conn.execute(
1261 statement, params or {}, execution_options=execution_options
1262 )
1263 return result
1264
1265 if insert_options._dml_strategy == "bulk":
1266 mapper = insert_options._subject_mapper
1267
1268 if (
1269 statement._post_values_clause is not None
1270 and mapper._multiple_persistence_tables
1271 ):
1272 raise sa_exc.InvalidRequestError(
1273 "bulk INSERT with a 'post values' clause "
1274 "(typically upsert) not supported for multi-table "
1275 f"mapper {mapper}"
1276 )
1277
1278 assert mapper is not None
1279 assert session._transaction is not None
1280 result = _bulk_insert(
1281 mapper,
1282 cast(
1283 "Iterable[Dict[str, Any]]",
1284 [params] if isinstance(params, dict) else params,
1285 ),
1286 session._transaction,
1287 isstates=False,
1288 return_defaults=insert_options._return_defaults,
1289 render_nulls=insert_options._render_nulls,
1290 use_orm_insert_stmt=statement,
1291 execution_options=execution_options,
1292 )
1293 elif insert_options._dml_strategy == "orm":
1294 result = conn.execute(
1295 statement, params or {}, execution_options=execution_options
1296 )
1297 else:
1298 raise AssertionError()
1299
1300 if not bool(statement._returning):
1301 return result
1302
1303 if insert_options._populate_existing:
1304 load_options = execution_options.get(
1305 "_sa_orm_load_options", QueryContext.default_load_options
1306 )
1307 load_options += {"_populate_existing": True}
1308 execution_options = execution_options.union(
1309 {"_sa_orm_load_options": load_options}
1310 )
1311
1312 return cls._return_orm_returning(
1313 session,
1314 statement,
1315 params,
1316 execution_options,
1317 bind_arguments,
1318 result,
1319 )
1320
1321 @classmethod
1322 def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
1323 self = cast(
1324 BulkORMInsert,
1325 super().create_for_statement(statement, compiler, **kw),
1326 )
1327
1328 if compiler is not None:
1329 toplevel = not compiler.stack
1330 else:
1331 toplevel = True
1332 if not toplevel:
1333 return self
1334
1335 mapper = statement._propagate_attrs["plugin_subject"]
1336 dml_strategy = statement._annotations.get("dml_strategy", "raw")
1337 if dml_strategy == "bulk":
1338 self._setup_for_bulk_insert(compiler)
1339 elif dml_strategy == "orm":
1340 self._setup_for_orm_insert(compiler, mapper)
1341
1342 return self
1343
1344 @classmethod
1345 def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
1346 return {
1347 col.key if col is not None else k: v
1348 for col, k, v in (
1349 (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
1350 )
1351 }
1352
1353 def _setup_for_orm_insert(self, compiler, mapper):
1354 statement = orm_level_statement = cast(dml.Insert, self.statement)
1355
1356 statement = self._setup_orm_returning(
1357 compiler,
1358 orm_level_statement,
1359 statement,
1360 dml_mapper=mapper,
1361 use_supplemental_cols=False,
1362 )
1363 self.statement = statement
1364
1365 def _setup_for_bulk_insert(self, compiler):
1366 """establish an INSERT statement within the context of
1367 bulk insert.
1368
1369 This method will be within the "conn.execute()" call that is invoked
1370 by persistence._emit_insert_statement().
1371
1372 """
1373 statement = orm_level_statement = cast(dml.Insert, self.statement)
1374 an = statement._annotations
1375
1376 emit_insert_table, emit_insert_mapper = (
1377 an["_emit_insert_table"],
1378 an["_emit_insert_mapper"],
1379 )
1380
1381 statement = statement._clone()
1382
1383 statement.table = emit_insert_table
1384 if self._dict_parameters:
1385 self._dict_parameters = {
1386 col: val
1387 for col, val in self._dict_parameters.items()
1388 if col.table is emit_insert_table
1389 }
1390
1391 statement = self._setup_orm_returning(
1392 compiler,
1393 orm_level_statement,
1394 statement,
1395 dml_mapper=emit_insert_mapper,
1396 use_supplemental_cols=True,
1397 )
1398
1399 if (
1400 self.from_statement_ctx is not None
1401 and self.from_statement_ctx.compile_options._is_star
1402 ):
1403 raise sa_exc.CompileError(
1404 "Can't use RETURNING * with bulk ORM INSERT. "
1405 "Please use a different INSERT form, such as INSERT..VALUES "
1406 "or INSERT with a Core Connection"
1407 )
1408
1409 self.statement = statement
1410
1411
1412@CompileState.plugin_for("orm", "update")
1413class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
1414 @classmethod
1415 def create_for_statement(cls, statement, compiler, **kw):
1416 self = cls.__new__(cls)
1417
1418 dml_strategy = statement._annotations.get(
1419 "dml_strategy", "unspecified"
1420 )
1421
1422 toplevel = not compiler.stack
1423
1424 if toplevel and dml_strategy == "bulk":
1425 self._setup_for_bulk_update(statement, compiler)
1426 elif (
1427 dml_strategy == "core_only"
1428 or dml_strategy == "unspecified"
1429 and "parententity" not in statement.table._annotations
1430 ):
1431 UpdateDMLState.__init__(self, statement, compiler, **kw)
1432 elif not toplevel or dml_strategy in ("orm", "unspecified"):
1433 self._setup_for_orm_update(statement, compiler)
1434
1435 return self
1436
1437 def _setup_for_orm_update(self, statement, compiler, **kw):
1438 orm_level_statement = statement
1439
1440 toplevel = not compiler.stack
1441
1442 ext_info = statement.table._annotations["parententity"]
1443
1444 self.mapper = mapper = ext_info.mapper
1445
1446 self._resolved_values = self._get_resolved_values(mapper, statement)
1447
1448 self._init_global_attributes(
1449 statement,
1450 compiler,
1451 toplevel=toplevel,
1452 process_criteria_for_toplevel=toplevel,
1453 )
1454
1455 if statement._values:
1456 self._resolved_values = dict(self._resolved_values)
1457
1458 new_stmt = statement._clone()
1459
1460 if new_stmt.table._annotations["parententity"] is mapper:
1461 new_stmt.table = mapper.local_table
1462
1463 # note if the statement has _multi_values, these
1464 # are passed through to the new statement, which will then raise
1465 # InvalidRequestError because UPDATE doesn't support multi_values
1466 # right now.
1467 if statement._ordered_values:
1468 new_stmt._ordered_values = self._resolved_values
1469 elif statement._values:
1470 new_stmt._values = self._resolved_values
1471
1472 new_crit = self._adjust_for_extra_criteria(
1473 self.global_attributes, mapper
1474 )
1475 if new_crit:
1476 new_stmt = new_stmt.where(*new_crit)
1477
1478 # if we are against a lambda statement we might not be the
1479 # topmost object that received per-execute annotations
1480
1481 # do this first as we need to determine if there is
1482 # UPDATE..FROM
1483
1484 UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
1485
1486 use_supplemental_cols = False
1487
1488 if not toplevel:
1489 synchronize_session = None
1490 else:
1491 synchronize_session = compiler._annotations.get(
1492 "synchronize_session", None
1493 )
1494 can_use_returning = compiler._annotations.get(
1495 "can_use_returning", None
1496 )
1497 if can_use_returning is not False:
1498 # even though pre_exec has determined basic
1499 # can_use_returning for the dialect, if we are to use
1500 # RETURNING we need to run can_use_returning() at this level
1501 # unconditionally because is_delete_using was not known
1502 # at the pre_exec level
1503 can_use_returning = (
1504 synchronize_session == "fetch"
1505 and self.can_use_returning(
1506 compiler.dialect, mapper, is_multitable=self.is_multitable
1507 )
1508 )
1509
1510 if synchronize_session == "fetch" and can_use_returning:
1511 use_supplemental_cols = True
1512
1513 # NOTE: we might want to RETURNING the actual columns to be
1514 # synchronized also. however this is complicated and difficult
1515 # to align against the behavior of "evaluate". Additionally,
1516 # in a large number (if not the majority) of cases, we have the
1517 # "evaluate" answer, usually a fixed value, in memory already and
1518 # there's no need to re-fetch the same value
1519 # over and over again. so perhaps if it could be RETURNING just
1520 # the elements that were based on a SQL expression and not
1521 # a constant. For now it doesn't quite seem worth it
1522 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1523
1524 if toplevel:
1525 new_stmt = self._setup_orm_returning(
1526 compiler,
1527 orm_level_statement,
1528 new_stmt,
1529 dml_mapper=mapper,
1530 use_supplemental_cols=use_supplemental_cols,
1531 )
1532
1533 self.statement = new_stmt
1534
1535 def _setup_for_bulk_update(self, statement, compiler, **kw):
1536 """establish an UPDATE statement within the context of
1537 bulk insert.
1538
1539 This method will be within the "conn.execute()" call that is invoked
1540 by persistence._emit_update_statement().
1541
1542 """
1543 statement = cast(dml.Update, statement)
1544 an = statement._annotations
1545
1546 emit_update_table, _ = (
1547 an["_emit_update_table"],
1548 an["_emit_update_mapper"],
1549 )
1550
1551 statement = statement._clone()
1552 statement.table = emit_update_table
1553
1554 UpdateDMLState.__init__(self, statement, compiler, **kw)
1555
1556 if self._ordered_values:
1557 raise sa_exc.InvalidRequestError(
1558 "bulk ORM UPDATE does not support ordered_values() for "
1559 "custom UPDATE statements with bulk parameter sets. Use a "
1560 "non-bulk UPDATE statement or use values()."
1561 )
1562
1563 if self._dict_parameters:
1564 self._dict_parameters = {
1565 col: val
1566 for col, val in self._dict_parameters.items()
1567 if col.table is emit_update_table
1568 }
1569 self.statement = statement
1570
1571 @classmethod
1572 def orm_execute_statement(
1573 cls,
1574 session: Session,
1575 statement: dml.Update,
1576 params: _CoreAnyExecuteParams,
1577 execution_options: OrmExecuteOptionsParameter,
1578 bind_arguments: _BindArguments,
1579 conn: Connection,
1580 ) -> _result.Result:
1581
1582 update_options = execution_options.get(
1583 "_sa_orm_update_options", cls.default_update_options
1584 )
1585
1586 if update_options._populate_existing:
1587 load_options = execution_options.get(
1588 "_sa_orm_load_options", QueryContext.default_load_options
1589 )
1590 load_options += {"_populate_existing": True}
1591 execution_options = execution_options.union(
1592 {"_sa_orm_load_options": load_options}
1593 )
1594
1595 if update_options._dml_strategy not in (
1596 "orm",
1597 "auto",
1598 "bulk",
1599 "core_only",
1600 ):
1601 raise sa_exc.ArgumentError(
1602 "Valid strategies for ORM UPDATE strategy "
1603 "are 'orm', 'auto', 'bulk', 'core_only'"
1604 )
1605
1606 result: _result.Result[Any]
1607
1608 if update_options._dml_strategy == "bulk":
1609 enable_check_rowcount = not statement._where_criteria
1610
1611 assert update_options._synchronize_session != "fetch"
1612
1613 if (
1614 statement._where_criteria
1615 and update_options._synchronize_session == "evaluate"
1616 ):
1617 raise sa_exc.InvalidRequestError(
1618 "bulk synchronize of persistent objects not supported "
1619 "when using bulk update with additional WHERE "
1620 "criteria right now. add synchronize_session=None "
1621 "execution option to bypass synchronize of persistent "
1622 "objects."
1623 )
1624 mapper = update_options._subject_mapper
1625 assert mapper is not None
1626 assert session._transaction is not None
1627 result = _bulk_update(
1628 mapper,
1629 cast(
1630 "Iterable[Dict[str, Any]]",
1631 [params] if isinstance(params, dict) else params,
1632 ),
1633 session._transaction,
1634 isstates=False,
1635 update_changed_only=False,
1636 use_orm_update_stmt=statement,
1637 enable_check_rowcount=enable_check_rowcount,
1638 )
1639 return cls.orm_setup_cursor_result(
1640 session,
1641 statement,
1642 params,
1643 execution_options,
1644 bind_arguments,
1645 result,
1646 )
1647 else:
1648 return super().orm_execute_statement(
1649 session,
1650 statement,
1651 params,
1652 execution_options,
1653 bind_arguments,
1654 conn,
1655 )
1656
1657 @classmethod
1658 def can_use_returning(
1659 cls,
1660 dialect: Dialect,
1661 mapper: Mapper[Any],
1662 *,
1663 is_multitable: bool = False,
1664 is_update_from: bool = False,
1665 is_delete_using: bool = False,
1666 is_executemany: bool = False,
1667 ) -> bool:
1668 # normal answer for "should we use RETURNING" at all.
1669 normal_answer = (
1670 dialect.update_returning and mapper.local_table.implicit_returning
1671 )
1672 if not normal_answer:
1673 return False
1674
1675 if is_executemany:
1676 return dialect.update_executemany_returning
1677
1678 # these workarounds are currently hypothetical for UPDATE,
1679 # unlike DELETE where they impact MariaDB
1680 if is_update_from:
1681 return dialect.update_returning_multifrom
1682
1683 elif is_multitable and not dialect.update_returning_multifrom:
1684 raise sa_exc.CompileError(
1685 f'Dialect "{dialect.name}" does not support RETURNING '
1686 "with UPDATE..FROM; for synchronize_session='fetch', "
1687 "please add the additional execution option "
1688 "'is_update_from=True' to the statement to indicate that "
1689 "a separate SELECT should be used for this backend."
1690 )
1691
1692 return True
1693
1694 @classmethod
1695 def _do_post_synchronize_bulk_evaluate(
1696 cls, session, params, result, update_options
1697 ):
1698 if not params:
1699 return
1700
1701 mapper = update_options._subject_mapper
1702 pk_keys = [prop.key for prop in mapper._identity_key_props]
1703
1704 identity_map = session.identity_map
1705
1706 for param in params:
1707 identity_key = mapper.identity_key_from_primary_key(
1708 (param[key] for key in pk_keys),
1709 update_options._identity_token,
1710 )
1711 state = identity_map.fast_get_state(identity_key)
1712 if not state:
1713 continue
1714
1715 evaluated_keys = set(param).difference(pk_keys)
1716
1717 dict_ = state.dict
1718 # only evaluate unmodified attributes
1719 to_evaluate = state.unmodified.intersection(evaluated_keys)
1720 for key in to_evaluate:
1721 if key in dict_:
1722 dict_[key] = param[key]
1723
1724 state.manager.dispatch.refresh(state, None, to_evaluate)
1725
1726 state._commit(dict_, list(to_evaluate))
1727
1728 # attributes that were formerly modified instead get expired.
1729 # this only gets hit if the session had pending changes
1730 # and autoflush were set to False.
1731 to_expire = evaluated_keys.intersection(dict_).difference(
1732 to_evaluate
1733 )
1734 if to_expire:
1735 state._expire_attributes(dict_, to_expire)
1736
1737 @classmethod
1738 def _do_post_synchronize_evaluate(
1739 cls, session, statement, result, update_options
1740 ):
1741 matched_objects = cls._get_matched_objects_on_criteria(
1742 update_options,
1743 session.identity_map.all_states(),
1744 )
1745
1746 cls._apply_update_set_values_to_objects(
1747 session,
1748 update_options,
1749 statement,
1750 result.context.compiled_parameters[0],
1751 [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
1752 result.prefetch_cols(),
1753 result.postfetch_cols(),
1754 )
1755
1756 @classmethod
1757 def _do_post_synchronize_fetch(
1758 cls, session, statement, result, update_options
1759 ):
1760 target_mapper = update_options._subject_mapper
1761
1762 returned_defaults_rows = result.returned_defaults_rows
1763 if returned_defaults_rows:
1764 pk_rows = cls._interpret_returning_rows(
1765 result, target_mapper, returned_defaults_rows
1766 )
1767 matched_rows = [
1768 tuple(row) + (update_options._identity_token,)
1769 for row in pk_rows
1770 ]
1771 else:
1772 matched_rows = update_options._matched_rows
1773
1774 objs = [
1775 session.identity_map[identity_key]
1776 for identity_key in [
1777 target_mapper.identity_key_from_primary_key(
1778 list(primary_key),
1779 identity_token=identity_token,
1780 )
1781 for primary_key, identity_token in [
1782 (row[0:-1], row[-1]) for row in matched_rows
1783 ]
1784 if update_options._identity_token is None
1785 or identity_token == update_options._identity_token
1786 ]
1787 if identity_key in session.identity_map
1788 ]
1789
1790 if not objs:
1791 return
1792
1793 cls._apply_update_set_values_to_objects(
1794 session,
1795 update_options,
1796 statement,
1797 result.context.compiled_parameters[0],
1798 [
1799 (
1800 obj,
1801 attributes.instance_state(obj),
1802 attributes.instance_dict(obj),
1803 )
1804 for obj in objs
1805 ],
1806 result.prefetch_cols(),
1807 result.postfetch_cols(),
1808 )
1809
1810 @classmethod
1811 def _apply_update_set_values_to_objects(
1812 cls,
1813 session,
1814 update_options,
1815 statement,
1816 effective_params,
1817 matched_objects,
1818 prefetch_cols,
1819 postfetch_cols,
1820 ):
1821 """apply values to objects derived from an update statement, e.g.
1822 UPDATE..SET <values>
1823
1824 """
1825
1826 mapper = update_options._subject_mapper
1827 target_cls = mapper.class_
1828 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
1829 resolved_values = cls._get_resolved_values(mapper, statement)
1830 resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
1831 mapper, resolved_values
1832 )
1833 value_evaluators = {}
1834 for key, value in resolved_keys_as_propnames:
1835 try:
1836 _evaluator = evaluator_compiler.process(
1837 coercions.expect(roles.ExpressionElementRole, value)
1838 )
1839 except evaluator.UnevaluatableError:
1840 pass
1841 else:
1842 value_evaluators[key] = _evaluator
1843
1844 evaluated_keys = list(value_evaluators.keys())
1845 attrib = {k for k, v in resolved_keys_as_propnames}
1846
1847 states = set()
1848
1849 to_prefetch = {
1850 c
1851 for c in prefetch_cols
1852 if c.key in effective_params
1853 and c in mapper._columntoproperty
1854 and c.key not in evaluated_keys
1855 }
1856 to_expire = {
1857 mapper._columntoproperty[c].key
1858 for c in postfetch_cols
1859 if c in mapper._columntoproperty
1860 }.difference(evaluated_keys)
1861
1862 prefetch_transfer = [
1863 (mapper._columntoproperty[c].key, c.key) for c in to_prefetch
1864 ]
1865
1866 for obj, state, dict_ in matched_objects:
1867
1868 dict_.update(
1869 {
1870 col_to_prop: effective_params[c_key]
1871 for col_to_prop, c_key in prefetch_transfer
1872 }
1873 )
1874
1875 state._expire_attributes(state.dict, to_expire)
1876
1877 to_evaluate = state.unmodified.intersection(evaluated_keys)
1878
1879 for key in to_evaluate:
1880 if key in dict_:
1881 # only run eval for attributes that are present.
1882 dict_[key] = value_evaluators[key](obj)
1883
1884 state.manager.dispatch.refresh(state, None, to_evaluate)
1885
1886 state._commit(dict_, list(to_evaluate))
1887
1888 # attributes that were formerly modified instead get expired.
1889 # this only gets hit if the session had pending changes
1890 # and autoflush were set to False.
1891 to_expire = attrib.intersection(dict_).difference(to_evaluate)
1892 if to_expire:
1893 state._expire_attributes(dict_, to_expire)
1894
1895 states.add(state)
1896 session._register_altered(states)
1897
1898
1899@CompileState.plugin_for("orm", "delete")
1900class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
1901 @classmethod
1902 def create_for_statement(cls, statement, compiler, **kw):
1903 self = cls.__new__(cls)
1904
1905 dml_strategy = statement._annotations.get(
1906 "dml_strategy", "unspecified"
1907 )
1908
1909 if (
1910 dml_strategy == "core_only"
1911 or dml_strategy == "unspecified"
1912 and "parententity" not in statement.table._annotations
1913 ):
1914 DeleteDMLState.__init__(self, statement, compiler, **kw)
1915 return self
1916
1917 toplevel = not compiler.stack
1918
1919 orm_level_statement = statement
1920
1921 ext_info = statement.table._annotations["parententity"]
1922 self.mapper = mapper = ext_info.mapper
1923
1924 self._init_global_attributes(
1925 statement,
1926 compiler,
1927 toplevel=toplevel,
1928 process_criteria_for_toplevel=toplevel,
1929 )
1930
1931 new_stmt = statement._clone()
1932
1933 if new_stmt.table._annotations["parententity"] is mapper:
1934 new_stmt.table = mapper.local_table
1935
1936 new_crit = cls._adjust_for_extra_criteria(
1937 self.global_attributes, mapper
1938 )
1939 if new_crit:
1940 new_stmt = new_stmt.where(*new_crit)
1941
1942 # do this first as we need to determine if there is
1943 # DELETE..FROM
1944 DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
1945
1946 use_supplemental_cols = False
1947
1948 if not toplevel:
1949 synchronize_session = None
1950 else:
1951 synchronize_session = compiler._annotations.get(
1952 "synchronize_session", None
1953 )
1954 can_use_returning = compiler._annotations.get(
1955 "can_use_returning", None
1956 )
1957 if can_use_returning is not False:
1958 # even though pre_exec has determined basic
1959 # can_use_returning for the dialect, if we are to use
1960 # RETURNING we need to run can_use_returning() at this level
1961 # unconditionally because is_delete_using was not known
1962 # at the pre_exec level
1963 can_use_returning = (
1964 synchronize_session == "fetch"
1965 and self.can_use_returning(
1966 compiler.dialect,
1967 mapper,
1968 is_multitable=self.is_multitable,
1969 is_delete_using=compiler._annotations.get(
1970 "is_delete_using", False
1971 ),
1972 )
1973 )
1974
1975 if can_use_returning:
1976 use_supplemental_cols = True
1977
1978 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1979
1980 if toplevel:
1981 new_stmt = self._setup_orm_returning(
1982 compiler,
1983 orm_level_statement,
1984 new_stmt,
1985 dml_mapper=mapper,
1986 use_supplemental_cols=use_supplemental_cols,
1987 )
1988
1989 self.statement = new_stmt
1990
1991 return self
1992
1993 @classmethod
1994 def orm_execute_statement(
1995 cls,
1996 session: Session,
1997 statement: dml.Delete,
1998 params: _CoreAnyExecuteParams,
1999 execution_options: OrmExecuteOptionsParameter,
2000 bind_arguments: _BindArguments,
2001 conn: Connection,
2002 ) -> _result.Result:
2003 update_options = execution_options.get(
2004 "_sa_orm_update_options", cls.default_update_options
2005 )
2006
2007 if update_options._dml_strategy == "bulk":
2008 raise sa_exc.InvalidRequestError(
2009 "Bulk ORM DELETE not supported right now. "
2010 "Statement may be invoked at the "
2011 "Core level using "
2012 "session.connection().execute(stmt, parameters)"
2013 )
2014
2015 if update_options._dml_strategy not in ("orm", "auto", "core_only"):
2016 raise sa_exc.ArgumentError(
2017 "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
2018 "'core_only'"
2019 )
2020
2021 return super().orm_execute_statement(
2022 session, statement, params, execution_options, bind_arguments, conn
2023 )
2024
2025 @classmethod
2026 def can_use_returning(
2027 cls,
2028 dialect: Dialect,
2029 mapper: Mapper[Any],
2030 *,
2031 is_multitable: bool = False,
2032 is_update_from: bool = False,
2033 is_delete_using: bool = False,
2034 is_executemany: bool = False,
2035 ) -> bool:
2036 # normal answer for "should we use RETURNING" at all.
2037 normal_answer = (
2038 dialect.delete_returning and mapper.local_table.implicit_returning
2039 )
2040 if not normal_answer:
2041 return False
2042
2043 # now get into special workarounds because MariaDB supports
2044 # DELETE...RETURNING but not DELETE...USING...RETURNING.
2045 if is_delete_using:
2046 # is_delete_using hint was passed. use
2047 # additional dialect feature (True for PG, False for MariaDB)
2048 return dialect.delete_returning_multifrom
2049
2050 elif is_multitable and not dialect.delete_returning_multifrom:
2051 # is_delete_using hint was not passed, but we determined
2052 # at compile time that this is in fact a DELETE..USING.
2053 # it's too late to continue since we did not pre-SELECT.
2054 # raise that we need that hint up front.
2055
2056 raise sa_exc.CompileError(
2057 f'Dialect "{dialect.name}" does not support RETURNING '
2058 "with DELETE..USING; for synchronize_session='fetch', "
2059 "please add the additional execution option "
2060 "'is_delete_using=True' to the statement to indicate that "
2061 "a separate SELECT should be used for this backend."
2062 )
2063
2064 return True
2065
2066 @classmethod
2067 def _do_post_synchronize_evaluate(
2068 cls, session, statement, result, update_options
2069 ):
2070 matched_objects = cls._get_matched_objects_on_criteria(
2071 update_options,
2072 session.identity_map.all_states(),
2073 )
2074
2075 to_delete = []
2076
2077 for _, state, dict_, is_partially_expired in matched_objects:
2078 if is_partially_expired:
2079 state._expire(dict_, session.identity_map._modified)
2080 else:
2081 to_delete.append(state)
2082
2083 if to_delete:
2084 session._remove_newly_deleted(to_delete)
2085
2086 @classmethod
2087 def _do_post_synchronize_fetch(
2088 cls, session, statement, result, update_options
2089 ):
2090 target_mapper = update_options._subject_mapper
2091
2092 returned_defaults_rows = result.returned_defaults_rows
2093
2094 if returned_defaults_rows:
2095 pk_rows = cls._interpret_returning_rows(
2096 result, target_mapper, returned_defaults_rows
2097 )
2098
2099 matched_rows = [
2100 tuple(row) + (update_options._identity_token,)
2101 for row in pk_rows
2102 ]
2103 else:
2104 matched_rows = update_options._matched_rows
2105
2106 for row in matched_rows:
2107 primary_key = row[0:-1]
2108 identity_token = row[-1]
2109
2110 # TODO: inline this and call remove_newly_deleted
2111 # once
2112 identity_key = target_mapper.identity_key_from_primary_key(
2113 list(primary_key),
2114 identity_token=identity_token,
2115 )
2116 if identity_key in session.identity_map:
2117 session._remove_newly_deleted(
2118 [
2119 attributes.instance_state(
2120 session.identity_map[identity_key]
2121 )
2122 ]
2123 )