1# orm/state_changes.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7 
    8"""State tracking utilities used by :class:`_orm.Session`.""" 
    9 
    10from __future__ import annotations 
    11 
    12import contextlib 
    13from enum import Enum 
    14from typing import Any 
    15from typing import Callable 
    16from typing import cast 
    17from typing import Iterator 
    18from typing import Literal 
    19from typing import NoReturn 
    20from typing import Optional 
    21from typing import Tuple 
    22from typing import TypeVar 
    23from typing import Union 
    24 
    25from .. import exc as sa_exc 
    26from .. import util 
    27 
    28_F = TypeVar("_F", bound=Callable[..., Any]) 
    29 
    30 
    class _StateChangeState(Enum):
        """Base enumeration type for state values tracked by ``_StateChange``.

        Declares no members of its own; concrete sets of states are defined
        by subclasses.
        """
    33 
    34 
    class _StateChangeStates(_StateChangeState):
        """Generic state constants consulted by the ``_StateChange`` machinery."""

        # wildcard: no prerequisite state required / destination unrestricted
        ANY = 1

        # the decorated method is expected to leave the state untouched
        NO_CHANGE = 2

        # placeholder destination used while a decorated method is executing
        CHANGE_IN_PROGRESS = 3
    39 
    40 
    class _StateChange:
        """Provides decorators that assert legal state transitions.

        Currently used by the :class:`_orm.SessionTransaction` class.  Nothing
        in :class:`_StateChange` refers to :class:`_orm.SessionTransaction`
        directly, so it could in theory be reused by other systems.

        """

        # destination state the enclosing operation expects to reach;
        # ``ANY`` means the destination is unrestricted
        _next_state: _StateChangeState = _StateChangeStates.ANY
        # state the object is currently in
        _state: _StateChangeState = _StateChangeStates.NO_CHANGE
        # decorated method currently executing, if any
        _current_fn: Optional[Callable[..., Any]] = None

        def _raise_for_prerequisite_state(
            self, operation_name: str, state: _StateChangeState
        ) -> NoReturn:
            """Raise the error for an operation invoked in a disallowed state.

            :param operation_name: name of the method that was called.

            :param state: the state the object was in at call time.

            :raises IllegalStateChangeError: always.

            """
            message = (
                f"Can't run operation '{operation_name}()' when Session "
                f"is in state {state!r}"
            )
            raise sa_exc.IllegalStateChangeError(message, code="isce")

        @classmethod
        def declare_states(
            cls,
            prerequisite_states: Union[
                Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
            ],
            moves_to: _StateChangeState,
        ) -> Callable[[_F], _F]:
            """Return a method decorator that enforces declared states.

            :param prerequisite_states: tuple of states the object may be in
             when the method is called, or the single constant
             ``_StateChangeStates.ANY`` to accept any state.

            :param moves_to: the state the object must be in when the method
             returns normally, or ``_StateChangeStates.NO_CHANGE`` to require
             that the state is the same as when the method was entered.

            :raises IllegalStateChangeError: raised by the decorated method
             when a prerequisite is not met, when the call would conflict
             with the destination state expected by an enclosing operation,
             or when the method finishes in an unexpected state.

            """
            assert prerequisite_states, "no prequisite states sent"

            accepts_any_state = prerequisite_states is _StateChangeStates.ANY
            allowed_states = cast(
                "Tuple[_StateChangeState, ...]", prerequisite_states
            )
            changes_state = moves_to is not _StateChangeStates.NO_CHANGE

            @util.decorator
            def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
                starting_state = self._state

                # prerequisite check
                if not accepts_any_state and starting_state not in allowed_states:
                    self._raise_for_prerequisite_state(
                        fn.__name__, starting_state
                    )

                # remember the enclosing context so it can be restored on exit
                outer_next_state = self._next_state
                outer_fn = self._current_fn

                if changes_state:
                    expected_final = moves_to
                else:
                    expected_final = starting_state

                # the call conflicts when this method wants to change state,
                # the enclosing context restricts the destination, and the
                # two destinations differ
                destination_conflict = (
                    changes_state
                    and outer_next_state is not _StateChangeStates.ANY
                    and outer_next_state is not expected_final
                )
                if destination_conflict:
                    if outer_fn and outer_next_state in (
                        _StateChangeStates.NO_CHANGE,
                        _StateChangeStates.CHANGE_IN_PROGRESS,
                    ):
                        raise sa_exc.IllegalStateChangeError(
                            f"Method '{fn.__name__}()' can't be called here; "
                            f"method '{outer_fn.__name__}()' is already "
                            f"in progress and this would cause an unexpected "
                            f"state change to {moves_to!r}",
                            code="isce",
                        )
                    raise sa_exc.IllegalStateChangeError(
                        f"Cant run operation '{fn.__name__}()' here; "
                        f"will move to state {moves_to!r} where we are "
                        f"expecting {outer_next_state!r}",
                        code="isce",
                    )

                self._current_fn = fn
                self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS
                try:
                    ret_value = fn(self, *arg, **kw)

                    # only reached when fn() returned without raising
                    resulting_state = self._state
                    if resulting_state is expected_final:
                        return ret_value

                    if resulting_state is starting_state:
                        raise sa_exc.IllegalStateChangeError(
                            f"Method '{fn.__name__}()' failed to "
                            "change state "
                            f"to {moves_to!r} as expected",
                            code="isce",
                        )

                    if outer_fn:
                        raise sa_exc.IllegalStateChangeError(
                            f"While method '{outer_fn.__name__}()' was "
                            "running, "
                            f"method '{fn.__name__}()' caused an "
                            "unexpected "
                            f"state change to {resulting_state!r}",
                            code="isce",
                        )

                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' caused an unexpected "
                        f"state change to {resulting_state!r}",
                        code="isce",
                    )
                finally:
                    # restore the enclosing context on success or failure
                    self._next_state = outer_next_state
                    self._current_fn = outer_fn

            return _go

        @contextlib.contextmanager
        def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
            """Declare the state the enclosed block is expected to produce.

            To be used inside a method decorated with ``@declare_states()``;
            an assertion checks that such a method is currently running.

            :param expected: the state the object must be in when the block
             exits without raising.

            :raises IllegalStateChangeError: if the block completes normally
             and the object is in a state other than ``expected``.

            """
            assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, (
                "Unexpected call to _expect_state outside of "
                "state-changing method"
            )

            self._next_state = expected
            try:
                yield
                # only reached when the enclosed block did not raise
                if self._state is not expected:
                    raise sa_exc.IllegalStateChangeError(
                        f"Unexpected state change to {self._state!r}",
                        code="isce",
                    )
            finally:
                self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS