# orm/bulk_persistence.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


"""additional ORM persistence classes related to "bulk" operations,
specifically outside of the flush() process.

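For example, a minimal sketch of the kind of statement that reaches the
bulk INSERT path implemented here, assuming a mapped class ``User``
(illustrative only)::

    from sqlalchemy import insert

    session.execute(
        insert(User),
        [{"name": "spongebob"}, {"name": "sandy"}],
    )
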
13"""
14
15from __future__ import annotations
16
17from typing import Any
18from typing import cast
19from typing import Dict
20from typing import Iterable
21from typing import Optional
22from typing import overload
23from typing import TYPE_CHECKING
24from typing import TypeVar
25from typing import Union
26
27from . import attributes
28from . import context
29from . import evaluator
30from . import exc as orm_exc
31from . import loading
32from . import persistence
33from .base import NO_VALUE
34from .context import AbstractORMCompileState
35from .context import FromStatement
36from .context import ORMFromStatementCompileState
37from .context import QueryContext
38from .. import exc as sa_exc
39from .. import util
40from ..engine import Dialect
41from ..engine import result as _result
42from ..sql import coercions
43from ..sql import dml
44from ..sql import expression
45from ..sql import roles
46from ..sql import select
47from ..sql import sqltypes
48from ..sql.base import _entity_namespace_key
49from ..sql.base import CompileState
50from ..sql.base import Options
51from ..sql.dml import DeleteDMLState
52from ..sql.dml import InsertDMLState
53from ..sql.dml import UpdateDMLState
54from ..util import EMPTY_DICT
55from ..util.typing import Literal
56from ..util.typing import TupleAny
57from ..util.typing import Unpack
58
59if TYPE_CHECKING:
60 from ._typing import DMLStrategyArgument
61 from ._typing import OrmExecuteOptionsParameter
62 from ._typing import SynchronizeSessionArgument
63 from .mapper import Mapper
64 from .session import _BindArguments
65 from .session import ORMExecuteState
66 from .session import Session
67 from .session import SessionTransaction
68 from .state import InstanceState
69 from ..engine import Connection
70 from ..engine import cursor
71 from ..engine.interfaces import _CoreAnyExecuteParams
72
73_O = TypeVar("_O", bound=object)
74
75
76@overload
77def _bulk_insert(
78 mapper: Mapper[_O],
79 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
80 session_transaction: SessionTransaction,
81 *,
82 isstates: bool,
83 return_defaults: bool,
84 render_nulls: bool,
85 use_orm_insert_stmt: Literal[None] = ...,
86 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
87) -> None: ...
88
89
90@overload
91def _bulk_insert(
92 mapper: Mapper[_O],
93 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
94 session_transaction: SessionTransaction,
95 *,
96 isstates: bool,
97 return_defaults: bool,
98 render_nulls: bool,
99 use_orm_insert_stmt: Optional[dml.Insert] = ...,
100 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
101) -> cursor.CursorResult[Any]: ...
102
103
104def _bulk_insert(
105 mapper: Mapper[_O],
106 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
107 session_transaction: SessionTransaction,
108 *,
109 isstates: bool,
110 return_defaults: bool,
111 render_nulls: bool,
112 use_orm_insert_stmt: Optional[dml.Insert] = None,
113 execution_options: Optional[OrmExecuteOptionsParameter] = None,
114) -> Optional[cursor.CursorResult[Any]]:
115 base_mapper = mapper.base_mapper
116
117 if session_transaction.session.connection_callable:
118 raise NotImplementedError(
119 "connection_callable / per-instance sharding "
120 "not supported in bulk_insert()"
121 )
122
123 if isstates:
124 if TYPE_CHECKING:
125 mappings = cast(Iterable[InstanceState[_O]], mappings)
126
127 if return_defaults:
128 # list of states allows us to attach .key for return_defaults case
129 states = [(state, state.dict) for state in mappings]
130 mappings = [dict_ for (state, dict_) in states]
131 else:
132 mappings = [state.dict for state in mappings]
133 else:
134 if TYPE_CHECKING:
135 mappings = cast(Iterable[Dict[str, Any]], mappings)
136
137 if return_defaults:
138 # use dictionaries given, so that newly populated defaults
139 # can be delivered back to the caller (see #11661). This is **not**
140 # compatible with other use cases such as a session-executed
141 # insert() construct, as this will confuse the case of
142 # insert-per-subclass for joined inheritance cases (see
143 # test_bulk_statements.py::BulkDMLReturningJoinedInhTest).
144 #
145 # So in this conditional, we have **only** called
146 # session.bulk_insert_mappings() which does not have this
147 # requirement
148 mappings = list(mappings)
149 else:
150 # for all other cases we need to establish a local dictionary
151 # so that the incoming dictionaries aren't mutated
152 mappings = [dict(m) for m in mappings]
153 _expand_composites(mapper, mappings)
154
155 connection = session_transaction.connection(base_mapper)
156
157 return_result: Optional[cursor.CursorResult[Any]] = None
158
159 mappers_to_run = [
160 (table, mp)
161 for table, mp in base_mapper._sorted_tables.items()
162 if table in mapper._pks_by_table
163 ]
164
165 if return_defaults:
166 # not used by new-style bulk inserts, only used for legacy
167 bookkeeping = True
168 elif len(mappers_to_run) > 1:
        # if we have more than one table / mapper to run, where we will
        # be either horizontally splicing or copying values between
        # tables, we need the "bookkeeping" / deterministic returning
        # order
        bookkeeping = True
    else:
        bookkeeping = False

    for table, super_mapper in mappers_to_run:
        # find bindparams in the statement. For bulk, we don't really know if
        # a key in the params applies to a different table since we are
        # potentially inserting for multiple tables here; looking at the
        # bindparam() is a lot more direct. in most cases this will
        # use _generate_cache_key() which is memoized, although in practice
        # the ultimate statement that's executed is probably not the same
        # object so that memoization might not matter much.
        extra_bp_names = (
            [
                b.key
                for b in use_orm_insert_stmt._get_embedded_bindparams()
                if b.key in mappings[0]
            ]
            if use_orm_insert_stmt is not None
            else ()
        )

        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in persistence._collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=bookkeeping,
                render_nulls=render_nulls,
                include_bulk_keys=extra_bp_names,
            )
        )

        result = persistence._emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=bookkeeping,
            use_orm_insert_stmt=use_orm_insert_stmt,
            execution_options=execution_options,
        )
        if use_orm_insert_stmt is not None:
            if not use_orm_insert_stmt._returning or return_result is None:
                return_result = result
            elif result.returns_rows:
                assert bookkeeping
                return_result = return_result.splice_horizontally(result)

    if return_defaults and isstates:
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
                None,
            )

    if use_orm_insert_stmt is not None:
        assert return_result is not None
        return return_result


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Literal[None] = ...,
    enable_check_rowcount: bool = True,
) -> None: ...


@overload
def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = ...,
    enable_check_rowcount: bool = True,
) -> _result.Result[Unpack[TupleAny]]: ...


def _bulk_update(
    mapper: Mapper[Any],
    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
    session_transaction: SessionTransaction,
    *,
    isstates: bool,
    update_changed_only: bool,
    use_orm_update_stmt: Optional[dml.Update] = None,
    enable_check_rowcount: bool = True,
) -> Optional[_result.Result[Unpack[TupleAny]]]:
    base_mapper = mapper.base_mapper

    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        return {
            k: v
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        }

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = [dict(m) for m in mappings]
        _expand_composites(mapper, mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    # find bindparams in the statement. see _bulk_insert for similar
    # notes for the insert case
    extra_bp_names = (
        [
            b.key
            for b in use_orm_update_stmt._get_embedded_bindparams()
            if b.key in mappings[0]
        ]
        if use_orm_update_stmt is not None
        else ()
    )

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
            continue

        records = persistence._collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
            use_orm_update_stmt=use_orm_update_stmt,
            include_bulk_keys=extra_bp_names,
        )
        persistence._emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
            use_orm_update_stmt=use_orm_update_stmt,
            enable_check_rowcount=enable_check_rowcount,
        )

    if use_orm_update_stmt is not None:
        return _result.null_result()


def _expand_composites(mapper, mappings):
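    """expand any composite attribute values present in the given
    mappings into their individual column-attribute entries, in place.

    A minimal sketch of the transformation, assuming a hypothetical
    mapping with a composite ``start`` attribute over columns ``x1`` and
    ``y1``: an incoming mapping ``{"start": Point(3, 4)}`` becomes
    ``{"x1": 3, "y1": 4}``.

    """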
    composite_attrs = mapper.composites
    if not composite_attrs:
        return

    composite_keys = set(composite_attrs.keys())
    populators = {
        key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
        for key in composite_keys
    }
    for mapping in mappings:
        for key in composite_keys.intersection(mapping):
            populators[key](mapping)


class ORMDMLState(AbstractORMCompileState):
    is_dml_returning = True
    from_statement_ctx: Optional[ORMFromStatementCompileState] = None

    @classmethod
    def _get_orm_crud_kv_pairs(
        cls, mapper, statement, kv_iterator, needs_to_be_cacheable
    ):
        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, str):
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if desc is NO_VALUE:
                    yield (
                        coercions.expect(roles.DMLColumnRole, k),
                        (
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            )
                            if needs_to_be_cacheable
                            else v
                        ),
                    )
                else:
                    yield from core_get_crud_kv_pairs(
                        statement,
                        desc._bulk_update_tuples(v),
                        needs_to_be_cacheable,
                    )
            elif "entity_namespace" in k._annotations:
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                yield from core_get_crud_kv_pairs(
                    statement,
                    attr._bulk_update_tuples(v),
                    needs_to_be_cacheable,
                )
            else:
                yield (
                    k,
                    (
                        v
                        if not needs_to_be_cacheable
                        else coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        )
                    ),
                )

    @classmethod
    def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
        plugin_subject = statement._propagate_attrs["plugin_subject"]

        if not plugin_subject or not plugin_subject.mapper:
            return UpdateDMLState._get_multi_crud_kv_pairs(
                statement, kv_iterator
            )

        return [
            dict(
                cls._get_orm_crud_kv_pairs(
                    plugin_subject.mapper, statement, value_dict.items(), False
                )
            )
            for value_dict in kv_iterator
        ]

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
        assert (
            needs_to_be_cacheable
        ), "no test coverage for needs_to_be_cacheable=False"

        plugin_subject = statement._propagate_attrs["plugin_subject"]

        if not plugin_subject or not plugin_subject.mapper:
            return UpdateDMLState._get_crud_kv_pairs(
                statement, kv_iterator, needs_to_be_cacheable
            )

        return list(
            cls._get_orm_crud_kv_pairs(
                plugin_subject.mapper,
                statement,
                kv_iterator,
                needs_to_be_cacheable,
            )
        )

    @classmethod
    def get_entity_description(cls, statement):
        ext_info = statement.table._annotations["parententity"]
        mapper = ext_info.mapper
        if ext_info.is_aliased_class:
            _label_name = ext_info.name
        else:
            _label_name = mapper.class_.__name__

        return {
            "name": _label_name,
            "type": mapper.class_,
            "expr": ext_info.entity,
            "entity": ext_info.entity,
            "table": mapper.local_table,
        }

    @classmethod
    def get_returning_column_descriptions(cls, statement):
        def _ent_for_col(c):
            return c._annotations.get("parententity", None)

        def _attr_for_col(c, ent):
            if ent is None:
                return c
            proxy_key = c._annotations.get("proxy_key", None)
            if not proxy_key:
                return c
            else:
                return getattr(ent.entity, proxy_key, c)

        return [
            {
                "name": c.key,
                "type": c.type,
                "expr": _attr_for_col(c, ent),
                "aliased": ent.is_aliased_class,
                "entity": ent.entity,
            }
            for c, ent in [
                (c, _ent_for_col(c)) for c in statement._all_selected_columns
            ]
        ]

    def _setup_orm_returning(
        self,
        compiler,
        orm_level_statement,
        dml_level_statement,
        dml_mapper,
        *,
        use_supplemental_cols=True,
    ):
        """establish ORM column handlers for an INSERT, UPDATE, or DELETE
        which uses explicit returning().

        called within compilation level create_for_statement.

        The _return_orm_returning() method then receives the Result
        after the statement was executed, and applies ORM loading to the
        state that we first established here.

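        E.g., as an illustrative sketch assuming a mapped class ``User``,
        an ORM-enabled statement such as ``insert(User).returning(User)``
        arrives here; the RETURNING entities are arranged into a
        FromStatement so that result rows can later be loaded as ``User``
        objects.
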
551 """
552
553 if orm_level_statement._returning:
554 fs = FromStatement(
555 orm_level_statement._returning,
556 dml_level_statement,
557 _adapt_on_names=False,
558 )
559 fs = fs.execution_options(**orm_level_statement._execution_options)
560 fs = fs.options(*orm_level_statement._with_options)
561 self.select_statement = fs
562 self.from_statement_ctx = fsc = (
563 ORMFromStatementCompileState.create_for_statement(fs, compiler)
564 )
565 fsc.setup_dml_returning_compile_state(dml_mapper)
566
567 dml_level_statement = dml_level_statement._generate()
568 dml_level_statement._returning = ()
569
570 cols_to_return = [c for c in fsc.primary_columns if c is not None]
571
572 # since we are splicing result sets together, make sure there
573 # are columns of some kind returned in each result set
574 if not cols_to_return:
575 cols_to_return.extend(dml_mapper.primary_key)
576
577 if use_supplemental_cols:
578 dml_level_statement = dml_level_statement.return_defaults(
579 # this is a little weird looking, but by passing
580 # primary key as the main list of cols, this tells
581 # return_defaults to omit server-default cols (and
582 # actually all cols, due to some weird thing we should
583 # clean up in crud.py).
584 # Since we have cols_to_return, just return what we asked
585 # for (plus primary key, which ORM persistence needs since
586 # we likely set bookkeeping=True here, which is another
                    # whole thing...). We don't want to clutter the
                    # statement up with lots of other cols the user didn't
                    # ask for. see #9685
                    *dml_mapper.primary_key,
                    supplemental_cols=cols_to_return,
                )
            else:
                dml_level_statement = dml_level_statement.returning(
                    *cols_to_return
                )

        return dml_level_statement

    @classmethod
    def _return_orm_returning(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        execution_context = result.context
        compile_state = execution_context.compiled.compile_state

        if (
            compile_state.from_statement_ctx
            and not compile_state.from_statement_ctx.compile_options._is_star
        ):
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )

            querycontext = QueryContext(
                compile_state.from_statement_ctx,
                compile_state.select_statement,
                params,
                session,
                load_options,
                execution_options,
                bind_arguments,
            )
            return loading.instances(result, querycontext)
        else:
            return result


class BulkUDCompileState(ORMDMLState):
    class default_update_options(Options):
        _dml_strategy: DMLStrategyArgument = "auto"
        _synchronize_session: SynchronizeSessionArgument = "auto"
        _can_use_returning: bool = False
        _is_delete_using: bool = False
        _is_update_from: bool = False
        _autoflush: bool = True
        _subject_mapper: Optional[Mapper[Any]] = None
        _resolved_values = EMPTY_DICT
        _eval_condition = None
        _matched_rows = None
        _identity_token = None

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        raise NotImplementedError()

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        (
            update_options,
            execution_options,
        ) = BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {
                "synchronize_session",
                "autoflush",
                "identity_token",
                "is_delete_using",
                "is_update_from",
                "dml_strategy",
            },
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                update_options += {"_subject_mapper": plugin_subject.mapper}

        if "parententity" not in statement.table._annotations:
            update_options += {"_dml_strategy": "core_only"}
        elif not isinstance(params, list):
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "orm"}
            elif update_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM UPDATE strategy without '
                    "passing separate parameters"
                )
        else:
            if update_options._dml_strategy == "auto":
                update_options += {"_dml_strategy": "bulk"}

        sync = update_options._synchronize_session
        if sync is not None:
            if sync not in ("auto", "evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'auto', 'evaluate', 'fetch', False"
                )
            if update_options._dml_strategy == "bulk" and sync == "fetch":
                raise sa_exc.InvalidRequestError(
                    "The 'fetch' synchronization strategy is not available "
                    "for 'bulk' ORM updates (i.e. multiple parameter sets)"
                )

        if not is_pre_event:
            if update_options._autoflush:
                session._autoflush()

            if update_options._dml_strategy == "orm":
                if update_options._synchronize_session == "auto":
                    update_options = cls._do_pre_synchronize_auto(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "evaluate":
                    update_options = cls._do_pre_synchronize_evaluate(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
                elif update_options._synchronize_session == "fetch":
                    update_options = cls._do_pre_synchronize_fetch(
                        session,
                        statement,
                        params,
                        execution_options,
                        bind_arguments,
                        update_options,
                    )
            elif update_options._dml_strategy == "bulk":
                if update_options._synchronize_session == "auto":
                    update_options += {"_synchronize_session": "evaluate"}

            # indicators from the "pre exec" step that are then
            # added to the DML statement, which will also be part of the cache
            # key. The compile level create_for_statement() method will then
            # consume these at compiler time.
            statement = statement._annotate(
                {
                    "synchronize_session": update_options._synchronize_session,
                    "is_delete_using": update_options._is_delete_using,
                    "is_update_from": update_options._is_update_from,
                    "dml_strategy": update_options._dml_strategy,
                    "can_use_returning": update_options._can_use_returning,
                }
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
        )

    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        # this stage of the execution is called after the
        # do_orm_execute event hook. meaning for an extension like
        # horizontal sharding, this step happens *within* the horizontal
        # sharding event handler which calls session.execute() re-entrantly
        # and will occur for each backend individually.
        # the sharding extension then returns its own merged result from the
        # individual ones we return here.

        update_options = execution_options["_sa_orm_update_options"]
        if update_options._dml_strategy == "orm":
            if update_options._synchronize_session == "evaluate":
                cls._do_post_synchronize_evaluate(
                    session, statement, result, update_options
                )
            elif update_options._synchronize_session == "fetch":
                cls._do_post_synchronize_fetch(
                    session, statement, result, update_options
                )
        elif update_options._dml_strategy == "bulk":
            if update_options._synchronize_session == "evaluate":
                cls._do_post_synchronize_bulk_evaluate(
                    session, params, result, update_options
                )
            return result

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
        """Apply extra criteria filtering.

        For all distinct single-table-inheritance mappers represented in the
        table being updated or deleted, produce additional WHERE criteria such
        that only the appropriate subtypes are selected from the total results.

        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
        collected from the statement.

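        E.g., as an illustrative sketch, for a single-table-inheritance
        subclass ``Manager(Employee)`` discriminated on ``employee.type``,
        an UPDATE against ``Manager`` would gain criteria along the lines
        of ``WHERE employee.type IN ('manager')``.
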
837 """
838
839 return_crit = ()
840
841 adapter = ext_info._adapter if ext_info.is_aliased_class else None
842
843 if (
844 "additional_entity_criteria",
845 ext_info.mapper,
846 ) in global_attributes:
847 return_crit += tuple(
848 ae._resolve_where_criteria(ext_info)
849 for ae in global_attributes[
850 ("additional_entity_criteria", ext_info.mapper)
851 ]
852 if ae.include_aliases or ae.entity is ext_info
853 )
854
855 if ext_info.mapper._single_table_criterion is not None:
856 return_crit += (ext_info.mapper._single_table_criterion,)
857
858 if adapter:
859 return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
860
861 return return_crit
862
863 @classmethod
864 def _interpret_returning_rows(cls, mapper, rows):
865 """translate from local inherited table columns to base mapper
866 primary key columns.
867
868 Joined inheritance mappers always establish the primary key in terms of
869 the base table. When we UPDATE a sub-table, we can only get
870 RETURNING for the sub-table's columns.
871
872 Here, we create a lookup from the local sub table's primary key
873 columns to the base table PK columns so that we can get identity
874 key values from RETURNING that's against the joined inheritance
875 sub-table.
876
        The complexity here is in supporting inheritance more than one
        level deep, where we have to link columns to each other across
        the inheritance hierarchy.

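        E.g., as an illustrative sketch, given a hypothetical
        joined-inheritance mapping where the ``manager`` sub-table's
        primary key references ``employee.id``, an UPDATE against
        ``manager`` can only RETURN ``manager``'s own primary key column;
        those values are re-interpreted here as ``employee.id`` values
        for identity-key construction.
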
881 """
882
        if mapper.local_table is mapper.base_mapper.local_table:
            return rows

        # this starts as a mapping of
        # local_pk_col: local_pk_col.
        # we will then iteratively rewrite the "value" of the dict with
        # each successive superclass column
        local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key}

        for mp in mapper.iterate_to_root():
            if mp.inherits is None:
                break
            elif mp.local_table is mp.inherits.local_table:
                continue

            t_to_e = dict(mp._table_to_equated[mp.inherits.local_table])
            col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]}
            for pk, super_ in local_pk_to_base_pk.items():
                local_pk_to_base_pk[pk] = col_to_col[super_]

        lookup = {
            local_pk_to_base_pk[lpk]: idx
            for idx, lpk in enumerate(mapper.local_table.primary_key)
        }
        primary_key_convert = [
            lookup[bpk] for bpk in mapper.base_mapper.primary_key
        ]
        return [tuple(row[idx] for idx in primary_key_convert) for row in rows]

    @classmethod
    def _get_matched_objects_on_criteria(cls, update_options, states):
        mapper = update_options._subject_mapper
        eval_condition = update_options._eval_condition

        raw_data = [
            (state.obj(), state, state.dict)
            for state in states
            if state.mapper.isa(mapper) and not state.expired
        ]

        identity_token = update_options._identity_token
        if identity_token is not None:
            raw_data = [
                (obj, state, dict_)
                for obj, state, dict_ in raw_data
                if state.identity_token == identity_token
            ]

        result = []
        for obj, state, dict_ in raw_data:
            evaled_condition = eval_condition(obj)

            # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
            # evaluates as True for all comparisons
            if (
                evaled_condition is True
                or evaled_condition is evaluator._EXPIRED_OBJECT
            ):
                result.append(
                    (
                        obj,
                        state,
                        dict_,
                        evaled_condition is evaluator._EXPIRED_OBJECT,
                    )
                )
        return result

    @classmethod
    def _eval_condition_from_statement(cls, update_options, statement):
        mapper = update_options._subject_mapper
        target_cls = mapper.class_

        evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
        crit = ()
        if statement._where_criteria:
            crit += statement._where_criteria

        global_attributes = {}
        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(global_attributes)

        if global_attributes:
            crit += cls._adjust_for_extra_criteria(global_attributes, mapper)

        if crit:
            eval_condition = evaluator_compiler.process(*crit)
        else:
            # workaround for mypy https://github.com/python/mypy/issues/14027
            def _eval_condition(obj):
                return True

            eval_condition = _eval_condition

        return eval_condition

    @classmethod
    def _do_pre_synchronize_auto(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
990 """setup auto sync strategy
991
992
993 "auto" checks if we can use "evaluate" first, then falls back
994 to "fetch"
995
996 evaluate is vastly more efficient for the common case
997 where session is empty, only has a few objects, and the UPDATE
998 statement can potentially match thousands/millions of rows.
999
1000 OTOH more complex criteria that fails to work with "evaluate"
1001 we would hope usually correlates with fewer net rows.
1002
1003 """

        try:
            eval_condition = cls._eval_condition_from_statement(
                update_options, statement
            )

        except evaluator.UnevaluatableError:
            pass
        else:
            return update_options + {
                "_eval_condition": eval_condition,
                "_synchronize_session": "evaluate",
            }

        update_options += {"_synchronize_session": "fetch"}
        return cls._do_pre_synchronize_fetch(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            update_options,
        )

    @classmethod
    def _do_pre_synchronize_evaluate(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        try:
            eval_condition = cls._eval_condition_from_statement(
                update_options, statement
            )

        except evaluator.UnevaluatableError as err:
            raise sa_exc.InvalidRequestError(
                'Could not evaluate current criteria in Python: "%s". '
                "Specify 'fetch' or False for the "
                "synchronize_session execution option." % err
            ) from err

        return update_options + {
            "_eval_condition": eval_condition,
        }

    @classmethod
    def _get_resolved_values(cls, mapper, statement):
        if statement._multi_values:
            return []
        elif statement._ordered_values:
            return list(statement._ordered_values)
        elif statement._values:
            return list(statement._values.items())
        else:
            return []

    @classmethod
    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
        values = []
        for k, v in resolved_values:
            if mapper and isinstance(k, expression.ColumnElement):
                try:
                    attr = mapper._columntoproperty[k]
                except orm_exc.UnmappedColumnError:
                    pass
                else:
                    values.append((attr.key, v))
            else:
                raise sa_exc.InvalidRequestError(
                    "Attribute name not found, can't be "
                    "synchronized back to objects: %r" % k
                )
        return values

    @classmethod
    def _do_pre_synchronize_fetch(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        mapper = update_options._subject_mapper

        select_stmt = (
            select(*(mapper.primary_key + (mapper.select_identity_token,)))
            .select_from(mapper)
            .options(*statement._with_options)
        )
        select_stmt._where_criteria = statement._where_criteria

        # conditionally run the SELECT statement for pre-fetch, testing the
        # "bind" for if we can use RETURNING or not using the do_orm_execute
        # event. If RETURNING is available, the do_orm_execute event
        # will cancel the SELECT from being actually run.
        #
        # The way this is organized seems strange; why don't we just
        # call can_use_returning() before invoking the statement and get
        # the answer? Why does this go through the whole execute phase
        # using an event? Answer: because we are integrating with
        # extensions such as the horizontal sharding extension that
        # "multiplexes" an individual statement run through multiple
        # engines, and it uses do_orm_execute() to do that.

        can_use_returning = None

        def skip_for_returning(orm_context: ORMExecuteState) -> Any:
            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
            nonlocal can_use_returning

            per_bind_result = cls.can_use_returning(
                bind.dialect,
                mapper,
                is_update_from=update_options._is_update_from,
                is_delete_using=update_options._is_delete_using,
                is_executemany=orm_context.is_executemany,
            )

            if can_use_returning is not None:
                if can_use_returning != per_bind_result:
                    raise sa_exc.InvalidRequestError(
                        "For synchronize_session='fetch', can't mix multiple "
                        "backends where some support RETURNING and others "
                        "don't"
                    )
            elif orm_context.is_executemany and not per_bind_result:
                raise sa_exc.InvalidRequestError(
                    "For synchronize_session='fetch', can't use multiple "
                    "parameter sets in ORM mode, which this backend does not "
                    "support with RETURNING"
                )
            else:
                can_use_returning = per_bind_result

            if per_bind_result:
                return _result.null_result()
            else:
                return None

        result = session.execute(
            select_stmt,
            params,
            execution_options=execution_options,
            bind_arguments=bind_arguments,
            _add_event=skip_for_returning,
        )
        matched_rows = result.fetchall()

        return update_options + {
            "_matched_rows": matched_rows,
            "_can_use_returning": can_use_returning,
        }


@CompileState.plugin_for("orm", "insert")
class BulkORMInsert(ORMDMLState, InsertDMLState):
    class default_insert_options(Options):
        _dml_strategy: DMLStrategyArgument = "auto"
        _render_nulls: bool = False
        _return_defaults: bool = False
        _subject_mapper: Optional[Mapper[Any]] = None
        _autoflush: bool = True
        _populate_existing: bool = False

    select_statement: Optional[FromStatement] = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_pre_event,
    ):
        (
            insert_options,
            execution_options,
        ) = BulkORMInsert.default_insert_options.from_execution_options(
            "_sa_orm_insert_options",
            {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
            execution_options,
            statement._execution_options,
        )
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            if plugin_subject:
                bind_arguments["mapper"] = plugin_subject.mapper
                insert_options += {"_subject_mapper": plugin_subject.mapper}

        if not params:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "orm"}
            elif insert_options._dml_strategy == "bulk":
                raise sa_exc.InvalidRequestError(
                    'Can\'t use "bulk" ORM insert strategy without '
                    "passing separate parameters"
                )
        else:
            if insert_options._dml_strategy == "auto":
                insert_options += {"_dml_strategy": "bulk"}

        if insert_options._dml_strategy != "raw":
            # for ORM object loading, like ORMContext, we have to disable
            # result set adapt_to_context, because we will be generating a
            # new statement with specific columns that's cached inside of
            # an ORMFromStatementCompileState, which we will re-use for
            # each result.
            if not execution_options:
                execution_options = context._orm_load_exec_options
            else:
                execution_options = execution_options.union(
                    context._orm_load_exec_options
                )

        if not is_pre_event and insert_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"dml_strategy": insert_options._dml_strategy}
        )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_insert_options": insert_options}
            ),
        )

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Insert,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        insert_options = execution_options.get(
            "_sa_orm_insert_options", cls.default_insert_options
        )

        if insert_options._dml_strategy not in (
            "raw",
            "bulk",
            "orm",
            "auto",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM insert strategy "
1267 "are 'raw', 'orm', 'bulk', 'auto"
            )

        result: _result.Result[Unpack[TupleAny]]

        if insert_options._dml_strategy == "raw":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
            return result

        if insert_options._dml_strategy == "bulk":
            mapper = insert_options._subject_mapper

            if (
                statement._post_values_clause is not None
                and mapper._multiple_persistence_tables
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk INSERT with a 'post values' clause "
                    "(typically upsert) not supported for multi-table "
                    f"mapper {mapper}"
                )

            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_insert(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                return_defaults=insert_options._return_defaults,
                render_nulls=insert_options._render_nulls,
                use_orm_insert_stmt=statement,
                execution_options=execution_options,
            )
        elif insert_options._dml_strategy == "orm":
            result = conn.execute(
                statement, params or {}, execution_options=execution_options
            )
        else:
            raise AssertionError()

        if not bool(statement._returning):
            return result

        if insert_options._populate_existing:
            load_options = execution_options.get(
                "_sa_orm_load_options", QueryContext.default_load_options
            )
            load_options += {"_populate_existing": True}
            execution_options = execution_options.union(
                {"_sa_orm_load_options": load_options}
            )

        return cls._return_orm_returning(
            session,
            statement,
            params,
            execution_options,
            bind_arguments,
            result,
        )

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
        self = cast(
            BulkORMInsert,
            super().create_for_statement(statement, compiler, **kw),
        )

        if compiler is not None:
            toplevel = not compiler.stack
        else:
            toplevel = True
        if not toplevel:
            return self

        mapper = statement._propagate_attrs["plugin_subject"]
        dml_strategy = statement._annotations.get("dml_strategy", "raw")
        if dml_strategy == "bulk":
            self._setup_for_bulk_insert(compiler)
        elif dml_strategy == "orm":
            self._setup_for_orm_insert(compiler, mapper)

        return self

    @classmethod
    def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
        return {
            col.key if col is not None else k: v
            for col, k, v in (
                (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
            )
        }

    def _setup_for_orm_insert(self, compiler, mapper):
        statement = orm_level_statement = cast(dml.Insert, self.statement)

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=mapper,
            use_supplemental_cols=False,
        )
        self.statement = statement

    def _setup_for_bulk_insert(self, compiler):
        """establish an INSERT statement within the context of
        bulk insert.

        This method will be within the "conn.execute()" call that is invoked
        by persistence._emit_insert_statements().

        """
        statement = orm_level_statement = cast(dml.Insert, self.statement)
        an = statement._annotations

        emit_insert_table, emit_insert_mapper = (
            an["_emit_insert_table"],
            an["_emit_insert_mapper"],
        )

        statement = statement._clone()

        statement.table = emit_insert_table
        if self._dict_parameters:
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_insert_table
            }

        statement = self._setup_orm_returning(
            compiler,
            orm_level_statement,
            statement,
            dml_mapper=emit_insert_mapper,
            use_supplemental_cols=True,
        )

        if (
            self.from_statement_ctx is not None
            and self.from_statement_ctx.compile_options._is_star
        ):
            raise sa_exc.CompileError(
                "Can't use RETURNING * with bulk ORM INSERT. "
                "Please use a different INSERT form, such as INSERT..VALUES "
                "or INSERT with a Core Connection"
            )

        self.statement = statement


@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        self = cls.__new__(cls)

        dml_strategy = statement._annotations.get(
            "dml_strategy", "unspecified"
        )

        toplevel = not compiler.stack

        if toplevel and dml_strategy == "bulk":
            self._setup_for_bulk_update(statement, compiler)
        elif (
            dml_strategy == "core_only"
            or dml_strategy == "unspecified"
            and "parententity" not in statement.table._annotations
        ):
            UpdateDMLState.__init__(self, statement, compiler, **kw)
        elif not toplevel or dml_strategy in ("orm", "unspecified"):
            self._setup_for_orm_update(statement, compiler)

        return self

    def _setup_for_orm_update(self, statement, compiler, **kw):
        orm_level_statement = statement

        toplevel = not compiler.stack

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self._resolved_values = self._get_resolved_values(mapper, statement)

        self._init_global_attributes(
            statement,
            compiler,
            toplevel=toplevel,
            process_criteria_for_toplevel=toplevel,
        )

        if statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = statement._clone()

        if new_stmt.table._annotations["parententity"] is mapper:
            new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = self._adjust_for_extra_criteria(
            self.global_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        # do this first as we need to determine if there is
        # UPDATE..FROM

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        use_supplemental_cols = False

        if not toplevel:
            synchronize_session = None
        else:
            synchronize_session = compiler._annotations.get(
                "synchronize_session", None
            )
            can_use_returning = compiler._annotations.get(
                "can_use_returning", None
            )
            if can_use_returning is not False:
                # even though pre_exec has determined basic
                # can_use_returning for the dialect, if we are to use
                # RETURNING we need to run can_use_returning() at this level
                # unconditionally because is_delete_using was not known
                # at the pre_exec level
                can_use_returning = (
                    synchronize_session == "fetch"
                    and self.can_use_returning(
                        compiler.dialect, mapper, is_multitable=self.is_multitable
                    )
                )

            if synchronize_session == "fetch" and can_use_returning:
                use_supplemental_cols = True

                # NOTE: we might want to RETURNING the actual columns to be
                # synchronized also. however this is complicated and difficult
                # to align against the behavior of "evaluate". Additionally,
                # in a large number (if not the majority) of cases, we have the
                # "evaluate" answer, usually a fixed value, in memory already and
                # there's no need to re-fetch the same value
                # over and over again. so perhaps if it could be RETURNING just
                # the elements that were based on a SQL expression and not
                # a constant. For now it doesn't quite seem worth it
                new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)

        if toplevel:
            new_stmt = self._setup_orm_returning(
                compiler,
                orm_level_statement,
                new_stmt,
                dml_mapper=mapper,
                use_supplemental_cols=use_supplemental_cols,
            )

        self.statement = new_stmt

    def _setup_for_bulk_update(self, statement, compiler, **kw):
1549 """establish an UPDATE statement within the context of
1550 bulk insert.
1551
1552 This method will be within the "conn.execute()" call that is invoked
1553 by persistence._emit_update_statement().
1554
1555 """
        statement = cast(dml.Update, statement)
        an = statement._annotations

        emit_update_table, _ = (
            an["_emit_update_table"],
            an["_emit_update_mapper"],
        )

        statement = statement._clone()
        statement.table = emit_update_table

        UpdateDMLState.__init__(self, statement, compiler, **kw)

        if self._ordered_values:
            raise sa_exc.InvalidRequestError(
                "bulk ORM UPDATE does not support ordered_values() for "
                "custom UPDATE statements with bulk parameter sets. Use a "
                "non-bulk UPDATE statement or use values()."
            )

        if self._dict_parameters:
            self._dict_parameters = {
                col: val
                for col, val in self._dict_parameters.items()
                if col.table is emit_update_table
            }
        self.statement = statement

    @classmethod
    def orm_execute_statement(
        cls,
        session: Session,
        statement: dml.Update,
        params: _CoreAnyExecuteParams,
        execution_options: OrmExecuteOptionsParameter,
        bind_arguments: _BindArguments,
        conn: Connection,
    ) -> _result.Result:
        update_options = execution_options.get(
            "_sa_orm_update_options", cls.default_update_options
        )

        if update_options._dml_strategy not in (
            "orm",
            "auto",
            "bulk",
            "core_only",
        ):
            raise sa_exc.ArgumentError(
                "Valid strategies for ORM UPDATE strategy "
                "are 'orm', 'auto', 'bulk', 'core_only'"
            )

        result: _result.Result[Unpack[TupleAny]]

        if update_options._dml_strategy == "bulk":
            enable_check_rowcount = not statement._where_criteria

            assert update_options._synchronize_session != "fetch"

            if (
                statement._where_criteria
                and update_options._synchronize_session == "evaluate"
            ):
                raise sa_exc.InvalidRequestError(
                    "bulk synchronize of persistent objects not supported "
                    "when using bulk update with additional WHERE "
                    "criteria right now. add synchronize_session=None "
                    "execution option to bypass synchronize of persistent "
                    "objects."
                )
            mapper = update_options._subject_mapper
            assert mapper is not None
            assert session._transaction is not None
            result = _bulk_update(
                mapper,
                cast(
                    "Iterable[Dict[str, Any]]",
                    [params] if isinstance(params, dict) else params,
                ),
                session._transaction,
                isstates=False,
                update_changed_only=False,
                use_orm_update_stmt=statement,
                enable_check_rowcount=enable_check_rowcount,
            )
            return cls.orm_setup_cursor_result(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                result,
            )
        else:
            return super().orm_execute_statement(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                conn,
            )

    @classmethod
    def can_use_returning(
        cls,
        dialect: Dialect,
        mapper: Mapper[Any],
        *,
        is_multitable: bool = False,
        is_update_from: bool = False,
        is_delete_using: bool = False,
        is_executemany: bool = False,
    ) -> bool:
        # normal answer for "should we use RETURNING" at all.
        normal_answer = (
            dialect.update_returning and mapper.local_table.implicit_returning
        )
        if not normal_answer:
            return False

        if is_executemany:
            return dialect.update_executemany_returning

        # these workarounds are currently hypothetical for UPDATE,
        # unlike DELETE where they impact MariaDB
        if is_update_from:
            return dialect.update_returning_multifrom

        elif is_multitable and not dialect.update_returning_multifrom:
            raise sa_exc.CompileError(
                f'Dialect "{dialect.name}" does not support RETURNING '
                "with UPDATE..FROM; for synchronize_session='fetch', "
                "please add the additional execution option "
                "'is_update_from=True' to the statement to indicate that "
                "a separate SELECT should be used for this backend."
            )

        return True

    @classmethod
    def _do_post_synchronize_bulk_evaluate(
        cls, session, params, result, update_options
    ):
        if not params:
            return

        mapper = update_options._subject_mapper
        pk_keys = [prop.key for prop in mapper._identity_key_props]

        identity_map = session.identity_map

        for param in params:
            identity_key = mapper.identity_key_from_primary_key(
                (param[key] for key in pk_keys),
                update_options._identity_token,
            )
            state = identity_map.fast_get_state(identity_key)
            if not state:
                continue

            evaluated_keys = set(param).difference(pk_keys)

            dict_ = state.dict
            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = param[key]

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were formerly modified instead get expired.
            # this only gets hit if the session had pending changes
            # and autoflush were set to False.
            to_expire = evaluated_keys.intersection(dict_).difference(
                to_evaluate
            )
            if to_expire:
                state._expire_attributes(dict_, to_expire)

    @classmethod
    def _do_post_synchronize_evaluate(
        cls, session, statement, result, update_options
    ):
        matched_objects = cls._get_matched_objects_on_criteria(
            update_options,
            session.identity_map.all_states(),
        )

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
        )

    @classmethod
    def _do_post_synchronize_fetch(
        cls, session, statement, result, update_options
    ):
        target_mapper = update_options._subject_mapper

        returned_defaults_rows = result.returned_defaults_rows
        if returned_defaults_rows:
            pk_rows = cls._interpret_returning_rows(
                target_mapper, returned_defaults_rows
            )

            matched_rows = [
                tuple(row) + (update_options._identity_token,)
                for row in pk_rows
            ]
        else:
            matched_rows = update_options._matched_rows

        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._identity_token is None
                or identity_token == update_options._identity_token
            ]
            if identity_key in session.identity_map
        ]

        if not objs:
            return

        cls._apply_update_set_values_to_objects(
            session,
            update_options,
            statement,
            [
                (
                    obj,
                    attributes.instance_state(obj),
                    attributes.instance_dict(obj),
                )
                for obj in objs
            ],
        )

    @classmethod
    def _apply_update_set_values_to_objects(
        cls, session, update_options, statement, matched_objects
    ):
        """apply values to objects derived from an update statement, e.g.
        UPDATE..SET <values>

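        E.g., as an illustrative sketch, for a statement along the lines
        of ``update(User).values(status="archived")``, each matched
        ``User`` object already present in the Session has
        ``status = "archived"`` applied directly to its attribute state
        here, rather than being expired and re-fetched.
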
1815 """
1816 mapper = update_options._subject_mapper
1817 target_cls = mapper.class_
1818 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
1819 resolved_values = cls._get_resolved_values(mapper, statement)
1820 resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
1821 mapper, resolved_values
1822 )
1823 value_evaluators = {}
1824 for key, value in resolved_keys_as_propnames:
1825 try:
1826 _evaluator = evaluator_compiler.process(
1827 coercions.expect(roles.ExpressionElementRole, value)
1828 )
1829 except evaluator.UnevaluatableError:
1830 pass
1831 else:
1832 value_evaluators[key] = _evaluator
1833
1834 evaluated_keys = list(value_evaluators.keys())
1835 attrib = {k for k, v in resolved_keys_as_propnames}
1836
1837 states = set()
1838 for obj, state, dict_ in matched_objects:
1839 to_evaluate = state.unmodified.intersection(evaluated_keys)
1840
1841 for key in to_evaluate:
1842 if key in dict_:
1843 # only run eval for attributes that are present.
1844 dict_[key] = value_evaluators[key](obj)
1845
1846 state.manager.dispatch.refresh(state, None, to_evaluate)
1847
1848 state._commit(dict_, list(to_evaluate))
1849
1850 # attributes that were formerly modified instead get expired.
1851 # this only gets hit if the session had pending changes
1852 # and autoflush were set to False.
1853 to_expire = attrib.intersection(dict_).difference(to_evaluate)
1854 if to_expire:
1855 state._expire_attributes(dict_, to_expire)
1856
1857 states.add(state)
1858 session._register_altered(states)
1859
1860
1861@CompileState.plugin_for("orm", "delete")
1862class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
1863 @classmethod
1864 def create_for_statement(cls, statement, compiler, **kw):
1865 self = cls.__new__(cls)
1866
1867 dml_strategy = statement._annotations.get(
1868 "dml_strategy", "unspecified"
1869 )
1870
1871 if (
1872 dml_strategy == "core_only"
1873 or dml_strategy == "unspecified"
1874 and "parententity" not in statement.table._annotations
1875 ):
1876 DeleteDMLState.__init__(self, statement, compiler, **kw)
1877 return self
1878
1879 toplevel = not compiler.stack
1880
1881 orm_level_statement = statement
1882
1883 ext_info = statement.table._annotations["parententity"]
1884 self.mapper = mapper = ext_info.mapper
1885
1886 self._init_global_attributes(
1887 statement,
1888 compiler,
1889 toplevel=toplevel,
1890 process_criteria_for_toplevel=toplevel,
1891 )
1892
1893 new_stmt = statement._clone()
1894
1895 if new_stmt.table._annotations["parententity"] is mapper:
1896 new_stmt.table = mapper.local_table
1897
1898 new_crit = cls._adjust_for_extra_criteria(
1899 self.global_attributes, mapper
1900 )
1901 if new_crit:
1902 new_stmt = new_stmt.where(*new_crit)
1903
1904 # do this first as we need to determine if there is
1905 # DELETE..FROM
1906 DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
1907
1908 use_supplemental_cols = False
1909
1910 if not toplevel:
1911 synchronize_session = None
1912 else:
1913 synchronize_session = compiler._annotations.get(
1914 "synchronize_session", None
1915 )
1916 can_use_returning = compiler._annotations.get(
1917 "can_use_returning", None
1918 )
1919 if can_use_returning is not False:
1920 # even though pre_exec has determined basic
1921 # can_use_returning for the dialect, if we are to use
1922 # RETURNING we need to run can_use_returning() at this level
1923 # unconditionally because is_delete_using was not known
1924 # at the pre_exec level
1925 can_use_returning = (
1926 synchronize_session == "fetch"
1927 and self.can_use_returning(
1928 compiler.dialect,
1929 mapper,
1930 is_multitable=self.is_multitable,
1931 is_delete_using=compiler._annotations.get(
1932 "is_delete_using", False
1933 ),
1934 )
1935 )
1936
1937 if can_use_returning:
1938 use_supplemental_cols = True
1939
1940 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1941
1942 if toplevel:
1943 new_stmt = self._setup_orm_returning(
1944 compiler,
1945 orm_level_statement,
1946 new_stmt,
1947 dml_mapper=mapper,
1948 use_supplemental_cols=use_supplemental_cols,
1949 )
1950
1951 self.statement = new_stmt
1952
1953 return self
1954
1955 @classmethod
1956 def orm_execute_statement(
1957 cls,
1958 session: Session,
1959 statement: dml.Delete,
1960 params: _CoreAnyExecuteParams,
1961 execution_options: OrmExecuteOptionsParameter,
1962 bind_arguments: _BindArguments,
1963 conn: Connection,
1964 ) -> _result.Result:
1965 update_options = execution_options.get(
1966 "_sa_orm_update_options", cls.default_update_options
1967 )
1968
1969 if update_options._dml_strategy == "bulk":
1970 raise sa_exc.InvalidRequestError(
1971 "Bulk ORM DELETE not supported right now. "
1972 "Statement may be invoked at the "
1973 "Core level using "
1974 "session.connection().execute(stmt, parameters)"
1975 )
1976
1977 if update_options._dml_strategy not in ("orm", "auto", "core_only"):
1978 raise sa_exc.ArgumentError(
1979 "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
1980 "'core_only'"
1981 )
1982
1983 return super().orm_execute_statement(
1984 session, statement, params, execution_options, bind_arguments, conn
1985 )
1986
1987 @classmethod
1988 def can_use_returning(
1989 cls,
1990 dialect: Dialect,
1991 mapper: Mapper[Any],
1992 *,
1993 is_multitable: bool = False,
1994 is_update_from: bool = False,
1995 is_delete_using: bool = False,
1996 is_executemany: bool = False,
1997 ) -> bool:
1998 # normal answer for "should we use RETURNING" at all.
1999 normal_answer = (
2000 dialect.delete_returning and mapper.local_table.implicit_returning
2001 )
2002 if not normal_answer:
2003 return False
2004
2005 # now get into special workarounds because MariaDB supports
2006 # DELETE...RETURNING but not DELETE...USING...RETURNING.
2007 if is_delete_using:
2008 # is_delete_using hint was passed. use
2009 # additional dialect feature (True for PG, False for MariaDB)
2010 return dialect.delete_returning_multifrom
2011
2012 elif is_multitable and not dialect.delete_returning_multifrom:
2013 # is_delete_using hint was not passed, but we determined
2014 # at compile time that this is in fact a DELETE..USING.
2015 # it's too late to continue since we did not pre-SELECT.
2016 # raise that we need that hint up front.
2017
2018 raise sa_exc.CompileError(
2019 f'Dialect "{dialect.name}" does not support RETURNING '
2020 "with DELETE..USING; for synchronize_session='fetch', "
2021 "please add the additional execution option "
2022 "'is_delete_using=True' to the statement to indicate that "
2023 "a separate SELECT should be used for this backend."
2024 )
2025
2026 return True
2027
2028 @classmethod
2029 def _do_post_synchronize_evaluate(
2030 cls, session, statement, result, update_options
2031 ):
2032 matched_objects = cls._get_matched_objects_on_criteria(
2033 update_options,
2034 session.identity_map.all_states(),
2035 )
2036
2037 to_delete = []
2038
2039 for _, state, dict_, is_partially_expired in matched_objects:
2040 if is_partially_expired:
2041 state._expire(dict_, session.identity_map._modified)
2042 else:
2043 to_delete.append(state)
2044
2045 if to_delete:
2046 session._remove_newly_deleted(to_delete)
2047
2048 @classmethod
2049 def _do_post_synchronize_fetch(
2050 cls, session, statement, result, update_options
2051 ):
2052 target_mapper = update_options._subject_mapper
2053
2054 returned_defaults_rows = result.returned_defaults_rows
2055
2056 if returned_defaults_rows:
2057 pk_rows = cls._interpret_returning_rows(
2058 target_mapper, returned_defaults_rows
2059 )
2060
2061 matched_rows = [
2062 tuple(row) + (update_options._identity_token,)
2063 for row in pk_rows
2064 ]
2065 else:
2066 matched_rows = update_options._matched_rows
2067
2068 for row in matched_rows:
2069 primary_key = row[0:-1]
2070 identity_token = row[-1]
2071
2072 # TODO: inline this and call remove_newly_deleted
2073 # once
2074 identity_key = target_mapper.identity_key_from_primary_key(
2075 list(primary_key),
2076 identity_token=identity_token,
2077 )
2078 if identity_key in session.identity_map:
2079 session._remove_newly_deleted(
2080 [
2081 attributes.instance_state(
2082 session.identity_map[identity_key]
2083 )
2084 ]
2085 )