1# orm/persistence.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9
10"""private module containing functions used to emit INSERT, UPDATE
11and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
12mappers.
13
14The functions here are called only by the unit of work functions
15in unitofwork.py.
16
17"""
18from __future__ import annotations
19
20from itertools import chain
21from itertools import groupby
22from itertools import zip_longest
23import operator
24
25from . import attributes
26from . import exc as orm_exc
27from . import loading
28from . import sync
29from .base import state_str
30from .. import exc as sa_exc
31from .. import future
32from .. import sql
33from .. import util
34from ..engine import cursor as _cursor
35from ..sql import operators
36from ..sql.elements import BooleanClauseList
37from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
38
39
def save_obj(base_mapper, states, uowtransaction, single=False):
    """Emit ``INSERT`` and/or ``UPDATE`` statements for a collection
    of object states.

    Invoked by the unit of work within the context of a
    UOWTransaction during a flush.  The base mapper of an
    inheritance hierarchy handles the inserts/updates for all of its
    descendant mappers.

    """

    # when batching is disabled, flush each state individually in
    # dependency-sorted order by re-entering with single=True
    if not single and not base_mapper.batch:
        for single_state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [single_state], uowtransaction, single=True)
        return

    update_recs = []
    insert_recs = []

    for rec in _organize_states_for_save(base_mapper, states, uowtransaction):
        (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        ) = rec
        if not has_identity and not row_switch:
            insert_recs.append((state, dict_, mapper, connection))
        else:
            update_recs.append(
                (state, dict_, mapper, connection, update_version_id)
            )

    for table, sub_mapper in base_mapper._sorted_tables.items():
        if table not in sub_mapper._pks_by_table:
            continue

        insert = _collect_insert_commands(table, insert_recs)

        update = _collect_update_commands(
            uowtransaction, table, update_recs
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            sub_mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            sub_mapper,
            table,
            insert,
        )

    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in insert_recs
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in update_recs
            ),
        ),
    )
122
def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Emit UPDATE statements on behalf of a relationship() which
    specifies post_update.

    """

    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        def _with_version_id(recs, mapper=mapper, table=table):
            # lazily attach the committed version id value (or None
            # when versioning is disabled) to each record that
            # participates in this table
            for state, state_dict, sub_mapper, connection in recs:
                if table not in sub_mapper._pks_by_table:
                    continue
                if mapper.version_id_col is not None:
                    version_id = mapper._get_committed_state_attr_by_column(
                        state, state_dict, mapper.version_id_col
                    )
                else:
                    version_id = None
                yield (state, state_dict, sub_mapper, connection, version_id)

        update = _collect_post_update_commands(
            base_mapper,
            uowtransaction,
            table,
            _with_version_id(states_to_update),
            post_update_cols,
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )
166
167
def delete_obj(base_mapper, states, uowtransaction):
    """Emit ``DELETE`` statements for a collection of object states.

    Invoked within the context of a UOWTransaction during a flush
    operation.

    """

    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    # deletes proceed in reverse table order relative to inserts/updates
    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        if mapper.inherits and mapper.passive_deletes:
            # an inheriting mapper with passive_deletes relies on
            # database-level cascade for this table
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            delete,
        )

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)
209
210
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    Yields 7-tuples of
    ``(state, dict_, mapper, connection, has_identity, row_switch,
    update_version_id)``.

    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        # a state with an identity key has already been persisted
        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    # the existing persistent row is marked for
                    # deletion; perform a "row switch", turning this
                    # INSERT into an UPDATE of the existing row
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        if (has_identity or row_switch) and mapper.version_id_col is not None:
            # fetch the committed version id value; on a row switch,
            # it comes from the existing (to-be-replaced) state
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
289
290
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Perform an initial pass over states slated for a post_update
    UPDATE.

    Pairs each state with its dictionary, mapper, and the connection
    on which its statement will be executed.

    """
    return _connections_for_states(base_mapper, uowtransaction, states)
301
302
def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Perform an initial pass over states slated for DELETE.

    Dispatches the before_delete event for each state and collects
    its dictionary, mapper, the connection to use per state, and the
    committed version id value when versioning is enabled.

    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        mapper.dispatch.before_delete(mapper, connection, state)

        update_version_id = (
            mapper._get_committed_state_attr_by_column(
                state, dict_, mapper.version_id_col
            )
            if mapper.version_id_col is not None
            else None
        )

        yield (state, dict_, mapper, connection, update_version_id)
324
325
def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    :param table: the Table currently being processed.
    :param states_to_insert: iterable of
     ``(state, state_dict, mapper, connection)`` tuples.
    :param bulk: if True, per-attribute SQL expression handling and
     the legacy explicit-None fill-in are skipped.
    :param return_defaults: if True, compute has_all_pks /
     has_all_defaults even in the bulk case.
    :param render_nulls: if True, include explicit None values in the
     parameters rather than omitting unset columns.
    :param include_bulk_keys: extra state-dict keys copied verbatim
     into the parameters (bulk mode only).

    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        # params: plain bound-parameter values; value_params: SQL
        # expressions rendered inline via statement.values()
        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        # columns for which a Python-level None is a meaningful value
        # to render, rather than "missing"
        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT. This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper._prefer_eager_defaults(
                connection.dialect, table
            ):
                has_all_defaults = mapper._server_default_col_keys[
                    table
                ].issubset(params)
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            # generate the initial version id for the new row
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        if bulk:
            if mapper._set_polymorphic_identity:
                # bulk dictionaries may omit the discriminator; fill
                # it in from the mapper's polymorphic identity
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )

            if include_bulk_keys:
                params.update((k, state_dict[k]) for k in include_bulk_keys)

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
425
426
def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement. Includes some tricky scenarios where the primary
    key of an object might have been changed.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if use_orm_update_stmt is not None:
            # TODO: ordered values, etc
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            params = {}
            # only attributes with committed history (i.e. net changes)
            # are considered for the SET clause
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            # the old version id value is matched in the WHERE clause
            # via the col._label bound parameter
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing to SET; skip this state for this table
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # the PK value changed via cascade (or was
                        # newly set); match the row on the new value
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                    if pk_params[col._label] is None:
                        raise orm_exc.FlushError(
                            "Can't update table %s using NULL for primary "
                            "key value on column %s" % (table, col)
                        )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
