1# orm/persistence.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
"""private module containing functions used to emit INSERT, UPDATE
and DELETE statements on behalf of a :class:`_orm.Mapper` and its descendant
mappers.
13
14The functions here are called only by the unit of work functions
15in unitofwork.py.
16
17"""
18
19from __future__ import annotations
20
21from itertools import chain
22from itertools import groupby
23from itertools import zip_longest
24import operator
25
26from . import attributes
27from . import exc as orm_exc
28from . import loading
29from . import sync
30from .base import state_str
31from .. import exc as sa_exc
32from .. import future
33from .. import sql
34from .. import util
35from ..engine import cursor as _cursor
36from ..sql import operators
37from ..sql.elements import BooleanClauseList
38from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
39
40
def save_obj(base_mapper, states, uowtransaction, single=False):
    """Persist a list of states by emitting ``INSERT`` and/or ``UPDATE``
    statements.

    Runs inside a UOWTransaction as part of a flush.  The base mapper of
    an inheritance hierarchy does the work for all of its descendant
    mappers: the states are split into INSERT and UPDATE candidates,
    then for each of the base mapper's sorted tables the UPDATEs are
    emitted followed by the INSERTs, and finally every state is handed
    to ``_finalize_insert_update_commands()``.

    :param base_mapper: base mapper of the hierarchy being flushed.
    :param states: the states to persist.
    :param uowtransaction: the current UOWTransaction.
    :param single: when True, ``states`` is processed as one group even
     if ``base_mapper.batch`` is False; the non-batch path sets this
     when it recurses with one state at a time.
    """
    # Mappers configured with batch=False flush one object per call, in
    # the order produced by _sort_states().
    if not (single or base_mapper.batch):
        for one_state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [one_state], uowtransaction, single=True)
        return

    pending_inserts = []
    pending_updates = []

    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        # A state that already has an identity, or that takes over the
        # row of a deleted object with the same identity ("row switch"),
        # is UPDATEd; everything else is INSERTed.
        if not (has_identity or row_switch):
            pending_inserts.append((state, dict_, mapper, connection))
        else:
            pending_updates.append(
                (state, dict_, mapper, connection, update_version_id)
            )

    for table, mapper in base_mapper._sorted_tables.items():
        if table in mapper._pks_by_table:
            update_commands = _collect_update_commands(
                uowtransaction, table, pending_updates
            )
            insert_commands = _collect_insert_commands(table, pending_inserts)

            # per table, UPDATE statements go out before INSERT statements
            _emit_update_statements(
                base_mapper, uowtransaction, mapper, table, update_commands
            )
            _emit_insert_statements(
                base_mapper, uowtransaction, mapper, table, insert_commands
            )

    # The trailing flag marks whether the state came from the UPDATE list
    # (True) or from the INSERT list (False).
    finalize_records = chain(
        (
            (ins_state, ins_dict, ins_mapper, ins_conn, False)
            for ins_state, ins_dict, ins_mapper, ins_conn in pending_inserts
        ),
        (
            (upd_state, upd_dict, upd_mapper, upd_conn, True)
            for (
                upd_state,
                upd_dict,
                upd_mapper,
                upd_conn,
                _,
            ) in pending_updates
        ),
    )
    _finalize_insert_update_commands(
        base_mapper, uowtransaction, finalize_records
    )
