1# orm/persistence.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""private module containing functions used to emit INSERT, UPDATE
11and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
12mappers.
13
14The functions here are called only by the unit of work functions
15in unitofwork.py.
16
17"""
18from __future__ import annotations
19
20from itertools import chain
21from itertools import groupby
22from itertools import zip_longest
23import operator
24
25from . import attributes
26from . import exc as orm_exc
27from . import loading
28from . import sync
29from .base import state_str
30from .. import exc as sa_exc
31from .. import sql
32from .. import util
33from ..engine import cursor as _cursor
34from ..sql import operators
35from ..sql.elements import BooleanClauseList
36
37
def _save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    """

    # with batch=False on the mapper, flush every state on its own,
    # in deterministic sorted order
    if not single and not base_mapper.batch:
        for st in _sort_states(base_mapper, states):
            _save_obj(base_mapper, [st], uowtransaction, single=True)
        return

    to_insert = []
    to_update = []

    # split the incoming states into INSERT vs. UPDATE buckets; a
    # state with an identity or a detected "row switch" is an UPDATE
    for rec in _organize_states_for_save(base_mapper, states, uowtransaction):
        (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        ) = rec
        if has_identity or row_switch:
            to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            to_insert.append((state, dict_, mapper, connection))

    # process each mapped table in dependency order, emitting the
    # UPDATE batch and then the INSERT batch for that table
    for table, sub_mapper in base_mapper._sorted_tables.items():
        if table not in sub_mapper._pks_by_table:
            continue

        insert = _collect_insert_commands(table, to_insert)
        update = _collect_update_commands(uowtransaction, table, to_update)

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            sub_mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            sub_mapper,
            table,
            insert,
        )

    # run post-save bookkeeping on everything; the boolean flag
    # marks whether the state was an UPDATE (True) or INSERT (False)
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (state, state_dict, mapper, connection, _) in to_update
            ),
        ),
    )
119
120
def _post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    """

    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        # lazily augment each record with its committed version id
        # (None when the mapper does not version rows), restricted
        # to mappers that actually map this table
        commands = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                (
                    None
                    if mapper.version_id_col is None
                    else mapper._get_committed_state_attr_by_column(
                        state, state_dict, mapper.version_id_col
                    )
                ),
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        commands = _collect_post_update_commands(
            base_mapper, uowtransaction, table, commands, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            commands,
        )
164
165
def _delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    """

    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    # tables are sorted for INSERT; deletes run in the reverse order
    for table in reversed(list(table_to_mapper)):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table or (
            mapper.inherits and mapper.passive_deletes
        ):
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            delete,
        )

    # fire after_delete events once all tables are processed
    for (
        state,
        _state_dict,
        mapper,
        connection,
        _update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)
207
208
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    Yields 7-tuples of ``(state, dict_, mapper, connection,
    has_identity, row_switch, update_version_id)`` consumed by
    :func:`._save_obj`.

    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        # a state with an identity key is already persistent -> UPDATE
        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    # conflicting persistent instance that is not being
                    # deleted in this flush; warn rather than fail
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    # "row switch": the existing row is slated for
                    # deletion in this flush; reuse it via UPDATE and
                    # cancel its pending DELETE action
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        # committed version id for the optimistic-concurrency WHERE
        # criteria; taken from the switched-in row's state when a row
        # switch occurred
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
287
288
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for UPDATE
    corresponding to post_update.

    This includes obtaining key information for each state
    including its dictionary, mapper, the connection to use for
    the execution per state.

    """
    # no event hooks or filtering here; per-state connection
    # resolution is all that's needed for post_update
    per_state = _connections_for_states(base_mapper, uowtransaction, states)
    return per_state
299
300
def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for DELETE.

    This includes calling out before_delete and obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state.

    """
    for state, state_dict, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        # event hook fires before version-id bookkeeping
        mapper.dispatch.before_delete(mapper, connection, state)

        # committed version id for the DELETE criteria, when the
        # mapper versions its rows
        version_id = (
            mapper._get_committed_state_attr_by_column(
                state, state_dict, mapper.version_id_col
            )
            if mapper.version_id_col is not None
            else None
        )

        yield (state, state_dict, mapper, connection, version_id)
