1# orm/persistence.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""private module containing functions used to emit INSERT, UPDATE
11and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
12mappers.
13
14The functions here are called only by the unit of work functions
15in unitofwork.py.
16
17"""
18
19from __future__ import annotations
20
21from itertools import chain
22from itertools import groupby
23from itertools import zip_longest
24import operator
25
26from . import attributes
27from . import exc as orm_exc
28from . import loading
29from . import sync
30from .base import state_str
31from .. import exc as sa_exc
32from .. import sql
33from .. import util
34from ..engine import cursor as _cursor
35from ..sql import operators
36from ..sql.elements import BooleanClauseList
37
38
def _save_obj(base_mapper, states, uowtransaction, single=False):
    """Flush *states* by emitting INSERT and/or UPDATE statements.

    Called by the unit of work during a flush.  The base mapper of an
    inheritance hierarchy is responsible for the rows of all of its
    descendant mappers; tables are visited in
    ``base_mapper._sorted_tables`` order and, for each table, UPDATEs are
    emitted before INSERTs.

    :param base_mapper: base mapper of the hierarchy being flushed.
    :param states: instance states to persist.
    :param uowtransaction: the UOWTransaction driving the flush.
    :param single: internal; True when this call handles exactly one
        state because ``base_mapper.batch`` is disabled.
    """

    # batching disabled: recurse once per state, in sorted order
    if not (single or base_mapper.batch):
        for lone_state in _sort_states(base_mapper, states):
            _save_obj(base_mapper, [lone_state], uowtransaction, single=True)
        return

    pending_inserts = []
    pending_updates = []

    for organized in _organize_states_for_save(
        base_mapper, states, uowtransaction
    ):
        (
            state,
            dict_,
            state_mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        ) = organized
        if not (has_identity or row_switch):
            pending_inserts.append((state, dict_, state_mapper, connection))
        else:
            # persistent state, or a pending state taking over the row of
            # a deleted object with the same identity ("row switch")
            pending_updates.append(
                (state, dict_, state_mapper, connection, update_version_id)
            )

    for table, table_mapper in base_mapper._sorted_tables.items():
        if table not in table_mapper._pks_by_table:
            continue

        insert_cmds = _collect_insert_commands(table, pending_inserts)
        update_cmds = _collect_update_commands(
            uowtransaction, table, pending_updates
        )

        _emit_update_statements(
            base_mapper, uowtransaction, table_mapper, table, update_cmds
        )
        _emit_insert_statements(
            base_mapper, uowtransaction, table_mapper, table, insert_cmds
        )

    # trailing flag tells the finalizer whether the row was an UPDATE
    finalize_recs = chain(
        (
            (state, state_dict, rec_mapper, connection, False)
            for state, state_dict, rec_mapper, connection in pending_inserts
        ),
        (
            (state, state_dict, rec_mapper, connection, True)
            for (
                state,
                state_dict,
                rec_mapper,
                connection,
                _version_id,
            ) in pending_updates
        ),
    )
    _finalize_insert_update_commands(
        base_mapper, uowtransaction, finalize_recs
    )
120
121
def _post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Emit UPDATE statements on behalf of a relationship() configured
    with post_update.

    For each table of the hierarchy (in ``base_mapper._sorted_tables``
    order), the states whose mapper has a primary key on that table are
    paired with their committed version id and handed to the post-update
    collect/emit helpers.  The version id is only looked up when the
    table's mapper is versioned; otherwise it is None.

    :param base_mapper: base mapper of the hierarchy being flushed.
    :param states: instance states needing the post-update.
    :param uowtransaction: the UOWTransaction driving the flush.
    :param post_update_cols: columns managed by the post_update
        relationship.
    """

    organized = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, table_mapper in base_mapper._sorted_tables.items():
        if table not in table_mapper._pks_by_table:
            continue

        # generator: rows are produced as the collect helper iterates it
        versioned_rows = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                (
                    None
                    if table_mapper.version_id_col is None
                    else table_mapper._get_committed_state_attr_by_column(
                        state, state_dict, table_mapper.version_id_col
                    )
                ),
            )
            for state, state_dict, sub_mapper, connection in organized
            if table in sub_mapper._pks_by_table
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            table_mapper,
            table,
            _collect_post_update_commands(
                base_mapper,
                uowtransaction,
                table,
                versioned_rows,
                post_update_cols,
            ),
        )