122
123
def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Emit the UPDATE statements required by a relationship() that is
    configured with ``post_update``.

    Each table in ``base_mapper._sorted_tables`` for which its mapper has
    primary key columns gets its own pass: the states that map that
    table are turned into parameter sets by
    ``_collect_post_update_commands()`` and executed by
    ``_emit_post_update_statements()``.

    :param post_update_cols: columns whose newly added values should be
     included in the UPDATE; passed through to
     ``_collect_post_update_commands()``.
    """
    organized = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    def _versioned_records(table, table_mapper):
        # Yield each organized state whose mapper has a primary key for
        # ``table``, extended with the committed version id as read
        # through the table's mapper (None when it has no version_id_col).
        for state, state_dict, sub_mapper, connection in organized:
            if table not in sub_mapper._pks_by_table:
                continue
            if table_mapper.version_id_col is not None:
                version = table_mapper._get_committed_state_attr_by_column(
                    state, state_dict, table_mapper.version_id_col
                )
            else:
                version = None
            yield state, state_dict, sub_mapper, connection, version

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        commands = _collect_post_update_commands(
            base_mapper,
            uowtransaction,
            table,
            _versioned_records(table, mapper),
            post_update_cols,
        )

        _emit_post_update_statements(
            base_mapper, uowtransaction, mapper, table, commands
        )
167
168
def delete_obj(base_mapper, states, uowtransaction):
    """Emit ``DELETE`` statements for the given states.

    Called from a UOWTransaction during flush.  Tables are visited in the
    reverse of ``base_mapper._sorted_tables`` order (``save_obj()`` walks
    that mapping forwards).  ``after_delete`` is dispatched for every
    state once all DELETE statements have been emitted.
    """
    to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]

        # Skip tables for which this mapper has no primary key, and the
        # tables of inheriting mappers configured with passive_deletes
        # (no DELETE is emitted for those here).
        if table not in mapper._pks_by_table or (
            mapper.inherits and mapper.passive_deletes
        ):
            continue

        delete_commands = _collect_delete_commands(
            base_mapper, uowtransaction, table, to_delete
        )
        _emit_delete_statements(
            base_mapper, uowtransaction, mapper, table, delete_commands
        )

    for state, _state_dict, state_mapper, connection, _version in to_delete:
        state_mapper.dispatch.after_delete(state_mapper, connection, state)
210
211
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Prepare states for INSERT or UPDATE, yielding one tuple per state.

    For each state, with its dict, mapper and connection as supplied by
    ``_connections_for_states()``, this:

    * dispatches ``before_update`` (state has a key) or ``before_insert``;
    * runs the mapper's ``_validate_polymorphic_identity`` hook, if set;
    * detects a "row switch": a pending state whose identity key is
      already held by a persistent object that is marked deleted in this
      flush.  That object's pending actions are removed and the new state
      will UPDATE its row.  If the other object is not being deleted, a
      warning is emitted instead;
    * reads the committed version id when the row will be UPDATEd and the
      mapper has a ``version_id_col``.

    Yields ``(state, dict_, mapper, connection, has_identity, row_switch,
    update_version_id)``, where ``row_switch`` is the replaced state or
    None and ``update_version_id`` is None when not applicable.
    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        has_identity = bool(state.key)
        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = None
        update_version_id = None

        # fire the "before" hook that matches the operation
        if has_identity:
            mapper.dispatch.before_update(mapper, connection, state)
        else:
            mapper.dispatch.before_insert(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        if not has_identity:
            identity_map = uowtransaction.session.identity_map
            if instance_key in identity_map:
                existing = attributes.instance_state(
                    identity_map[instance_key]
                )
                if not uowtransaction.was_already_deleted(existing):
                    if uowtransaction.is_deleted(existing):
                        base_mapper._log_debug(
                            "detected row switch for identity %s. "
                            "will update %s, remove %s from "
                            "transaction",
                            instance_key,
                            state_str(state),
                            state_str(existing),
                        )
                        # drop the pending "delete" for the replaced object
                        uowtransaction.remove_state_actions(existing)
                        row_switch = existing
                    else:
                        util.warn(
                            "New instance %s with identity key %s conflicts "
                            "with persistent instance %s"
                            % (
                                state_str(state),
                                instance_key,
                                state_str(existing),
                            )
                        )

        if (has_identity or row_switch) and mapper.version_id_col is not None:
            # in a row switch, the version comes from the row being replaced
            if row_switch:
                version_state, version_dict = row_switch, row_switch.dict
            else:
                version_state, version_dict = state, dict_
            update_version_id = mapper._get_committed_state_attr_by_column(
                version_state, version_dict, mapper.version_id_col
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
290
291
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Prepare states for a post_update UPDATE pass.

    This simply returns what ``_connections_for_states()`` produces: per
    state, the state, its dictionary, its mapper and the connection to
    execute with.  Unlike the save and delete variants, no events are
    dispatched here.
    """
    per_state_connections = _connections_for_states(
        base_mapper, uowtransaction, states
    )
    return per_state_connections
302
303
def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Prepare states for DELETE, yielding one tuple per state.

    For each state supplied by ``_connections_for_states()``,
    ``before_delete`` is dispatched, and then the committed version id is
    read (None when the mapper has no ``version_id_col``).

    Yields ``(state, dict_, mapper, connection, update_version_id)``.
    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        mapper.dispatch.before_delete(mapper, connection, state)

        version_col = mapper.version_id_col
        update_version_id = (
            None
            if version_col is None
            else mapper._get_committed_state_attr_by_column(
                state, dict_, version_col
            )
        )

        yield state, dict_, mapper, connection, update_version_id
325
326
def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Yield INSERT parameter records for ``table``, one per state.

    States whose mapper has no primary key for ``table`` are skipped.
    Each yielded record is ``(state, state_dict, params, mapper,
    connection, value_params, has_all_pks, has_all_defaults)``, where:

    * ``params`` maps column keys to plain values;
    * ``value_params`` maps columns to SQL expression values (only filled
      when ``bulk`` is False);
    * ``has_all_pks`` tells whether every primary key column key of
      ``table`` is present in ``params``;
    * ``has_all_defaults`` tells whether every server-default column key
      is present in ``params``.  It is always True when the base mapper
      does not prefer eager defaults for this dialect/table.

    :param bulk: bulk mode.  SQL expressions are not split out, no
     explicit NULLs are added, the polymorphic identity is defaulted in,
     and both flags are simply True unless ``return_defaults`` is set.
    :param return_defaults: in bulk mode, still compute both flags.
    :param render_nulls: keep None values in ``params`` even for columns
     that do not evaluate None.
    :param include_bulk_keys: in bulk mode, extra ``state_dict`` keys
     copied verbatim into ``params``.
    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        propkey_to_col = mapper._propkey_to_col[table]
        eval_none = mapper._insert_cols_evaluating_none[table]

        params = {}
        value_params = {}

        for propkey in set(propkey_to_col).intersection(state_dict):
            col = propkey_to_col[propkey]
            value = state_dict[propkey]

            # None is dropped unless the column evaluates None or the
            # caller asked for NULLs to be rendered
            if value is None and col not in eval_none and not render_nulls:
                continue

            if not bulk and hasattr(value, "__clause_element__"):
                value_params[col] = value.__clause_element__()
            elif not bulk and isinstance(value, sql.ClauseElement):
                value_params[col] = value
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT. This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            expression_keys = [c.key for c in value_params]
            missing_as_none = (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference(expression_keys)
            )
            for colkey in missing_as_none:
                params[colkey] = None

        if bulk and not return_defaults:
            has_all_defaults = has_all_pks = True
        else:
            # params is keyed by Column key, like _pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper._prefer_eager_defaults(
                connection.dialect, table
            ):
                has_all_defaults = mapper._server_default_col_keys[
                    table
                ].issubset(params)
            else:
                has_all_defaults = True

        version_col = mapper.version_id_col
        version_gen = mapper.version_id_generator
        if (
            version_gen is not False
            and version_col is not None
            and version_col in mapper._cols_by_table[table]
        ):
            # a new row receives the generator's initial version value
            params[version_col.key] = version_gen(None)

        if bulk:
            if mapper._set_polymorphic_identity:
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )

            if include_bulk_keys:
                params.update(
                    (bulk_key, state_dict[bulk_key])
                    for bulk_key in include_bulk_keys
                )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
