1# orm/persistence.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""private module containing functions used to emit INSERT, UPDATE
11and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
12mappers.
13
14The functions here are called only by the unit of work functions
15in unitofwork.py.
16
17"""
18from __future__ import annotations
19
20from itertools import chain
21from itertools import groupby
22from itertools import zip_longest
23import operator
24
25from . import attributes
26from . import exc as orm_exc
27from . import loading
28from . import sync
29from .base import state_str
30from .. import exc as sa_exc
31from .. import future
32from .. import sql
33from .. import util
34from ..engine import cursor as _cursor
35from ..sql import operators
36from ..sql.elements import BooleanClauseList
37from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
38
39
def _save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    Called within the context of a UOWTransaction during a flush
    operation, given the list of states to be flushed.  The base
    mapper in an inheritance hierarchy handles the inserts/updates
    for all descendant mappers.

    """

    # when batching is disabled on the mapper, recurse once per state
    # in deterministic sort order
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            _save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    to_update = []
    to_insert = []

    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        # a state with no identity and no row switch is a new row;
        # everything else becomes an UPDATE
        if not has_identity and not row_switch:
            to_insert.append((state, dict_, mapper, connection))
        else:
            to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            _collect_update_commands(uowtransaction, table, to_update),
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            _collect_insert_commands(table, to_insert),
        )

    # run the post-save bookkeeping for all states; the final boolean
    # flag indicates "was an update"
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, dict_, mapper, connection, False)
                for (state, dict_, mapper, connection) in to_insert
            ),
            (rec[:4] + (True,) for rec in to_update),
        ),
    )
121
122
def _post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    """

    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        if mapper.version_id_col is None:
            # no versioning; the version slot of each record is None
            update = (
                (state, state_dict, sub_mapper, connection, None)
                for state, state_dict, sub_mapper, connection in (
                    states_to_update
                )
                if table in sub_mapper._pks_by_table
            )
        else:
            # attach the committed version id value to each record so
            # the UPDATE can match on it
            update = (
                (
                    state,
                    state_dict,
                    sub_mapper,
                    connection,
                    mapper._get_committed_state_attr_by_column(
                        state, state_dict, mapper.version_id_col
                    ),
                )
                for state, state_dict, sub_mapper, connection in (
                    states_to_update
                )
                if table in sub_mapper._pks_by_table
            )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )
166
167
def _delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    Called within the context of a UOWTransaction during a flush
    operation.

    """

    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    # dependent tables are deleted before their parents; walk the
    # sorted table list back-to-front
    for table in reversed(list(table_to_mapper)):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        if mapper.inherits and mapper.passive_deletes:
            # let ON DELETE cascade handle inherited rows
            continue

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            _collect_delete_commands(
                base_mapper, uowtransaction, table, states_to_delete
            ),
        )

    # all rows removed; fire the after_delete event per state
    for rec in states_to_delete:
        state, _, mapper, connection, _ = rec
        mapper.dispatch.after_delete(mapper, connection, state)
