1# orm/persistence.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""private module containing functions used to emit INSERT, UPDATE
11and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
12mappers.
13
14The functions here are called only by the unit of work functions
15in unitofwork.py.
16
17"""
18from __future__ import annotations
19
20from itertools import chain
21from itertools import groupby
22from itertools import zip_longest
23import operator
24
25from . import attributes
26from . import exc as orm_exc
27from . import loading
28from . import sync
29from .base import state_str
30from .. import exc as sa_exc
31from .. import future
32from .. import sql
33from .. import util
34from ..engine import cursor as _cursor
35from ..sql import operators
36from ..sql.elements import BooleanClauseList
37from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
38
39
def _save_obj(base_mapper, states, uowtransaction, single=False):
    """Flush a list of states by emitting ``INSERT`` and/or ``UPDATE``
    statements for them.

    Called by the unit of work while a flush is in progress.  The
    ``base_mapper`` of an inheritance hierarchy persists rows on behalf
    of all of its descendant mappers, walking its tables in the order
    given by ``base_mapper._sorted_tables``.

    When ``base_mapper.batch`` is false and ``single`` is not set, each
    state is flushed on its own by re-entering this function with
    ``single=True``.

    """

    if not (single or base_mapper.batch):
        # batching disabled: one recursive call per state
        for lone_state in _sort_states(base_mapper, states):
            _save_obj(base_mapper, [lone_state], uowtransaction, single=True)
        return

    pending_inserts = []
    pending_updates = []

    for organized in _organize_states_for_save(
        base_mapper, states, uowtransaction
    ):
        (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        ) = organized

        if not (has_identity or row_switch):
            pending_inserts.append((state, dict_, mapper, connection))
            continue

        # already has an identity, or is taking over an existing row
        pending_updates.append(
            (state, dict_, mapper, connection, update_version_id)
        )

    for table, table_mapper in base_mapper._sorted_tables.items():
        if table not in table_mapper._pks_by_table:
            continue

        insert = _collect_insert_commands(table, pending_inserts)
        update = _collect_update_commands(
            uowtransaction, table, pending_updates
        )

        # for each table, UPDATEs are emitted ahead of INSERTs
        _emit_update_statements(
            base_mapper, uowtransaction, table_mapper, table, update
        )
        _emit_insert_statements(
            base_mapper, uowtransaction, table_mapper, table, insert
        )

    # trailing flag: False for records taken from the insert list,
    # True for records taken from the update list
    inserted = (
        (st, st_dict, mp, conn, False)
        for (st, st_dict, mp, conn) in pending_inserts
    )
    updated = (
        (st, st_dict, mp, conn, True)
        for (st, st_dict, mp, conn, _version_id) in pending_updates
    )

    _finalize_insert_update_commands(
        base_mapper, uowtransaction, chain(inserted, updated)
    )
121
122
def _post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Emit the UPDATE statements required by a relationship() that is
    configured with post_update.

    Each table of ``base_mapper`` that has primary key columns for its
    mapper is processed in turn; the records handed to the collection
    step carry the committed version id as read through that table's
    mapper, or ``None`` when that mapper has no version id column.

    """

    organized = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    def versioned_records(table, mapper):
        # lazily produce (state, dict, sub_mapper, connection,
        # committed version id) for the states whose own mapper
        # includes ``table``
        for state, state_dict, sub_mapper, connection in organized:
            if table not in sub_mapper._pks_by_table:
                continue

            if mapper.version_id_col is None:
                committed_version = None
            else:
                committed_version = (
                    mapper._get_committed_state_attr_by_column(
                        state, state_dict, mapper.version_id_col
                    )
                )

            yield (
                state,
                state_dict,
                sub_mapper,
                connection,
                committed_version,
            )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        update = _collect_post_update_commands(
            base_mapper,
            uowtransaction,
            table,
            versioned_records(table, mapper),
            post_update_cols,
        )

        _emit_post_update_statements(
            base_mapper, uowtransaction, mapper, table, update
        )