426
427
def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement. Includes some tricky scenarios where the primary
    key of an object might have been changed.

    :param uowtransaction: the current UOWTransaction.  When not in bulk
     mode it is consulted for ``("pk_cascaded", state, col)`` markers.
    :param table: the table UPDATEs are collected for.  Records whose
     mapper has no primary key for it are skipped.
    :param states_to_update: iterable of ``(state, state_dict, mapper,
     connection, update_version_id)`` tuples.
    :param bulk: when True, ``state_dict`` is treated as a mapping of
     attribute keys to new values.  Every present non-primary-key mapped
     attribute is SET, and attribute history is not consulted.
    :param use_orm_update_stmt: optional ORM UPDATE statement whose
     ``_values`` are used as ``value_params``.
    :param include_bulk_keys: extra ``state_dict`` keys copied verbatim
     into ``params``.

    Yields ``(state, state_dict, params, mapper, connection, value_params,
    has_all_defaults, has_all_pks)`` for each row needing an UPDATE.  In
    ``params``, new values are keyed by column ``key``.  The values that
    locate the row (primary key and, when versioned, the current version
    id) are keyed by column ``_label``, which are the bind names that
    _emit_update_statements() uses in its WHERE clause.

    :raises sa_exc.InvalidRequestError: bulk mode, a primary key value is
     missing or None.
    :raises orm_exc.FlushError: non-bulk mode, a primary key value used to
     locate the row is None.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # this mapper has no primary key for ``table``; nothing to UPDATE
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if use_orm_update_stmt is not None:
            # TODO: ordered values, etc
            # NOTE(review): this is the statement's own ``_values``
            # mapping, not a copy; the non-bulk branch below would write
            # SQL-expression values into it.  Presumably
            # use_orm_update_stmt is only passed with bulk=True -- confirm.
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            # Only attributes that have an entry in committed_state are
            # candidates.  SQL expressions go to value_params; plain values
            # go to params unless they compare equal to the committed value.
            params = {}
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            # With eager_defaults=True, record whether every server-side
            # onupdate column already has a value; _emit_update_statements()
            # uses this flag when deciding whether to fetch them back.
            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change on any mapped attribute: skip this
                    # state (this ``continue`` applies to the outer loop)
                    continue

            col = mapper.version_id_col
            # remember whether anything besides the version id is SET
            no_params = not params and not value_params
            # current version id, matched in the UPDATE's WHERE clause
            params[col._label] = update_version_id

            # In bulk mode a new version id is always generated; otherwise
            # only when the version column was not set explicitly.
            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing changed for this table
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # The row is expected to already carry the new pk
                        # (no old value, or a cascade is marked for it), so
                        # match on the new value and don't SET it.
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        # The new pk is a SQL expression, so its value isn't
                        # known here; has_all_pks=False makes the emitter
                        # request the pk columns via return_defaults().
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
635
636
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Yield UPDATE parameter sets for ``table`` during a post_update pass.

    ``states_to_update`` holds ``(state, state_dict, mapper, connection,
    update_version_id)`` records.  For each record:

    * every primary key value of ``table`` is collected under the column's
      ``_label``;
    * for every other column that is in ``post_update_cols`` or has an
      ``onupdate``, a newly added value (from attribute history) is
      collected under the column's ``key``.

    A ``(state, state_dict, mapper, connection, params)`` tuple is yielded
    only when at least one such non-primary-key value was found.  When
    ``update_version_id`` is not None and the version column belongs to
    ``table``, the current version id is added under the column's
    ``_label``.  For states that have a key, a generated next version id
    is also added when no explicit value was collected.

    ``base_mapper`` and ``uowtransaction`` are accepted but unused here.
    """
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # callers pass only mappers that have a primary key for ``table``
        pks = mapper._pks_by_table[table]
        params = {}
        has_changes = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )
                continue

            if col not in post_update_cols and col.onupdate is None:
                continue

            prop = mapper._columntoproperty[col]
            history = state.manager[prop.key].impl.get_history(
                state, state_dict, attributes.PASSIVE_NO_INITIALIZE
            )
            if history.added:
                params[col.key] = history.added[0]
                has_changes = True

        if not has_changes:
            continue

        version_col = mapper.version_id_col
        if (
            update_version_id is not None
            and version_col in mapper._cols_by_table[table]
        ):
            params[version_col._label] = update_version_id

            if (
                bool(state.key)
                and version_col.key not in params
                and mapper.version_id_generator is not False
            ):
                params[version_col.key] = mapper.version_id_generator(
                    update_version_id
                )

        yield state, state_dict, mapper, connection, params
689
690
def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Yield ``(params, connection)`` pairs for DELETE statements on
    ``table``.

    ``params`` maps each primary key column's ``key`` to its committed
    value.  When ``update_version_id`` is not None and the mapper's
    version column belongs to ``table``, the version column's ``key`` is
    mapped to that id as well.  Records whose mapper has no primary key
    for ``table`` are skipped.  ``base_mapper`` and ``uowtransaction`` are
    accepted but unused.

    :raises orm_exc.FlushError: if a committed primary key value is None.
    """
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        for pk_col in mapper._pks_by_table[table]:
            pk_value = mapper._get_committed_state_attr_by_column(
                state, state_dict, pk_col
            )
            params[pk_col.key] = pk_value
            if pk_value is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, pk_col)
                )

        is_versioned_here = (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        )
        if is_versioned_here:
            params[mapper.version_id_col.key] = update_version_id

        yield params, connection