209
210
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    Yields tuples of ``(state, dict, mapper, connection, has_identity,
    row_switch, update_version_id)`` for each state.

    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        # a state with an identity key is already persistent
        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    # existing row is persistent and not marked for
                    # deletion; this is a genuine identity conflict
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        # for an UPDATE (or a row switch), capture the committed
        # version id value; this is used to match the row in the
        # UPDATE's WHERE criteria
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
289
290
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for UPDATE
    corresponding to post_update.

    This includes obtaining key information for each state
    including its dictionary, mapper, the connection to use for
    the execution per state.

    """
    # unlike the save/delete variants, no per-state event hooks fire
    # here; the (state, dict, mapper, connection) tuples are used as-is
    return _connections_for_states(base_mapper, uowtransaction, states)
301
302
def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for DELETE.

    Invokes the before_delete event for each state, then yields
    tuples of ``(state, dict, mapper, connection, committed version
    id)`` for each state.

    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        mapper.dispatch.before_delete(mapper, connection, state)

        # capture the committed version id value, used to match the
        # row in the DELETE criteria when versioning is enabled
        update_version_id = (
            mapper._get_committed_state_attr_by_column(
                state, dict_, mapper.version_id_col
            )
            if mapper.version_id_col is not None
            else None
        )

        yield (state, dict_, mapper, connection, update_version_id)
324
325
def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    Yields tuples of ``(state, state_dict, params, mapper, connection,
    value_params, has_all_pks, has_all_defaults)`` per state;
    ``params`` are plain bound values keyed on column key, while
    ``value_params`` holds SQL-expression values rendered inline into
    the statement.

    """
    for state, state_dict, mapper, connection in states_to_insert:
        # skip states whose mapper doesn't write to this table
        if table not in mapper._pks_by_table:
            continue

        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        # columns for which an explicit None should be rendered
        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                # SQL expression value; collected separately so it is
                # rendered into the statement rather than bound
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT. This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper._prefer_eager_defaults(
                connection.dialect, table
            ):
                has_all_defaults = mapper._server_default_col_keys[
                    table
                ].issubset(params)
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            # generate the initial version id value for the new row
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        if bulk:
            if mapper._set_polymorphic_identity:
                # ensure the discriminator value is present for bulk
                # inserts, where no per-instance event populated it
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )

            if include_bulk_keys:
                params.update((k, state_dict[k]) for k in include_bulk_keys)

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
425
426
def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.

    Yields tuples of ``(state, state_dict, params, mapper, connection,
    value_params, has_all_defaults, has_all_pks)`` per state; primary
    key matching values are keyed into ``params`` using the column's
    anonymous ``_label`` so they do not collide with SET values keyed
    on the plain column key.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if (
            use_orm_update_stmt is not None
            and not use_orm_update_stmt._maintain_values_ordering
        ):
            # TODO: ordered values, etc
            # ORM bulk_persistence will raise for the maintain_values_ordering
            # case right now
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            params = {}
            # only attributes with committed history (i.e. net changes)
            # are considered for the SET clause
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    # SQL expression value; rendered inline, not bound
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            # the committed version id goes into the WHERE criteria
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing to update for this state/table
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # the PK value was set by a cascade (or is new
                        # with no old value); match on the new value
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                    if pk_params[col._label] is None:
                        raise orm_exc.FlushError(
                            "Can't update table %s using NULL for primary "
                            "key value on column %s" % (table, col)
                        )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync._populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
639
640
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    Yields ``(state, state_dict, mapper, connection, params)`` only
    for states that have a net change on a post-update column.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        pks = mapper._pks_by_table[table]
        params = {}
        has_changes = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # primary key values go into the WHERE criteria, keyed
                # on the column's anonymous label
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )
            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    params[col.key] = history.added[0]
                    has_changes = True

        if not has_changes:
            continue

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            version_col = mapper.version_id_col
            # committed version id is matched in the WHERE criteria
            params[version_col._label] = update_version_id

            if (
                bool(state.key)
                and version_col.key not in params
                and mapper.version_id_generator is not False
            ):
                params[version_col.key] = mapper.version_id_generator(
                    update_version_id
                )

        yield state, state_dict, mapper, connection, params
693
694
def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Identify values to use in DELETE statements for a list of
    states to be deleted.

    Yields ``(params, connection)`` pairs, one per row to delete.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        # skip states whose mapper doesn't write to this table
        if table not in mapper._pks_by_table:
            continue

        params = {}
        for col in mapper._pks_by_table[table]:
            value = mapper._get_committed_state_attr_by_column(
                state, state_dict, col
            )
            if value is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, col)
                )
            params[col.key] = value

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            # match the committed version id in the DELETE criteria
            params[mapper.version_id_col.key] = update_version_id

        yield params, connection
