1# orm/state_changes.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""State tracking utilities used by :class:`_orm.Session`."""
9
10from __future__ import annotations
11
12import contextlib
13from enum import Enum
14from typing import Any
15from typing import Callable
16from typing import cast
17from typing import Iterator
18from typing import NoReturn
19from typing import Optional
20from typing import Tuple
21from typing import TypeVar
22from typing import Union
23
24from .. import exc as sa_exc
25from .. import util
26from ..util.typing import Literal
27
# Type variable for the decorated callable: the decorator produced by
# ``_StateChange.declare_states()`` is typed ``Callable[[_F], _F]`` so that the
# decorated method keeps its own signature for type checkers.
_F = TypeVar("_F", bound=Callable[..., Any])
29
30
31class _StateChangeState(Enum):
32 pass
33
34
class _StateChangeStates(_StateChangeState):
    """Sentinel states with special meaning to :class:`_StateChange`."""

    # As a prerequisite: any current state is acceptable.  As the value of
    # ``_StateChange._next_state``: the destination state is unrestricted.
    ANY = 1
    # As ``moves_to``: the decorated method is not expected to change state.
    # Also the class-level default of ``_StateChange._state``.
    NO_CHANGE = 2
    # Value of ``_StateChange._next_state`` while a method decorated with
    # ``declare_states()`` is running.
    CHANGE_IN_PROGRESS = 3
39
40
class _StateChange:
    """Supplies state assertion decorators.

    The current use case is for the :class:`_orm.SessionTransaction` class. The
    :class:`_StateChange` class itself is agnostic of the
    :class:`_orm.SessionTransaction` class so could in theory be generalized
    for other systems as well.

    Methods decorated with :meth:`declare_states` have the object's ``_state``
    checked before they run (prerequisite states) and after they return
    (expected resulting state); violations raise
    :class:`.IllegalStateChangeError`.

    """

    # State the object is currently allowed to move to.  ``ANY`` means
    # unrestricted; it is temporarily set to ``CHANGE_IN_PROGRESS`` while a
    # decorated method runs, or to a specific state inside ``_expect_state``.
    _next_state: _StateChangeState = _StateChangeStates.ANY
    # Current state of the object.  Decorated methods are expected to update
    # it; the decorator only reads it.
    _state: _StateChangeState = _StateChangeStates.NO_CHANGE
    # The decorated method currently executing, if any (used for messages).
    _current_fn: Optional[Callable[..., Any]] = None

    def _raise_for_prerequisite_state(
        self, operation_name: str, state: _StateChangeState
    ) -> NoReturn:
        """Raise because ``operation_name`` may not run in ``state``.

        :raises IllegalStateChangeError: always.
        """
        raise sa_exc.IllegalStateChangeError(
            f"Can't run operation '{operation_name}()' when Session "
            f"is in state {state!r}",
            code="isce",
        )

    @classmethod
    def declare_states(
        cls,
        prerequisite_states: Union[
            Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
        ],
        moves_to: _StateChangeState,
    ) -> Callable[[_F], _F]:
        """Method decorator declaring valid states.

        :param prerequisite_states: sequence of acceptable prerequisite
         states.  Can be the single constant ``_StateChangeStates.ANY`` to
         indicate no prerequisite state.

        :param moves_to: the expected state at the end of the method, assuming
         no exceptions raised.  Can be the constant
         ``_StateChangeStates.NO_CHANGE`` to indicate state should not change
         at the end of the method.

        :return: a decorator for methods of a :class:`_StateChange` subclass.
         The wrapped method raises :class:`.IllegalStateChangeError` when it
         is called in a state not listed in ``prerequisite_states``, when it
         would change state where a different destination is expected, or
         when it returns without arriving at the expected state.

        """
        assert prerequisite_states, "no prequisite states sent"
        has_prerequisite_states = (
            prerequisite_states is not _StateChangeStates.ANY
        )

        # only consulted when has_prerequisite_states is True
        prerequisite_state_collection = cast(
            "Tuple[_StateChangeState, ...]", prerequisite_states
        )
        expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE

        @util.decorator
        def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
            current_state = self._state

            # Checked before any bookkeeping is modified, so nothing needs
            # to be restored if this raises.
            if (
                has_prerequisite_states
                and current_state not in prerequisite_state_collection
            ):
                self._raise_for_prerequisite_state(fn.__name__, current_state)

            # Remember the outer context so it can be restored in ``finally``.
            next_state = self._next_state
            existing_fn = self._current_fn
            expect_state = moves_to if expect_state_change else current_state

            if (
                # destination states are restricted
                next_state is not _StateChangeStates.ANY
                # method seeks to change state
                and expect_state_change
                # destination state incorrect
                and next_state is not expect_state
            ):
                # An enclosing decorated method is running and has not
                # announced this destination via ``_expect_state()``.
                if existing_fn and next_state in (
                    _StateChangeStates.NO_CHANGE,
                    _StateChangeStates.CHANGE_IN_PROGRESS,
                ):
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' can't be called here; "
                        f"method '{existing_fn.__name__}()' is already "
                        f"in progress and this would cause an unexpected "
                        f"state change to {moves_to!r}",
                        code="isce",
                    )
                else:
                    raise sa_exc.IllegalStateChangeError(
                        f"Cant run operation '{fn.__name__}()' here; "
                        f"will move to state {moves_to!r} where we are "
                        f"expecting {next_state!r}",
                        code="isce",
                    )

            # While fn runs, nested state-changing calls are rejected unless
            # fn opens an ``_expect_state()`` block for them.
            self._current_fn = fn
            self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS
            try:
                ret_value = fn(self, *arg, **kw)

                # Reached only when fn returned normally; exceptions from fn
                # propagate untouched (after the ``finally`` below).
                if self._state is expect_state:
                    return ret_value

                if self._state is current_state:
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' failed to "
                        "change state "
                        f"to {moves_to!r} as expected",
                        code="isce",
                    )
                elif existing_fn:
                    raise sa_exc.IllegalStateChangeError(
                        f"While method '{existing_fn.__name__}()' was "
                        "running, "
                        f"method '{fn.__name__}()' caused an "
                        "unexpected "
                        f"state change to {self._state!r}",
                        code="isce",
                    )
                else:
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' caused an unexpected "
                        f"state change to {self._state!r}",
                        code="isce",
                    )
            finally:
                self._next_state = next_state
                self._current_fn = existing_fn

        return _go

    @contextlib.contextmanager
    def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
        """called within a method that changes states.

        method must also use the ``@declare_states()`` decorator.

        Allows decorated methods called inside the ``with`` block to move the
        object to ``expected``.  If the block exits without an exception and
        ``_state`` is not ``expected``, raises
        :class:`.IllegalStateChangeError`.

        """
        assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, (
            "Unexpected call to _expect_state outside of "
            "state-changing method"
        )

        self._next_state = expected
        try:
            yield
            # Reached only when the ``with`` block completed without raising;
            # an exception thrown in at ``yield`` propagates unchanged.
            if self._state is not expected:
                raise sa_exc.IllegalStateChangeError(
                    f"Unexpected state change to {self._state!r}", code="isce"
                )
        finally:
            self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS