# orm/persistence.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


    10"""private module containing functions used to emit INSERT, UPDATE 
    11and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending 
    12mappers. 

The functions here are called only by the unit of work functions
in unitofwork.py.

"""
from __future__ import annotations

from itertools import chain
from itertools import groupby
from itertools import zip_longest
import operator

from . import attributes
from . import exc as orm_exc
from . import loading
from . import sync
from .base import state_str
from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
from ..engine import cursor as _cursor
from ..sql import operators
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL


def _save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    """

    # if batch=False, call _save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            _save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []

    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

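    # for each table in the mapper hierarchy's dependency ordering, collect
    # and emit per-table statement sets; within a given table, UPDATEs are
    # emitted before INSERTs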
    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        insert = _collect_insert_commands(table, states_to_insert)

        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            insert,
        )

    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )


def _post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    """

    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

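        # build (state, state_dict, sub_mapper, connection, update_version_id)
        # tuples for sub-mappers that include this table, using the committed
        # version id value when a version_id_col is in play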
        update = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                (
                    mapper._get_committed_state_attr_by_column(
                        state, state_dict, mapper.version_id_col
                    )
                    if mapper.version_id_col is not None
                    else None
                ),
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )


def _delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    """

    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

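    # tables are visited in reverse dependency order so that rows in more
    # derived / dependent tables are deleted before rows in the tables they
    # depend upon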
    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        elif mapper.inherits and mapper.passive_deletes:
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            delete,
        )

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)


def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    base_mapper._log_debug(
                        "detected row switch for identity %s.  "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )


def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for UPDATE
    corresponding to post_update.

    This includes obtaining key information for each state
    including its dictionary, mapper, the connection to use for
    the execution per state.

    """
    return _connections_for_states(base_mapper, uowtransaction, states)


def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for DELETE.

    This includes calling out before_delete and obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state.

    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        mapper.dispatch.before_delete(mapper, connection, state)

        if mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                state, dict_, mapper.version_id_col
            )
        else:
            update_version_id = None

        yield (state, dict_, mapper, connection, update_version_id)


def _collect_insert_commands(
    table,
    states_to_insert,
    *,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
    include_bulk_keys=(),
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    """
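    # each yielded tuple is of the form
    # (state, state_dict, params, mapper, connection, value_params,
    #  has_all_pks, has_all_defaults)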
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        eval_none = mapper._insert_cols_evaluating_none[table]

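        # plain values are collected into params keyed by the column's .key;
        # SQL expressions (or objects providing __clause_element__) go into
        # value_params keyed by the Column itself, to be rendered inline in
        # the statement rather than passed as bound parameters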
        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.   This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper._prefer_eager_defaults(
                connection.dialect, table
            ):
                has_all_defaults = mapper._server_default_col_keys[
                    table
                ].issubset(params)
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        if bulk:
            if mapper._set_polymorphic_identity:
                params.setdefault(
                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
                )

            if include_bulk_keys:
                params.update((k, state_dict[k]) for k in include_bulk_keys)

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )


def _collect_update_commands(
    uowtransaction,
    table,
    states_to_update,
    *,
    bulk=False,
    use_orm_update_stmt=None,
    include_bulk_keys=(),
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.

    """

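    # each yielded tuple is of the form
    # (state, state_dict, params, mapper, connection, value_params,
    #  has_all_defaults, has_all_pks) - note the ordering of the final two
    # flags differs from that of _collect_insert_commands()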
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        if (
            use_orm_update_stmt is not None
            and not use_orm_update_stmt._maintain_values_ordering
        ):
            # TODO: ordered values, etc
            # ORM bulk_persistence will raise for the maintain_values_ordering
            # case right now
            value_params = use_orm_update_stmt._values
        else:
            value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = {
                propkey_to_col[propkey].key: state_dict[propkey]
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            }
            has_all_defaults = True
        else:
            params = {}
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults is True:
                has_all_defaults = (
                    mapper._server_onupdate_default_col_keys[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is.  This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, skip this state for this table
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = {
                propkey_to_col[propkey]._label: state_dict.get(propkey)
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            }
            if util.NONE_SET.intersection(pk_params.values()):
                raise sa_exc.InvalidRequestError(
                    f"No primary key value supplied for column(s) "
                    f"""{
                        ', '.join(
                            str(c) for c in pks if pk_params[c._label] is None
                        )
                    }; """
                    "per-row ORM Bulk UPDATE by Primary Key requires that "
                    "records contain primary key values",
                    code="bupq",
                )

        else:
            pk_params = {}
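            # WHERE criteria are keyed by col._label.  when the primary key
            # has a pending change, the old (deleted) value is used to locate
            # the row, unless there is no old value or the change is known to
            # have been cascaded ("pk_cascaded"), in which case the new value
            # is used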
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if include_bulk_keys:
            params.update((k, state_dict[k]) for k in include_bulk_keys)

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync._populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )


def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    """

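    # each yielded tuple is of the form
    # (state, state_dict, mapper, connection, params)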
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:
        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):
                col = mapper.version_id_col
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params


def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Identify values to use in DELETE statements for a list of
    states to be deleted."""

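    # yields (params, connection) pairs; params contains the primary key
    # values (and the committed version id, if versioning is in effect)
    # forming the WHERE criteria of the DELETE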
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        for col in mapper._pks_by_table[table]:
            params[col.key] = value = (
                mapper._get_committed_state_attr_by_column(
                    state, state_dict, col
                )
            )
            if value is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, col)
                )

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = update_version_id
        yield params, connection


def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    *,
    bookkeeping=True,
    use_orm_update_stmt=None,
    enable_check_rowcount=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands()."""

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt(existing_stmt=None):
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        if existing_stmt is not None:
            stmt = existing_stmt.where(clauses)
        else:
            stmt = table.update().where(clauses)
        return stmt

    if use_orm_update_stmt is not None:
        cached_stmt = update_stmt(use_orm_update_stmt)

    else:
        cached_stmt = base_mapper._memo(("update", table), update_stmt)

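    # records are grouped by (connection, parameter keys, presence of SQL
    # expression values, has_all_defaults, has_all_pks) so that each group
    # can potentially be invoked as a single executemany() batch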
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt

        if use_orm_update_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_update_table": table,
                    "_emit_update_mapper": mapper,
                }
            )

        return_defaults = False

        if not has_all_pks:
            statement = statement.return_defaults(*mapper._pks_by_table[table])
            return_defaults = True

        if (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults is True
            # change as of #8889 - if RETURNING is not going to be used anyway,
            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
            # we can do an executemany UPDATE which is more efficient
            and table.implicit_returning
            and connection.dialect.update_returning
        ):
            statement = statement.return_defaults(
                *mapper._server_onupdate_default_cols[table]
            )
            return_defaults = True

        if mapper._version_id_has_server_side_value:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = connection.dialect.supports_sane_rowcount

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )

        # change as of #8889 - if RETURNING is not going to be used anyway,
        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure
        # we can do an executemany UPDATE which is more efficient
        allow_executemany = not return_defaults and not needs_version_id

        if hasvalue:
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection.execute(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = enable_check_rowcount and assert_singlerow
        else:
            if not allow_executemany:
                check_rowcount = enable_check_rowcount and assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                multiparams = [rec[2] for rec in records]

                check_rowcount = enable_check_rowcount and (
                    assert_multirow
                    or (assert_singlerow and len(multiparams) == 1)
                )

                c = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            (
                                c.returned_defaults
                                if not c.context.executemany
                                else None
                            ),
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )


def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    *,
    bookkeeping=True,
    use_orm_insert_stmt=None,
    execution_options=None,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands()."""

    if use_orm_insert_stmt is not None:
        cached_stmt = use_orm_insert_stmt
        exec_opt = util.EMPTY_DICT

        # if a user query with RETURNING was passed, we definitely need
        # to use RETURNING.
        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
        deterministic_results_reqd = (
            returning_is_required_anyway
            and use_orm_insert_stmt._sort_by_parameter_order
        ) or bookkeeping
    else:
        returning_is_required_anyway = False
        deterministic_results_reqd = bookkeeping
        cached_stmt = base_mapper._memo(("insert", table), table.insert)
        exec_opt = {"compiled_cache": base_mapper._compiled_cache}

    if execution_options:
        execution_options = util.EMPTY_DICT.merge_with(
            exec_opt, execution_options
        )
    else:
        execution_options = exec_opt

    return_result = None

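    # as with UPDATE, records are grouped by (connection, parameter keys,
    # presence of SQL expression values, has_all_pks, has_all_defaults) so
    # that each group may be batched with executemany() where possible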
    for (
        (connection, _, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):
        statement = cached_stmt

        if use_orm_insert_stmt is not None:
            statement = statement._annotate(
                {
                    "_emit_insert_table": table,
                    "_emit_insert_mapper": mapper,
                }
            )

        if (
            (
                not bookkeeping
                or (
                    has_all_defaults
                    or not base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                    or not table.implicit_returning
                    or not connection.dialect.insert_returning
                )
            )
            and not returning_is_required_anyway
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            result = connection.execute(
                statement, multiparams, execution_options=execution_options
            )
            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, result.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            result,
                            last_inserted_params,
                            value_params,
                            False,
                            (
                                result.returned_defaults
                                if not result.context.executemany
                                else None
                            ),
                        )
                    else:
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back or we otherwise
            # know that we are using RETURNING in any case

            records = list(records)

            if returning_is_required_anyway or (
                table.implicit_returning and not hasvalue and len(records) > 1
            ):
                if (
                    deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501
                ) or (
                    not deterministic_results_reqd
                    and connection.dialect.insert_executemany_returning
                ):
                    do_executemany = True
                elif returning_is_required_anyway:
                    if deterministic_results_reqd:
                        dt = " with RETURNING and sort by parameter order"
                    else:
                        dt = " with RETURNING"
                    raise sa_exc.InvalidRequestError(
                        f"Can't use explicit RETURNING for bulk INSERT "
                        f"operation with "
                        f"{connection.dialect.dialect_description} backend; "
                        f"executemany{dt} is not enabled for this dialect."
                    )
                else:
                    do_executemany = False
            else:
                do_executemany = False

            if use_orm_insert_stmt is None:
                if (
                    not has_all_defaults
                    and base_mapper._prefer_eager_defaults(
                        connection.dialect, table
                    )
                ):
                    statement = statement.return_defaults(
                        *mapper._server_default_cols[table],
                        sort_by_parameter_order=bookkeeping,
                    )

            if mapper.version_id_col is not None:
                statement = statement.return_defaults(
                    mapper.version_id_col,
                    sort_by_parameter_order=bookkeeping,
                )
            elif do_executemany:
                statement = statement.return_defaults(
                    *table.primary_key, sort_by_parameter_order=bookkeeping
                )

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                result = connection.execute(
                    statement, multiparams, execution_options=execution_options
                )

                if use_orm_insert_stmt is not None:
                    if return_result is None:
                        return_result = result
                    else:
                        return_result = return_result.splice_vertically(result)

                if bookkeeping:
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in zip_longest(
                        records,
                        result.context.compiled_parameters,
                        result.inserted_primary_key_rows,
                        result.returned_defaults_rows or (),
                    ):
                        if inserted_primary_key is None:
                            # this is a real problem and means that we didn't
                            # get back as many PK rows.  we can't continue
                            # since this indicates PK rows were missing, which
                            # means we likely mis-populated records starting
                            # at that point with incorrectly matched PK
                            # values.
                            raise orm_exc.FlushError(
                                "Multi-row INSERT statement for %s did not "
                                "produce "
                                "the correct number of INSERTed rows for "
                                "RETURNING.  Ensure there are no triggers or "
                                "special driver issues preventing INSERT from "
                                "functioning properly." % mapper_rec
                            )

                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                assert not returning_is_required_anyway

                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        result = connection.execute(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection.execute(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key
                    if primary_key is None:
                        raise orm_exc.FlushError(
                            "Single-row INSERT statement for %s "
                            "did not produce a new primary key result.  "
                            "Ensure there are no triggers or "
                            "special driver issues preventing INSERT from "
                            "functioning properly." % (mapper_rec,)
                        )
    1254                    for pk, col in zip( 
    1255                        primary_key, mapper._pks_by_table[table] 
    1256                    ): 
    1257                        prop = mapper_rec._columntoproperty[col] 
    1258                        if ( 
    1259                            col in value_params 
    1260                            or state_dict.get(prop.key) is None 
    1261                        ): 
    1262                            state_dict[prop.key] = pk 
    1263                    if bookkeeping: 
    1264                        if state: 
    1265                            _postfetch( 
    1266                                mapper_rec, 
    1267                                uowtransaction, 
    1268                                table, 
    1269                                state, 
    1270                                state_dict, 
    1271                                result, 
    1272                                result.context.compiled_parameters[0], 
    1273                                value_params, 
    1274                                False, 
    1275                                ( 
    1276                                    result.returned_defaults 
    1277                                    if not result.context.executemany 
    1278                                    else None 
    1279                                ), 
    1280                            ) 
    1281                        else: 
    1282                            _postfetch_bulk_save(mapper_rec, state_dict, table) 
    1283 
    1284    if use_orm_insert_stmt is not None: 
    1285        if return_result is None: 
    1286            return _cursor.null_dml_result() 
    1287        else: 
    1288            return return_result 
    1289 
    1290 
    1291def _emit_post_update_statements( 
    1292    base_mapper, uowtransaction, mapper, table, update 
    1293): 
    1294    """Emit UPDATE statements corresponding to value lists collected 
    1295    by _collect_post_update_commands().""" 

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        return stmt

    statement = base_mapper._memo(("post_update", table), update_stmt)

    if mapper._version_id_has_server_side_value:
        statement = statement.return_defaults(mapper.version_id_col)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
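    # a rough sketch of that grouping, using hypothetical records; each batch
    # shares a connection and an identical set of parameter keys, which is
    # what allows a single executemany() call per batch:
    #
    #     recs = [
    #         (st1, d1, m1, conn_a, {"id": 1, "x": 5}),
    #         (st2, d2, m2, conn_a, {"id": 2, "x": 6}),
    #     ]
    #     groups = [
    #         (key, list(g))
    #         for key, g in groupby(recs, lambda r: (r[3], set(r[4])))
    #     ]
    #     # -> one group keyed on (conn_a, {"id", "x"}) holding both records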
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection  # parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        assert_singlerow = connection.dialect.supports_sane_rowcount
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_executemany = not needs_version_id or assert_multirow

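        # when a version check is needed but the dialect can't report an
        # accurate rowcount for executemany(), fall back to one execute()
        # per row so that each UPDATE's rowcount can still be verified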
        if not allow_executemany:
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection.execute(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection.execute(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for i, (
                state,
                state_dict,
                mapper_rec,
                connection,
                params,
            ) in enumerate(records):
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[i],
                )

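        # a rowcount that doesn't match the number of records means one or
        # more targeted rows were not matched (deleted, or modified
        # concurrently when versioning is in effect); surface this as
        # StaleDataError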
        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )


def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands()."""
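    # each record in ``delete`` is expected to be a (params, connection)
    # pair as collected by _collect_delete_commands(); records are grouped
    # by connection so each group can be emitted as a single batch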

    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses._append_inplace(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses._append_inplace(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

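    # illustrative only: for a hypothetical table "mytable" with primary key
    # "id" and a version_id_col "version_id", the memoized statement below
    # renders roughly as
    #
    #     DELETE FROM mytable WHERE mytable.id = :id
    #         AND mytable.version_id = :version_id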
    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection.execute(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection.execute(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection.execute(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched.  Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched.  Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )


def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    """
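    # each element of ``states`` is expected to be a five-tuple of
    # (state, state_dict, mapper, connection, has_identity), where
    # has_identity distinguishes an UPDATE (existing identity) from an INSERT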
    for state, state_dict, mapper, connection, has_identity in states:
        if mapper._readonly_props:
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols.  Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        # this is specifically to emit a second SELECT for eager_defaults,
        # so only if it's set to True, not "auto"
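        # for example (hypothetical mapping), this branch applies to
        #
        #     class Widget(Base):
        #         __tablename__ = "widget"
        #         __mapper_args__ = {"eager_defaults": True}
        #
        # but not when eager_defaults is left at "auto"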
        if base_mapper.eager_defaults is True:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            state.key = base_mapper._identity_key_from_state(state)
            stmt = future.select(mapper).set_label_style(
                LABEL_STYLE_TABLENAME_PLUS_COL
            )
            loading._load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )


def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    if not uowtransaction.is_deleted(state):
        # post updating after a regular INSERT or UPDATE, do a full postfetch
        prefetch_cols = result.context.compiled.prefetch
        postfetch_cols = result.context.compiled.postfetch
    elif needs_version_id:
        # post updating before a DELETE with a version_id_col, need to
        # postfetch just version_id_col
        prefetch_cols = postfetch_cols = ()
    else:
        # post updating before a DELETE without a version_id_col,
        # don't need to postfetch
        return

    if needs_version_id:
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )


def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state."""
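    # ``prefetch`` columns had their defaults evaluated up front, so their
    # values are present in the bound parameters and are copied directly into
    # the object's dict; ``postfetch`` columns received server-generated
    # values not present in the parameters, so those attributes are expired
    # and loaded on next access.  RETURNING values, when available, are
    # applied from ``returned_defaults``.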

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.effective_returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            pkey = mapper._columntoproperty[c].key

            # set prefetched value in dict and also pop from committed_state,
            # since this is new database state that replaces whatever might
            # have previously been fetched (see #10800).  this is essentially a
            # shorthand version of set_committed_value(), which could also be
            # used here directly (with more overhead)
            dict_[pkey] = params[c.key]
            state.committed_state.pop(pkey, None)

            if refresh_flush:
                load_evt_attrs.append(pkey)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly handle the use case specified by [ticket:3801]:
        # PK SQL expressions in an UPDATE against a non-RETURNING database,
        # which are set to themselves in order to do a version bump.
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often.  would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync._populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )


def _postfetch_bulk_save(mapper, dict_, table):
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync._bulk_populate_inherit_keys(dict_, m, equated_pairs)


def _connections_for_states(base_mapper, uowtransaction, states):
    """Return an iterator of (state, state.dict, mapper, connection).

    The states are sorted according to _sort_states, then paired
    with the connection they should be using for the given
    unit of work transaction.

    """
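    # ``connection_callable`` is an optional per-Session hook, called with
    # (mapper, instance), that selects the Connection for each individual
    # state (used, for example, by the horizontal sharding extension); when
    # it is not set, a single connection from the current transaction serves
    # every state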
    # if session has a connection callable,
    # organize individual states with the connection
    # to use for update
    if uowtransaction.session.connection_callable:
        connection_callable = uowtransaction.session.connection_callable
    else:
        connection = uowtransaction.transaction.connection(base_mapper)
        connection_callable = None

    for state in _sort_states(base_mapper, states):
        if connection_callable:
            connection = connection_callable(base_mapper, state.obj())

        mapper = state.manager.mapper

        yield state, state.dict, mapper, connection


def _sort_states(mapper, states):
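    # flush ordering is made deterministic: pending states (no identity key
    # yet, awaiting INSERT) sort by their insert_order, while persistent
    # states sort by primary key via the mapper's sort key function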
    pending = set(states)
    persistent = {s for s in pending if s.key is not None}
    pending.difference_update(persistent)

    try:
        persistent_sorted = sorted(
            persistent, key=mapper._persistent_sortkey_fn
        )
    except TypeError as err:
        raise sa_exc.InvalidRequestError(
            "Could not sort objects by primary key; primary key "
            "values must be sortable in Python (was: %s)" % err
        ) from err
    return (
        sorted(pending, key=operator.attrgetter("insert_order"))
        + persistent_sorted
    )