165
166
def _delete_obj(base_mapper, states, uowtransaction):
    """Emit DELETE statements for *states* during a flush.

    Tables are walked in the reverse of ``base_mapper._sorted_tables``
    order, i.e. the reverse of the order used for INSERT/UPDATE.  Tables
    for which their mapper has no primary key are skipped, as are tables
    of inheriting mappers that set ``passive_deletes``.  After all DELETEs
    have been emitted, ``after_delete`` fires for every state.
    """

    to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    for table in reversed(list(table_to_mapper.keys())):
        tbl_mapper = table_to_mapper[table]
        # passive_deletes on an inheriting mapper: no DELETE is emitted
        # for its table here (presumably left to the database's cascade)
        if table not in tbl_mapper._pks_by_table or (
            tbl_mapper.inherits and tbl_mapper.passive_deletes
        ):
            continue

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            tbl_mapper,
            table,
            _collect_delete_commands(
                base_mapper, uowtransaction, table, to_delete
            ),
        )

    for state, _state_dict, state_mapper, connection, _version in to_delete:
        state_mapper.dispatch.after_delete(state_mapper, connection, state)
208
209
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Yield per-state information needed to INSERT or UPDATE *states*.

    For each state this fires ``before_update`` (state has an identity key)
    or ``before_insert`` (it does not), and validates the polymorphic
    identity when the mapper provides a validator.  It also detects a
    "row switch": a pending state whose identity key is already present in
    the session's identity map under an object that is being deleted in
    this flush.  That existing state's actions are removed from the
    transaction, and the pending state is reported so it is UPDATEd
    instead.  If the existing object is not being deleted, a conflict
    warning is emitted instead.

    Yields 7-tuples ``(state, dict_, mapper, connection, has_identity,
    row_switch, update_version_id)``.  ``row_switch`` is the replaced
    existing state, or None.  ``update_version_id`` is the committed
    version id when the row will be UPDATEd on a versioned mapper, or
    None.
    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        has_identity = bool(state.key)

        # computed before any before_* event handler runs
        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = None
        update_version_id = None

        if has_identity:
            mapper.dispatch.before_update(mapper, connection, state)
        else:
            mapper.dispatch.before_insert(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        pending_with_existing_key = (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        )
        if pending_with_existing_key:
            existing = attributes.instance_state(
                uowtransaction.session.identity_map[instance_key]
            )

            if not uowtransaction.was_already_deleted(existing):
                if uowtransaction.is_deleted(existing):
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # drop the pending DELETE of the existing element; its
                    # row will be UPDATEd with this state's values instead
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing
                else:
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )

        if mapper.version_id_col is not None and (has_identity or row_switch):
            # the version to match comes from the row actually being updated
            if row_switch:
                version_state, version_dict = row_switch, row_switch.dict
            else:
                version_state, version_dict = state, dict_
            update_version_id = mapper._get_committed_state_attr_by_column(
                version_state, version_dict, mapper.version_id_col
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
288
289
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Return per-state information for a post_update UPDATE pass.

    Only connection resolution is needed here, so the iterable produced
    by ``_connections_for_states`` is handed back unchanged.  It supplies
    the state, its dictionary, its mapper and the connection to use for
    each state.
    """
    per_state_connections = _connections_for_states(
        base_mapper, uowtransaction, states
    )
    return per_state_connections
300
301
def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Yield per-state information needed to DELETE *states*.

    Fires ``before_delete`` for each state, then yields 5-tuples
    ``(state, dict_, mapper, connection, update_version_id)``.
    ``update_version_id`` is the committed version id for versioned
    mappers and None otherwise.
    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        mapper.dispatch.before_delete(mapper, connection, state)

        version_col = mapper.version_id_col
        if version_col is None:
            update_version_id = None
        else:
            update_version_id = mapper._get_committed_state_attr_by_column(
                state, dict_, version_col
            )

        yield state, dict_, mapper, connection, update_version_id
323
324
def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Yield the parameter sets for INSERT statements on *table*.

    :param table: target Table; states whose mapper has no primary key
        for it are skipped.
    :param states_to_insert: iterable of ``(state, state_dict, mapper,
        connection)`` tuples.
    :param bulk: bulk-insert mode.  SQL expression values are not split
        out, no explicit NULLs are added for missing columns, and the
        polymorphic identity and ``include_bulk_keys`` are merged in.
    :param return_defaults: in bulk mode, still compute the
        ``has_all_pks`` / ``has_all_defaults`` flags.
    :param render_nulls: keep None values in the parameters even for
        columns that do not treat None as a value.
    :param include_bulk_keys: in bulk mode, extra ``state_dict`` keys to
        copy verbatim into the parameters.

    Yields 8-tuples ``(state, state_dict, params, mapper, connection,
    value_params, has_all_pks, has_all_defaults)``.  ``params`` maps
    ``Column.key`` to a plain value; ``value_params`` maps Column objects
    to SQL expressions to be rendered inline.
    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        propkey_to_col = mapper._propkey_to_col[table]
        eval_none = mapper._insert_cols_evaluating_none[table]

        params = {}
        value_params = {}

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]

            # None is left out unless the column treats None as a value
            # or nulls were explicitly requested
            if value is None and col not in eval_none and not render_nulls:
                continue

            is_sql_expression = not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            )
            if is_sql_expression:
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # legacy behavior: columns with no default, no value, and for
            # which None is not special receive an explicit None
            inline_keys = [c.key for c in value_params]
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference(inline_keys)
            ):
                params[colkey] = None

        if bulk and not return_defaults:
            has_all_pks = has_all_defaults = True
        else:
            # params are keyed by Column.key, as is _pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
            if mapper.base_mapper._prefer_eager_defaults(
                connection.dialect, table
            ):
                has_all_defaults = mapper._server_default_col_keys[
                    table
                ].issubset(params)
            else:
                has_all_defaults = True

        version_col = mapper.version_id_col
        if (
            mapper.version_id_generator is not False
            and version_col is not None
            and version_col in mapper._cols_by_table[table]
        ):
            # new rows receive the generator's initial version id
            params[version_col.key] = mapper.version_id_generator(None)

        if bulk:
            if mapper._set_polymorphic_identity:
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )
            if include_bulk_keys:
                params.update(
                    (key, state_dict[key]) for key in include_bulk_keys
                )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
424
425
def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement. Includes some tricky scenarios where the primary
    key of an object might have been changed.

    :param uowtransaction: the active UOWTransaction.  Its ``attributes``
     dict is checked for ``("pk_cascaded", state, col)`` markers, and it
     is passed to ``sync._populate``.
    :param table: the Table to build UPDATE parameters for.  States whose
     mapper has no primary key for this table are skipped.
    :param states_to_update: iterable of 5-tuples ``(state, state_dict,
     mapper, connection, update_version_id)``.
    :param bulk: when True, the values to write are taken directly from
     ``state_dict`` (keyed by mapped attribute name) rather than derived
     from attribute history; a missing primary key value raises.
    :param use_orm_update_stmt: optional ORM UPDATE statement.  Unless it
     maintains ordered values, its ``_values`` are used as the inline
     ``value_params``.
    :param include_bulk_keys: extra ``state_dict`` keys copied verbatim
     into ``params``.
    :raises sa_exc.InvalidRequestError: bulk mode with a None primary key.
    :raises orm_exc.FlushError: non-bulk mode where the row would be
     located using a NULL primary key value.

    Yields 8-tuples ``(state, state_dict, params, mapper, connection,
    value_params, has_all_defaults, has_all_pks)``.  Note the flag order
    differs from ``_collect_insert_commands``.  In ``params``, new values
    are keyed by ``Column.key``, while the values used to locate the row
    (primary key and old version id) are keyed by the column's ``_label``.

    When nothing needs to be UPDATEd for a state but its primary key is
    expected to have been changed by a cascade, nothing is yielded.
    Instead, equated columns are synchronized via ``sync._populate``.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if (
            use_orm_update_stmt is not None
            and not use_orm_update_stmt._maintain_values_ordering
        ):
            # TODO: ordered values, etc
            # ORM bulk_persistence will raise for the maintain_values_ordering
            # case right now
            # NOTE(review): this aliases the statement's own ``_values``
            # mapping; the non-bulk branch below writes into value_params,
            # so presumably use_orm_update_stmt is only passed together
            # with bulk=True -- confirm against callers.
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            params = {}
            # only attributes with a committed_state entry are candidates
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    # SQL expressions are rendered inline via .values()
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            # with eager_defaults=True, note whether every server onupdate
            # default column already has an explicit value
            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            # the old version id goes into the WHERE clause (via _label)
            params[col._label] = update_version_id

            # compute the new version id unless the user already set one
            # (in bulk mode it is always generated)
            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            continue

        # has_all_pks becomes False only when a PK column's new value is
        # an inline SQL expression, so the new key must be fetched back
        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # no prior value, or a cascade was recorded: the
                        # row is located by the new value, which is then
                        # not SET again
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync._populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
638
639
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Yield UPDATE parameter sets for a post_update pass on *table*.

    For every column of the mapper's table, primary key values are taken
    under the column's ``_label`` (used to locate the row).  Newly added
    values of columns that are in *post_update_cols* or have an
    ``onupdate`` are taken under ``Column.key``.  A state is yielded only
    when at least one such new value exists.  When the mapper is
    versioned on this table, the old version id is added under the
    version column's ``_label``.  A new version id is generated for
    persistent states unless one was already given.

    Yields ``(state, state_dict, mapper, connection, params)``.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # callers only pass states whose mapper has a primary key on table
        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )
                continue

            if col not in post_update_cols and col.onupdate is None:
                continue

            prop = mapper._columntoproperty[col]
            history = state.manager[prop.key].impl.get_history(
                state, state_dict, attributes.PASSIVE_NO_INITIALIZE
            )
            if history.added:
                params[col.key] = history.added[0]
                hasdata = True

        if not hasdata:
            continue

        version_col = mapper.version_id_col
        if (
            update_version_id is not None
            and version_col in mapper._cols_by_table[table]
        ):
            params[version_col._label] = update_version_id

            if (
                bool(state.key)
                and version_col.key not in params
                and mapper.version_id_generator is not False
            ):
                params[version_col.key] = mapper.version_id_generator(
                    update_version_id
                )

        yield state, state_dict, mapper, connection, params
692
693
def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Yield ``(params, connection)`` pairs for DELETE statements on
    *table*.

    ``params`` holds each primary key column's committed value, keyed by
    ``Column.key``.  When the mapper is versioned on this table and a
    version id is known, it also holds that version id under the version
    column's key.  States whose mapper has no primary key for *table* are
    skipped.

    :raises orm_exc.FlushError: a primary key value is None.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        for pk_col in mapper._pks_by_table[table]:
            committed = mapper._get_committed_state_attr_by_column(
                state, state_dict, pk_col
            )
            params[pk_col.key] = committed
            if committed is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, pk_col)
                )

        version_col = mapper.version_id_col
        if (
            update_version_id is not None
            and version_col in mapper._cols_by_table[table]
        ):
            params[version_col.key] = update_version_id

        yield params, connection
730
731
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    Records are grouped with ``itertools.groupby`` into consecutive runs
    sharing the same connection, parameter key set, presence of inline
    ``value_params``, ``has_all_defaults`` and ``has_all_pks``.  Each group
    is executed one row at a time when inline SQL values, RETURNING or
    version-id checking are involved.  Otherwise it runs as a single
    executemany.  The matched rowcount is then compared with the number of
    records, as far as the dialect's ``supports_sane_rowcount`` /
    ``supports_sane_multi_rowcount`` flags allow.

    :param bookkeeping: when True, ``_postfetch`` is called per record and
     server onupdate defaults may be fetched via RETURNING.
    :param use_orm_update_stmt: optional ORM UPDATE statement that receives
     the primary key / version WHERE criteria instead of a fresh
     ``table.update()``.  It is not memoized.
    :param enable_check_rowcount: set False to skip the rowcount check.
    :raises orm_exc.StaleDataError: the number of matched rows differs
     from the number of records while rowcount checking is in effect.
    """

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        # WHERE pk_col = :pk_label [AND version_col = :version_label]
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)

    else:
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        # NOTE(review): the per-record unpacking loops below rebind the
        # enclosing ``mapper`` and ``connection`` names, so groups after
        # the first see the last record's mapper here; presumably
        # harmless since records are for mappers sharing this table --
        # confirm.
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        # a primary key SET to a SQL expression must be read back
        if not has_all_pks:
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            # inline SQL values differ per record: one execute() per row
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                # RETURNING or version checking: one execute() per row
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                multiparams = [rec[2] for rec in records]

                # a multi-row rowcount is only trusted when the dialect
                # says so; a single-row executemany uses the 1-row flag
                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            # the version WHERE criteria were applied but cannot be checked
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
964
965
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Records are grouped with ``itertools.groupby`` into consecutive runs
    sharing the same connection, parameter key set, presence of inline
    ``value_params``, ``has_all_pks`` and ``has_all_defaults``.  Each group
    takes one of three paths:

    * a plain executemany without RETURNING, when nothing generated needs
      to come back (all PKs present, no inline SQL values, RETURNING not
      required, and defaults either all present, not wanted eagerly, or
      not fetchable via RETURNING on this table/dialect);
    * an executemany with RETURNING, when the dialect supports it with
      the required ordering guarantee;
    * one execute() per record otherwise.

    :param bookkeeping: when True, results are applied back to each state
     via ``_postfetch`` (or ``_postfetch_bulk_save`` when the record has
     no state).
    :param use_orm_insert_stmt: optional ORM INSERT statement used instead
     of the memoized ``table.insert()``.  If it has RETURNING, RETURNING
     is required.
    :param execution_options: extra execution options merged over the
     defaults.
    :return: only when ``use_orm_insert_stmt`` is given.  Returns the
     results of the executemany-with-RETURNING executions spliced
     vertically, or a null DML result if none occurred.  Otherwise None.
    :raises sa_exc.InvalidRequestError: RETURNING is required but the
     dialect does not support executemany with RETURNING (sorted by
     parameter order when deterministic results are required).
    :raises orm_exc.FlushError: the INSERT did not produce the expected
     primary key rows.
    """

    if use_orm_insert_stmt is not None:
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    return_result = None

    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],  # has_all_pks
            rec[7],  # has_all_defaults
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            # decide whether the group can run as one executemany with
            # RETURNING; otherwise fall back to per-row execute()
            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

            if mapper.version_id_col is not None:
                statement = statement.return_defaults(
                    mapper.version_id_col,
                    sort_by_parameter_order=bookkeeping,
                )
            elif do_executemany:
                statement = statement.return_defaults(
                    *table.primary_key, sort_by_parameter_order=bookkeeping
                )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    # zip_longest pads the shorter sequences with None.
                    # NOTE(review): if RETURNING produced more rows than
                    # records, unpacking a None record raises TypeError
                    # rather than the FlushError below.
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows. we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING. Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        # fill in generated PK values not already present
                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                assert not returning_is_required_anyway

                # one execute() per record
                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked. Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    # PK values are written back regardless of bookkeeping;
                    # a PK given as an inline SQL expression is overwritten
                    # with the value actually produced
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                (
                                    result.returned_defaults
                                    if not result.context.executemany
                                    else None
                                ),
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)

    if use_orm_insert_stmt is not None:
        if return_result is None:
            return _cursor.null_dml_result()
        else:
            return return_result
1288
1289
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Run the UPDATE statements gathered by
    _collect_post_update_commands() against ``table``.

    Each record in ``update`` is a 5-tuple of
    ``(state, state_dict, mapper_rec, connection, params)``.  Consecutive
    records that share a connection and the same set of parameter keys
    form one batch.  A batch is sent through a single executemany() call,
    unless the mapper is versioned on this table and the dialect cannot
    report executemany() rowcounts.  In that case every record is
    executed individually.

    Afterwards the matched row count is compared against the batch size
    whenever the dialect can report it; a mismatch raises
    :class:`.StaleDataError`.  If a versioned UPDATE cannot be verified,
    a warning is emitted instead.

    """
    exec_opts = {"compiled_cache": base_mapper._compiled_cache}

    version_col = mapper.version_id_col
    needs_version_id = (
        version_col is not None
        and version_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # WHERE <each pk col> = :<its label>
        # [AND <version col> = :<its label>]
        where = BooleanClauseList._construct_raw(operators.and_)
        criteria_cols = list(mapper._pks_by_table[table])
        if needs_version_id:
            criteria_cols.append(version_col)
        for crit_col in criteria_cols:
            where._append_inplace(
                crit_col
                == sql.bindparam(crit_col._label, type_=crit_col.type)
            )
        return table.update().where(where)

    statement = base_mapper._memo(("post_update", table), update_stmt)

    if mapper._version_id_has_server_side_value:
        statement = statement.return_defaults(version_col)

    def _batch_key(rec):
        # (connection, set of parameter keys)
        return rec[3], set(rec[4])

    # Records are visited in their original order so that row access order
    # is preserved; groupby() only merges *adjacent* records with an equal
    # key, which is what allows executemany() batching.
    for (connection, _param_keys), batch in groupby(update, _batch_key):
        batch = list(batch)

        dialect = connection.dialect
        assert_singlerow = dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow and dialect.supports_sane_multi_rowcount
        )
        matched = 0

        if needs_version_id and not assert_multirow:
            # versioned rows must be counted one statement at a time
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, rec_conn, params in batch:
                cursor = rec_conn.execute(
                    statement, params, execution_options=exec_opts
                )
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    cursor,
                    cursor.context.compiled_parameters[0],
                )
                matched += cursor.rowcount
        else:
            param_sets = [
                params for _st, _sd, _mr, _cn, params in batch
            ]
            check_rowcount = assert_multirow or (
                assert_singlerow and len(param_sets) == 1
            )
            cursor = connection.execute(
                statement, param_sets, execution_options=exec_opts
            )
            matched += cursor.rowcount
            for idx, (state, state_dict, mapper_rec, _cn, _pr) in enumerate(
                batch
            ):
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    cursor,
                    cursor.context.compiled_parameters[idx],
                )

        if check_rowcount:
            if matched != len(batch):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(batch), matched)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % cursor.dialect.dialect_description
            )