322
323
def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    :param table: the :class:`.Table` being processed for this pass.
    :param states_to_insert: sequence of
     ``(state, state_dict, mapper, connection)`` tuples.
    :param bulk: bulk-insert mode; SQL expression values are not
     collected and the legacy "explicit None" behavior is skipped.
    :param return_defaults: in bulk mode, still compute the
     ``has_all_pks`` / ``has_all_defaults`` flags.
    :param render_nulls: include ``None`` values as parameters
     rather than omitting them.
    :param include_bulk_keys: extra ``state_dict`` keys copied
     directly into the parameter dictionary (bulk mode only).

    Yields 8-tuples of ``(state, state_dict, params, mapper,
    connection, value_params, has_all_pks, has_all_defaults)``
    consumed by :func:`._emit_insert_statements`.

    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                # SQL expression value; collected separately so it can
                # be rendered into the statement via .values() rather
                # than passed as a bound parameter
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT. This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper._prefer_eager_defaults(
                connection.dialect, table
            ):
                has_all_defaults = mapper._server_default_col_keys[
                    table
                ].issubset(params)
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        # client-side version counter: generate the initial value
        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        if bulk:
            if mapper._set_polymorphic_identity:
                # bulk dicts may omit the discriminator value; fill it
                # in from the mapper's polymorphic identity
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )

            if include_bulk_keys:
                params.update((k, state_dict[k]) for k in include_bulk_keys)

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
423
424
def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement. Includes some tricky scenarios where the primary
    key of an object might have been changed.

    Yields 8-tuples of ``(state, state_dict, params, mapper,
    connection, value_params, has_all_defaults, has_all_pks)``
    consumed by :func:`._emit_update_statements`.  Note the flag
    ordering differs from that of _collect_insert_commands.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if (
            use_orm_update_stmt is not None
            and not use_orm_update_stmt._maintain_values_ordering
        ):
            # TODO: ordered values, etc
            # ORM bulk_persistence will raise for the maintain_values_ordering
            # case right now
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            # compare current values against committed state; only
            # attributes with a net change become UPDATE parameters
            params = {}
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            # remember whether any net UPDATE values were present
            # before version-id handling adds parameters
            no_params = not params and not value_params
            # committed version id goes into the WHERE criteria; the
            # _label key corresponds to the WHERE-clause bindparam
            # set up in _emit_update_statements
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            # non-bulk: match the row using PK history; a changed PK
            # means we locate the row by its old value
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                    if pk_params[col._label] is None:
                        raise orm_exc.FlushError(
                            "Can't update table %s using NULL for primary "
                            "key value on column %s" % (table, col)
                        )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync._populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
637
638
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    Yields 5-tuples of ``(state, state_dict, mapper, connection,
    params)`` only for rows that have at least one pending value
    among the post_update / onupdate columns.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # primary key values locate the row; keyed by the
                # column's _label (presumably matching the bind names
                # in the emitted UPDATE's WHERE clause — same
                # convention as _emit_update_statements)
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        # only emit for this row when at least one targeted column
        # has a pending (added) value
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):
                # committed version id participates in the WHERE
                # criteria; generate the next value when applicable
                col = mapper.version_id_col
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params
691
692
693def _collect_delete_commands(
694 base_mapper, uowtransaction, table, states_to_delete
695):
696 """Identify values to use in DELETE statements for a list of
697 states to be deleted."""
698
699 for (
700 state,
701 state_dict,
702 mapper,
703 connection,
704 update_version_id,
705 ) in states_to_delete:
706 if table not in mapper._pks_by_table:
707 continue
708
709 params = {}
710 for col in mapper._pks_by_table[table]:
711 params[col.key] = value = (
712 mapper._get_committed_state_attr_by_column(
713 state, state_dict, col
714 )
715 )
716 if value is None:
717 raise orm_exc.FlushError(
718 "Can't delete from table %s "
719 "using NULL for primary "
720 "key value on column %s" % (table, col)
721 )
722
723 if (
724 update_version_id is not None
725 and mapper.version_id_col in mapper._cols_by_table[table]
726 ):
727 params[mapper.version_id_col.key] = update_version_id
728 yield params, connection
729
730
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    Records are grouped by (connection, parameter-key set, presence
    of SQL-expression values, has_all_defaults, has_all_pks) so that
    compatible rows can share a single statement or executemany
    batch; note :func:`itertools.groupby` only merges *consecutive*
    records.

    :param bookkeeping: when True, run _postfetch() per row to
     process server-generated values and related state.
    :param use_orm_update_stmt: optional pre-built ORM UPDATE
     statement to augment, instead of the memoized table UPDATE.
    :param enable_check_rowcount: when False, skip the
     StaleDataError matched-rowcount verification.

    """

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        # WHERE criteria: all primary key columns, plus the version
        # id column when versioning is in effect; bind names use the
        # columns' _label, matching _collect_update_commands
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)

    else:
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        # a generated/cascaded PK value needs to come back via RETURNING
        if not has_all_pks:
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            # SQL-expression values present: each row gets its own
            # statement with those values rendered via .values()
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
            check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                # RETURNING or version-id checking requires per-row
                # execution
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # plain executemany over the whole group
                multiparams = [rec[2] for rec in records]

                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            if rows != len(records):
                # the WHERE criteria (pk + version id) failed to match
                # one row per state
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
963
964
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Records are grouped by (connection, parameter keys, presence of
    SQL-expression values, has_all_pks, has_all_defaults) and each
    group executed either as an executemany batch or row-by-row,
    depending on whether generated values need to come back.

    :param bookkeeping: when True, populate primary key attributes
     on the states and run _postfetch() per row.
    :param use_orm_insert_stmt: optional ORM INSERT statement (e.g.
     with user-specified RETURNING) used in place of the memoized
     table INSERT; when given, a result is returned.
    :param execution_options: extra execution options merged with
     the default compiled-cache options.

    """

    if use_orm_insert_stmt is not None:
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    # accumulates spliced results when an ORM statement is in use
    return_result = None

    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    # a None state indicates a bulk-save dictionary
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                # decide whether executemany with RETURNING is possible
                # on this dialect, raising if explicit RETURNING was
                # requested but is unsupported
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

                if mapper.version_id_col is not None:
                    statement = statement.return_defaults(
                        mapper.version_id_col,
                        sort_by_parameter_order=bookkeeping,
                    )
                elif do_executemany:
                    statement = statement.return_defaults(
                        *table.primary_key, sort_by_parameter_order=bookkeeping
                    )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows.  we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING. Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        # populate newly generated PK values onto the
                        # state dictionary where not already set
                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                # row-at-a-time execution path
                assert not returning_is_required_anyway

                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked. Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                (
                                    result.returned_defaults
                                    if not result.context.executemany
                                    else None
                                ),
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)

    if use_orm_insert_stmt is not None:
        if return_result is None:
            return _cursor.null_dml_result()
        else:
            return return_result
