Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/orm/persistence.py: 12%
857 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1# orm/persistence.py
2# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
8"""private module containing functions used to emit INSERT, UPDATE
9and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
10mappers.
12The functions here are called only by the unit of work functions
13in unitofwork.py.
15"""
17from itertools import chain
18from itertools import groupby
19import operator
21from . import attributes
22from . import evaluator
23from . import exc as orm_exc
24from . import loading
25from . import sync
26from .base import NO_VALUE
27from .base import state_str
28from .. import exc as sa_exc
29from .. import future
30from .. import sql
31from .. import util
32from ..engine import result as _result
33from ..sql import coercions
34from ..sql import expression
35from ..sql import operators
36from ..sql import roles
37from ..sql import select
38from ..sql import sqltypes
39from ..sql.base import _entity_namespace_key
40from ..sql.base import CompileState
41from ..sql.base import Options
42from ..sql.dml import DeleteDMLState
43from ..sql.dml import InsertDMLState
44from ..sql.dml import UpdateDMLState
45from ..sql.elements import BooleanClauseList
46from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
def _bulk_insert(
    mapper,
    mappings,
    session_transaction,
    isstates,
    return_defaults,
    render_nulls,
):
    """Emit INSERT statements for a bulk-insert operation.

    Each row in ``mappings`` is INSERTed into every table of ``mapper``'s
    base-mapper hierarchy that ``mapper`` itself is a member of
    (``mapper.isa(super_mapper)``).

    :param mapper: the Mapper whose rows are being inserted.
    :param mappings: when ``isstates`` is true, an iterable of
        InstanceState objects; otherwise an iterable of plain dicts keyed
        by mapped attribute name (they are intersected with
        ``mapper._propkey_to_col`` in ``_collect_insert_commands``).
    :param session_transaction: supplies ``.session.connection_callable``
        and ``.connection(base_mapper)``.
    :param isstates: whether ``mappings`` holds states rather than dicts.
    :param return_defaults: passed through as ``bookkeeping`` to
        ``_emit_insert_statements``. When combined with ``isstates``,
        each state also receives an identity key afterwards.
    :param render_nulls: when true, ``None`` values are sent in the
        INSERT instead of being omitted.
    :raises NotImplementedError: if the session has a
        ``connection_callable`` (per-instance sharding).
    """
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if return_defaults:
            # keep (state, dict) pairs so identity keys can be assigned
            # at the end; the dicts are the states' own ``state.dict``
            # objects, so values written into them during the INSERTs
            # are visible afterwards
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # materialize, since the mappings are iterated once per table
        mappings = list(mappings)

    connection = session_transaction.connection(base_mapper)
    for table, super_mapper in base_mapper._sorted_tables.items():
        # only tables belonging to mappers that ``mapper`` inherits from
        if not mapper.isa(super_mapper):
            continue

        # re-pack each insert record with a None state and this call's
        # mapper/connection; downstream, a None state routes to
        # _postfetch_bulk_save() instead of _postfetch()
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in _collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=return_defaults,
                render_nulls=render_nulls,
            )
        )
        _emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=return_defaults,
        )

    if return_defaults and isstates:
        # build each state's identity key from its (now populated) dict
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
            )


def _bulk_update(
    mapper, mappings, session_transaction, isstates, update_changed_only
):
    """Emit UPDATE statements for a bulk-update operation.

    Rows are located by primary key (and by version id, when the mapper
    has a ``_version_id_prop``). No bookkeeping is done on the results
    (``bookkeeping=False``).

    :param mapper: the Mapper whose rows are being updated.
    :param mappings: when ``isstates`` is true, an iterable of
        InstanceState objects; otherwise an iterable of dicts keyed by
        mapped attribute name.
    :param session_transaction: supplies ``.session.connection_callable``
        and ``.connection(base_mapper)``.
    :param isstates: whether ``mappings`` holds states rather than dicts.
    :param update_changed_only: for states, include only attributes
        present in ``state.committed_state``, plus the primary-key and
        version keys needed to locate the row.
    :raises NotImplementedError: if the session has a
        ``connection_callable`` (per-instance sharding).
    """
    base_mapper = mapper.base_mapper

    # attribute keys that must always be present to locate the row
    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        # attributes with pending changes, plus the row-locating keys
        return dict(
            (k, v)
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        )

    # NOTE(review): unlike _bulk_insert, mappings are converted before the
    # connection_callable check below.
    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = list(mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        # the fifth element is the version id to match on; it is read
        # directly from the mapping, so the mapping must contain it
        # whenever the mapper is versioned
        records = _collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
        )

        _emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
        )


def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed. The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    States with an identity, or that are a "row switch" (a pending
    object replacing a deleted persistent one with the same identity),
    are UPDATEd; all others are INSERTed. For each table in
    ``base_mapper._sorted_tables``, UPDATEs are emitted before INSERTs.

    :param single: internal flag. When ``base_mapper.batch`` is false,
        this function calls itself once per state with ``single=True``.
    """

    # if batch=false, call save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []

    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        # both collectors are generators; they are consumed by the
        # corresponding _emit_* call below
        insert = _collect_insert_commands(table, states_to_insert)

        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            insert,
        )

    # the trailing boolean tells the finalizer whether the state was
    # UPDATEd (True) or INSERTed (False)
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )


def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    :param base_mapper: base mapper of the hierarchy being flushed.
    :param states: the states whose rows receive the post-update.
    :param uowtransaction: the current unit-of-work transaction.
    :param post_update_cols: columns to consider for the UPDATE; passed
        to ``_collect_post_update_commands``, which also considers columns
        that have an ``onupdate``.
    """

    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        # only states whose own mapper maps this table; append the
        # committed version id (or None for unversioned mappers)
        update = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                mapper._get_committed_state_attr_by_column(
                    state, state_dict, mapper.version_id_col
                )
                if mapper.version_id_col is not None
                else None,
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )


