1# orm/bulk_persistence.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""additional ORM persistence classes related to "bulk" operations,
11specifically outside of the flush() process.
12
13"""
14
15from __future__ import annotations
16
17from typing import Any
18from typing import cast
19from typing import Dict
20from typing import Iterable
21from typing import Optional
22from typing import overload
23from typing import TYPE_CHECKING
24from typing import TypeVar
25from typing import Union
26
27from . import attributes
28from . import context
29from . import evaluator
30from . import exc as orm_exc
31from . import loading
32from . import persistence
33from .base import NO_VALUE
34from .context import AbstractORMCompileState
35from .context import FromStatement
36from .context import ORMFromStatementCompileState
37from .context import QueryContext
38from .. import exc as sa_exc
39from .. import util
40from ..engine import Dialect
41from ..engine import result as _result
42from ..sql import coercions
43from ..sql import dml
44from ..sql import expression
45from ..sql import roles
46from ..sql import select
47from ..sql import sqltypes
48from ..sql.base import _entity_namespace_key
49from ..sql.base import CompileState
50from ..sql.base import Options
51from ..sql.dml import DeleteDMLState
52from ..sql.dml import InsertDMLState
53from ..sql.dml import UpdateDMLState
54from ..util import EMPTY_DICT
55from ..util.typing import Literal
56from ..util.typing import TupleAny
57from ..util.typing import Unpack
58
59if TYPE_CHECKING:
60 from ._typing import DMLStrategyArgument
61 from ._typing import OrmExecuteOptionsParameter
62 from ._typing import SynchronizeSessionArgument
63 from .mapper import Mapper
64 from .session import _BindArguments
65 from .session import ORMExecuteState
66 from .session import Session
67 from .session import SessionTransaction
68 from .state import InstanceState
69 from ..engine import Connection
70 from ..engine import cursor
71 from ..engine.interfaces import _CoreAnyExecuteParams
72
73_O = TypeVar("_O", bound=object)
74
75
76@overload
77def _bulk_insert(
78 mapper: Mapper[_O],
79 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
80 session_transaction: SessionTransaction,
81 *,
82 isstates: bool,
83 return_defaults: bool,
84 render_nulls: bool,
85 use_orm_insert_stmt: Literal[None] = ...,
86 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
87) -> None: ...
88
89
90@overload
91def _bulk_insert(
92 mapper: Mapper[_O],
93 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
94 session_transaction: SessionTransaction,
95 *,
96 isstates: bool,
97 return_defaults: bool,
98 render_nulls: bool,
99 use_orm_insert_stmt: Optional[dml.Insert] = ...,
100 execution_options: Optional[OrmExecuteOptionsParameter] = ...,
101) -> cursor.CursorResult[Any]: ...
102
103
104def _bulk_insert(
105 mapper: Mapper[_O],
106 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
107 session_transaction: SessionTransaction,
108 *,
109 isstates: bool,
110 return_defaults: bool,
111 render_nulls: bool,
112 use_orm_insert_stmt: Optional[dml.Insert] = None,
113 execution_options: Optional[OrmExecuteOptionsParameter] = None,
114) -> Optional[cursor.CursorResult[Any]]:
115 base_mapper = mapper.base_mapper
116
117 if session_transaction.session.connection_callable:
118 raise NotImplementedError(
119 "connection_callable / per-instance sharding "
120 "not supported in bulk_insert()"
121 )
122
123 if isstates:
124 if return_defaults:
125 states = [(state, state.dict) for state in mappings]
126 mappings = [dict_ for (state, dict_) in states]
127 else:
128 mappings = [state.dict for state in mappings]
129 else:
130 mappings = [dict(m) for m in mappings]
131 _expand_composites(mapper, mappings)
132
133 connection = session_transaction.connection(base_mapper)
134
135 return_result: Optional[cursor.CursorResult[Any]] = None
136
137 mappers_to_run = [
138 (table, mp)
139 for table, mp in base_mapper._sorted_tables.items()
140 if table in mapper._pks_by_table
141 ]
142
143 if return_defaults:
144 # not used by new-style bulk inserts, only used for legacy
145 bookkeeping = True
146 elif len(mappers_to_run) > 1:
147 # if we have more than one table, mapper to run where we will be
148 # either horizontally splicing, or copying values between tables,
149 # we need the "bookkeeping" / deterministic returning order
150 bookkeeping = True
151 else:
152 bookkeeping = False
153
154 for table, super_mapper in mappers_to_run:
155 # find bindparams in the statement. For bulk, we don't really know if
156 # a key in the params applies to a different table since we are
157 # potentially inserting for multiple tables here; looking at the
158 # bindparam() is a lot more direct. in most cases this will
159 # use _generate_cache_key() which is memoized, although in practice
160 # the ultimate statement that's executed is probably not the same
161 # object so that memoization might not matter much.
162 extra_bp_names = (
163 [
164 b.key
165 for b in use_orm_insert_stmt._get_embedded_bindparams()
166 if b.key in mappings[0]
167 ]
168 if use_orm_insert_stmt is not None
169 else ()
170 )
171
172 records = (
173 (
174 None,
175 state_dict,
176 params,
177 mapper,
178 connection,
179 value_params,
180 has_all_pks,
181 has_all_defaults,
182 )
183 for (
184 state,
185 state_dict,
186 params,
187 mp,
188 conn,
189 value_params,
190 has_all_pks,
191 has_all_defaults,
192 ) in persistence._collect_insert_commands(
193 table,
194 ((None, mapping, mapper, connection) for mapping in mappings),
195 bulk=True,
196 return_defaults=bookkeeping,
197 render_nulls=render_nulls,
198 include_bulk_keys=extra_bp_names,
199 )
200 )
201
202 result = persistence._emit_insert_statements(
203 base_mapper,
204 None,
205 super_mapper,
206 table,
207 records,
208 bookkeeping=bookkeeping,
209 use_orm_insert_stmt=use_orm_insert_stmt,
210 execution_options=execution_options,
211 )
212 if use_orm_insert_stmt is not None:
213 if not use_orm_insert_stmt._returning or return_result is None:
214 return_result = result
215 elif result.returns_rows:
216 assert bookkeeping
217 return_result = return_result.splice_horizontally(result)
218
219 if return_defaults and isstates:
220 identity_cls = mapper._identity_class
221 identity_props = [p.key for p in mapper._identity_key_props]
222 for state, dict_ in states:
223 state.key = (
224 identity_cls,
225 tuple([dict_[key] for key in identity_props]),
226 None,
227 )
228
229 if use_orm_insert_stmt is not None:
230 assert return_result is not None
231 return return_result
232
233
234@overload
235def _bulk_update(
236 mapper: Mapper[Any],
237 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
238 session_transaction: SessionTransaction,
239 *,
240 isstates: bool,
241 update_changed_only: bool,
242 use_orm_update_stmt: Literal[None] = ...,
243 enable_check_rowcount: bool = True,
244) -> None: ...
245
246
247@overload
248def _bulk_update(
249 mapper: Mapper[Any],
250 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
251 session_transaction: SessionTransaction,
252 *,
253 isstates: bool,
254 update_changed_only: bool,
255 use_orm_update_stmt: Optional[dml.Update] = ...,
256 enable_check_rowcount: bool = True,
257) -> _result.Result[Unpack[TupleAny]]: ...
258
259
260def _bulk_update(
261 mapper: Mapper[Any],
262 mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
263 session_transaction: SessionTransaction,
264 *,
265 isstates: bool,
266 update_changed_only: bool,
267 use_orm_update_stmt: Optional[dml.Update] = None,
268 enable_check_rowcount: bool = True,
269) -> Optional[_result.Result[Unpack[TupleAny]]]:
270 base_mapper = mapper.base_mapper
271
272 search_keys = mapper._primary_key_propkeys
273 if mapper._version_id_prop:
274 search_keys = {mapper._version_id_prop.key}.union(search_keys)
275
276 def _changed_dict(mapper, state):
277 return {
278 k: v
279 for k, v in state.dict.items()
280 if k in state.committed_state or k in search_keys
281 }
282
283 if isstates:
284 if update_changed_only:
285 mappings = [_changed_dict(mapper, state) for state in mappings]
286 else:
287 mappings = [state.dict for state in mappings]
288 else:
289 mappings = [dict(m) for m in mappings]
290 _expand_composites(mapper, mappings)
291
292 if session_transaction.session.connection_callable:
293 raise NotImplementedError(
294 "connection_callable / per-instance sharding "
295 "not supported in bulk_update()"
296 )
297
298 connection = session_transaction.connection(base_mapper)
299
300 # find bindparams in the statement. see _bulk_insert for similar
301 # notes for the insert case
302 extra_bp_names = (
303 [
304 b.key
305 for b in use_orm_update_stmt._get_embedded_bindparams()
306 if b.key in mappings[0]
307 ]
308 if use_orm_update_stmt is not None
309 else ()
310 )
311
312 for table, super_mapper in base_mapper._sorted_tables.items():
313 if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
314 continue
315
316 records = persistence._collect_update_commands(
317 None,
318 table,
319 (
320 (
321 None,
322 mapping,
323 mapper,
324 connection,
325 (
326 mapping[mapper._version_id_prop.key]
327 if mapper._version_id_prop
328 else None
329 ),
330 )
331 for mapping in mappings
332 ),
333 bulk=True,
334 use_orm_update_stmt=use_orm_update_stmt,
335 include_bulk_keys=extra_bp_names,
336 )
337 persistence._emit_update_statements(
338 base_mapper,
339 None,
340 super_mapper,
341 table,
342 records,
343 bookkeeping=False,
344 use_orm_update_stmt=use_orm_update_stmt,
345 enable_check_rowcount=enable_check_rowcount,
346 )
347
348 if use_orm_update_stmt is not None:
349 return _result.null_result()
350
351
352def _expand_composites(mapper, mappings):
353 composite_attrs = mapper.composites
354 if not composite_attrs:
355 return
356
357 composite_keys = set(composite_attrs.keys())
358 populators = {
359 key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
360 for key in composite_keys
361 }
362 for mapping in mappings:
363 for key in composite_keys.intersection(mapping):
364 populators[key](mapping)
365
366
367class ORMDMLState(AbstractORMCompileState):
368 is_dml_returning = True
369 from_statement_ctx: Optional[ORMFromStatementCompileState] = None
370
371 @classmethod
372 def _get_orm_crud_kv_pairs(
373 cls, mapper, statement, kv_iterator, needs_to_be_cacheable
374 ):
375 core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
376
377 for k, v in kv_iterator:
378 k = coercions.expect(roles.DMLColumnRole, k)
379
380 if isinstance(k, str):
381 desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
382 if desc is NO_VALUE:
383 yield (
384 coercions.expect(roles.DMLColumnRole, k),
385 (
386 coercions.expect(
387 roles.ExpressionElementRole,
388 v,
389 type_=sqltypes.NullType(),
390 is_crud=True,
391 )
392 if needs_to_be_cacheable
393 else v
394 ),
395 )
396 else:
397 yield from core_get_crud_kv_pairs(
398 statement,
399 desc._bulk_update_tuples(v),
400 needs_to_be_cacheable,
401 )
402 elif "entity_namespace" in k._annotations:
403 k_anno = k._annotations
404 attr = _entity_namespace_key(
405 k_anno["entity_namespace"], k_anno["proxy_key"]
406 )
407 yield from core_get_crud_kv_pairs(
408 statement,
409 attr._bulk_update_tuples(v),
410 needs_to_be_cacheable,
411 )
412 else:
413 yield (
414 k,
415 (
416 v
417 if not needs_to_be_cacheable
418 else coercions.expect(
419 roles.ExpressionElementRole,
420 v,
421 type_=sqltypes.NullType(),
422 is_crud=True,
423 )
424 ),
425 )
426
427 @classmethod
428 def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
429 plugin_subject = statement._propagate_attrs["plugin_subject"]
430
431 if not plugin_subject or not plugin_subject.mapper:
432 return UpdateDMLState._get_multi_crud_kv_pairs(
433 statement, kv_iterator
434 )
435
436 return [
437 dict(
438 cls._get_orm_crud_kv_pairs(
439 plugin_subject.mapper, statement, value_dict.items(), False
440 )
441 )
442 for value_dict in kv_iterator
443 ]
444
445 @classmethod
446 def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
447 assert (
448 needs_to_be_cacheable
449 ), "no test coverage for needs_to_be_cacheable=False"
450
451 plugin_subject = statement._propagate_attrs["plugin_subject"]
452
453 if not plugin_subject or not plugin_subject.mapper:
454 return UpdateDMLState._get_crud_kv_pairs(
455 statement, kv_iterator, needs_to_be_cacheable
456 )
457
458 return list(
459 cls._get_orm_crud_kv_pairs(
460 plugin_subject.mapper,
461 statement,
462 kv_iterator,
463 needs_to_be_cacheable,
464 )
465 )
466
467 @classmethod
468 def get_entity_description(cls, statement):
469 ext_info = statement.table._annotations["parententity"]
470 mapper = ext_info.mapper
471 if ext_info.is_aliased_class:
472 _label_name = ext_info.name
473 else:
474 _label_name = mapper.class_.__name__
475
476 return {
477 "name": _label_name,
478 "type": mapper.class_,
479 "expr": ext_info.entity,
480 "entity": ext_info.entity,
481 "table": mapper.local_table,
482 }
483
484 @classmethod
485 def get_returning_column_descriptions(cls, statement):
486 def _ent_for_col(c):
487 return c._annotations.get("parententity", None)
488
489 def _attr_for_col(c, ent):
490 if ent is None:
491 return c
492 proxy_key = c._annotations.get("proxy_key", None)
493 if not proxy_key:
494 return c
495 else:
496 return getattr(ent.entity, proxy_key, c)
497
498 return [
499 {
500 "name": c.key,
501 "type": c.type,
502 "expr": _attr_for_col(c, ent),
503 "aliased": ent.is_aliased_class,
504 "entity": ent.entity,
505 }
506 for c, ent in [
507 (c, _ent_for_col(c)) for c in statement._all_selected_columns
508 ]
509 ]
510
511 def _setup_orm_returning(
512 self,
513 compiler,
514 orm_level_statement,
515 dml_level_statement,
516 dml_mapper,
517 *,
518 use_supplemental_cols=True,
519 ):
520 """establish ORM column handlers for an INSERT, UPDATE, or DELETE
521 which uses explicit returning().
522
523 called within compilation level create_for_statement.
524
525 The _return_orm_returning() method then receives the Result
526 after the statement was executed, and applies ORM loading to the
527 state that we first established here.
528
529 """
530
531 if orm_level_statement._returning:
532 fs = FromStatement(
533 orm_level_statement._returning,
534 dml_level_statement,
535 _adapt_on_names=False,
536 )
537 fs = fs.execution_options(**orm_level_statement._execution_options)
538 fs = fs.options(*orm_level_statement._with_options)
539 self.select_statement = fs
540 self.from_statement_ctx = fsc = (
541 ORMFromStatementCompileState.create_for_statement(fs, compiler)
542 )
543 fsc.setup_dml_returning_compile_state(dml_mapper)
544
545 dml_level_statement = dml_level_statement._generate()
546 dml_level_statement._returning = ()
547
548 cols_to_return = [c for c in fsc.primary_columns if c is not None]
549
550 # since we are splicing result sets together, make sure there
551 # are columns of some kind returned in each result set
552 if not cols_to_return:
553 cols_to_return.extend(dml_mapper.primary_key)
554
555 if use_supplemental_cols:
556 dml_level_statement = dml_level_statement.return_defaults(
557 # this is a little weird looking, but by passing
558 # primary key as the main list of cols, this tells
559 # return_defaults to omit server-default cols (and
560 # actually all cols, due to some weird thing we should
561 # clean up in crud.py).
562 # Since we have cols_to_return, just return what we asked
563 # for (plus primary key, which ORM persistence needs since
564 # we likely set bookkeeping=True here, which is another
565 # whole thing...). We dont want to clutter the
566 # statement up with lots of other cols the user didn't
567 # ask for. see #9685
568 *dml_mapper.primary_key,
569 supplemental_cols=cols_to_return,
570 )
571 else:
572 dml_level_statement = dml_level_statement.returning(
573 *cols_to_return
574 )
575
576 return dml_level_statement
577
578 @classmethod
579 def _return_orm_returning(
580 cls,
581 session,
582 statement,
583 params,
584 execution_options,
585 bind_arguments,
586 result,
587 ):
588 execution_context = result.context
589 compile_state = execution_context.compiled.compile_state
590
591 if (
592 compile_state.from_statement_ctx
593 and not compile_state.from_statement_ctx.compile_options._is_star
594 ):
595 load_options = execution_options.get(
596 "_sa_orm_load_options", QueryContext.default_load_options
597 )
598
599 querycontext = QueryContext(
600 compile_state.from_statement_ctx,
601 compile_state.select_statement,
602 params,
603 session,
604 load_options,
605 execution_options,
606 bind_arguments,
607 )
608 return loading.instances(result, querycontext)
609 else:
610 return result
611
612
613class BulkUDCompileState(ORMDMLState):
614 class default_update_options(Options):
615 _dml_strategy: DMLStrategyArgument = "auto"
616 _synchronize_session: SynchronizeSessionArgument = "auto"
617 _can_use_returning: bool = False
618 _is_delete_using: bool = False
619 _is_update_from: bool = False
620 _autoflush: bool = True
621 _subject_mapper: Optional[Mapper[Any]] = None
622 _resolved_values = EMPTY_DICT
623 _eval_condition = None
624 _matched_rows = None
625 _identity_token = None
626
627 @classmethod
628 def can_use_returning(
629 cls,
630 dialect: Dialect,
631 mapper: Mapper[Any],
632 *,
633 is_multitable: bool = False,
634 is_update_from: bool = False,
635 is_delete_using: bool = False,
636 is_executemany: bool = False,
637 ) -> bool:
638 raise NotImplementedError()
639
640 @classmethod
641 def orm_pre_session_exec(
642 cls,
643 session,
644 statement,
645 params,
646 execution_options,
647 bind_arguments,
648 is_pre_event,
649 ):
650 (
651 update_options,
652 execution_options,
653 ) = BulkUDCompileState.default_update_options.from_execution_options(
654 "_sa_orm_update_options",
655 {
656 "synchronize_session",
657 "autoflush",
658 "identity_token",
659 "is_delete_using",
660 "is_update_from",
661 "dml_strategy",
662 },
663 execution_options,
664 statement._execution_options,
665 )
666 bind_arguments["clause"] = statement
667 try:
668 plugin_subject = statement._propagate_attrs["plugin_subject"]
669 except KeyError:
670 assert False, "statement had 'orm' plugin but no plugin_subject"
671 else:
672 if plugin_subject:
673 bind_arguments["mapper"] = plugin_subject.mapper
674 update_options += {"_subject_mapper": plugin_subject.mapper}
675
676 if "parententity" not in statement.table._annotations:
677 update_options += {"_dml_strategy": "core_only"}
678 elif not isinstance(params, list):
679 if update_options._dml_strategy == "auto":
680 update_options += {"_dml_strategy": "orm"}
681 elif update_options._dml_strategy == "bulk":
682 raise sa_exc.InvalidRequestError(
683 'Can\'t use "bulk" ORM insert strategy without '
684 "passing separate parameters"
685 )
686 else:
687 if update_options._dml_strategy == "auto":
688 update_options += {"_dml_strategy": "bulk"}
689
690 sync = update_options._synchronize_session
691 if sync is not None:
692 if sync not in ("auto", "evaluate", "fetch", False):
693 raise sa_exc.ArgumentError(
694 "Valid strategies for session synchronization "
695 "are 'auto', 'evaluate', 'fetch', False"
696 )
697 if update_options._dml_strategy == "bulk" and sync == "fetch":
698 raise sa_exc.InvalidRequestError(
699 "The 'fetch' synchronization strategy is not available "
700 "for 'bulk' ORM updates (i.e. multiple parameter sets)"
701 )
702
703 if not is_pre_event:
704 if update_options._autoflush:
705 session._autoflush()
706
707 if update_options._dml_strategy == "orm":
708 if update_options._synchronize_session == "auto":
709 update_options = cls._do_pre_synchronize_auto(
710 session,
711 statement,
712 params,
713 execution_options,
714 bind_arguments,
715 update_options,
716 )
717 elif update_options._synchronize_session == "evaluate":
718 update_options = cls._do_pre_synchronize_evaluate(
719 session,
720 statement,
721 params,
722 execution_options,
723 bind_arguments,
724 update_options,
725 )
726 elif update_options._synchronize_session == "fetch":
727 update_options = cls._do_pre_synchronize_fetch(
728 session,
729 statement,
730 params,
731 execution_options,
732 bind_arguments,
733 update_options,
734 )
735 elif update_options._dml_strategy == "bulk":
736 if update_options._synchronize_session == "auto":
737 update_options += {"_synchronize_session": "evaluate"}
738
739 # indicators from the "pre exec" step that are then
740 # added to the DML statement, which will also be part of the cache
741 # key. The compile level create_for_statement() method will then
742 # consume these at compiler time.
743 statement = statement._annotate(
744 {
745 "synchronize_session": update_options._synchronize_session,
746 "is_delete_using": update_options._is_delete_using,
747 "is_update_from": update_options._is_update_from,
748 "dml_strategy": update_options._dml_strategy,
749 "can_use_returning": update_options._can_use_returning,
750 }
751 )
752
753 return (
754 statement,
755 util.immutabledict(execution_options).union(
756 {"_sa_orm_update_options": update_options}
757 ),
758 )
759
760 @classmethod
761 def orm_setup_cursor_result(
762 cls,
763 session,
764 statement,
765 params,
766 execution_options,
767 bind_arguments,
768 result,
769 ):
770 # this stage of the execution is called after the
771 # do_orm_execute event hook. meaning for an extension like
772 # horizontal sharding, this step happens *within* the horizontal
773 # sharding event handler which calls session.execute() re-entrantly
774 # and will occur for each backend individually.
775 # the sharding extension then returns its own merged result from the
776 # individual ones we return here.
777
778 update_options = execution_options["_sa_orm_update_options"]
779 if update_options._dml_strategy == "orm":
780 if update_options._synchronize_session == "evaluate":
781 cls._do_post_synchronize_evaluate(
782 session, statement, result, update_options
783 )
784 elif update_options._synchronize_session == "fetch":
785 cls._do_post_synchronize_fetch(
786 session, statement, result, update_options
787 )
788 elif update_options._dml_strategy == "bulk":
789 if update_options._synchronize_session == "evaluate":
790 cls._do_post_synchronize_bulk_evaluate(
791 session, params, result, update_options
792 )
793 return result
794
795 return cls._return_orm_returning(
796 session,
797 statement,
798 params,
799 execution_options,
800 bind_arguments,
801 result,
802 )
803
804 @classmethod
805 def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
806 """Apply extra criteria filtering.
807
808 For all distinct single-table-inheritance mappers represented in the
809 table being updated or deleted, produce additional WHERE criteria such
810 that only the appropriate subtypes are selected from the total results.
811
812 Additionally, add WHERE criteria originating from LoaderCriteriaOptions
813 collected from the statement.
814
815 """
816
817 return_crit = ()
818
819 adapter = ext_info._adapter if ext_info.is_aliased_class else None
820
821 if (
822 "additional_entity_criteria",
823 ext_info.mapper,
824 ) in global_attributes:
825 return_crit += tuple(
826 ae._resolve_where_criteria(ext_info)
827 for ae in global_attributes[
828 ("additional_entity_criteria", ext_info.mapper)
829 ]
830 if ae.include_aliases or ae.entity is ext_info
831 )
832
833 if ext_info.mapper._single_table_criterion is not None:
834 return_crit += (ext_info.mapper._single_table_criterion,)
835
836 if adapter:
837 return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
838
839 return return_crit
840
841 @classmethod
842 def _interpret_returning_rows(cls, mapper, rows):
843 """translate from local inherited table columns to base mapper
844 primary key columns.
845
846 Joined inheritance mappers always establish the primary key in terms of
847 the base table. When we UPDATE a sub-table, we can only get
848 RETURNING for the sub-table's columns.
849
850 Here, we create a lookup from the local sub table's primary key
851 columns to the base table PK columns so that we can get identity
852 key values from RETURNING that's against the joined inheritance
853 sub-table.
854
855 the complexity here is to support more than one level deep of
856 inheritance, where we have to link columns to each other across
857 the inheritance hierarchy.
858
859 """
860
861 if mapper.local_table is not mapper.base_mapper.local_table:
862 return rows
863
864 # this starts as a mapping of
865 # local_pk_col: local_pk_col.
866 # we will then iteratively rewrite the "value" of the dict with
867 # each successive superclass column
868 local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key}
869
870 for mp in mapper.iterate_to_root():
871 if mp.inherits is None:
872 break
873 elif mp.local_table is mp.inherits.local_table:
874 continue
875
876 t_to_e = dict(mp._table_to_equated[mp.inherits.local_table])
877 col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]}
878 for pk, super_ in local_pk_to_base_pk.items():
879 local_pk_to_base_pk[pk] = col_to_col[super_]
880
881 lookup = {
882 local_pk_to_base_pk[lpk]: idx
883 for idx, lpk in enumerate(mapper.local_table.primary_key)
884 }
885 primary_key_convert = [
886 lookup[bpk] for bpk in mapper.base_mapper.primary_key
887 ]
888 return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
889
890 @classmethod
891 def _get_matched_objects_on_criteria(cls, update_options, states):
892 mapper = update_options._subject_mapper
893 eval_condition = update_options._eval_condition
894
895 raw_data = [
896 (state.obj(), state, state.dict)
897 for state in states
898 if state.mapper.isa(mapper) and not state.expired
899 ]
900
901 identity_token = update_options._identity_token
902 if identity_token is not None:
903 raw_data = [
904 (obj, state, dict_)
905 for obj, state, dict_ in raw_data
906 if state.identity_token == identity_token
907 ]
908
909 result = []
910 for obj, state, dict_ in raw_data:
911 evaled_condition = eval_condition(obj)
912
913 # caution: don't use "in ()" or == here, _EXPIRE_OBJECT
914 # evaluates as True for all comparisons
915 if (
916 evaled_condition is True
917 or evaled_condition is evaluator._EXPIRED_OBJECT
918 ):
919 result.append(
920 (
921 obj,
922 state,
923 dict_,
924 evaled_condition is evaluator._EXPIRED_OBJECT,
925 )
926 )
927 return result
928
929 @classmethod
930 def _eval_condition_from_statement(cls, update_options, statement):
931 mapper = update_options._subject_mapper
932 target_cls = mapper.class_
933
934 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
935 crit = ()
936 if statement._where_criteria:
937 crit += statement._where_criteria
938
939 global_attributes = {}
940 for opt in statement._with_options:
941 if opt._is_criteria_option:
942 opt.get_global_criteria(global_attributes)
943
944 if global_attributes:
945 crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
946
947 if crit:
948 eval_condition = evaluator_compiler.process(*crit)
949 else:
950 # workaround for mypy https://github.com/python/mypy/issues/14027
951 def _eval_condition(obj):
952 return True
953
954 eval_condition = _eval_condition
955
956 return eval_condition
957
958 @classmethod
959 def _do_pre_synchronize_auto(
960 cls,
961 session,
962 statement,
963 params,
964 execution_options,
965 bind_arguments,
966 update_options,
967 ):
968 """setup auto sync strategy
969
970
971 "auto" checks if we can use "evaluate" first, then falls back
972 to "fetch"
973
974 evaluate is vastly more efficient for the common case
975 where session is empty, only has a few objects, and the UPDATE
976 statement can potentially match thousands/millions of rows.
977
978 OTOH more complex criteria that fails to work with "evaluate"
979 we would hope usually correlates with fewer net rows.
980
981 """
982
983 try:
984 eval_condition = cls._eval_condition_from_statement(
985 update_options, statement
986 )
987
988 except evaluator.UnevaluatableError:
989 pass
990 else:
991 return update_options + {
992 "_eval_condition": eval_condition,
993 "_synchronize_session": "evaluate",
994 }
995
996 update_options += {"_synchronize_session": "fetch"}
997 return cls._do_pre_synchronize_fetch(
998 session,
999 statement,
1000 params,
1001 execution_options,
1002 bind_arguments,
1003 update_options,
1004 )
1005
1006 @classmethod
1007 def _do_pre_synchronize_evaluate(
1008 cls,
1009 session,
1010 statement,
1011 params,
1012 execution_options,
1013 bind_arguments,
1014 update_options,
1015 ):
1016 try:
1017 eval_condition = cls._eval_condition_from_statement(
1018 update_options, statement
1019 )
1020
1021 except evaluator.UnevaluatableError as err:
1022 raise sa_exc.InvalidRequestError(
1023 'Could not evaluate current criteria in Python: "%s". '
1024 "Specify 'fetch' or False for the "
1025 "synchronize_session execution option." % err
1026 ) from err
1027
1028 return update_options + {
1029 "_eval_condition": eval_condition,
1030 }
1031
1032 @classmethod
1033 def _get_resolved_values(cls, mapper, statement):
1034 if statement._multi_values:
1035 return []
1036 elif statement._ordered_values:
1037 return list(statement._ordered_values)
1038 elif statement._values:
1039 return list(statement._values.items())
1040 else:
1041 return []
1042
1043 @classmethod
1044 def _resolved_keys_as_propnames(cls, mapper, resolved_values):
1045 values = []
1046 for k, v in resolved_values:
1047 if mapper and isinstance(k, expression.ColumnElement):
1048 try:
1049 attr = mapper._columntoproperty[k]
1050 except orm_exc.UnmappedColumnError:
1051 pass
1052 else:
1053 values.append((attr.key, v))
1054 else:
1055 raise sa_exc.InvalidRequestError(
1056 "Attribute name not found, can't be "
1057 "synchronized back to objects: %r" % k
1058 )
1059 return values
1060
1061 @classmethod
1062 def _do_pre_synchronize_fetch(
1063 cls,
1064 session,
1065 statement,
1066 params,
1067 execution_options,
1068 bind_arguments,
1069 update_options,
1070 ):
1071 mapper = update_options._subject_mapper
1072
1073 select_stmt = (
1074 select(*(mapper.primary_key + (mapper.select_identity_token,)))
1075 .select_from(mapper)
1076 .options(*statement._with_options)
1077 )
1078 select_stmt._where_criteria = statement._where_criteria
1079
1080 # conditionally run the SELECT statement for pre-fetch, testing the
1081 # "bind" for if we can use RETURNING or not using the do_orm_execute
1082 # event. If RETURNING is available, the do_orm_execute event
1083 # will cancel the SELECT from being actually run.
1084 #
1085 # The way this is organized seems strange, why don't we just
1086 # call can_use_returning() before invoking the statement and get
1087 # answer?, why does this go through the whole execute phase using an
1088 # event? Answer: because we are integrating with extensions such
1089 # as the horizontal sharding extention that "multiplexes" an individual
1090 # statement run through multiple engines, and it uses
1091 # do_orm_execute() to do that.
1092
1093 can_use_returning = None
1094
1095 def skip_for_returning(orm_context: ORMExecuteState) -> Any:
1096 bind = orm_context.session.get_bind(**orm_context.bind_arguments)
1097 nonlocal can_use_returning
1098
1099 per_bind_result = cls.can_use_returning(
1100 bind.dialect,
1101 mapper,
1102 is_update_from=update_options._is_update_from,
1103 is_delete_using=update_options._is_delete_using,
1104 is_executemany=orm_context.is_executemany,
1105 )
1106
1107 if can_use_returning is not None:
1108 if can_use_returning != per_bind_result:
1109 raise sa_exc.InvalidRequestError(
1110 "For synchronize_session='fetch', can't mix multiple "
1111 "backends where some support RETURNING and others "
1112 "don't"
1113 )
1114 elif orm_context.is_executemany and not per_bind_result:
1115 raise sa_exc.InvalidRequestError(
1116 "For synchronize_session='fetch', can't use multiple "
1117 "parameter sets in ORM mode, which this backend does not "
1118 "support with RETURNING"
1119 )
1120 else:
1121 can_use_returning = per_bind_result
1122
1123 if per_bind_result:
1124 return _result.null_result()
1125 else:
1126 return None
1127
1128 result = session.execute(
1129 select_stmt,
1130 params,
1131 execution_options=execution_options,
1132 bind_arguments=bind_arguments,
1133 _add_event=skip_for_returning,
1134 )
1135 matched_rows = result.fetchall()
1136
1137 return update_options + {
1138 "_matched_rows": matched_rows,
1139 "_can_use_returning": can_use_returning,
1140 }
1141
1142
1143@CompileState.plugin_for("orm", "insert")
1144class BulkORMInsert(ORMDMLState, InsertDMLState):
1145 class default_insert_options(Options):
1146 _dml_strategy: DMLStrategyArgument = "auto"
1147 _render_nulls: bool = False
1148 _return_defaults: bool = False
1149 _subject_mapper: Optional[Mapper[Any]] = None
1150 _autoflush: bool = True
1151 _populate_existing: bool = False
1152
1153 select_statement: Optional[FromStatement] = None
1154
1155 @classmethod
1156 def orm_pre_session_exec(
1157 cls,
1158 session,
1159 statement,
1160 params,
1161 execution_options,
1162 bind_arguments,
1163 is_pre_event,
1164 ):
1165 (
1166 insert_options,
1167 execution_options,
1168 ) = BulkORMInsert.default_insert_options.from_execution_options(
1169 "_sa_orm_insert_options",
1170 {"dml_strategy", "autoflush", "populate_existing", "render_nulls"},
1171 execution_options,
1172 statement._execution_options,
1173 )
1174 bind_arguments["clause"] = statement
1175 try:
1176 plugin_subject = statement._propagate_attrs["plugin_subject"]
1177 except KeyError:
1178 assert False, "statement had 'orm' plugin but no plugin_subject"
1179 else:
1180 if plugin_subject:
1181 bind_arguments["mapper"] = plugin_subject.mapper
1182 insert_options += {"_subject_mapper": plugin_subject.mapper}
1183
1184 if not params:
1185 if insert_options._dml_strategy == "auto":
1186 insert_options += {"_dml_strategy": "orm"}
1187 elif insert_options._dml_strategy == "bulk":
1188 raise sa_exc.InvalidRequestError(
1189 'Can\'t use "bulk" ORM insert strategy without '
1190 "passing separate parameters"
1191 )
1192 else:
1193 if insert_options._dml_strategy == "auto":
1194 insert_options += {"_dml_strategy": "bulk"}
1195
1196 if insert_options._dml_strategy != "raw":
1197 # for ORM object loading, like ORMContext, we have to disable
1198 # result set adapt_to_context, because we will be generating a
1199 # new statement with specific columns that's cached inside of
1200 # an ORMFromStatementCompileState, which we will re-use for
1201 # each result.
1202 if not execution_options:
1203 execution_options = context._orm_load_exec_options
1204 else:
1205 execution_options = execution_options.union(
1206 context._orm_load_exec_options
1207 )
1208
1209 if not is_pre_event and insert_options._autoflush:
1210 session._autoflush()
1211
1212 statement = statement._annotate(
1213 {"dml_strategy": insert_options._dml_strategy}
1214 )
1215
1216 return (
1217 statement,
1218 util.immutabledict(execution_options).union(
1219 {"_sa_orm_insert_options": insert_options}
1220 ),
1221 )
1222
1223 @classmethod
1224 def orm_execute_statement(
1225 cls,
1226 session: Session,
1227 statement: dml.Insert,
1228 params: _CoreAnyExecuteParams,
1229 execution_options: OrmExecuteOptionsParameter,
1230 bind_arguments: _BindArguments,
1231 conn: Connection,
1232 ) -> _result.Result:
1233 insert_options = execution_options.get(
1234 "_sa_orm_insert_options", cls.default_insert_options
1235 )
1236
1237 if insert_options._dml_strategy not in (
1238 "raw",
1239 "bulk",
1240 "orm",
1241 "auto",
1242 ):
1243 raise sa_exc.ArgumentError(
1244 "Valid strategies for ORM insert strategy "
1245 "are 'raw', 'orm', 'bulk', 'auto"
1246 )
1247
1248 result: _result.Result[Unpack[TupleAny]]
1249
1250 if insert_options._dml_strategy == "raw":
1251 result = conn.execute(
1252 statement, params or {}, execution_options=execution_options
1253 )
1254 return result
1255
1256 if insert_options._dml_strategy == "bulk":
1257 mapper = insert_options._subject_mapper
1258
1259 if (
1260 statement._post_values_clause is not None
1261 and mapper._multiple_persistence_tables
1262 ):
1263 raise sa_exc.InvalidRequestError(
1264 "bulk INSERT with a 'post values' clause "
1265 "(typically upsert) not supported for multi-table "
1266 f"mapper {mapper}"
1267 )
1268
1269 assert mapper is not None
1270 assert session._transaction is not None
1271 result = _bulk_insert(
1272 mapper,
1273 cast(
1274 "Iterable[Dict[str, Any]]",
1275 [params] if isinstance(params, dict) else params,
1276 ),
1277 session._transaction,
1278 isstates=False,
1279 return_defaults=insert_options._return_defaults,
1280 render_nulls=insert_options._render_nulls,
1281 use_orm_insert_stmt=statement,
1282 execution_options=execution_options,
1283 )
1284 elif insert_options._dml_strategy == "orm":
1285 result = conn.execute(
1286 statement, params or {}, execution_options=execution_options
1287 )
1288 else:
1289 raise AssertionError()
1290
1291 if not bool(statement._returning):
1292 return result
1293
1294 if insert_options._populate_existing:
1295 load_options = execution_options.get(
1296 "_sa_orm_load_options", QueryContext.default_load_options
1297 )
1298 load_options += {"_populate_existing": True}
1299 execution_options = execution_options.union(
1300 {"_sa_orm_load_options": load_options}
1301 )
1302
1303 return cls._return_orm_returning(
1304 session,
1305 statement,
1306 params,
1307 execution_options,
1308 bind_arguments,
1309 result,
1310 )
1311
1312 @classmethod
1313 def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
1314 self = cast(
1315 BulkORMInsert,
1316 super().create_for_statement(statement, compiler, **kw),
1317 )
1318
1319 if compiler is not None:
1320 toplevel = not compiler.stack
1321 else:
1322 toplevel = True
1323 if not toplevel:
1324 return self
1325
1326 mapper = statement._propagate_attrs["plugin_subject"]
1327 dml_strategy = statement._annotations.get("dml_strategy", "raw")
1328 if dml_strategy == "bulk":
1329 self._setup_for_bulk_insert(compiler)
1330 elif dml_strategy == "orm":
1331 self._setup_for_orm_insert(compiler, mapper)
1332
1333 return self
1334
1335 @classmethod
1336 def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
1337 return {
1338 col.key if col is not None else k: v
1339 for col, k, v in (
1340 (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
1341 )
1342 }
1343
1344 def _setup_for_orm_insert(self, compiler, mapper):
1345 statement = orm_level_statement = cast(dml.Insert, self.statement)
1346
1347 statement = self._setup_orm_returning(
1348 compiler,
1349 orm_level_statement,
1350 statement,
1351 dml_mapper=mapper,
1352 use_supplemental_cols=False,
1353 )
1354 self.statement = statement
1355
1356 def _setup_for_bulk_insert(self, compiler):
1357 """establish an INSERT statement within the context of
1358 bulk insert.
1359
1360 This method will be within the "conn.execute()" call that is invoked
1361 by persistence._emit_insert_statement().
1362
1363 """
1364 statement = orm_level_statement = cast(dml.Insert, self.statement)
1365 an = statement._annotations
1366
1367 emit_insert_table, emit_insert_mapper = (
1368 an["_emit_insert_table"],
1369 an["_emit_insert_mapper"],
1370 )
1371
1372 statement = statement._clone()
1373
1374 statement.table = emit_insert_table
1375 if self._dict_parameters:
1376 self._dict_parameters = {
1377 col: val
1378 for col, val in self._dict_parameters.items()
1379 if col.table is emit_insert_table
1380 }
1381
1382 statement = self._setup_orm_returning(
1383 compiler,
1384 orm_level_statement,
1385 statement,
1386 dml_mapper=emit_insert_mapper,
1387 use_supplemental_cols=True,
1388 )
1389
1390 if (
1391 self.from_statement_ctx is not None
1392 and self.from_statement_ctx.compile_options._is_star
1393 ):
1394 raise sa_exc.CompileError(
1395 "Can't use RETURNING * with bulk ORM INSERT. "
1396 "Please use a different INSERT form, such as INSERT..VALUES "
1397 "or INSERT with a Core Connection"
1398 )
1399
1400 self.statement = statement
1401
1402
1403@CompileState.plugin_for("orm", "update")
1404class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
1405 @classmethod
1406 def create_for_statement(cls, statement, compiler, **kw):
1407 self = cls.__new__(cls)
1408
1409 dml_strategy = statement._annotations.get(
1410 "dml_strategy", "unspecified"
1411 )
1412
1413 toplevel = not compiler.stack
1414
1415 if toplevel and dml_strategy == "bulk":
1416 self._setup_for_bulk_update(statement, compiler)
1417 elif (
1418 dml_strategy == "core_only"
1419 or dml_strategy == "unspecified"
1420 and "parententity" not in statement.table._annotations
1421 ):
1422 UpdateDMLState.__init__(self, statement, compiler, **kw)
1423 elif not toplevel or dml_strategy in ("orm", "unspecified"):
1424 self._setup_for_orm_update(statement, compiler)
1425
1426 return self
1427
1428 def _setup_for_orm_update(self, statement, compiler, **kw):
1429 orm_level_statement = statement
1430
1431 toplevel = not compiler.stack
1432
1433 ext_info = statement.table._annotations["parententity"]
1434
1435 self.mapper = mapper = ext_info.mapper
1436
1437 self._resolved_values = self._get_resolved_values(mapper, statement)
1438
1439 self._init_global_attributes(
1440 statement,
1441 compiler,
1442 toplevel=toplevel,
1443 process_criteria_for_toplevel=toplevel,
1444 )
1445
1446 if statement._values:
1447 self._resolved_values = dict(self._resolved_values)
1448
1449 new_stmt = statement._clone()
1450
1451 # note if the statement has _multi_values, these
1452 # are passed through to the new statement, which will then raise
1453 # InvalidRequestError because UPDATE doesn't support multi_values
1454 # right now.
1455 if statement._ordered_values:
1456 new_stmt._ordered_values = self._resolved_values
1457 elif statement._values:
1458 new_stmt._values = self._resolved_values
1459
1460 new_crit = self._adjust_for_extra_criteria(
1461 self.global_attributes, mapper
1462 )
1463 if new_crit:
1464 new_stmt = new_stmt.where(*new_crit)
1465
1466 # if we are against a lambda statement we might not be the
1467 # topmost object that received per-execute annotations
1468
1469 # do this first as we need to determine if there is
1470 # UPDATE..FROM
1471
1472 UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
1473
1474 use_supplemental_cols = False
1475
1476 if not toplevel:
1477 synchronize_session = None
1478 else:
1479 synchronize_session = compiler._annotations.get(
1480 "synchronize_session", None
1481 )
1482 can_use_returning = compiler._annotations.get(
1483 "can_use_returning", None
1484 )
1485 if can_use_returning is not False:
1486 # even though pre_exec has determined basic
1487 # can_use_returning for the dialect, if we are to use
1488 # RETURNING we need to run can_use_returning() at this level
1489 # unconditionally because is_delete_using was not known
1490 # at the pre_exec level
1491 can_use_returning = (
1492 synchronize_session == "fetch"
1493 and self.can_use_returning(
1494 compiler.dialect, mapper, is_multitable=self.is_multitable
1495 )
1496 )
1497
1498 if synchronize_session == "fetch" and can_use_returning:
1499 use_supplemental_cols = True
1500
1501 # NOTE: we might want to RETURNING the actual columns to be
1502 # synchronized also. however this is complicated and difficult
1503 # to align against the behavior of "evaluate". Additionally,
1504 # in a large number (if not the majority) of cases, we have the
1505 # "evaluate" answer, usually a fixed value, in memory already and
1506 # there's no need to re-fetch the same value
1507 # over and over again. so perhaps if it could be RETURNING just
1508 # the elements that were based on a SQL expression and not
1509 # a constant. For now it doesn't quite seem worth it
1510 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1511
1512 if toplevel:
1513 new_stmt = self._setup_orm_returning(
1514 compiler,
1515 orm_level_statement,
1516 new_stmt,
1517 dml_mapper=mapper,
1518 use_supplemental_cols=use_supplemental_cols,
1519 )
1520
1521 self.statement = new_stmt
1522
1523 def _setup_for_bulk_update(self, statement, compiler, **kw):
1524 """establish an UPDATE statement within the context of
1525 bulk insert.
1526
1527 This method will be within the "conn.execute()" call that is invoked
1528 by persistence._emit_update_statement().
1529
1530 """
1531 statement = cast(dml.Update, statement)
1532 an = statement._annotations
1533
1534 emit_update_table, _ = (
1535 an["_emit_update_table"],
1536 an["_emit_update_mapper"],
1537 )
1538
1539 statement = statement._clone()
1540 statement.table = emit_update_table
1541
1542 UpdateDMLState.__init__(self, statement, compiler, **kw)
1543
1544 if self._ordered_values:
1545 raise sa_exc.InvalidRequestError(
1546 "bulk ORM UPDATE does not support ordered_values() for "
1547 "custom UPDATE statements with bulk parameter sets. Use a "
1548 "non-bulk UPDATE statement or use values()."
1549 )
1550
1551 if self._dict_parameters:
1552 self._dict_parameters = {
1553 col: val
1554 for col, val in self._dict_parameters.items()
1555 if col.table is emit_update_table
1556 }
1557 self.statement = statement
1558
1559 @classmethod
1560 def orm_execute_statement(
1561 cls,
1562 session: Session,
1563 statement: dml.Update,
1564 params: _CoreAnyExecuteParams,
1565 execution_options: OrmExecuteOptionsParameter,
1566 bind_arguments: _BindArguments,
1567 conn: Connection,
1568 ) -> _result.Result:
1569 update_options = execution_options.get(
1570 "_sa_orm_update_options", cls.default_update_options
1571 )
1572
1573 if update_options._dml_strategy not in (
1574 "orm",
1575 "auto",
1576 "bulk",
1577 "core_only",
1578 ):
1579 raise sa_exc.ArgumentError(
1580 "Valid strategies for ORM UPDATE strategy "
1581 "are 'orm', 'auto', 'bulk', 'core_only'"
1582 )
1583
1584 result: _result.Result[Unpack[TupleAny]]
1585
1586 if update_options._dml_strategy == "bulk":
1587 enable_check_rowcount = not statement._where_criteria
1588
1589 assert update_options._synchronize_session != "fetch"
1590
1591 if (
1592 statement._where_criteria
1593 and update_options._synchronize_session == "evaluate"
1594 ):
1595 raise sa_exc.InvalidRequestError(
1596 "bulk synchronize of persistent objects not supported "
1597 "when using bulk update with additional WHERE "
1598 "criteria right now. add synchronize_session=None "
1599 "execution option to bypass synchronize of persistent "
1600 "objects."
1601 )
1602 mapper = update_options._subject_mapper
1603 assert mapper is not None
1604 assert session._transaction is not None
1605 result = _bulk_update(
1606 mapper,
1607 cast(
1608 "Iterable[Dict[str, Any]]",
1609 [params] if isinstance(params, dict) else params,
1610 ),
1611 session._transaction,
1612 isstates=False,
1613 update_changed_only=False,
1614 use_orm_update_stmt=statement,
1615 enable_check_rowcount=enable_check_rowcount,
1616 )
1617 return cls.orm_setup_cursor_result(
1618 session,
1619 statement,
1620 params,
1621 execution_options,
1622 bind_arguments,
1623 result,
1624 )
1625 else:
1626 return super().orm_execute_statement(
1627 session,
1628 statement,
1629 params,
1630 execution_options,
1631 bind_arguments,
1632 conn,
1633 )
1634
1635 @classmethod
1636 def can_use_returning(
1637 cls,
1638 dialect: Dialect,
1639 mapper: Mapper[Any],
1640 *,
1641 is_multitable: bool = False,
1642 is_update_from: bool = False,
1643 is_delete_using: bool = False,
1644 is_executemany: bool = False,
1645 ) -> bool:
1646 # normal answer for "should we use RETURNING" at all.
1647 normal_answer = (
1648 dialect.update_returning and mapper.local_table.implicit_returning
1649 )
1650 if not normal_answer:
1651 return False
1652
1653 if is_executemany:
1654 return dialect.update_executemany_returning
1655
1656 # these workarounds are currently hypothetical for UPDATE,
1657 # unlike DELETE where they impact MariaDB
1658 if is_update_from:
1659 return dialect.update_returning_multifrom
1660
1661 elif is_multitable and not dialect.update_returning_multifrom:
1662 raise sa_exc.CompileError(
1663 f'Dialect "{dialect.name}" does not support RETURNING '
1664 "with UPDATE..FROM; for synchronize_session='fetch', "
1665 "please add the additional execution option "
1666 "'is_update_from=True' to the statement to indicate that "
1667 "a separate SELECT should be used for this backend."
1668 )
1669
1670 return True
1671
1672 @classmethod
1673 def _do_post_synchronize_bulk_evaluate(
1674 cls, session, params, result, update_options
1675 ):
1676 if not params:
1677 return
1678
1679 mapper = update_options._subject_mapper
1680 pk_keys = [prop.key for prop in mapper._identity_key_props]
1681
1682 identity_map = session.identity_map
1683
1684 for param in params:
1685 identity_key = mapper.identity_key_from_primary_key(
1686 (param[key] for key in pk_keys),
1687 update_options._identity_token,
1688 )
1689 state = identity_map.fast_get_state(identity_key)
1690 if not state:
1691 continue
1692
1693 evaluated_keys = set(param).difference(pk_keys)
1694
1695 dict_ = state.dict
1696 # only evaluate unmodified attributes
1697 to_evaluate = state.unmodified.intersection(evaluated_keys)
1698 for key in to_evaluate:
1699 if key in dict_:
1700 dict_[key] = param[key]
1701
1702 state.manager.dispatch.refresh(state, None, to_evaluate)
1703
1704 state._commit(dict_, list(to_evaluate))
1705
1706 # attributes that were formerly modified instead get expired.
1707 # this only gets hit if the session had pending changes
1708 # and autoflush were set to False.
1709 to_expire = evaluated_keys.intersection(dict_).difference(
1710 to_evaluate
1711 )
1712 if to_expire:
1713 state._expire_attributes(dict_, to_expire)
1714
1715 @classmethod
1716 def _do_post_synchronize_evaluate(
1717 cls, session, statement, result, update_options
1718 ):
1719 matched_objects = cls._get_matched_objects_on_criteria(
1720 update_options,
1721 session.identity_map.all_states(),
1722 )
1723
1724 cls._apply_update_set_values_to_objects(
1725 session,
1726 update_options,
1727 statement,
1728 [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
1729 )
1730
1731 @classmethod
1732 def _do_post_synchronize_fetch(
1733 cls, session, statement, result, update_options
1734 ):
1735 target_mapper = update_options._subject_mapper
1736
1737 returned_defaults_rows = result.returned_defaults_rows
1738 if returned_defaults_rows:
1739 pk_rows = cls._interpret_returning_rows(
1740 target_mapper, returned_defaults_rows
1741 )
1742
1743 matched_rows = [
1744 tuple(row) + (update_options._identity_token,)
1745 for row in pk_rows
1746 ]
1747 else:
1748 matched_rows = update_options._matched_rows
1749
1750 objs = [
1751 session.identity_map[identity_key]
1752 for identity_key in [
1753 target_mapper.identity_key_from_primary_key(
1754 list(primary_key),
1755 identity_token=identity_token,
1756 )
1757 for primary_key, identity_token in [
1758 (row[0:-1], row[-1]) for row in matched_rows
1759 ]
1760 if update_options._identity_token is None
1761 or identity_token == update_options._identity_token
1762 ]
1763 if identity_key in session.identity_map
1764 ]
1765
1766 if not objs:
1767 return
1768
1769 cls._apply_update_set_values_to_objects(
1770 session,
1771 update_options,
1772 statement,
1773 [
1774 (
1775 obj,
1776 attributes.instance_state(obj),
1777 attributes.instance_dict(obj),
1778 )
1779 for obj in objs
1780 ],
1781 )
1782
1783 @classmethod
1784 def _apply_update_set_values_to_objects(
1785 cls, session, update_options, statement, matched_objects
1786 ):
1787 """apply values to objects derived from an update statement, e.g.
1788 UPDATE..SET <values>
1789
1790 """
1791 mapper = update_options._subject_mapper
1792 target_cls = mapper.class_
1793 evaluator_compiler = evaluator._EvaluatorCompiler(target_cls)
1794 resolved_values = cls._get_resolved_values(mapper, statement)
1795 resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
1796 mapper, resolved_values
1797 )
1798 value_evaluators = {}
1799 for key, value in resolved_keys_as_propnames:
1800 try:
1801 _evaluator = evaluator_compiler.process(
1802 coercions.expect(roles.ExpressionElementRole, value)
1803 )
1804 except evaluator.UnevaluatableError:
1805 pass
1806 else:
1807 value_evaluators[key] = _evaluator
1808
1809 evaluated_keys = list(value_evaluators.keys())
1810 attrib = {k for k, v in resolved_keys_as_propnames}
1811
1812 states = set()
1813 for obj, state, dict_ in matched_objects:
1814 to_evaluate = state.unmodified.intersection(evaluated_keys)
1815
1816 for key in to_evaluate:
1817 if key in dict_:
1818 # only run eval for attributes that are present.
1819 dict_[key] = value_evaluators[key](obj)
1820
1821 state.manager.dispatch.refresh(state, None, to_evaluate)
1822
1823 state._commit(dict_, list(to_evaluate))
1824
1825 # attributes that were formerly modified instead get expired.
1826 # this only gets hit if the session had pending changes
1827 # and autoflush were set to False.
1828 to_expire = attrib.intersection(dict_).difference(to_evaluate)
1829 if to_expire:
1830 state._expire_attributes(dict_, to_expire)
1831
1832 states.add(state)
1833 session._register_altered(states)
1834
1835
1836@CompileState.plugin_for("orm", "delete")
1837class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
1838 @classmethod
1839 def create_for_statement(cls, statement, compiler, **kw):
1840 self = cls.__new__(cls)
1841
1842 dml_strategy = statement._annotations.get(
1843 "dml_strategy", "unspecified"
1844 )
1845
1846 if (
1847 dml_strategy == "core_only"
1848 or dml_strategy == "unspecified"
1849 and "parententity" not in statement.table._annotations
1850 ):
1851 DeleteDMLState.__init__(self, statement, compiler, **kw)
1852 return self
1853
1854 toplevel = not compiler.stack
1855
1856 orm_level_statement = statement
1857
1858 ext_info = statement.table._annotations["parententity"]
1859 self.mapper = mapper = ext_info.mapper
1860
1861 self._init_global_attributes(
1862 statement,
1863 compiler,
1864 toplevel=toplevel,
1865 process_criteria_for_toplevel=toplevel,
1866 )
1867
1868 new_stmt = statement._clone()
1869
1870 new_crit = cls._adjust_for_extra_criteria(
1871 self.global_attributes, mapper
1872 )
1873 if new_crit:
1874 new_stmt = new_stmt.where(*new_crit)
1875
1876 # do this first as we need to determine if there is
1877 # DELETE..FROM
1878 DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
1879
1880 use_supplemental_cols = False
1881
1882 if not toplevel:
1883 synchronize_session = None
1884 else:
1885 synchronize_session = compiler._annotations.get(
1886 "synchronize_session", None
1887 )
1888 can_use_returning = compiler._annotations.get(
1889 "can_use_returning", None
1890 )
1891 if can_use_returning is not False:
1892 # even though pre_exec has determined basic
1893 # can_use_returning for the dialect, if we are to use
1894 # RETURNING we need to run can_use_returning() at this level
1895 # unconditionally because is_delete_using was not known
1896 # at the pre_exec level
1897 can_use_returning = (
1898 synchronize_session == "fetch"
1899 and self.can_use_returning(
1900 compiler.dialect,
1901 mapper,
1902 is_multitable=self.is_multitable,
1903 is_delete_using=compiler._annotations.get(
1904 "is_delete_using", False
1905 ),
1906 )
1907 )
1908
1909 if can_use_returning:
1910 use_supplemental_cols = True
1911
1912 new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
1913
1914 if toplevel:
1915 new_stmt = self._setup_orm_returning(
1916 compiler,
1917 orm_level_statement,
1918 new_stmt,
1919 dml_mapper=mapper,
1920 use_supplemental_cols=use_supplemental_cols,
1921 )
1922
1923 self.statement = new_stmt
1924
1925 return self
1926
1927 @classmethod
1928 def orm_execute_statement(
1929 cls,
1930 session: Session,
1931 statement: dml.Delete,
1932 params: _CoreAnyExecuteParams,
1933 execution_options: OrmExecuteOptionsParameter,
1934 bind_arguments: _BindArguments,
1935 conn: Connection,
1936 ) -> _result.Result:
1937 update_options = execution_options.get(
1938 "_sa_orm_update_options", cls.default_update_options
1939 )
1940
1941 if update_options._dml_strategy == "bulk":
1942 raise sa_exc.InvalidRequestError(
1943 "Bulk ORM DELETE not supported right now. "
1944 "Statement may be invoked at the "
1945 "Core level using "
1946 "session.connection().execute(stmt, parameters)"
1947 )
1948
1949 if update_options._dml_strategy not in ("orm", "auto", "core_only"):
1950 raise sa_exc.ArgumentError(
1951 "Valid strategies for ORM DELETE strategy are 'orm', 'auto', "
1952 "'core_only'"
1953 )
1954
1955 return super().orm_execute_statement(
1956 session, statement, params, execution_options, bind_arguments, conn
1957 )
1958
1959 @classmethod
1960 def can_use_returning(
1961 cls,
1962 dialect: Dialect,
1963 mapper: Mapper[Any],
1964 *,
1965 is_multitable: bool = False,
1966 is_update_from: bool = False,
1967 is_delete_using: bool = False,
1968 is_executemany: bool = False,
1969 ) -> bool:
1970 # normal answer for "should we use RETURNING" at all.
1971 normal_answer = (
1972 dialect.delete_returning and mapper.local_table.implicit_returning
1973 )
1974 if not normal_answer:
1975 return False
1976
1977 # now get into special workarounds because MariaDB supports
1978 # DELETE...RETURNING but not DELETE...USING...RETURNING.
1979 if is_delete_using:
1980 # is_delete_using hint was passed. use
1981 # additional dialect feature (True for PG, False for MariaDB)
1982 return dialect.delete_returning_multifrom
1983
1984 elif is_multitable and not dialect.delete_returning_multifrom:
1985 # is_delete_using hint was not passed, but we determined
1986 # at compile time that this is in fact a DELETE..USING.
1987 # it's too late to continue since we did not pre-SELECT.
1988 # raise that we need that hint up front.
1989
1990 raise sa_exc.CompileError(
1991 f'Dialect "{dialect.name}" does not support RETURNING '
1992 "with DELETE..USING; for synchronize_session='fetch', "
1993 "please add the additional execution option "
1994 "'is_delete_using=True' to the statement to indicate that "
1995 "a separate SELECT should be used for this backend."
1996 )
1997
1998 return True
1999
2000 @classmethod
2001 def _do_post_synchronize_evaluate(
2002 cls, session, statement, result, update_options
2003 ):
2004 matched_objects = cls._get_matched_objects_on_criteria(
2005 update_options,
2006 session.identity_map.all_states(),
2007 )
2008
2009 to_delete = []
2010
2011 for _, state, dict_, is_partially_expired in matched_objects:
2012 if is_partially_expired:
2013 state._expire(dict_, session.identity_map._modified)
2014 else:
2015 to_delete.append(state)
2016
2017 if to_delete:
2018 session._remove_newly_deleted(to_delete)
2019
2020 @classmethod
2021 def _do_post_synchronize_fetch(
2022 cls, session, statement, result, update_options
2023 ):
2024 target_mapper = update_options._subject_mapper
2025
2026 returned_defaults_rows = result.returned_defaults_rows
2027
2028 if returned_defaults_rows:
2029 pk_rows = cls._interpret_returning_rows(
2030 target_mapper, returned_defaults_rows
2031 )
2032
2033 matched_rows = [
2034 tuple(row) + (update_options._identity_token,)
2035 for row in pk_rows
2036 ]
2037 else:
2038 matched_rows = update_options._matched_rows
2039
2040 for row in matched_rows:
2041 primary_key = row[0:-1]
2042 identity_token = row[-1]
2043
2044 # TODO: inline this and call remove_newly_deleted
2045 # once
2046 identity_key = target_mapper.identity_key_from_primary_key(
2047 list(primary_key),
2048 identity_token=identity_token,
2049 )
2050 if identity_key in session.identity_map:
2051 session._remove_newly_deleted(
2052 [
2053 attributes.instance_state(
2054 session.identity_map[identity_key]
2055 )
2056 ]
2057 )