166
167
def _delete_obj(base_mapper, states, uowtransaction):
    """Emit ``DELETE`` statements for a list of states.

    Called by the unit of work while a flush is in progress.  Tables
    are visited in the reverse of ``base_mapper._sorted_tables``; once
    all statements have been emitted, the ``after_delete`` mapper event
    is dispatched for every state.

    """

    doomed = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    sorted_tables = base_mapper._sorted_tables

    for table in reversed(list(sorted_tables.keys())):
        table_mapper = sorted_tables[table]

        # skip tables the mapper has no primary key for, as well as
        # tables of inheriting mappers that set passive_deletes
        if table not in table_mapper._pks_by_table or (
            table_mapper.inherits and table_mapper.passive_deletes
        ):
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, doomed
        )

        _emit_delete_statements(
            base_mapper, uowtransaction, table_mapper, table, delete
        )

    for state, _state_dict, mapper, connection, _version_id in doomed:
        mapper.dispatch.after_delete(mapper, connection, state)
209
210
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Run the first pass over a set of states headed for INSERT or
    UPDATE.

    For every state this dispatches ``before_insert`` or
    ``before_update``, runs the mapper's polymorphic identity
    validation when one is configured, looks for a "row switch" (a
    pending object whose identity key matches an object in the
    session's identity map that is marked deleted in this flush), and
    reads the committed version id when the mapper has a version id
    column.

    Yields ``(state, dict_, mapper, connection, has_identity,
    row_switch, update_version_id)`` tuples.

    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        has_identity = bool(state.key)

        # the key is worked out ahead of the before_XXX hooks
        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = None
        update_version_id = None

        # call before_XXX extensions
        if has_identity:
            mapper.dispatch.before_update(mapper, connection, state)
        else:
            mapper.dispatch.before_insert(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # a pending instance (no key of its own yet) whose identity
        # key is already present in the identity map is either a
        # conflict, or replaces an object being deleted; the latter
        # is converted to an UPDATE
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            existing = attributes.instance_state(
                uowtransaction.session.identity_map[instance_key]
            )

            if not uowtransaction.was_already_deleted(existing):
                if uowtransaction.is_deleted(existing):
                    base_mapper._log_debug(
                        "detected row switch for identity %s.  "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing
                else:
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )

        if (has_identity or row_switch) and mapper.version_id_col is not None:
            # on a row switch, the version comes from the row being
            # replaced rather than from the new object
            if row_switch:
                version_state, version_dict = row_switch, row_switch.dict
            else:
                version_state, version_dict = state, dict_

            update_version_id = mapper._get_committed_state_attr_by_column(
                version_state,
                version_dict,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
289
290
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Run the first pass over a set of states for the UPDATE that
    corresponds to post_update.

    No events are dispatched here; the result is whatever
    ``_connections_for_states()`` produces for the given states, which
    callers in this module unpack as ``(state, dict, mapper,
    connection)`` records.

    """
    organized = _connections_for_states(base_mapper, uowtransaction, states)
    return organized
301
302
def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Run the first pass over a set of states for DELETE.

    Dispatches ``before_delete`` for each state, then yields
    ``(state, dict_, mapper, connection, update_version_id)`` where the
    last element is the committed value of the mapper's version id
    column, or ``None`` when the mapper has no such column.

    """
    organized = _connections_for_states(base_mapper, uowtransaction, states)

    for state, dict_, mapper, connection in organized:
        mapper.dispatch.before_delete(mapper, connection, state)

        update_version_id = (
            None
            if mapper.version_id_col is None
            else mapper._get_committed_state_attr_by_column(
                state, dict_, mapper.version_id_col
            )
        )

        yield (state, dict_, mapper, connection, update_version_id)
324
325
def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Work out the values to use in INSERT statements against
    ``table`` for a list of states.

    Records whose mapper has no primary key entry for ``table`` are
    skipped.  For each remaining record an 8-tuple is yielded::

        (state, state_dict, params, mapper, connection,
         value_params, has_all_pks, has_all_defaults)

    ``params`` is keyed by column key and holds plain values;
    ``value_params`` is keyed by column object and holds SQL expression
    values (only collected when ``bulk`` is false).

    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]
        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]

            # None is left out unless the column evaluates None itself
            # or the caller asked for NULLs to be rendered
            if value is None and col not in eval_none and not render_nulls:
                continue

            if not bulk:
                # SQL expressions travel separately from bound values
                if hasattr(value, "__clause_element__"):
                    value_params[col] = value.__clause_element__()
                    continue
                if isinstance(value, sql.ClauseElement):
                    value_params[col] = value
                    continue

            params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.   This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            value_param_keys = [c.key for c in value_params]
            explicit_nones = (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference(value_param_keys)
            )
            for colkey in explicit_nones:
                params[colkey] = None

        if bulk and not return_defaults:
            has_all_pks = has_all_defaults = True
        else:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            has_all_defaults = (
                mapper._server_default_col_keys[table].issubset(params)
                if mapper.base_mapper._prefer_eager_defaults(
                    connection.dialect, table
                )
                else True
            )

        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            # initial version value: the generator is called with None
            initial_version = mapper.version_id_generator(None)
            params[mapper.version_id_col.key] = initial_version

        if bulk:
            if mapper._set_polymorphic_identity:
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )

            if include_bulk_keys:
                params.update(
                    (bulk_key, state_dict[bulk_key])
                    for bulk_key in include_bulk_keys
                )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