def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    Tables are visited in the reverse of ``base_mapper._sorted_tables``
    order. Two kinds of table are skipped: tables the owning mapper has
    no primary key for, and tables of inheriting mappers configured with
    ``passive_deletes``. Once every table has been processed, the
    ``after_delete`` event is dispatched for each state.
    """
    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables
    tables_in_reverse = list(table_to_mapper.keys())
    tables_in_reverse.reverse()

    for table in tables_in_reverse:
        table_mapper = table_to_mapper[table]
        if table not in table_mapper._pks_by_table or (
            table_mapper.inherits and table_mapper.passive_deletes
        ):
            continue

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            table_mapper,
            table,
            _collect_delete_commands(
                base_mapper, uowtransaction, table, states_to_delete
            ),
        )

    for state, _state_dict, mapper, connection, _version_id in (
        states_to_delete
    ):
        mapper.dispatch.after_delete(mapper, connection, state)


def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    For each state this:

    * dispatches ``before_insert`` (no identity yet) or
      ``before_update`` (has identity);
    * runs the mapper's polymorphic-identity validator, if any;
    * detects a "row switch", i.e. a pending state whose identity key
      matches a persistent instance marked for deletion;
    * reads the committed version id when an UPDATE will occur and the
      mapper is versioned.

    Yields ``(state, dict_, mapper, connection, has_identity, row_switch,
    update_version_id)``. The caller (``save_obj``) uses
    ``has_identity`` / ``row_switch`` to split states into INSERT and
    UPDATE lists.
    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):

        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    # a genuine conflict: only a warning is issued, and
                    # the pending state stays on the INSERT path
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        # for a row switch, the version id to match on comes from the
        # existing (replaced) row, not from the pending state
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )


def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Resolve the per-state information needed for a post_update UPDATE.

    The result is the ``(state, state_dict, mapper, connection)`` tuples
    produced by ``_connections_for_states``, returned unchanged. Unlike
    the save and delete variants, this dispatches no mapper events and
    looks up no version id.
    """
    organized_states = _connections_for_states(
        base_mapper, uowtransaction, states
    )
    return organized_states


def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Prepare each state in ``states`` for a DELETE.

    For every ``(state, dict_, mapper, connection)`` produced by
    ``_connections_for_states``, the mapper's ``before_delete`` event is
    dispatched first. Then, when the mapper has a ``version_id_col``,
    the committed version id is read so it can be included in the DELETE
    parameters.

    Yields ``(state, dict_, mapper, connection, update_version_id)``;
    ``update_version_id`` is ``None`` for unversioned mappers.
    """
    per_state = _connections_for_states(base_mapper, uowtransaction, states)
    for state, dict_, mapper, connection in per_state:
        mapper.dispatch.before_delete(mapper, connection, state)

        version_col = mapper.version_id_col
        update_version_id = None
        if version_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                state, dict_, version_col
            )

        yield state, dict_, mapper, connection, update_version_id


def _collect_insert_commands(
    table,
    states_to_insert,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    :param table: the Table being inserted into. Records whose mapper
        has no primary key for it are skipped.
    :param states_to_insert: iterable of
        ``(state, state_dict, mapper, connection)``.
    :param bulk: bulk mode. SQL-expression values are not separated out,
        and no explicit ``None`` values are added for missing columns.
    :param return_defaults: in bulk mode, still compute
        ``has_all_pks`` / ``has_all_defaults`` instead of assuming True.
    :param render_nulls: keep ``None`` values in ``params`` instead of
        omitting them.

    Yields ``(state, state_dict, params, mapper, connection, value_params,
    has_all_pks, has_all_defaults)``, where ``params`` maps
    ``Column.key`` to a literal value and ``value_params`` maps Column to
    a SQL expression.
    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            # omit None unless the column treats None as a real value
            # or NULLs are explicitly requested
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.   This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params is keyed by Column.key, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = mapper._server_default_cols[table].issubset(
                    params
                )
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        # generate the initial version id when the version column lives
        # in this table and a generator is configured
        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )


def _collect_update_commands(
    uowtransaction, table, states_to_update, bulk=False
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.

    :param uowtransaction: the unit-of-work transaction. Its
        ``attributes`` are consulted for ``("pk_cascaded", state, col)``
        markers. It is not used in bulk mode, where callers pass None.
    :param table: the Table being updated. Records whose mapper has no
        primary key for it are skipped.
    :param states_to_update: iterable of
        ``(state, state_dict, mapper, connection, update_version_id)``.
    :param bulk: bulk mode. ``state_dict`` is a plain mapping keyed by
        attribute name, every non-PK attribute in it is SET, and primary
        key values are read from it directly.

    Yields ``(state, state_dict, params, mapper, connection, value_params,
    has_all_defaults, has_all_pks)``. Note that the order of the last two
    items is the reverse of ``_collect_insert_commands``. In ``params``,
    SET values are keyed by ``Column.key`` and WHERE-clause values
    (primary key, version id) are keyed by ``Column._label``.

    States with no net change are skipped. If such a state's primary key
    is expected to have been changed by a database cascade, ``sync.populate``
    is invoked for it instead of yielding a record.

    :raises orm_exc.FlushError: (non-bulk) if a primary-key value used to
        locate the row is ``None``.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = dict(
                (propkey_to_col[propkey].key, state_dict[propkey])
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            )
            has_all_defaults = True
        else:
            params = {}
            # only attributes with a committed (pre-change) value are
            # candidates for the SET clause
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = (
                    mapper._server_onupdate_default_cols[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):

            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is.  This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change in any table; skip this state
                    # (continue the outer loop)
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            # old version id, used in the WHERE clause
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = dict(
                (propkey_to_col[propkey]._label, state_dict.get(propkey))
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            )
        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # the row is matched by the new pk value; it is
                        # expected to have been changed already (cascade),
                        # so it is not SET again
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            # new pk is a SQL expression, so its final
                            # value is not known client-side
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )


def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    :param states_to_update: iterable of
        ``(state, state_dict, mapper, connection, update_version_id)``.
    :param post_update_cols: columns eligible for the SET clause. Columns
        with an ``onupdate`` are also considered.

    Yields ``(state, state_dict, mapper, connection, params)``, but only
    for states that have at least one added value among the eligible
    columns. Primary-key and old version-id values are keyed by
    ``Column._label``; SET values are keyed by ``Column.key``.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        # assert table in mapper._pks_by_table
        # (post_update() only passes states for which this holds)

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):

                col = mapper.version_id_col
                params[col._label] = update_version_id

                # bump the version only for persistent states, and only
                # if the version column isn't already being set
                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params


807def _collect_delete_commands(
808 base_mapper, uowtransaction, table, states_to_delete
809):
810 """Identify values to use in DELETE statements for a list of
811 states to be deleted."""
813 for (
814 state,
815 state_dict,
816 mapper,
817 connection,
818 update_version_id,
819 ) in states_to_delete:
821 if table not in mapper._pks_by_table:
822 continue
824 params = {}
825 for col in mapper._pks_by_table[table]:
826 params[
827 col.key
828 ] = value = mapper._get_committed_state_attr_by_column(
829 state, state_dict, col
830 )
831 if value is None:
832 raise orm_exc.FlushError(
833 "Can't delete from table %s "
834 "using NULL for primary "
835 "key value on column %s" % (table, col)
836 )
838 if (
839 update_version_id is not None
840 and mapper.version_id_col in mapper._cols_by_table[table]
841 ):
842 params[mapper.version_id_col.key] = update_version_id
843 yield params, connection
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    bookkeeping=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    The WHERE clause matches each primary-key column, plus the version
    column when it lives in ``table``, against bind parameters named by
    ``Column._label``. Consecutive records are grouped by (connection,
    parameter keys, presence of SQL-expression values, has_all_defaults,
    has_all_pks). Records with SQL-expression values, or groups that
    cannot use executemany, run one statement per record; all other
    groups run as a single executemany.

    :param bookkeeping: when true, ``_postfetch`` is called per record to
        apply fetched/returned values to the state.
    :raises orm_exc.StaleDataError: when the dialect reports reliable
        rowcounts and the number of matched rows differs from the number
        of records in a group.
    """

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)
        return stmt

    # statement is obtained via base_mapper._memo keyed on ("update", table)
    cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # NOTE(review): the inner loops below rebind ``mapper`` and
    # ``connection`` to each record's values, so later groups see the
    # last record's mapper.
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt
        return_defaults = False

        # decide whether RETURNING / returned defaults are needed
        if not has_all_pks:
            statement = statement.return_defaults()
            return_defaults = True
        elif (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults
        ):
            statement = statement.return_defaults()
            return_defaults = True
        elif mapper.version_id_col is not None:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if not return_defaults
            else connection.dialect.supports_sane_rowcount_returning
        )

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = has_all_defaults and not needs_version_id

        if hasvalue:
            # SQL-expression values: each record gets its own .values()
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection._execute_20(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = assert_singlerow
        else:
            if not allow_multirow:
                check_rowcount = assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                multiparams = [rec[2] for rec in records]

                check_rowcount = assert_multirow or (
                    assert_singlerow and len(multiparams) == 1
                )

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        # returned_defaults is only passed for a
                        # single-row execution
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )


def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    bookkeeping=True,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Consecutive records are grouped by (connection, parameter keys,
    presence of SQL-expression values, has_all_pks, has_all_defaults).
    Each group takes one of three routes:

    * a plain executemany, when no values need to come back;
    * an executemany with RETURNING, when the dialect supports
      ``insert_executemany_returning``, the group has more than one
      record, and no record has SQL-expression values;
    * one INSERT per record, with the new primary key read from
      ``inserted_primary_key``.

    Newly generated primary keys are written into ``state_dict`` when the
    corresponding attribute is ``None`` (or, in the per-record path, when
    the pk column was given as a SQL expression). Records with a ``None``
    state use ``_postfetch_bulk_save`` instead of ``_postfetch``.

    :raises orm_exc.FlushError: when RETURNING yields fewer primary-key
        rows than records, or a single-row INSERT yields no primary key.
    """

    cached_stmt = base_mapper._memo(("insert", table), table.insert)

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    for (
        (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):

        statement = cached_stmt

        # precedence: ``not bookkeeping or (<defaults ok> and has_all_pks
        # and not hasvalue)`` -- without bookkeeping this branch is always
        # taken
        if (
            not bookkeeping
            or (
                has_all_defaults
                or not base_mapper.eager_defaults
                or not connection.dialect.implicit_returning
            )
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, c.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            last_inserted_params,
                            value_params,
                            False,
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )
                    else:
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back.

            records = list(records)
            if (
                not hasvalue
                and connection.dialect.insert_executemany_returning
                and len(records) > 1
            ):
                do_executemany = True
            else:
                do_executemany = False

            if not has_all_defaults and base_mapper.eager_defaults:
                statement = statement.return_defaults()
            elif mapper.version_id_col is not None:
                statement = statement.return_defaults(mapper.version_id_col)
            elif do_executemany:
                statement = statement.return_defaults(*table.primary_key)

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                if bookkeeping:
                    # zip_longest so that a shortfall in returned pk rows
                    # surfaces as None and is reported below
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in util.zip_longest(
                        records,
                        c.context.compiled_parameters,
                        c.inserted_primary_key_rows,
                        c.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows.  we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING. Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                c,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                # one INSERT per record; note ``connection`` is rebound
                # to each record's own connection here
                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection._execute_20(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection._execute_20(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked. Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        # a pk given as a SQL expression is replaced by
                        # the value the database produced
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                result.returned_defaults
                                if not result.context.executemany
                                else None,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)


def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements for the value lists collected by
    ``_collect_post_update_commands()``.

    Each record in ``update`` is unpacked as
    ``(state, state_dict, mapper_rec, connection, params)``.  Adjacent
    records that share a connection and the same set of parameter keys
    form one group.  A group is sent as a single executemany() unless
    version checking requires one statement per row.

    :raises orm_exc.StaleDataError: when the matched rowcount can be
      trusted and differs from the number of records in a group.
    """
    exec_opts = {"compiled_cache": base_mapper._compiled_cache}
    version_col = mapper.version_id_col

    needs_version_id = (
        version_col is not None
        and version_col in mapper._cols_by_table[table]
    )

    def build_update():
        # WHERE pk_1 = :pk_1_label AND ... [AND version = :version_label]
        where = BooleanClauseList._construct_raw(operators.and_)
        criteria = where.clauses

        for pk_col in mapper._pks_by_table[table]:
            criteria.append(
                pk_col == sql.bindparam(pk_col._label, type_=pk_col.type)
            )

        if needs_version_id:
            criteria.append(
                version_col
                == sql.bindparam(version_col._label, type_=version_col.type)
            )

        upd = table.update().where(where)
        if version_col is not None:
            upd = upd.return_defaults(version_col)
        return upd

    statement = base_mapper._memo(("post_update", table), build_update)

    def group_key(rec):
        # (connection, parameter keys).  groupby() only merges *adjacent*
        # records, so rows are still emitted in the original state order.
        return rec[3], set(rec[4])

    for (connection, _param_keys), group in groupby(update, group_key):
        group = list(group)
        dialect = connection.dialect
        matched = 0

        if version_col is None:
            assert_singlerow = dialect.supports_sane_rowcount
        else:
            assert_singlerow = dialect.supports_sane_rowcount_returning
        assert_multirow = (
            assert_singlerow and dialect.supports_sane_multi_rowcount
        )

        if needs_version_id and not assert_multirow:
            # multi-row rowcounts can't be trusted for a versioned table;
            # emit one UPDATE per row instead
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, conn, params in group:
                result = conn._execute_20(
                    statement, params, execution_options=exec_opts
                )
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    result,
                    result.context.compiled_parameters[0],
                )
                matched += result.rowcount
        else:
            multiparams = [
                params for _st, _sd, _mr, _conn, params in group
            ]
            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )
            result = connection._execute_20(
                statement, multiparams, execution_options=exec_opts
            )
            matched += result.rowcount
            for state, state_dict, mapper_rec, _conn, _params in group:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    result,
                    result.context.compiled_parameters[0],
                )

        if check_rowcount:
            if matched != len(group):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(group), matched)
                )
        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % result.dialect.dialect_description
            )
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements for the value lists collected by
    ``_collect_delete_commands()``.

    Each record in ``delete`` is unpacked as ``(params, connection)``.
    Adjacent records on the same connection are deleted together.  When
    ``base_mapper.confirm_deleted_rows`` is set and a usable rowcount is
    available, a mismatch between expected and matched rows either warns
    (no version column) or raises :class:`.StaleDataError` (versioned).
    """
    version_col = mapper.version_id_col
    need_version_id = (
        version_col is not None
        and version_col in mapper._cols_by_table[table]
    )

    def build_delete():
        # WHERE pk_1 = :pk_1_key AND ... [AND version = :version_key]
        where = BooleanClauseList._construct_raw(operators.and_)
        criteria = where.clauses

        for pk_col in mapper._pks_by_table[table]:
            criteria.append(
                pk_col == sql.bindparam(pk_col.key, type_=pk_col.type)
            )

        if need_version_id:
            criteria.append(
                version_col
                == sql.bindparam(version_col.key, type_=version_col.type)
            )

        return table.delete().where(where)

    statement = base_mapper._memo(("delete", table), build_delete)

    for connection, group in groupby(delete, lambda rec: rec[1]):
        del_objects = [params for params, _conn in group]
        dialect = connection.dialect
        exec_opts = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)

        # -1 means "no usable rowcount"; the check below is then skipped
        rows_matched = -1
        only_warn = False

        if need_version_id and not dialect.supports_sane_multi_rowcount:
            if not dialect.supports_sane_rowcount:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % dialect.dialect_description
                )
                connection._execute_20(
                    statement, del_objects, execution_options=exec_opts
                )
            else:
                # one DELETE per row so that versioned rows can be counted
                rows_matched = 0
                for params in del_objects:
                    result = connection._execute_20(
                        statement, params, execution_options=exec_opts
                    )
                    rows_matched += result.rowcount
        else:
            result = connection._execute_20(
                statement, del_objects, execution_options=exec_opts
            )
            # without a version column a mismatch only warns
            only_warn = not need_version_id
            rows_matched = result.rowcount

        # NOTE(review): on the per-row versioned path above,
        # supports_sane_multi_rowcount is False, so the accumulated count
        # is only verified here when a single row was deleted.
        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (dialect.supports_sane_multi_rowcount or expected == 1)
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            message = (
                "DELETE statement on table '%s' expected to "
                "delete %d row(s); %d were matched. Please set "
                "confirm_deleted_rows=False within the mapper "
                "configuration to prevent this warning."
                % (table.description, expected, rows_matched)
            )
            if only_warn:
                util.warn(message)
            else:
                raise orm_exc.StaleDataError(message)
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """Finalize states that were just inserted or updated.

    For each ``(state, state_dict, mapper, connection, has_identity)`` in
    ``states`` this function does four things:

    * expires unmodified read-only properties that qualify;
    * immediately re-loads the unloaded server-default / onupdate
      attributes (when ``base_mapper.eager_defaults`` is set) and an
      unloaded server-side version attribute
      (``version_id_generator is False``);
    * dispatches ``after_insert`` or ``after_update``;
    * raises :class:`.FlushError` if a server-side version value is still
      NULL.
    """
    for state, state_dict, mapper, connection, has_identity in states:
        readonly_props = mapper._readonly_props
        if readonly_props:
            candidate_keys = []
            for prop in readonly_props:
                if prop.expire_on_flush:
                    # expire unless deferred and never loaded
                    wanted = not prop.deferred or prop.key in state.dict
                else:
                    # only non-deferred attributes that aren't loaded yet
                    wanted = not prop.deferred and prop.key not in state.dict
                if wanted:
                    candidate_keys.append(prop.key)

            readonly = state.unmodified_intersection(candidate_keys)
            if readonly:
                state._expire_attributes(state.dict, readonly)

        server_versioned = (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        )

        toload_now = []
        if base_mapper.eager_defaults:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )
        if server_versioned and mapper._version_id_prop.key in state.unloaded:
            toload_now.append(mapper._version_id_prop.key)

        if toload_now:
            state.key = base_mapper._identity_key_from_state(state)
            loading.load_on_ident(
                uowtransaction.session,
                future.select(mapper).set_label_style(
                    LABEL_STYLE_TABLENAME_PLUS_COL
                ),
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if has_identity:
            mapper.dispatch.after_update(mapper, connection, state)
        else:
            mapper.dispatch.after_insert(mapper, connection, state)

        if (
            server_versioned
            and state_dict[mapper._version_id_prop.key] is None
        ):
            raise orm_exc.FlushError(
                "Instance does not contain a non-NULL version value"
            )
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Apply post-execution bookkeeping after a post-update UPDATE.

    Does nothing for states deleted in this unit of work.  Otherwise:

    * prefetched column values (plus the version column, when this table
      carries it) that appear in ``params`` are copied into ``dict_``;
    * ``refresh_flush`` is dispatched with those keys if it has listeners;
    * postfetch columns are expired.
    """
    if uowtransaction.is_deleted(state):
        return

    compiled = result.context.compiled
    prefetch_cols = compiled.prefetch
    postfetch_cols = compiled.postfetch
    col_to_prop = mapper._columntoproperty

    version_col = mapper.version_id_col
    if (
        version_col is not None
        and version_col in mapper._cols_by_table[table]
    ):
        prefetch_cols = list(prefetch_cols) + [version_col]

    has_refresh_listeners = bool(mapper.class_manager.dispatch.refresh_flush)
    refreshed_keys = []

    for col in prefetch_cols:
        if col.key in params and col in col_to_prop:
            attr_key = col_to_prop[col].key
            dict_[attr_key] = params[col.key]
            refreshed_keys.append(attr_key)

    if has_refresh_listeners and refreshed_keys:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, refreshed_keys
        )

    if postfetch_cols:
        expire_keys = [
            col_to_prop[col].key for col in postfetch_cols if col in col_to_prop
        ]
        state._expire_attributes(state.dict, expire_keys)