1287
1288
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Each record in ``update`` is a
    ``(state, state_dict, mapper_rec, connection, params)`` tuple.
    Records are grouped by (connection, parameter key set) so that
    executemany() can be used where the dialect's rowcount support
    permits; when a version id column is involved, rowcounts are
    verified to detect stale rows.
    """

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    # True if this table carries the version id column; the UPDATE then
    # includes it in the WHERE criteria and rowcounts must be checked
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # construct the UPDATE statement: WHERE matches each primary key
        # column (plus the version id column, if versioning applies)
        # against bound parameters named by the column's ._label
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        return stmt

    # statement construction is memoized per-table on the base mapper
    statement = base_mapper._memo(("post_update", table), update_stmt)

    if mapper._version_id_has_server_side_value:
        # fetch the server-generated version id back via RETURNING /
        # return_defaults
        statement = statement.return_defaults(mapper.version_id_col)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection, parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        # rowcount verification strategy depends on what the dialect
        # reports accurately for single-row vs. multi-row statements
        assert_singlerow = connection.dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        # with versioning, executemany is only usable when multi-row
        # rowcounts can be trusted
        allow_executemany = not needs_version_id or assert_multirow

        if not allow_executemany:
            # emit one UPDATE per record so that each rowcount may be
            # inspected individually
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection.execute(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            # one executemany() call for the whole group
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection.execute(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for i, (
                state,
                state_dict,
                mapper_rec,
                connection,
                params,
            ) in enumerate(records):
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[i],
                )

        if check_rowcount:
            if rows != len(records):
                # fewer/more rows matched than states flushed; with
                # versioning in play this indicates concurrent
                # modification of the rows
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
1411
1412
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    Records in ``delete`` are ``(params, connection)`` pairs; they are
    grouped per connection and executed with executemany() where the
    dialect's rowcount support allows.  With a version id column
    present, rowcounts are verified so that stale rows raise
    :class:`.StaleDataError`.
    """

    # True if this table carries the version id column; the DELETE
    # criteria then includes it and rowcounts must be verified
    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # construct the DELETE statement: WHERE matches each primary key
        # column (plus the version id column, if versioning applies)
        # against bound parameters named by the column's .key
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    # statement construction is memoized per-table on the base mapper
    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        # -1 is a sentinel meaning "rowcount unavailable"; it skips the
        # verification block at the bottom
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                # no usable rowcount at all; versioning cannot be
                # enforced, so warn and proceed with executemany()
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                # without versioning, a mismatched count only warns
                # rather than raising
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
1510
1511
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    ``states`` is an iterable of
    ``(state, state_dict, mapper, connection, has_identity)`` tuples;
    ``has_identity`` False indicates the state was INSERTed during this
    flush, True that it was UPDATEd.

    """
    for state, state_dict, mapper, connection, has_identity in states:
        if mapper._readonly_props:
            # expire readonly attributes the flush may have affected:
            # either eagerly-expiring props that are loaded (or not
            # deferred), or non-expiring, non-deferred props that are
            # not yet present in the dict
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols.  Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        # this is specifically to emit a second SELECT for eager_defaults,
        # so only if it's set to True, not "auto"
        if base_mapper.eager_defaults is True:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            # server-generated version id must be loaded after the flush
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            # emit a SELECT to refresh the needed attributes on the
            # now-persistent identity
            state.key = base_mapper._identity_key_from_state(state)
            stmt = sql.select(mapper)
            loading._load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            # with a server-side version generator, the version value
            # must be non-NULL once the flush has completed
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )
1583
1584
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Refresh instance state after a post-UPDATE statement.

    Populates prefetched (client-side default) values from ``params``
    into ``dict_`` and expires attributes for postfetched (server-side)
    columns, mirroring _postfetch() for the post_update case.
    """
    # True if this table carries the version id column
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    if not uowtransaction.is_deleted(state):
        # post updating after a regular INSERT or UPDATE, do a full postfetch
        prefetch_cols = result.context.compiled.prefetch
        postfetch_cols = result.context.compiled.postfetch
    elif needs_version_id:
        # post updating before a DELETE with a version_id_col, need to
        # postfetch just version_id_col
        prefetch_cols = postfetch_cols = ()
    else:
        # post updating before a DELETE without a version_id_col,
        # don't need to postfetch
        return

    if needs_version_id:
        # the version id column is always treated as prefetched
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    # accumulate attribute keys for the refresh_flush event, but only
    # if listeners are registered
    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    for c in prefetch_cols:
        # copy prefetched (client-invoked default) values from the
        # statement parameters into the instance dict
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        # server-generated values are not available here; expire the
        # attributes so they are loaded on next access
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )
1633
1634
def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    Values available from RETURNING (``returned_defaults``) or from
    prefetched (client-side default) parameters are populated directly
    into ``dict_``; remaining server-generated (postfetch) columns are
    expired instead.  Also fires the ``refresh_flush`` event for
    populated keys and synchronizes newly generated primary key values
    to dependent tables.
    """

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.effective_returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # the version id column is always treated as prefetched
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    # accumulate attribute keys for the refresh_flush event, but only
    # if listeners are registered
    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        # populate attributes directly from the RETURNING row
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            pkey = mapper._columntoproperty[c].key

            # set prefetched value in dict and also pop from committed_state,
            # since this is new database state that replaces whatever might
            # have previously been fetched (see #10800).  this is essentially a
            # shorthand version of set_committed_value(), which could also be
            # used here directly (with more overhead)
            dict_[pkey] = params[c.key]
            state.committed_state.pop(pkey, None)

            if refresh_flush:
                load_evt_attrs.append(pkey)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        # NOTE(review): this extends the compiled statement's postfetch
        # list in place; presumably safe here because the value_params
        # path builds a new statement per execution — confirm against
        # the caller's use of statement.values()
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        # expire attributes whose new value is only available
        # server-side, so they refresh on next access
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often.  would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync._populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )
1739
1740
def _postfetch_bulk_save(mapper, dict_, table):
    """Synchronize inherit-key values within a bulk-save parameter dict.

    For each (mapper, equated column pairs) entry registered for
    ``table``, populate the equated key values within ``dict_`` so that
    dependent tables in the hierarchy receive them.
    """
    for target_mapper, pairs in mapper._table_to_equated[table]:
        sync._bulk_populate_inherit_keys(dict_, target_mapper, pairs)
1744
1745
def _connections_for_states(base_mapper, uowtransaction, states):
    """Yield (state, state.dict, mapper, connection) tuples.

    States are ordered via _sort_states() and paired with the
    connection each should use for the given unit of work transaction:
    if the session defines a connection_callable it is consulted once
    per state, otherwise a single connection obtained from the
    transaction is shared by all states.

    """
    connection_callable = uowtransaction.session.connection_callable
    if not connection_callable:
        # no per-state selection; one shared connection for all states
        connection_callable = None
        connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if connection_callable:
            # per-state connection selection
            connection = connection_callable(base_mapper, state.obj())

        yield state, state.dict, state.manager.mapper, connection
1770
1771
1772def _sort_states(mapper, states):
1773 pending = set(states)
1774 persistent = {s for s in pending if s.key is not None}
1775 pending.difference_update(persistent)
1776
1777 try:
1778 persistent_sorted = sorted(
1779 persistent, key=mapper._persistent_sortkey_fn
1780 )
1781 except TypeError as err:
1782 raise sa_exc.InvalidRequestError(
1783 "Could not sort objects by primary key; primary key "
1784 "values must be sortable in Python (was: %s)" % err
1785 ) from err
1786 return (
1787 sorted(pending, key=operator.attrgetter("insert_order"))
1788 + persistent_sorted
1789 )