425
426
def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.

    :param uowtransaction: the flush context; its ``attributes``
     collection is consulted for ``("pk_cascaded", state, col)`` markers
     and it is passed along to ``sync._populate()``.
    :param table: the table being updated; records whose mapper has no
     primary key entry for it are skipped.
    :param states_to_update: iterable of ``(state, state_dict, mapper,
     connection, update_version_id)`` records.
    :param bulk: when true, values are taken directly from
     ``state_dict`` rather than from attribute history.
    :param use_orm_update_stmt: optional statement; unless it has
     ``_maintain_values_ordering`` set, its ``_values`` are used as the
     initial SQL-expression values.
    :param include_bulk_keys: extra ``state_dict`` keys copied verbatim
     into the parameters.

    Yields 8-tuples ``(state, state_dict, params, mapper, connection,
    value_params, has_all_defaults, has_all_pks)``.  Within ``params``,
    values to SET are keyed by column key, while values used to locate
    the row (primary key, current version id) are keyed by the column's
    ``_label``.

    :raises sqlalchemy.exc.InvalidRequestError: in bulk mode, when a
     primary key value is missing from a record.
    :raises sqlalchemy.orm.exc.FlushError: in non-bulk mode, when the
     primary key value used to locate the row is ``None``.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if (
            use_orm_update_stmt is not None
            and not use_orm_update_stmt._maintain_values_ordering
        ):
            # TODO: ordered values, etc
            # ORM bulk_persistence will raise for the maintain_values_ordering
            # case right now
            # NOTE(review): this is the statement's own ``_values`` object,
            # not a copy, and the non-bulk branch below assigns into
            # ``value_params`` - confirm callers only combine
            # use_orm_update_stmt with bulk=True
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            params = {}
            # only attributes that have an entry in committed_state
            # are considered
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is.  This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    # (for/else: reached only when no property had added
                    # history; skips to the next record)
                    continue

            col = mapper.version_id_col
            # captured before the version criteria is added to params
            no_params = not params and not value_params
            # current version, keyed by label, used to match the row
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # locate the row by the new value and drop the
                        # column from the SET values
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync._populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
639
640
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Work out the values to use in UPDATE statements for a list of
    states taking part in a post_update operation.

    A record is yielded, as ``(state, state_dict, mapper, connection,
    params)``, only when at least one non-primary-key column that is
    either in ``post_update_cols`` or has an ``onupdate`` shows added
    history.  Primary key values and the current version id are stored
    in ``params`` under the column's ``_label``; values to SET are
    stored under the column's key.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )
                continue

            if not (col in post_update_cols or col.onupdate is not None):
                continue

            propkey = mapper._columntoproperty[col].key
            history = state.manager[propkey].impl.get_history(
                state, state_dict, attributes.PASSIVE_NO_INITIALIZE
            )
            if history.added:
                params[col.key] = history.added[0]
                hasdata = True

        if not hasdata:
            continue

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            version_col = mapper.version_id_col
            params[version_col._label] = update_version_id

            # generate a new version only for states that have a key,
            # when none was set explicitly and a generator is in use
            if (
                bool(state.key)
                and version_col.key not in params
                and mapper.version_id_generator is not False
            ):
                params[version_col.key] = mapper.version_id_generator(
                    update_version_id
                )

        yield state, state_dict, mapper, connection, params