731
732
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    Records are grouped by (connection, parameter keys, presence of
    SQL-expression values, default/PK completeness) so that compatible
    rows can be batched into a single executemany() where possible.
    When the dialect reports trustworthy rowcounts, the number of
    matched rows is compared against the number of records to detect
    stale data.

    """

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        # WHERE criteria: one bindparam per primary key column, keyed
        # on the column's anonymous label, plus the version id column
        # when versioning is in effect
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)

    else:
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # NOTE: groupby only merges *adjacent* records sharing a key; the
    # collect phase produces records in a compatible order
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        if not has_all_pks:
            # a PK value is an un-executed SQL expression; fetch the
            # resulting values back via RETURNING
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            # records include SQL-expression values; execute one
            # statement per row, rendering the expressions via .values()
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
            check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                # RETURNING or version checks preclude executemany;
                # one execute() per record
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # batched executemany path
                multiparams = [rec[2] for rec in records]

                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            if rows != len(records):
                # matched-row count mismatch indicates a concurrent
                # modification (or missing row) -> stale data
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
965
966
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Groups compatible records for executemany() where possible,
    choosing between plain INSERT, INSERT..RETURNING, and single-row
    execution based on dialect capabilities and whether generated
    defaults / primary keys need to be fetched back.

    Returns a result object only when ``use_orm_insert_stmt`` was
    passed; otherwise returns None.

    """

    if use_orm_insert_stmt is not None:
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    # accumulates spliced results across groups when an ORM-level
    # insert statement was passed in
    return_result = None

    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        # bulk-save path: no InstanceState present
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    # explicit RETURNING was requested but the dialect
                    # can't do executemany + RETURNING; can't proceed
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

                if mapper.version_id_col is not None:
                    statement = statement.return_defaults(
                        mapper.version_id_col,
                        sort_by_parameter_order=bookkeeping,
                    )
            elif do_executemany:
                statement = statement.return_defaults(
                    *table.primary_key, sort_by_parameter_order=bookkeeping
                )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows.  we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING.  Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        # populate returned PK values onto the states
                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                # single-row execution path
                assert not returning_is_required_anyway

                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        # SQL expression values rendered via .values()
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked.  Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                (
                                    result.returned_defaults
                                    if not result.context.executemany
                                    else None
                                ),
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)

    if use_orm_insert_stmt is not None:
        if return_result is None:
            return _cursor.null_dml_result()
        else:
            return return_result