727
728
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    Records are grouped with ``itertools.groupby()`` on (connection,
    parameter key set, presence of ``value_params``, ``has_all_defaults``,
    ``has_all_pks``).  ``groupby()`` only merges *consecutive* records with
    equal keys.  Each group is executed either once per record, or as a
    single executemany when no RETURNING or version check is needed.
    Afterwards the matched rowcount is compared with the number of
    records when rowcount checking applies.

    :param bookkeeping: when True, ``_postfetch()`` is called for every
     record after execution, and server-side onupdate defaults may be
     fetched with RETURNING (eager_defaults).
    :param use_orm_update_stmt: optional ORM UPDATE statement used as the
     base statement (pk/version criteria are added to it) instead of the
     memoized ``table.update()``.
    :param enable_check_rowcount: when False, no rowcount check is made.

    :raises orm_exc.StaleDataError: when the rowcount is checked and the
     number of matched rows differs from the number of records.
    """

    # The mapper is versioned and its version column lives in this table,
    # so the WHERE clause must also match the current version id.
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        # WHERE criteria: each pk column (and the version column when
        # needed) compared with a bind parameter named by the column's
        # ``_label``, the key _collect_update_commands() uses for them.
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)

    else:
        # the plain table.update() form is memoized on the base mapper
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # NOTE(review): the record unpacking in the loops below rebinds the
    # ``mapper`` parameter (and ``connection``) to the record's values, so
    # from the second group onward the statement-building code here reads
    # the last record's mapper rather than the one passed in.
    # _emit_insert_statements() avoids this by using ``mapper_rec``;
    # verify this is intended.
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        # a pk is being SET to a SQL expression; fetch the resulting pk
        # values back
        if not has_all_pks:
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            # SQL-expression values are applied per record via .values(),
            # so each record is executed on its own
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                # RETURNING or a version check is involved: one execution
                # per record
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # single executemany for the whole group
                multiparams = [rec[2] for rec in records]

                # a multi-row rowcount can only be trusted when the dialect
                # supports sane multi-rowcount (or there is just one row)
                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        # NOTE(review): this branch is also reached when the caller passed
        # enable_check_rowcount=False, in which case the message blaming
        # the dialect is misleading -- confirm intended.
        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
961
962
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Consecutive records sharing the same connection, parameter key set,
    presence of ``value_params``, ``has_all_pks`` and ``has_all_defaults``
    are grouped with ``itertools.groupby()``.  Each group goes down one of
    three paths:

    * one executemany, when no generated values need to be fetched back;
    * one executemany with RETURNING (``do_executemany``), when values are
      needed and the dialect can return them for executemany;
    * otherwise, one execution per record.

    Primary key values reported by the database are copied into
    ``state_dict``.  In the RETURNING executemany path this happens only
    when ``bookkeeping`` is True.  With ``bookkeeping`` True,
    ``_postfetch()`` is called for records that have a state, and
    ``_postfetch_bulk_save()`` for records that don't.

    :param use_orm_insert_stmt: optional ORM INSERT statement used in
     place of the memoized ``table.insert()``.  When given, the function
     returns the results of the RETURNING executemany groups spliced
     together, or an empty DML result if there were none.
    :param execution_options: extra execution options merged over the
     defaults.

    :raises sa_exc.InvalidRequestError: explicit RETURNING was requested
     but the dialect can't do executemany with RETURNING (in the required
     ordering mode).
    :raises orm_exc.FlushError: the driver did not report back the
     expected primary key rows.
    """

    if use_orm_insert_stmt is not None:
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        # results must line up with the input parameters when bookkeeping,
        # or when the user statement asked for sort_by_parameter_order
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    # accumulated RETURNING results, only used with use_orm_insert_stmt
    return_result = None

    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],  # has_all_pks
            rec[7],  # has_all_defaults
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            # Without explicit RETURNING, executemany+RETURNING is only
            # considered for several records with no SQL-expression values
            # on a table that uses implicit returning.
            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

            if mapper.version_id_col is not None:
                statement = statement.return_defaults(
                    mapper.version_id_col,
                    sort_by_parameter_order=bookkeeping,
                )
            elif do_executemany:
                statement = statement.return_defaults(
                    *table.primary_key, sort_by_parameter_order=bookkeeping
                )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows. we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING. Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        # only fill pk attributes that have no value yet
                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                # one execution per record; explicit RETURNING must have
                # raised above if executemany RETURNING wasn't available
                assert not returning_is_required_anyway

                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked. Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    # a pk supplied as a SQL expression is replaced with
                    # the value the database produced; otherwise only
                    # missing pk attributes are filled
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                (
                                    result.returned_defaults
                                    if not result.context.executemany
                                    else None
                                ),
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)

    if use_orm_insert_stmt is not None:
        if return_result is None:
            return _cursor.null_dml_result()
        else:
            return return_result