693
694
def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Work out the values to use in DELETE statements against
    ``table`` for a list of states to be deleted.

    Records whose mapper has no primary key entry for ``table`` are
    skipped.  Yields ``(params, connection)`` where ``params`` maps
    each primary key column key to its committed value, plus the
    version id column key to ``update_version_id`` when one was given
    and the version column is part of ``table``.

    Raises ``FlushError`` when a committed primary key value is
    ``None``.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        for pk_col in mapper._pks_by_table[table]:
            committed = mapper._get_committed_state_attr_by_column(
                state, state_dict, pk_col
            )
            params[pk_col.key] = committed

            if committed is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, pk_col)
                )

        versioned_here = (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        )
        if versioned_here:
            params[mapper.version_id_col.key] = update_version_id

        yield params, connection
731
732
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    :param base_mapper: supplies the compiled cache and the memoized
     UPDATE statement for ``table``.
    :param uowtransaction: passed through to ``_postfetch()``.
    :param mapper: mapper whose primary key / version id columns for
     ``table`` form the WHERE criteria.
    :param table: the table being updated.
    :param update: iterable of 8-tuples ``(state, state_dict, params,
     mapper, connection, value_params, has_all_defaults, has_all_pks)``.
    :param bookkeeping: when true, ``_postfetch()`` is invoked for each
     record after execution.
    :param use_orm_update_stmt: optional statement that receives the
     WHERE criteria in place of the memoized ``table.update()``.
    :param enable_check_rowcount: when false, the matched row count is
     not compared with the number of records.

    :raises sqlalchemy.orm.exc.StaleDataError: when the row count is
     checked and differs from the number of records in a group.

    """

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        # build the WHERE criteria: each primary key column (and the
        # version id column when versioned) compared with a bind
        # parameter named after the column's _label
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)

    else:
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # groupby() only merges consecutive records that share the same key,
    # so each group is a run of adjacent, like-shaped records
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        # total of result.rowcount across this group's executions
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        if not has_all_pks:
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            # SQL-expression values differ per record, so each record
            # is executed on its own with .values() applied
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                # one execution per record
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # the whole group is sent in a single execute() call
                multiparams = [rec[2] for rec in records]

                # with several parameter sets the total is only checked
                # when the dialect reports sane multi rowcounts
                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        # NOTE(review): every record is handed the first
                        # compiled parameter set - confirm this is intended
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            # NOTE(review): this branch is also taken when
            # enable_check_rowcount=False was passed, not only when the
            # dialect lacks rowcount support - confirm that is intended
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
965
966
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    :param base_mapper: supplies the compiled cache, the memoized
     INSERT statement for ``table`` and the eager-defaults preference.
    :param uowtransaction: passed through to ``_postfetch()``.
    :param mapper: mapper whose primary key columns, server default
     columns and version id column for ``table`` are consulted.
    :param table: the table being inserted into.
    :param insert: iterable of 8-tuples ``(state, state_dict, params,
     mapper, connection, value_params, has_all_pks, has_all_defaults)``.
    :param bookkeeping: when true, ``_postfetch()`` is invoked for
     records that have a state and ``_postfetch_bulk_save()`` for those
     that do not.
    :param use_orm_insert_stmt: optional statement used in place of the
     memoized ``table.insert()``.
    :param execution_options: optional execution options merged over
     the defaults chosen here.

    :return: when ``use_orm_insert_stmt`` is given, the result of the
     executemany-with-RETURNING executions spliced together, or
     ``_cursor.null_dml_result()`` if there were none; otherwise
     nothing is returned.

    :raises sqlalchemy.exc.InvalidRequestError: when RETURNING is
     required but the dialect flags do not allow executemany with it.
    :raises sqlalchemy.orm.exc.FlushError: when primary key results are
     missing after an INSERT.

    """

    if use_orm_insert_stmt is not None:
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    # accumulates spliced results; only assigned when
    # use_orm_insert_stmt is given
    return_result = None

    # groupby() only merges consecutive records that share the same key
    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                # records are paired positionally with the compiled
                # parameter sets of the execution
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                # executemany is used only if the dialect flag matching
                # the ordering requirement is set
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

            if mapper.version_id_col is not None:
                statement = statement.return_defaults(
                    mapper.version_id_col,
                    sort_by_parameter_order=bookkeeping,
                )
            elif do_executemany:
                statement = statement.return_defaults(
                    *table.primary_key, sort_by_parameter_order=bookkeeping
                )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    # zip_longest pads the shorter inputs with None, so
                    # a shortfall in returned primary key rows shows up
                    # as inserted_primary_key being None
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows.  we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING.  Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        # fill in primary key attributes that are
                        # currently None in the state's dict
                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                assert not returning_is_required_anyway

                # one execution per record
                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked.  Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    # a primary key column supplied as a SQL expression
                    # is always overwritten with the inserted value;
                    # others only when currently None
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                (
                                    result.returned_defaults
                                    if not result.context.executemany
                                    else None
                                ),
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)

    if use_orm_insert_stmt is not None:
        if return_result is None:
            return _cursor.null_dml_result()
        else:
            return return_result