def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    * Values in ``returned_defaults`` (a RETURNING row, or None) are
      written into ``dict_`` positionally against ``compiled.returning``.
      Primary key columns of an INSERT are skipped.
    * Prefetched columns (plus the version column, when this table has
      it) are copied from ``params`` into ``dict_``.
    * ``refresh_flush`` is dispatched for the keys populated above, if the
      event has listeners.
    * Postfetch columns are expired.  For an UPDATE with ``value_params``,
      so are the primary key columns those params set that were not
      returned.
    * ``sync.populate()`` runs for each entry of
      ``mapper._table_to_equated[table]``.

    The compiled statement's column lists are never modified here.
    """
    compiled = result.context.compiled
    prefetch_cols = compiled.prefetch
    postfetch_cols = compiled.postfetch
    returning_cols = compiled.returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        # Build a new list rather than extending ``compiled.postfetch`` in
        # place: the compiled object may be reused through the
        # ``compiled_cache`` execution option, and mutating it would
        # accumulate these columns across every later execution.
        postfetch_cols = list(postfetch_cols) + [
            col
            for col in value_params
            if col.primary_key and col not in returning_cols
        ]

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often. would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )
def _postfetch_bulk_save(mapper, dict_, table):
    """Bulk-save counterpart of the final synchronization step of
    ``_postfetch()``.

    Works on a plain attribute dictionary instead of an instance state.
    Every ``(mapper, equated_pairs)`` entry registered for ``table`` is
    handed to ``sync.bulk_populate_inherit_keys()``.
    """
    equated = mapper._table_to_equated[table]
    for entry in equated:
        target_mapper, pairs = entry
        sync.bulk_populate_inherit_keys(dict_, target_mapper, pairs)
def _connections_for_states(base_mapper, uowtransaction, states):
    """Yield ``(state, state.dict, mapper, connection)`` for each state.

    States are produced in ``_sort_states()`` order.  If the session has
    a ``connection_callable`` (e.g. for per-instance sharding), that
    callable picks the connection for each state.  Otherwise a single
    connection for ``base_mapper`` is obtained from the unit of work's
    transaction and shared by every state.
    """
    connection_callable = uowtransaction.session.connection_callable
    if not connection_callable:
        shared_connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if connection_callable:
            connection = connection_callable(base_mapper, state.obj())
        else:
            connection = shared_connection

        state_mapper = state.manager.mapper
        yield state, state.dict, state_mapper, connection
def _sort_states(mapper, states):
    """Return ``states`` as a list: pending first, then persistent.

    Pending states (``key is None``) are ordered by ``insert_order``.
    Persistent states are ordered by ``mapper._persistent_sortkey_fn``.
    The persistent ordering is computed first.

    :raises sa_exc.InvalidRequestError: if the persistent sort keys cannot
      be compared in Python.
    """
    pending = set(states)
    persistent = {s for s in pending if s.key is not None}
    pending -= persistent

    try:
        ordered_persistent = sorted(
            persistent, key=mapper._persistent_sortkey_fn
        )
    except TypeError as err:
        util.raise_(
            sa_exc.InvalidRequestError(
                "Could not sort objects by primary key; primary key "
                "values must be sortable in Python (was: %s)" % err
            ),
            replace_context=err,
        )

    ordered_pending = sorted(pending, key=operator.attrgetter("insert_order"))
    return ordered_pending + ordered_persistent
# Shared empty immutable mapping, used as a default value below.
_EMPTY_DICT = util.immutabledict()


class BulkUDCompileState(CompileState):
    """Execution-time support shared by ORM-enabled bulk UPDATE and DELETE.

    Implements the ``synchronize_session`` execution option
    (``"evaluate"``, ``"fetch"`` or ``False``).  Before the statement runs,
    the matching in-session objects ("evaluate") or matching primary keys
    ("fetch") are gathered.  After execution, the subclass's
    ``_do_post_synchronize_*`` hook reconciles the Session.
    """

    class default_update_options(Options):
        _synchronize_session = "evaluate"
        _autoflush = True
        _subject_mapper = None
        _resolved_values = _EMPTY_DICT
        _resolved_keys_as_propnames = _EMPTY_DICT
        _value_evaluators = _EMPTY_DICT
        _matched_objects = None
        _matched_rows = None
        _refresh_identity_token = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_reentrant_invoke,
    ):
        """Prepare an ORM-enabled UPDATE/DELETE before it is executed.

        This method:

        * resolves and validates ``synchronize_session``;
        * records the statement and subject mapper in ``bind_arguments``;
        * autoflushes if enabled;
        * runs the "pre" half of the chosen synchronization strategy.

        :return: ``(statement, execution_options)``.  The options gain an
          ``"_sa_orm_update_options"`` entry that carries the resolved
          options to :meth:`orm_setup_cursor_result`.
        :raises sa_exc.ArgumentError: for an unknown synchronization
          strategy.
        """
        if is_reentrant_invoke:
            return statement, execution_options

        (
            update_options,
            execution_options,
        ) = BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {"synchronize_session"},
            execution_options,
            statement._execution_options,
        )

        # named so it does not shadow the module-level ``sync`` import
        sync_strategy = update_options._synchronize_session
        if sync_strategy is not None:
            if sync_strategy not in ("evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'evaluate', 'fetch', False"
                )

        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            bind_arguments["mapper"] = plugin_subject.mapper

        update_options += {"_subject_mapper": plugin_subject.mapper}

        if update_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"synchronize_session": update_options._synchronize_session}
        )

        # this stage of the execution is called before the do_orm_execute event
        # hook. meaning for an extension like horizontal sharding, this step
        # happens before the extension splits out into multiple backends and
        # runs only once. if we do pre_sync_fetch, we execute a SELECT
        # statement, which the horizontal sharding extension splits amongst the
        # shards and combines the results together.

        if update_options._synchronize_session == "evaluate":
            update_options = cls._do_pre_synchronize_evaluate(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )
        elif update_options._synchronize_session == "fetch":
            update_options = cls._do_pre_synchronize_fetch(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
        )

    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        """Run the "post" half of session synchronization.

        Dispatches to ``_do_post_synchronize_evaluate`` or
        ``_do_post_synchronize_fetch``, which subclasses define, and then
        returns ``result`` unchanged.
        """

        # this stage of the execution is called after the
        # do_orm_execute event hook. meaning for an extension like
        # horizontal sharding, this step happens *within* the horizontal
        # sharding event handler which calls session.execute() re-entrantly
        # and will occur for each backend individually.
        # the sharding extension then returns its own merged result from the
        # individual ones we return here.

        update_options = execution_options["_sa_orm_update_options"]
        if update_options._synchronize_session == "evaluate":
            cls._do_post_synchronize_evaluate(session, result, update_options)
        elif update_options._synchronize_session == "fetch":
            cls._do_post_synchronize_fetch(session, result, update_options)

        return result

    @classmethod
    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
        """Apply extra criteria filtering.

        For all distinct single-table-inheritance mappers represented in the
        table being updated or deleted, produce additional WHERE criteria such
        that only the appropriate subtypes are selected from the total results.

        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
        collected from the statement.

        """

        return_crit = ()

        adapter = ext_info._adapter if ext_info.is_aliased_class else None

        if (
            "additional_entity_criteria",
            ext_info.mapper,
        ) in global_attributes:
            return_crit += tuple(
                ae._resolve_where_criteria(ext_info)
                for ae in global_attributes[
                    ("additional_entity_criteria", ext_info.mapper)
                ]
                if ae.include_aliases or ae.entity is ext_info
            )

        if ext_info.mapper._single_table_criterion is not None:
            return_crit += (ext_info.mapper._single_table_criterion,)

        if adapter:
            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)

        return return_crit

    @classmethod
    def _effective_dml_statement(cls, statement):
        """Return the real UPDATE/DELETE behind ``statement``.

        A ``lambda_element`` is unwrapped to its ``_resolved`` statement.
        Any other statement is returned as-is.
        """
        if statement.__visit_name__ == "lambda_element":
            # ._resolved is called on every LambdaElement in order to
            # generate the cache key, so this access does not add
            # additional expense
            return statement._resolved
        return statement

    @classmethod
    def _compile_value_evaluators(
        cls, evaluator_compiler, resolved_keys_as_propnames
    ):
        """Compile a Python evaluator for each SET value that supports one.

        :param resolved_keys_as_propnames: ``(attribute key, value)`` pairs.
        :return: dict of attribute key -> evaluator callable.  Values the
          evaluator compiler cannot handle are left out.
        """
        value_evaluators = {}
        for key, value in resolved_keys_as_propnames:
            try:
                _evaluator = evaluator_compiler.process(
                    coercions.expect(roles.ExpressionElementRole, value)
                )
            except evaluator.UnevaluatableError:
                pass
            else:
                value_evaluators[key] = _evaluator
        return value_evaluators

    @classmethod
    def _do_pre_synchronize_evaluate(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        """Pre-execution half of ``synchronize_session='evaluate'``.

        The statement's WHERE criteria, plus any extra criteria from
        criteria options or single-table inheritance, are compiled into a
        Python predicate.  Non-expired identity-map objects of the subject
        mapper (limited to ``_refresh_identity_token`` when set) that
        satisfy the predicate become ``_matched_objects``.

        :raises sa_exc.InvalidRequestError: if the criteria cannot be
          evaluated in Python.
        """
        mapper = update_options._subject_mapper
        target_cls = mapper.class_

        value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT

        try:
            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
            crit = ()
            if statement._where_criteria:
                crit += statement._where_criteria

            global_attributes = {}
            for opt in statement._with_options:
                if opt._is_criteria_option:
                    opt.get_global_criteria(global_attributes)

            if global_attributes:
                crit += cls._adjust_for_extra_criteria(
                    global_attributes, mapper
                )

            if crit:
                eval_condition = evaluator_compiler.process(*crit)
            else:

                def eval_condition(obj):
                    return True

        except evaluator.UnevaluatableError as err:
            util.raise_(
                sa_exc.InvalidRequestError(
                    'Could not evaluate current criteria in Python: "%s". '
                    "Specify 'fetch' or False for the "
                    "synchronize_session execution option." % err
                ),
                from_=err,
            )

        effective_statement = cls._effective_dml_statement(statement)

        if effective_statement.__visit_name__ == "update":
            resolved_values = cls._get_resolved_values(
                mapper, effective_statement
            )
            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                mapper, resolved_values
            )
            value_evaluators = cls._compile_value_evaluators(
                evaluator_compiler, resolved_keys_as_propnames
            )

        # TODO: detect when the where clause is a trivial primary key match.
        matched_objects = [
            state.obj()
            for state in session.identity_map.all_states()
            if state.mapper.isa(mapper)
            and not state.expired
            and eval_condition(state.obj())
            and (
                update_options._refresh_identity_token is None
                # TODO: coverage for the case where horizontal sharding
                # invokes an update() or delete() given an explicit identity
                # token up front
                or state.identity_token
                == update_options._refresh_identity_token
            )
        ]
        return update_options + {
            "_matched_objects": matched_objects,
            "_value_evaluators": value_evaluators,
            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
        }

    @classmethod
    def _get_resolved_values(cls, mapper, statement):
        """Return the statement's SET values as a list of ``(key, value)``.

        Multi-values statements and statements without values yield ``[]``.
        """
        if statement._multi_values:
            return []
        elif statement._ordered_values:
            return list(statement._ordered_values)
        elif statement._values:
            return list(statement._values.items())
        else:
            return []

    @classmethod
    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
        """Translate SET keys into ``(attribute key, value)`` pairs.

        Column keys that are not mapped by ``mapper`` are skipped.

        :raises sa_exc.InvalidRequestError: for a key that is neither a
          mapped attribute nor a column expression.
        """
        values = []
        for k, v in resolved_values:
            if isinstance(k, attributes.QueryableAttribute):
                values.append((k.key, v))
                continue
            elif hasattr(k, "__clause_element__"):
                k = k.__clause_element__()

            if mapper and isinstance(k, expression.ColumnElement):
                try:
                    attr = mapper._columntoproperty[k]
                except orm_exc.UnmappedColumnError:
                    pass
                else:
                    values.append((attr.key, v))
            else:
                raise sa_exc.InvalidRequestError(
                    "Invalid expression type: %r" % k
                )
        return values

    @classmethod
    def _do_pre_synchronize_fetch(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        """Pre-execution half of ``synchronize_session='fetch'``.

        SELECTs the primary key and identity token of every row matched by
        the statement's WHERE criteria and stores them as
        ``_matched_rows``.  When the bind's dialect reports
        ``full_returning``, an empty result is substituted instead of
        running the SELECT.  For an UPDATE, Python evaluators are also
        compiled for the SET values where possible.
        """
        mapper = update_options._subject_mapper

        select_stmt = (
            select(*(mapper.primary_key + (mapper.select_identity_token,)))
            .select_from(mapper)
            .options(*statement._with_options)
        )
        select_stmt._where_criteria = statement._where_criteria

        def skip_for_full_returning(orm_context):
            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
            if bind.dialect.full_returning:
                return _result.null_result()
            else:
                return None

        result = session.execute(
            select_stmt,
            params,
            execution_options,
            bind_arguments,
            _add_event=skip_for_full_returning,
        )
        matched_rows = result.fetchall()

        value_evaluators = _EMPTY_DICT
        resolved_keys_as_propnames = _EMPTY_DICT

        effective_statement = cls._effective_dml_statement(statement)

        if effective_statement.__visit_name__ == "update":
            evaluator_compiler = evaluator.EvaluatorCompiler(mapper.class_)
            resolved_values = cls._get_resolved_values(
                mapper, effective_statement
            )
            # computed once; it was previously recomputed identically
            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                mapper, resolved_values
            )
            value_evaluators = cls._compile_value_evaluators(
                evaluator_compiler, resolved_keys_as_propnames
            )

        return update_options + {
            "_value_evaluators": value_evaluators,
            "_matched_rows": matched_rows,
            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
        }
class ORMDMLState:
    """Entity and RETURNING-column descriptions for ORM-enabled DML."""

    @classmethod
    def get_entity_description(cls, statement):
        """Describe the entity targeted by a DML statement.

        :return: dict with these keys:

          * ``name``: the alias name for an aliased class, otherwise the
            mapped class's ``__name__``;
          * ``type``: the mapped class;
          * ``expr`` and ``entity``: the annotated entity;
          * ``table``: the mapper's local table.
        """
        ext_info = statement.table._annotations["parententity"]
        mapper = ext_info.mapper
        if ext_info.is_aliased_class:
            _label_name = ext_info.name
        else:
            _label_name = mapper.class_.__name__

        return {
            "name": _label_name,
            "type": mapper.class_,
            "expr": ext_info.entity,
            "entity": ext_info.entity,
            "table": mapper.local_table,
        }

    @classmethod
    def get_returning_column_descriptions(cls, statement):
        """Describe each column selected by the statement's RETURNING.

        A column annotated with a ``proxy_key`` is reported through the
        matching attribute of its entity.  A column with no
        ``parententity`` annotation is reported as itself, with
        ``aliased=False`` and ``entity=None``, instead of raising
        ``AttributeError``.
        """

        def _ent_for_col(c):
            return c._annotations.get("parententity", None)

        def _attr_for_col(c, ent):
            if ent is None:
                return c
            proxy_key = c._annotations.get("proxy_key", None)
            if not proxy_key:
                return c
            else:
                return getattr(ent.entity, proxy_key, c)

        descriptions = []
        for c in statement._all_selected_columns:
            ent = _ent_for_col(c)
            descriptions.append(
                {
                    "name": c.key,
                    "type": c.type,
                    "expr": _attr_for_col(c, ent),
                    # columns not tied to an entity have nothing to report
                    "aliased": ent.is_aliased_class
                    if ent is not None
                    else False,
                    "entity": ent.entity if ent is not None else None,
                }
            )
        return descriptions
@CompileState.plugin_for("orm", "insert")
class ORMInsert(ORMDMLState, InsertDMLState):
    """ORM plugin for INSERT statements.

    Only helps the Session choose a bind.  The subject mapper is recorded
    in ``bind_arguments``, and both the statement and the cursor result
    pass through unchanged.
    """

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_reentrant_invoke,
    ):
        """Store ``clause`` and ``mapper`` in ``bind_arguments``.

        :return: the unchanged ``statement`` together with
          ``execution_options`` frozen into an immutabledict.
        """
        bind_arguments["clause"] = statement

        try:
            subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            bind_arguments["mapper"] = subject.mapper

        frozen_options = util.immutabledict(execution_options)
        return statement, frozen_options

    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        """No ORM post-processing for INSERT; hand back ``result`` as-is."""
        return result
@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState):
    """ORM plugin for UPDATE statements against mapped entities."""

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Build the compile state for an ORM-enabled UPDATE.

        This method:

        * retargets the statement at the mapper's local table;
        * resolves its SET values;
        * adds extra WHERE criteria from criteria options and single-table
          inheritance;
        * for ``synchronize_session='fetch'`` on a ``full_returning``
          dialect, adds RETURNING of the primary key.

        :raises sa_exc.InvalidRequestError: if 'fetch' synchronization
          would need RETURNING but an explicit ``returning()`` is present.
        """

        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        self._resolved_values = cls._get_resolved_values(mapper, statement)

        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        if not statement._preserve_parameter_order and statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = sql.Update.__new__(sql.Update)
        new_stmt.__dict__.update(statement.__dict__)
        new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        if (
            compiler._annotations.get("synchronize_session", None) == "fetch"
            and compiler.dialect.full_returning
        ):
            if new_stmt._returning:
                raise sa_exc.InvalidRequestError(
                    "Can't use synchronize_session='fetch' "
                    "with explicit returning()"
                )
            new_stmt = new_stmt.returning(*mapper.primary_key)

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        return self

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator):
        """Resolve SET-clause keys for an ORM-enabled UPDATE.

        Two kinds of key are expanded through the attribute's
        ``_bulk_update_tuples()``: string keys naming a mapped attribute,
        and column keys annotated with ``entity_namespace``.  Any other
        key is kept as-is, and its value is coerced to an untyped
        (``NullType``) expression.  The Core implementation is used when
        the plugin subject is empty or has no mapper.
        """
        plugin_subject = statement._propagate_attrs["plugin_subject"]

        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        if not plugin_subject or not plugin_subject.mapper:
            return core_get_crud_kv_pairs(statement, kv_iterator)

        mapper = plugin_subject.mapper

        values = []

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, util.string_types):
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if desc is NO_VALUE:
                    values.append(
                        (
                            k,
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            ),
                        )
                    )
                else:
                    values.extend(
                        core_get_crud_kv_pairs(
                            statement, desc._bulk_update_tuples(v)
                        )
                    )
            elif "entity_namespace" in k._annotations:
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                values.extend(
                    core_get_crud_kv_pairs(
                        statement, attr._bulk_update_tuples(v)
                    )
                )
            else:
                values.append(
                    (
                        k,
                        coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        ),
                    )
                )
        return values

    @classmethod
    def _apply_update_set_values_to_objects(
        cls, session, update_options, objs
    ):
        """Reconcile in-session objects with the SET clause just executed.

        For each object:

        * unmodified attributes that have a Python evaluator are
          recomputed (when loaded) and committed;
        * other loaded attributes named in the SET clause are expired.

        The affected states are then passed to
        ``session._register_altered()``.  This logic is shared by the
        "evaluate" and "fetch" strategies.
        """
        evaluated_keys = list(update_options._value_evaluators.keys())
        values = update_options._resolved_keys_as_propnames
        attrib = set(k for k, v in values)

        states = set()
        for obj in objs:
            state, dict_ = (
                attributes.instance_state(obj),
                attributes.instance_dict(obj),
            )

            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = update_options._value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)

    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):
        """Post-execution half of ``synchronize_session='evaluate'``."""
        # the evaluated states were gathered across all identity tokens.
        # however the post_sync events are called per identity token,
        # so filter.
        token = update_options._refresh_identity_token
        objs = [
            obj
            for obj in update_options._matched_objects
            if token is None
            or attributes.instance_state(obj).identity_token == token
        ]
        cls._apply_update_set_values_to_objects(session, update_options, objs)

    @classmethod
    def _do_post_synchronize_fetch(cls, session, result, update_options):
        """Post-execution half of ``synchronize_session='fetch'``.

        If the UPDATE returned rows (RETURNING), the matched primary keys
        are taken from them.  Otherwise they come from the rows SELECTed
        before execution.  Identity-map objects for those keys, limited to
        ``_refresh_identity_token`` when set, are then reconciled.
        """
        target_mapper = update_options._subject_mapper
        token = update_options._refresh_identity_token

        if result.returns_rows:
            matched_rows = [tuple(row) + (token,) for row in result.all()]
        else:
            matched_rows = update_options._matched_rows

        identity_keys = [
            target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )
            for primary_key, identity_token in [
                (row[0:-1], row[-1]) for row in matched_rows
            ]
            if token is None or identity_token == token
        ]
        objs = [
            session.identity_map[identity_key]
            for identity_key in identity_keys
            if identity_key in session.identity_map
        ]

        cls._apply_update_set_values_to_objects(session, update_options, objs)