634
635
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    Yields 5-tuples of ``(state, state_dict, mapper, connection,
    params)``; states with no net changes are skipped.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # primary key values are WHERE-clause parameters,
                # keyed by col._label
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):
                # match the row on the old version id value
                col = mapper.version_id_col
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    # generate the next version id value
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params
688
689
def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Produce parameter dictionaries for DELETE statements to be
    emitted against ``table`` for a list of states to be deleted."""

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        for col in mapper._pks_by_table[table]:
            value = mapper._get_committed_state_attr_by_column(
                state, state_dict, col
            )
            if value is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, col)
                )
            params[col.key] = value

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            # match the row on its committed version id value as well
            params[mapper.version_id_col.key] = update_version_id
        yield params, connection
726
727
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands()."""

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        # build the WHERE criteria: all PK columns, plus the version
        # id column when versioning is in effect
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)

    else:
        # cache the base UPDATE construct on the mapper
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # group records so that each batch shares a connection, parameter
    # key set, value-params presence and defaults/pk flags; groupby
    # relies on _collect_update_commands yielding in grouped order
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        if not has_all_pks:
            # a PK value is an unevaluated SQL expression; fetch the
            # resulting PK values back via RETURNING
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            # SQL expressions present in values: execute one statement
            # per record, applying value_params via .values()
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                # one execute() per record so per-row RETURNING /
                # rowcount data is available
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # executemany path: one statement for all records
                multiparams = [rec[2] for rec in records]

                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            # a mismatch indicates the row was concurrently modified
            # or deleted (optimistic concurrency failure)
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
960
961
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands()."""

    if use_orm_insert_stmt is not None:
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    # accumulates spliced results across groups when an ORM-level
    # insert statement was passed in
    return_result = None

    # group records that share a connection, parameter key set,
    # value-params presence and pk/defaults flags so they can be
    # executed together
    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],  # has_all_pks
            rec[7],  # has_all_defaults
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        # bulk insert: no InstanceState, only a dict
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                # decide whether executemany w/ RETURNING is available
                # on this dialect, honoring the deterministic-ordering
                # requirement when bookkeeping is in effect
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                # request server-generated defaults / version id back
                # via RETURNING where needed
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

            if mapper.version_id_col is not None:
                statement = statement.return_defaults(
                    mapper.version_id_col,
                    sort_by_parameter_order=bookkeeping,
                )
            elif do_executemany:
                statement = statement.return_defaults(
                    *table.primary_key, sort_by_parameter_order=bookkeeping
                )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    # zip_longest: returned_defaults_rows may be
                    # shorter / absent, yielding None placeholders
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows.  we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING.  Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                assert not returning_is_required_anyway

                # single-row INSERT per record
                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a "
                            "new primary key result "
                            "being invoked.  Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                (
                                    result.returned_defaults
                                    if not result.context.executemany
                                    else None
                                ),
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)

    if use_orm_insert_stmt is not None:
        if return_result is None:
            return _cursor.null_dml_result()
        else:
            return return_result
1284
1285
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Rows are updated in the original state order; records are grouped
    by (connection, parameter keys) so that executemany() can be used
    where the dialect supports verifiable multi-row rowcounts.
    Raises :class:`.StaleDataError` when a verifiable rowcount does not
    match the number of records.
    """

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    # versioning participates only when the version column actually
    # belongs to the table being updated
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # construct the (memoized) UPDATE: WHERE matches every PK
        # column of this table, plus the version id column when
        # versioning applies
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        return stmt

    # statement is cached per-table on the base mapper
    statement = base_mapper._memo(("post_update", table), update_stmt)

    if mapper._version_id_has_server_side_value:
        # the server generates the new version value; ask for it back
        statement = statement.return_defaults(mapper.version_id_col)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # (connection, parameter keys)
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        # determine whether rowcounts from this dialect can be trusted
        # for single-row and multi-row (executemany) statements
        assert_singlerow = connection.dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        # with versioning, fall back to row-at-a-time execution unless
        # executemany rowcounts are verifiable
        allow_executemany = not needs_version_id or assert_multirow

        if not allow_executemany:
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection.execute(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            # a single-row executemany is still verifiable with only
            # single-row rowcount support
            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection.execute(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for state, state_dict, mapper_rec, connection, params in records:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            # note: ``c`` here is the result of the last execute()
            # issued above for this group
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
1402
1403
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    Records are grouped per connection.  With versioning on a dialect
    lacking verifiable multi-row rowcounts, rows are deleted one at a
    time so each rowcount can be checked; otherwise a single
    executemany() is issued.  A mismatch between expected and matched
    rows warns or raises depending on versioning state.
    """

    # versioning applies only if the version column belongs to this table
    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # construct the (memoized) DELETE: WHERE matches every PK
        # column, plus the version id column when versioning applies
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    # statement is cached per-table on the base mapper
    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        # -1 means "rowcount not verifiable"; skips the mismatch check
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                # no usable rowcount at all; versioning can't be checked
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            # single executemany() covering all records for this connection
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                # without versioning, a mismatch is only advisory
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
1501
1502
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    For each (state, dict, mapper, connection, has_identity) tuple:
    expires unmodified read-only attributes, optionally emits a second
    SELECT for eager defaults / server-generated version ids, then
    dispatches the after_insert or after_update event.  Raises
    :class:`.FlushError` when a server-side version id came back NULL.
    """
    for state, state_dict, mapper, connection, has_identity in states:
        if mapper._readonly_props:
            # expire read-only props whose database value may now differ:
            # either they should be refreshed on flush (and are loaded or
            # non-deferred), or they are non-deferred and simply absent
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols. Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        # this is specifically to emit a second SELECT for eager_defaults,
        # so only if it's set to True, not "auto"
        if base_mapper.eager_defaults is True:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            # server-side versioning: the new version value must be
            # present locally for the NULL check below
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            state.key = base_mapper._identity_key_from_state(state)
            stmt = future.select(mapper).set_label_style(
                LABEL_STYLE_TABLENAME_PLUS_COL
            )
            loading.load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )
1576
1577
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Refresh attribute state on *state* following a post-update UPDATE.

    Prefetched column values from *params* are written into *dict_*;
    postfetched columns are expired so they reload from the database.
    When the state is pending deletion, only the version id column (if
    any) is refreshed; with no version column, nothing is done at all.
    """
    has_version_col = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    if uowtransaction.is_deleted(state):
        if not has_version_col:
            # post updating before a DELETE without a version_id_col,
            # don't need to postfetch
            return
        # post updating before a DELETE with a version_id_col, need to
        # postfetch just version_id_col
        prefetch_cols = postfetch_cols = ()
    else:
        # post updating after a regular INSERT or UPDATE, do a full
        # postfetch based on what the compiled statement reports
        compiled = result.context.compiled
        prefetch_cols = compiled.prefetch
        postfetch_cols = compiled.postfetch

    if has_version_col:
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    dispatch_refresh = bool(mapper.class_manager.dispatch.refresh_flush)
    if dispatch_refresh:
        refreshed_keys = []

    for col in prefetch_cols:
        if col.key in params and col in mapper._columntoproperty:
            attr_key = mapper._columntoproperty[col].key
            dict_[attr_key] = params[col.key]
            if dispatch_refresh:
                refreshed_keys.append(attr_key)

    if dispatch_refresh and refreshed_keys:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, refreshed_keys
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[col].key
                for col in postfetch_cols
                if col in mapper._columntoproperty
            ],
        )
1626
1627
def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    RETURNING values and prefetched parameter values are written into
    *dict_*; postfetched columns are expired so they reload from the
    database; finally, newly inserted ids are synchronized across
    equated tables in the inheritance hierarchy.
    """

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.effective_returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # the version id is always treated as prefetched for this table
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            pkey = mapper._columntoproperty[c].key

            # set prefetched value in dict and also pop from committed_state,
            # since this is new database state that replaces whatever might
            # have previously been fetched (see #10800). this is essentially a
            # shorthand version of set_committed_value(), which could also be
            # used here directly (with more overhead)
            dict_[pkey] = params[c.key]
            state.committed_state.pop(pkey, None)

            if refresh_flush:
                load_evt_attrs.append(pkey)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        # NOTE(review): this extends the list obtained from the compiled
        # object in place; presumably safe for this flush, but confirm it
        # cannot pollute a cached compiled across uses.
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often. would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )
1732
1733
def _postfetch_bulk_save(mapper, dict_, table):
    """Propagate inherit-key values within *dict_* for every mapper
    equated to *table*, without per-instance bookkeeping."""
    for related_mapper, equated_pairs in mapper._table_to_equated[table]:
        sync.bulk_populate_inherit_keys(dict_, related_mapper, equated_pairs)
1737
1738
def _connections_for_states(base_mapper, uowtransaction, states):
    """Return an iterator of (state, state.dict, mapper, connection).

    The states are sorted according to _sort_states, then paired
    with the connection they should be using for the given
    unit of work transaction.

    """
    # a session-level connection callable, when configured, chooses a
    # connection per object; otherwise one connection serves all states
    connection_callable = (
        uowtransaction.session.connection_callable or None
    )
    if connection_callable is None:
        connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if connection_callable is not None:
            connection = connection_callable(base_mapper, state.obj())

        yield state, state.dict, state.manager.mapper, connection
1763
1764
1765def _sort_states(mapper, states):
1766 pending = set(states)
1767 persistent = {s for s in pending if s.key is not None}
1768 pending.difference_update(persistent)
1769
1770 try:
1771 persistent_sorted = sorted(
1772 persistent, key=mapper._persistent_sortkey_fn
1773 )
1774 except TypeError as err:
1775 raise sa_exc.InvalidRequestError(
1776 "Could not sort objects by primary key; primary key "
1777 "values must be sortable in Python (was: %s)" % err
1778 ) from err
1779 return (
1780 sorted(pending, key=operator.attrgetter("insert_order"))
1781 + persistent_sorted
1782 )