1289
1290
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit the UPDATE statements for the value lists collected
    by _collect_post_update_commands().

    ``update`` is consumed as a sequence of
    ``(state, state_dict, mapper, connection, params)`` records.
    Adjacent records that share a connection and the same set of
    parameter keys form one group.  A group is sent as a single
    executemany() unless the version id column is part of ``table`` and
    the dialect does not report a usable multi-row rowcount; in that
    case each record of the group is executed on its own.

    Raises :class:`.orm_exc.StaleDataError` when the rowcount can be
    checked and differs from the number of records in the group.

    """

    exec_opts = {"compiled_cache": base_mapper._compiled_cache}

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def build_update():
        # criteria are the primary key columns of the table, followed by
        # the version id column when it lives on this table; each one is
        # compared to a bind parameter named after the column's label
        criteria_cols = list(mapper._pks_by_table[table])
        if needs_version_id:
            criteria_cols.append(mapper.version_id_col)

        where = BooleanClauseList._construct_raw(operators.and_)
        for col in criteria_cols:
            where._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )
        return table.update().where(where)

    statement = base_mapper._memo(("post_update", table), build_update)

    if mapper._version_id_has_server_side_value:
        statement = statement.return_defaults(mapper.version_id_col)

    def group_key(rec):
        # (connection, set of parameter keys)
        return rec[3], set(rec[4])

    # the incoming order of records is kept so that rows are accessed
    # in a predictable order; only neighbouring records with an equal
    # (connection, parameter keys) key are batched for executemany().
    for (group_connection, _keys), grouped in groupby(update, group_key):
        batch = list(grouped)
        expected = len(batch)
        rows = 0

        dialect = group_connection.dialect
        assert_singlerow = dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow and dialect.supports_sane_multi_rowcount
        )

        if needs_version_id and not assert_multirow:
            # versioned rows without a usable multi-row rowcount:
            # one execute() per record
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in batch:
                c = connection.execute(
                    statement, params, execution_options=exec_opts
                )
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [params for _, _, _, _, params in batch]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = group_connection.execute(
                statement, multiparams, execution_options=exec_opts
            )
            rows += c.rowcount

            # compiled parameter sets line up positionally with the
            # records that were sent
            for position, (
                state,
                state_dict,
                mapper_rec,
                _,
                _,
            ) in enumerate(batch):
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[position],
                )

        if check_rowcount:
            if rows != expected:
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, expected, rows)
                )
        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
1413
1414
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit the DELETE statements for the value lists collected
    by _collect_delete_commands().

    ``delete`` is consumed as a sequence of ``(params, connection)``
    pairs; adjacent pairs that share a connection are handled together.
    If the version id column is part of ``table`` and the dialect does
    not report a usable multi-row rowcount, rows are deleted one at a
    time when a single-row rowcount is available; otherwise a warning
    is emitted and the rowcount is not verified.

    When ``base_mapper.confirm_deleted_rows`` is set and a usable
    rowcount differs from the number of parameter sets, a warning is
    emitted if the version id column is not involved, otherwise
    :class:`.orm_exc.StaleDataError` is raised.

    """

    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def build_delete():
        # criteria are the primary key columns of the table, followed by
        # the version id column when it lives on this table; each one is
        # compared to a bind parameter named after the column's key
        criteria_cols = list(mapper._pks_by_table[table])
        if need_version_id:
            criteria_cols.append(mapper.version_id_col)

        where = BooleanClauseList._construct_raw(operators.and_)
        for col in criteria_cols:
            where._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )
        return table.delete().where(where)

    statement = base_mapper._memo(("delete", table), build_delete)

    # group adjacent records by their connection (second element)
    for connection, group in groupby(delete, operator.itemgetter(1)):
        del_objects = [params for params, _ in group]
        dialect = connection.dialect

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        rows_matched = -1  # -1 means "no usable rowcount"
        only_warn = False

        if not need_version_id or dialect.supports_sane_multi_rowcount:
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )
            only_warn = not need_version_id
            rows_matched = c.rowcount
        elif dialect.supports_sane_rowcount:
            # execute deletes individually so that versioned
            # rows can be verified
            rows_matched = 0
            for params in del_objects:
                c = connection.execute(
                    statement, params, execution_options=execution_options
                )
                rows_matched += c.rowcount
        else:
            util.warn(
                "Dialect %s does not support deleted rowcount "
                "- versioning cannot be verified."
                % dialect.dialect_description
            )
            connection.execute(
                statement, del_objects, execution_options=execution_options
            )

        mismatch = (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (dialect.supports_sane_multi_rowcount or expected == 1)
        )
        if not mismatch:
            continue

        message = (
            "DELETE statement on table '%s' expected to "
            "delete %d row(s); %d were matched.  Please set "
            "confirm_deleted_rows=False within the mapper "
            "configuration to prevent this warning."
            % (table.description, expected, rows_matched)
        )

        # TODO: why does this "only warn" if versioning is turned off,
        # whereas the UPDATE raises?
        if only_warn:
            util.warn(message)
        else:
            raise orm_exc.StaleDataError(message)