@CompileState.plugin_for("orm", "delete")
class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState):
    """Compile state and Session synchronization for ORM bulk DELETE.

    Registered as the ``"orm"`` compile-state plugin for ``"delete"``
    statements.  ``create_for_statement`` prepares the statement for
    compilation; the ``_do_post_synchronize_*`` hooks hand the states
    of deleted objects to ``session._remove_newly_deleted``.
    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Build the compile state for an ORM-enabled DELETE.

        Global criteria gathered from ``statement._with_options``
        (options whose ``_is_criteria_option`` is set) are passed
        through ``_adjust_for_extra_criteria`` for the target mapper
        and, if any result, added to the WHERE clause.  When the
        compiler is annotated with ``synchronize_session="fetch"`` and
        the dialect reports ``full_returning``, RETURNING of the
        mapper's primary key columns is added so the deleted keys can
        be read back by ``_do_post_synchronize_fetch``.

        :param statement: the DELETE statement; its target table must
            carry a ``"parententity"`` annotation.
        :param compiler: the SQL compiler in use.
        :return: a new, initialized ``BulkORMDelete`` instance.
        """
        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            statement = statement.where(*new_crit)

        # "fetch" synchronization reads the matched primary keys back
        # from RETURNING when the dialect supports it
        if (
            mapper
            and compiler._annotations.get("synchronize_session", None)
            == "fetch"
            and compiler.dialect.full_returning
        ):
            statement = statement.returning(*mapper.primary_key)

        DeleteDMLState.__init__(self, statement, compiler, **kw)

        return self

    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):
        """Hand the objects in ``update_options._matched_objects`` to
        ``session._remove_newly_deleted``.

        ``result`` is not used here.
        """
        session._remove_newly_deleted(
            [
                attributes.instance_state(obj)
                for obj in update_options._matched_objects
            ]
        )

    @classmethod
    def _do_post_synchronize_fetch(cls, session, result, update_options):
        """Hand session objects whose rows were deleted to
        ``session._remove_newly_deleted``.

        Primary keys come from RETURNING rows when ``result`` returns
        rows (tagged with ``update_options._refresh_identity_token``),
        otherwise from ``update_options._matched_rows``; in both cases
        each row has the shape ``(*primary_key, identity_token)``.
        Objects present in ``session.identity_map`` are collected and
        removed with a single ``_remove_newly_deleted`` call, which is
        skipped when nothing matched.
        """
        target_mapper = update_options._subject_mapper

        if result.returns_rows:
            matched_rows = [
                tuple(row) + (update_options._refresh_identity_token,)
                for row in result.all()
            ]
        else:
            matched_rows = update_options._matched_rows

        identity_map = session.identity_map

        # Collect all states first and call _remove_newly_deleted() once,
        # as _do_post_synchronize_evaluate does, instead of once per row.
        # A primary key repeated in matched_rows is only handed over once.
        # NOTE(review): identity-map lookups now all happen before any
        # state is processed; confirm no persistent_to_deleted listener
        # relied on the old per-row interleaving.
        seen_keys = set()
        states = []
        for row in matched_rows:
            identity_key = target_mapper.identity_key_from_primary_key(
                list(row[0:-1]),
                identity_token=row[-1],
            )
            if identity_key in seen_keys or identity_key not in identity_map:
                continue
            seen_keys.add(identity_key)
            states.append(
                attributes.instance_state(identity_map[identity_key])
            )

        if states:
            session._remove_newly_deleted(states)