1285
1286
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    ``update`` is iterated as records of
    ``(state, state_dict, mapper_rec, connection, params)``.  Consecutive
    records that share a connection and the same set of parameter keys
    are grouped, so that each group can be sent as one executemany()
    when the dialect's rowcount support allows it; otherwise each record
    is executed individually.

    :raises orm_exc.StaleDataError: when the rowcount can be checked and
      the number of matched rows differs from the number of records in
      a group.
    """

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    # True when the mapper is versioned and its version column belongs to
    # ``table``; the UPDATE then also matches on the version value.
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # WHERE clause: each primary key column of ``table`` (plus the
        # version column, if applicable) compared to a bound parameter
        # named after the column's ``_label``.
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        return stmt

    # the statement is built once per table and memoized on base_mapper
    statement = base_mapper._memo(("post_update", table), update_stmt)

    if mapper._version_id_has_server_side_value:
        # request the version column back via return_defaults()
        statement = statement.return_defaults(mapper.version_id_col)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection # parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        # rowcount can be trusted for single-row executes only if the
        # dialect says so; for executemany it additionally requires
        # multi-row rowcount support.
        assert_singlerow = connection.dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        # a versioned UPDATE needs a verifiable rowcount; when executemany
        # cannot provide one, fall back to one execute() per record.
        allow_executemany = not needs_version_id or assert_multirow

        if not allow_executemany:
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection.execute(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            # a single-row "executemany" is checkable with single-row
            # rowcount support alone
            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection.execute(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for i, (
                state,
                state_dict,
                mapper_rec,
                connection,
                params,
            ) in enumerate(records):
                # compiled_parameters[i] lines up with records[i]
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[i],
                )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            # versioned, but the rowcount could not be checked for this group
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
1409
1410
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    ``delete`` is iterated as ``(params, connection)`` records, and
    consecutive records on the same connection are deleted together.  The
    WHERE clause matches the primary key columns of ``table`` and, when
    the mapper is versioned with its version column in ``table``, the
    version column as well.

    When ``base_mapper.confirm_deleted_rows`` is set and the matched row
    count can be checked, a mismatch emits a warning for unversioned
    deletes and raises for versioned ones.

    :raises orm_exc.StaleDataError: on a matched-row mismatch for a
      versioned DELETE.
    """

    # True when the mapper is versioned and its version column belongs to
    # ``table``.
    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # WHERE clause with bound parameters named by each column's ``key``
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    # the statement is built once per table and memoized on base_mapper
    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        # -1 means "matched row count unknown"; the check below is skipped
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )

            # unversioned deletes only warn on mismatch (see check below)
            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount

        # NOTE(review): when the versioned rows above were executed one at a
        # time (dialect with sane single-row but not multi-row rowcount),
        # the ``len(del_objects) == 1`` clause means a mismatch across more
        # than one object is not reported here -- confirm this is intended.
        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