1512
1513
1514def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
1515 """finalize state on states that have been inserted or updated,
1516 including calling after_insert/after_update events.
1517
1518 """
1519 for state, state_dict, mapper, connection, has_identity in states:
1520 if mapper._readonly_props:
1521 readonly = state.unmodified_intersection(
1522 [
1523 p.key
1524 for p in mapper._readonly_props
1525 if (
1526 p.expire_on_flush
1527 and (not p.deferred or p.key in state.dict)
1528 )
1529 or (
1530 not p.expire_on_flush
1531 and not p.deferred
1532 and p.key not in state.dict
1533 )
1534 ]
1535 )
1536 if readonly:
1537 state._expire_attributes(state.dict, readonly)
1538
1539 # if eager_defaults option is enabled, load
1540 # all expired cols. Else if we have a version_id_col, make sure
1541 # it isn't expired.
1542 toload_now = []
1543
1544 # this is specifically to emit a second SELECT for eager_defaults,
1545 # so only if it's set to True, not "auto"
1546 if base_mapper.eager_defaults is True:
1547 toload_now.extend(
1548 state._unloaded_non_object.intersection(
1549 mapper._server_default_plus_onupdate_propkeys
1550 )
1551 )
1552
1553 if (
1554 mapper.version_id_col is not None
1555 and mapper.version_id_generator is False
1556 ):
1557 if mapper._version_id_prop.key in state.unloaded:
1558 toload_now.extend([mapper._version_id_prop.key])
1559
1560 if toload_now:
1561 state.key = base_mapper._identity_key_from_state(state)
1562 stmt = future.select(mapper).set_label_style(
1563 LABEL_STYLE_TABLENAME_PLUS_COL
1564 )
1565 loading._load_on_ident(
1566 uowtransaction.session,
1567 stmt,
1568 state.key,
1569 refresh_state=state,
1570 only_load_props=toload_now,
1571 )
1572
1573 # call after_XXX extensions
1574 if not has_identity:
1575 mapper.dispatch.after_insert(mapper, connection, state)
1576 else:
1577 mapper.dispatch.after_update(mapper, connection, state)
1578
1579 if (
1580 mapper.version_id_generator is False
1581 and mapper.version_id_col is not None
1582 ):
1583 if state_dict[mapper._version_id_prop.key] is None:
1584 raise orm_exc.FlushError(
1585 "Instance does not contain a non-NULL version value"
1586 )
1587
1588
1589def _postfetch_post_update(
1590 mapper, uowtransaction, table, state, dict_, result, params
1591):
1592 needs_version_id = (
1593 mapper.version_id_col is not None
1594 and mapper.version_id_col in mapper._cols_by_table[table]
1595 )
1596
1597 if not uowtransaction.is_deleted(state):
1598 # post updating after a regular INSERT or UPDATE, do a full postfetch
1599 prefetch_cols = result.context.compiled.prefetch
1600 postfetch_cols = result.context.compiled.postfetch
1601 elif needs_version_id:
1602 # post updating before a DELETE with a version_id_col, need to
1603 # postfetch just version_id_col
1604 prefetch_cols = postfetch_cols = ()
1605 else:
1606 # post updating before a DELETE without a version_id_col,
1607 # don't need to postfetch
1608 return
1609
1610 if needs_version_id:
1611 prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
1612
1613 refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
1614 if refresh_flush:
1615 load_evt_attrs = []
1616
1617 for c in prefetch_cols:
1618 if c.key in params and c in mapper._columntoproperty:
1619 dict_[mapper._columntoproperty[c].key] = params[c.key]
1620 if refresh_flush:
1621 load_evt_attrs.append(mapper._columntoproperty[c].key)
1622
1623 if refresh_flush and load_evt_attrs:
1624 mapper.class_manager.dispatch.refresh_flush(
1625 state, uowtransaction, load_evt_attrs
1626 )
1627
1628 if postfetch_cols:
1629 state._expire_attributes(
1630 state.dict,
1631 [
1632 mapper._columntoproperty[c].key
1633 for c in postfetch_cols
1634 if c in mapper._columntoproperty
1635 ],
1636 )
1637
1638
1639def _postfetch(
1640 mapper,
1641 uowtransaction,
1642 table,
1643 state,
1644 dict_,
1645 result,
1646 params,
1647 value_params,
1648 isupdate,
1649 returned_defaults,
1650):
1651 """Expire attributes in need of newly persisted database state,
1652 after an INSERT or UPDATE statement has proceeded for that
1653 state."""
1654
1655 prefetch_cols = result.context.compiled.prefetch
1656 postfetch_cols = result.context.compiled.postfetch
1657 returning_cols = result.context.compiled.effective_returning
1658
1659 if (
1660 mapper.version_id_col is not None
1661 and mapper.version_id_col in mapper._cols_by_table[table]
1662 ):
1663 prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
1664
1665 refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
1666 if refresh_flush:
1667 load_evt_attrs = []
1668
1669 if returning_cols:
1670 row = returned_defaults
1671 if row is not None:
1672 for row_value, col in zip(row, returning_cols):
1673 # pk cols returned from insert are handled
1674 # distinctly, don't step on the values here
1675 if col.primary_key and result.context.isinsert:
1676 continue
1677
1678 # note that columns can be in the "return defaults" that are
1679 # not mapped to this mapper, typically because they are
1680 # "excluded", which can be specified directly or also occurs
1681 # when using declarative w/ single table inheritance
1682 prop = mapper._columntoproperty.get(col)
1683 if prop:
1684 dict_[prop.key] = row_value
1685 if refresh_flush:
1686 load_evt_attrs.append(prop.key)
1687
1688 for c in prefetch_cols:
1689 if c.key in params and c in mapper._columntoproperty:
1690 pkey = mapper._columntoproperty[c].key
1691
1692 # set prefetched value in dict and also pop from committed_state,
1693 # since this is new database state that replaces whatever might
1694 # have previously been fetched (see #10800). this is essentially a
1695 # shorthand version of set_committed_value(), which could also be
1696 # used here directly (with more overhead)
1697 dict_[pkey] = params[c.key]
1698 state.committed_state.pop(pkey, None)
1699
1700 if refresh_flush:
1701 load_evt_attrs.append(pkey)
1702
1703 if refresh_flush and load_evt_attrs:
1704 mapper.class_manager.dispatch.refresh_flush(
1705 state, uowtransaction, load_evt_attrs
1706 )
1707
1708 if isupdate and value_params:
1709 # explicitly suit the use case specified by
1710 # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
1711 # database which are set to themselves in order to do a version bump.
1712 postfetch_cols.extend(
1713 [
1714 col
1715 for col in value_params
1716 if col.primary_key and col not in returning_cols
1717 ]
1718 )
1719
1720 if postfetch_cols:
1721 state._expire_attributes(
1722 state.dict,
1723 [
1724 mapper._columntoproperty[c].key
1725 for c in postfetch_cols
1726 if c in mapper._columntoproperty
1727 ],
1728 )
1729
1730 # synchronize newly inserted ids from one table to the next
1731 # TODO: this still goes a little too often. would be nice to
1732 # have definitive list of "columns that changed" here
1733 for m, equated_pairs in mapper._table_to_equated[table]:
1734 sync._populate(
1735 state,
1736 m,
1737 state,
1738 m,
1739 equated_pairs,
1740 uowtransaction,
1741 mapper.passive_updates,
1742 )
1743
1744
def _postfetch_bulk_save(mapper, dict_, table):
    """Run ``sync._bulk_populate_inherit_keys()`` on ``dict_`` for every
    ``(mapper, equated_pairs)`` entry that ``mapper`` holds for ``table``.

    """
    equated_entries = mapper._table_to_equated[table]
    for equated_mapper, pairs in equated_entries:
        sync._bulk_populate_inherit_keys(dict_, equated_mapper, pairs)