1289
1290
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Records are grouped by (connection, parameter keys) so that
    compatible rows can be sent in a single executemany() call;
    rowcounts are verified where the dialect supports it, raising
    StaleDataError on a mismatch.
    """

    # reuse the mapper-level compiled cache so the UPDATE construct
    # is compiled at most once per dialect
    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    # True when the mapper versions rows and the version counter
    # column is part of the table being updated here
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # build UPDATE ... WHERE <pk cols> [AND <version col>], with
        # bindparams named after each column's ._label
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        return stmt

    # the statement construct itself is memoized on the base mapper,
    # keyed per table
    statement = base_mapper._memo(("post_update", table), update_stmt)

    if mapper._version_id_has_server_side_value:
        # the new version value is generated server side; arrange for
        # it to be fetched back
        statement = statement.return_defaults(mapper.version_id_col)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # (connection, parameter keys)
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        # what rowcount verification is possible depends on what the
        # dialect reports as reliable
        assert_singlerow = connection.dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        # with versioning, executemany() is only usable if multi-row
        # rowcounts are trustworthy
        allow_executemany = not needs_version_id or assert_multirow

        if not allow_executemany:
            check_rowcount = assert_singlerow
            # emit one statement per record so each rowcount can be
            # accumulated individually
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection.execute(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            # single executemany() call covering all records in this
            # group
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection.execute(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for state, state_dict, mapper_rec, connection, params in records:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )

        if check_rowcount:
            if rows != len(records):
                # fewer (or more) rows matched than states flushed;
                # a concurrent transaction likely modified or removed
                # rows we expected to be present
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
1407
1408
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    Parameter sets are grouped per connection and sent via
    executemany() where possible; when versioning is in play and the
    dialect cannot report reliable multi-row rowcounts, deletes are
    emitted one at a time so each rowcount can be verified.
    """

    # True when the mapper versions rows and the version counter
    # column is part of the table being deleted from
    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # build DELETE ... WHERE <pk cols> [AND <version col>], with
        # bindparams named after each column's .key
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    # the statement construct itself is memoized on the base mapper,
    # keyed per table
    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        # -1 means "rowcount was not / could not be checked"
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                # no reliable rowcount at all; emit the deletes but
                # skip verification entirely
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            # single execute (executemany for multiple parameter sets)
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
1506
1507
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    Also expires read-only attributes that may have changed in the
    database, optionally emits a refresh SELECT for eager_defaults /
    application-managed version ids, and validates that a non-NULL
    version value is present when version_id_generator is False.
    """
    for state, state_dict, mapper, connection, has_identity in states:
        if mapper._readonly_props:
            # expire read-only attributes whose database value may
            # now be stale: expire-on-flush props that are loaded (or
            # not deferred), plus non-expiring, non-deferred props not
            # yet present in the dict
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols. Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        # this is specifically to emit a second SELECT for eager_defaults,
        # so only if it's set to True, not "auto"
        if base_mapper.eager_defaults is True:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        # with application-side version management
        # (version_id_generator is False), the version value must be
        # present on the object
        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            # emit a SELECT against the object's identity, loading
            # only the attributes collected above
            state.key = base_mapper._identity_key_from_state(state)
            stmt = future.select(mapper).set_label_style(
                LABEL_STYLE_TABLENAME_PLUS_COL
            )
            loading._load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            # the flush must have produced a non-NULL version value
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )
1581
1582
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Refresh attribute state following a post-update UPDATE
    statement.

    Values for columns in the compiled statement's ``prefetch``
    collection are copied from the bound parameters into the object's
    dict; columns in ``postfetch`` are expired so they are re-loaded
    from the database on next access.  The refresh_flush event is
    fired for attributes populated here, if listeners are present.
    """
    # True when the mapper versions rows and the version counter
    # column is part of the table being updated
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    if not uowtransaction.is_deleted(state):
        # post updating after a regular INSERT or UPDATE, do a full postfetch
        prefetch_cols = result.context.compiled.prefetch
        postfetch_cols = result.context.compiled.postfetch
    elif needs_version_id:
        # post updating before a DELETE with a version_id_col, need to
        # postfetch just version_id_col
        prefetch_cols = postfetch_cols = ()
    else:
        # post updating before a DELETE without a version_id_col,
        # don't need to postfetch
        return

    if needs_version_id:
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    # only collect attribute keys for the event if listeners exist
    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    # copy parameter-supplied (prefetched) values into the object
    # dict; skip columns not mapped by this mapper
    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        # expire attributes whose new value is only known to the
        # database; skip columns not mapped by this mapper
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )
1631
1632
def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    Values available from RETURNING (``returned_defaults``) or from
    the bound parameters (``params``) are populated into the object's
    dict directly; remaining postfetch columns are expired so they
    load from the database on next access.  Finally, newly persisted
    key values are synchronized across related tables via
    ``sync._populate()``.
    """

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.effective_returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # the version counter value is always refreshed from the
        # parameters
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    # only collect attribute keys for the event if listeners exist
    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        # populate values that came back via RETURNING
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            pkey = mapper._columntoproperty[c].key

            # set prefetched value in dict and also pop from committed_state,
            # since this is new database state that replaces whatever might
            # have previously been fetched (see #10800). this is essentially a
            # shorthand version of set_committed_value(), which could also be
            # used here directly (with more overhead)
            dict_[pkey] = params[c.key]
            state.committed_state.pop(pkey, None)

            if refresh_flush:
                load_evt_attrs.append(pkey)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        # expire whatever remains so it's refreshed from the database
        # on next access; skip columns not mapped by this mapper
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often. would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync._populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )
1737
1738
def _postfetch_bulk_save(mapper, dict_, table):
    """Propagate inherited key values between tables for a bulk-save
    parameter dictionary, with no InstanceState involved."""
    for related_mapper, equated_pairs in mapper._table_to_equated[table]:
        sync._bulk_populate_inherit_keys(
            dict_, related_mapper, equated_pairs
        )
1742
1743
def _connections_for_states(base_mapper, uowtransaction, states):
    """Return an iterator of (state, state.dict, mapper, connection).

    The states are sorted according to _sort_states, then paired
    with the connection they should be using for the given
    unit of work transaction.

    """
    # a session-level connection callable, when configured, chooses a
    # connection per object; otherwise a single connection from the
    # transaction is used for every state
    connection_callable = uowtransaction.session.connection_callable
    if not connection_callable:
        connection_callable = None
        connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if connection_callable:
            connection = connection_callable(base_mapper, state.obj())

        yield state, state.dict, state.manager.mapper, connection
1768
1769
1770def _sort_states(mapper, states):
1771 pending = set(states)
1772 persistent = {s for s in pending if s.key is not None}
1773 pending.difference_update(persistent)
1774
1775 try:
1776 persistent_sorted = sorted(
1777 persistent, key=mapper._persistent_sortkey_fn
1778 )
1779 except TypeError as err:
1780 raise sa_exc.InvalidRequestError(
1781 "Could not sort objects by primary key; primary key "
1782 "values must be sortable in Python (was: %s)" % err
1783 ) from err
1784 return (
1785 sorted(pending, key=operator.attrgetter("insert_order"))
1786 + persistent_sorted
1787 )