1412
1413
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    ``delete`` is an iterable of ``(params, connection)`` pairs.  Runs of
    consecutive pairs that share a connection are deleted together.  The
    WHERE clause matches the table's primary key columns and, when the
    mapper is versioned on this table, the version id column as well.

    When ``base_mapper.confirm_deleted_rows`` is set and the matched row
    count could be determined reliably, a difference from the expected
    count raises :class:`.StaleDataError` for a versioned table.  For an
    unversioned table it emits a warning instead.

    """

    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        rows_matched = -1
        only_warn = False
        # True only when rows_matched is a count that can be trusted for
        # comparison against ``expected``.
        rowcount_reliable = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
                # Each DELETE ran on its own, so the summed single-row
                # counts are reliable no matter how many rows were deleted;
                # without this the per-row loop above could never catch a
                # stale versioned row when deleting more than one object.
                rowcount_reliable = True
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount
            # the rowcount of this (possibly executemany) call is trusted
            # only if the dialect reports multi-row counts sanely, or if a
            # single parameter set was sent
            rowcount_reliable = (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and rowcount_reliable
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
1511
1512
1513def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
1514 """finalize state on states that have been inserted or updated,
1515 including calling after_insert/after_update events.
1516
1517 """
1518 for state, state_dict, mapper, connection, has_identity in states:
1519 if mapper._readonly_props:
1520 readonly = state.unmodified_intersection(
1521 [
1522 p.key
1523 for p in mapper._readonly_props
1524 if (
1525 p.expire_on_flush
1526 and (not p.deferred or p.key in state.dict)
1527 )
1528 or (
1529 not p.expire_on_flush
1530 and not p.deferred
1531 and p.key not in state.dict
1532 )
1533 ]
1534 )
1535 if readonly:
1536 state._expire_attributes(state.dict, readonly)
1537
1538 # if eager_defaults option is enabled, load
1539 # all expired cols. Else if we have a version_id_col, make sure
1540 # it isn't expired.
1541 toload_now = []
1542
1543 # this is specifically to emit a second SELECT for eager_defaults,
1544 # so only if it's set to True, not "auto"
1545 if base_mapper.eager_defaults is True:
1546 toload_now.extend(
1547 state._unloaded_non_object.intersection(
1548 mapper._server_default_plus_onupdate_propkeys
1549 )
1550 )
1551
1552 if (
1553 mapper.version_id_col is not None
1554 and mapper.version_id_generator is False
1555 ):
1556 if mapper._version_id_prop.key in state.unloaded:
1557 toload_now.extend([mapper._version_id_prop.key])
1558
1559 if toload_now:
1560 state.key = base_mapper._identity_key_from_state(state)
1561 stmt = sql.select(mapper)
1562 loading._load_on_ident(
1563 uowtransaction.session,
1564 stmt,
1565 state.key,
1566 refresh_state=state,
1567 only_load_props=toload_now,
1568 )
1569
1570 # call after_XXX extensions
1571 if not has_identity:
1572 mapper.dispatch.after_insert(mapper, connection, state)
1573 else:
1574 mapper.dispatch.after_update(mapper, connection, state)
1575
1576 if (
1577 mapper.version_id_generator is False
1578 and mapper.version_id_col is not None
1579 ):
1580 if state_dict[mapper._version_id_prop.key] is None:
1581 raise orm_exc.FlushError(
1582 "Instance does not contain a non-NULL version value"
1583 )
1584
1585
1586def _postfetch_post_update(
1587 mapper, uowtransaction, table, state, dict_, result, params
1588):
1589 needs_version_id = (
1590 mapper.version_id_col is not None
1591 and mapper.version_id_col in mapper._cols_by_table[table]
1592 )
1593
1594 if not uowtransaction.is_deleted(state):
1595 # post updating after a regular INSERT or UPDATE, do a full postfetch
1596 prefetch_cols = result.context.compiled.prefetch
1597 postfetch_cols = result.context.compiled.postfetch
1598 elif needs_version_id:
1599 # post updating before a DELETE with a version_id_col, need to
1600 # postfetch just version_id_col
1601 prefetch_cols = postfetch_cols = ()
1602 else:
1603 # post updating before a DELETE without a version_id_col,
1604 # don't need to postfetch
1605 return
1606
1607 if needs_version_id:
1608 prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
1609
1610 refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
1611 if refresh_flush:
1612 load_evt_attrs = []
1613
1614 for c in prefetch_cols:
1615 if c.key in params and c in mapper._columntoproperty:
1616 dict_[mapper._columntoproperty[c].key] = params[c.key]
1617 if refresh_flush:
1618 load_evt_attrs.append(mapper._columntoproperty[c].key)
1619
1620 if refresh_flush and load_evt_attrs:
1621 mapper.class_manager.dispatch.refresh_flush(
1622 state, uowtransaction, load_evt_attrs
1623 )
1624
1625 if postfetch_cols:
1626 state._expire_attributes(
1627 state.dict,
1628 [
1629 mapper._columntoproperty[c].key
1630 for c in postfetch_cols
1631 if c in mapper._columntoproperty
1632 ],
1633 )
1634
1635
1636def _postfetch(
1637 mapper,
1638 uowtransaction,
1639 table,
1640 state,
1641 dict_,
1642 result,
1643 params,
1644 value_params,
1645 isupdate,
1646 returned_defaults,
1647):
1648 """Expire attributes in need of newly persisted database state,
1649 after an INSERT or UPDATE statement has proceeded for that
1650 state."""
1651
1652 prefetch_cols = result.context.compiled.prefetch
1653 postfetch_cols = result.context.compiled.postfetch
1654 returning_cols = result.context.compiled.effective_returning
1655
1656 if (
1657 mapper.version_id_col is not None
1658 and mapper.version_id_col in mapper._cols_by_table[table]
1659 ):
1660 prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
1661
1662 refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
1663 if refresh_flush:
1664 load_evt_attrs = []
1665
1666 if returning_cols:
1667 row = returned_defaults
1668 if row is not None:
1669 for row_value, col in zip(row, returning_cols):
1670 # pk cols returned from insert are handled
1671 # distinctly, don't step on the values here
1672 if col.primary_key and result.context.isinsert:
1673 continue
1674
1675 # note that columns can be in the "return defaults" that are
1676 # not mapped to this mapper, typically because they are
1677 # "excluded", which can be specified directly or also occurs
1678 # when using declarative w/ single table inheritance
1679 prop = mapper._columntoproperty.get(col)
1680 if prop:
1681 dict_[prop.key] = row_value
1682 if refresh_flush:
1683 load_evt_attrs.append(prop.key)
1684
1685 for c in prefetch_cols:
1686 if c.key in params and c in mapper._columntoproperty:
1687 pkey = mapper._columntoproperty[c].key
1688
1689 # set prefetched value in dict and also pop from committed_state,
1690 # since this is new database state that replaces whatever might
1691 # have previously been fetched (see #10800). this is essentially a
1692 # shorthand version of set_committed_value(), which could also be
1693 # used here directly (with more overhead)
1694 dict_[pkey] = params[c.key]
1695 state.committed_state.pop(pkey, None)
1696
1697 if refresh_flush:
1698 load_evt_attrs.append(pkey)
1699
1700 if refresh_flush and load_evt_attrs:
1701 mapper.class_manager.dispatch.refresh_flush(
1702 state, uowtransaction, load_evt_attrs
1703 )
1704
1705 if isupdate and value_params:
1706 # explicitly suit the use case specified by
1707 # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
1708 # database which are set to themselves in order to do a version bump.
1709 postfetch_cols.extend(
1710 [
1711 col
1712 for col in value_params
1713 if col.primary_key and col not in returning_cols
1714 ]
1715 )
1716
1717 if postfetch_cols:
1718 state._expire_attributes(
1719 state.dict,
1720 [
1721 mapper._columntoproperty[c].key
1722 for c in postfetch_cols
1723 if c in mapper._columntoproperty
1724 ],
1725 )
1726
1727 # synchronize newly inserted ids from one table to the next
1728 # TODO: this still goes a little too often. would be nice to
1729 # have definitive list of "columns that changed" here
1730 for m, equated_pairs in mapper._table_to_equated[table]:
1731 sync._populate(
1732 state,
1733 m,
1734 state,
1735 m,
1736 equated_pairs,
1737 uowtransaction,
1738 mapper.passive_updates,
1739 )
1740
1741
1742def _postfetch_bulk_save(mapper, dict_, table):
1743 for m, equated_pairs in mapper._table_to_equated[table]:
1744 sync._bulk_populate_inherit_keys(dict_, m, equated_pairs)
1745
1746
def _connections_for_states(base_mapper, uowtransaction, states):
    """Yield ``(state, state.dict, mapper, connection)`` for each state.

    States are produced in the order returned by _sort_states().  If the
    session defines a ``connection_callable``, it is called once per state
    with the base mapper and the state's object to choose a connection.
    Otherwise a single connection for the base mapper is obtained from the
    unit of work's transaction and shared by all states.

    """
    pick_connection = uowtransaction.session.connection_callable
    shared_connection = (
        None
        if pick_connection
        else uowtransaction.transaction.connection(base_mapper)
    )

    for state in _sort_states(base_mapper, states):
        if pick_connection:
            connection = pick_connection(base_mapper, state.obj())
        else:
            connection = shared_connection

        state_mapper = state.manager.mapper
        yield state, state.dict, state_mapper, connection
1771
1772
1773def _sort_states(mapper, states):
1774 pending = set(states)
1775 persistent = {s for s in pending if s.key is not None}
1776 pending.difference_update(persistent)
1777
1778 try:
1779 persistent_sorted = sorted(
1780 persistent, key=mapper._persistent_sortkey_fn
1781 )
1782 except TypeError as err:
1783 raise sa_exc.InvalidRequestError(
1784 "Could not sort objects by primary key; primary key "
1785 "values must be sortable in Python (was: %s)" % err
1786 ) from err
1787 return (
1788 sorted(pending, key=operator.attrgetter("insert_order"))
1789 + persistent_sorted
1790 )