1748
1749
def _connections_for_states(base_mapper, uowtransaction, states):
    """Generate ``(state, state.dict, mapper, connection)`` tuples.

    States are produced in the order given by _sort_states().  If the
    session has a ``connection_callable``, it is invoked for each state
    with ``(base_mapper, state.obj())`` to obtain that state's
    connection; otherwise a single connection is taken from the unit of
    work's transaction and paired with every state.

    """
    connection_callable = uowtransaction.session.connection_callable
    if not connection_callable:
        # no per-state routing: one connection serves all states
        connection_callable = None
        connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if connection_callable:
            connection = connection_callable(base_mapper, state.obj())

        state_mapper = state.manager.mapper

        yield state, state.dict, state_mapper, connection
1774
1775
1776def _sort_states(mapper, states):
1777 pending = set(states)
1778 persistent = {s for s in pending if s.key is not None}
1779 pending.difference_update(persistent)
1780
1781 try:
1782 persistent_sorted = sorted(
1783 persistent, key=mapper._persistent_sortkey_fn
1784 )
1785 except TypeError as err:
1786 raise sa_exc.InvalidRequestError(
1787 "Could not sort objects by primary key; primary key "
1788 "values must be sortable in Python (was: %s)" % err
1789 ) from err
1790 return (
1791 sorted(pending, key=operator.attrgetter("insert_order"))
1792 + persistent_sorted
1793 )