1508
1509
1510def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
1511 """finalize state on states that have been inserted or updated,
1512 including calling after_insert/after_update events.
1513
1514 """
1515 for state, state_dict, mapper, connection, has_identity in states:
1516 if mapper._readonly_props:
1517 readonly = state.unmodified_intersection(
1518 [
1519 p.key
1520 for p in mapper._readonly_props
1521 if (
1522 p.expire_on_flush
1523 and (not p.deferred or p.key in state.dict)
1524 )
1525 or (
1526 not p.expire_on_flush
1527 and not p.deferred
1528 and p.key not in state.dict
1529 )
1530 ]
1531 )
1532 if readonly:
1533 state._expire_attributes(state.dict, readonly)
1534
1535 # if eager_defaults option is enabled, load
1536 # all expired cols. Else if we have a version_id_col, make sure
1537 # it isn't expired.
1538 toload_now = []
1539
1540 # this is specifically to emit a second SELECT for eager_defaults,
1541 # so only if it's set to True, not "auto"
1542 if base_mapper.eager_defaults is True:
1543 toload_now.extend(
1544 state._unloaded_non_object.intersection(
1545 mapper._server_default_plus_onupdate_propkeys
1546 )
1547 )
1548
1549 if (
1550 mapper.version_id_col is not None
1551 and mapper.version_id_generator is False
1552 ):
1553 if mapper._version_id_prop.key in state.unloaded:
1554 toload_now.extend([mapper._version_id_prop.key])
1555
1556 if toload_now:
1557 state.key = base_mapper._identity_key_from_state(state)
1558 stmt = future.select(mapper).set_label_style(
1559 LABEL_STYLE_TABLENAME_PLUS_COL
1560 )
1561 loading.load_on_ident(
1562 uowtransaction.session,
1563 stmt,
1564 state.key,
1565 refresh_state=state,
1566 only_load_props=toload_now,
1567 )
1568
1569 # call after_XXX extensions
1570 if not has_identity:
1571 mapper.dispatch.after_insert(mapper, connection, state)
1572 else:
1573 mapper.dispatch.after_update(mapper, connection, state)
1574
1575 if (
1576 mapper.version_id_generator is False
1577 and mapper.version_id_col is not None
1578 ):
1579 if state_dict[mapper._version_id_prop.key] is None:
1580 raise orm_exc.FlushError(
1581 "Instance does not contain a non-NULL version value"
1582 )
1583
1584
1585def _postfetch_post_update(
1586 mapper, uowtransaction, table, state, dict_, result, params
1587):
1588 needs_version_id = (
1589 mapper.version_id_col is not None
1590 and mapper.version_id_col in mapper._cols_by_table[table]
1591 )
1592
1593 if not uowtransaction.is_deleted(state):
1594 # post updating after a regular INSERT or UPDATE, do a full postfetch
1595 prefetch_cols = result.context.compiled.prefetch
1596 postfetch_cols = result.context.compiled.postfetch
1597 elif needs_version_id:
1598 # post updating before a DELETE with a version_id_col, need to
1599 # postfetch just version_id_col
1600 prefetch_cols = postfetch_cols = ()
1601 else:
1602 # post updating before a DELETE without a version_id_col,
1603 # don't need to postfetch
1604 return
1605
1606 if needs_version_id:
1607 prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
1608
1609 refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
1610 if refresh_flush:
1611 load_evt_attrs = []
1612
1613 for c in prefetch_cols:
1614 if c.key in params and c in mapper._columntoproperty:
1615 dict_[mapper._columntoproperty[c].key] = params[c.key]
1616 if refresh_flush:
1617 load_evt_attrs.append(mapper._columntoproperty[c].key)
1618
1619 if refresh_flush and load_evt_attrs:
1620 mapper.class_manager.dispatch.refresh_flush(
1621 state, uowtransaction, load_evt_attrs
1622 )
1623
1624 if postfetch_cols:
1625 state._expire_attributes(
1626 state.dict,
1627 [
1628 mapper._columntoproperty[c].key
1629 for c in postfetch_cols
1630 if c in mapper._columntoproperty
1631 ],
1632 )
1633
1634
1635def _postfetch(
1636 mapper,
1637 uowtransaction,
1638 table,
1639 state,
1640 dict_,
1641 result,
1642 params,
1643 value_params,
1644 isupdate,
1645 returned_defaults,
1646):
1647 """Expire attributes in need of newly persisted database state,
1648 after an INSERT or UPDATE statement has proceeded for that
1649 state."""
1650
1651 prefetch_cols = result.context.compiled.prefetch
1652 postfetch_cols = result.context.compiled.postfetch
1653 returning_cols = result.context.compiled.effective_returning
1654
1655 if (
1656 mapper.version_id_col is not None
1657 and mapper.version_id_col in mapper._cols_by_table[table]
1658 ):
1659 prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
1660
1661 refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
1662 if refresh_flush:
1663 load_evt_attrs = []
1664
1665 if returning_cols:
1666 row = returned_defaults
1667 if row is not None:
1668 for row_value, col in zip(row, returning_cols):
1669 # pk cols returned from insert are handled
1670 # distinctly, don't step on the values here
1671 if col.primary_key and result.context.isinsert:
1672 continue
1673
1674 # note that columns can be in the "return defaults" that are
1675 # not mapped to this mapper, typically because they are
1676 # "excluded", which can be specified directly or also occurs
1677 # when using declarative w/ single table inheritance
1678 prop = mapper._columntoproperty.get(col)
1679 if prop:
1680 dict_[prop.key] = row_value
1681 if refresh_flush:
1682 load_evt_attrs.append(prop.key)
1683
1684 for c in prefetch_cols:
1685 if c.key in params and c in mapper._columntoproperty:
1686 pkey = mapper._columntoproperty[c].key
1687
1688 # set prefetched value in dict and also pop from committed_state,
1689 # since this is new database state that replaces whatever might
1690 # have previously been fetched (see #10800). this is essentially a
1691 # shorthand version of set_committed_value(), which could also be
1692 # used here directly (with more overhead)
1693 dict_[pkey] = params[c.key]
1694 state.committed_state.pop(pkey, None)
1695
1696 if refresh_flush:
1697 load_evt_attrs.append(pkey)
1698
1699 if refresh_flush and load_evt_attrs:
1700 mapper.class_manager.dispatch.refresh_flush(
1701 state, uowtransaction, load_evt_attrs
1702 )
1703
1704 if isupdate and value_params:
1705 # explicitly suit the use case specified by
1706 # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
1707 # database which are set to themselves in order to do a version bump.
1708 postfetch_cols.extend(
1709 [
1710 col
1711 for col in value_params
1712 if col.primary_key and col not in returning_cols
1713 ]
1714 )
1715
1716 if postfetch_cols:
1717 state._expire_attributes(
1718 state.dict,
1719 [
1720 mapper._columntoproperty[c].key
1721 for c in postfetch_cols
1722 if c in mapper._columntoproperty
1723 ],
1724 )
1725
1726 # synchronize newly inserted ids from one table to the next
1727 # TODO: this still goes a little too often. would be nice to
1728 # have definitive list of "columns that changed" here
1729 for m, equated_pairs in mapper._table_to_equated[table]:
1730 sync.populate(
1731 state,
1732 m,
1733 state,
1734 m,
1735 equated_pairs,
1736 uowtransaction,
1737 mapper.passive_updates,
1738 )
1739
1740
1741def _postfetch_bulk_save(mapper, dict_, table):
1742 for m, equated_pairs in mapper._table_to_equated[table]:
1743 sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
1744
1745
def _connections_for_states(base_mapper, uowtransaction, states):
    """Yield ``(state, state.dict, mapper, connection)`` for each state.

    States are visited in the order given by :func:`_sort_states`.  When
    the session has a ``connection_callable``, it is called for every
    state as ``connection_callable(base_mapper, state.obj())``.  Otherwise
    a single connection for ``base_mapper`` is taken from the unit of
    work's transaction and paired with every state.
    """
    per_state_connection = uowtransaction.session.connection_callable
    if not per_state_connection:
        # one connection shared by all states
        connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if per_state_connection:
            connection = per_state_connection(base_mapper, state.obj())
        owning_mapper = state.manager.mapper
        yield state, state.dict, owning_mapper, connection
1770
1771
1772def _sort_states(mapper, states):
1773 pending = set(states)
1774 persistent = {s for s in pending if s.key is not None}
1775 pending.difference_update(persistent)
1776
1777 try:
1778 persistent_sorted = sorted(
1779 persistent, key=mapper._persistent_sortkey_fn
1780 )
1781 except TypeError as err:
1782 raise sa_exc.InvalidRequestError(
1783 "Could not sort objects by primary key; primary key "
1784 "values must be sortable in Python (was: %s)" % err
1785 ) from err
1786 return (
1787 sorted(pending, key=operator.attrgetter("insert_order"))
1788 + persistent_sorted
1789 )