1# orm/state_changes.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""State tracking utilities used by :class:`_orm.Session`.
9
10"""
11
12from __future__ import annotations
13
14import contextlib
15from enum import Enum
16from typing import Any
17from typing import Callable
18from typing import cast
19from typing import Iterator
20from typing import NoReturn
21from typing import Optional
22from typing import Tuple
23from typing import TypeVar
24from typing import Union
25
26from .. import exc as sa_exc
27from .. import util
28from ..util.typing import Literal
29
# Type variable for the decorated callable, letting ``declare_states()``
# be typed as returning a callable of the same type it is given.
_F = TypeVar("_F", bound=Callable[..., Any])
31
32
33class _StateChangeState(Enum):
34 pass
35
36
class _StateChangeStates(_StateChangeState):
    """Sentinel states interpreted by :class:`_StateChange` itself."""

    # Wildcard: as a prerequisite it means "any current state is allowed";
    # as ``_next_state`` it means destination states are unrestricted.
    ANY = 1

    # As ``moves_to``: the decorated method must leave the state unchanged.
    NO_CHANGE = 2

    # Value of ``_next_state`` while a decorated method is executing.
    CHANGE_IN_PROGRESS = 3
41
42
class _StateChange:
    """Supplies state assertion decorators.

    The current use case is for the :class:`_orm.SessionTransaction` class. The
    :class:`_StateChange` class itself is agnostic of the
    :class:`_orm.SessionTransaction` class so could in theory be generalized
    for other systems as well.

    The current state lives in ``_state``.  Methods decorated with
    :meth:`declare_states` check ``_state`` before they run and verify,
    after a normal return, that they moved it to the declared target.

    """

    # Which state a nested decorated call is currently permitted to move to.
    # ANY means unrestricted; while a decorated method runs this is
    # CHANGE_IN_PROGRESS, and _expect_state() temporarily sets a specific
    # target state.
    _next_state: _StateChangeState = _StateChangeStates.ANY

    # The current state.  Decorated methods are expected to update it so
    # that it matches their declared ``moves_to`` when they return.
    _state: _StateChangeState = _StateChangeStates.NO_CHANGE

    # The decorated method currently executing, if any; it is consulted only
    # to pick and word error messages for nested calls.
    _current_fn: Optional[Callable[..., Any]] = None

    def _raise_for_prerequisite_state(
        self, operation_name: str, state: _StateChangeState
    ) -> NoReturn:
        """Raise because ``operation_name`` was called in a disallowed state.

        :param operation_name: name of the decorated method being invoked.
        :param state: the state the object was in at call time.
        :raises IllegalStateChangeError: always.
        """
        raise sa_exc.IllegalStateChangeError(
            f"Can't run operation '{operation_name}()' when Session "
            f"is in state {state!r}",
            code="isce",
        )

    @classmethod
    def declare_states(
        cls,
        prerequisite_states: Union[
            Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
        ],
        moves_to: _StateChangeState,
    ) -> Callable[[_F], _F]:
        """Method decorator declaring valid states.

        :param prerequisite_states: sequence of acceptable prerequisite
         states.  Can be the single constant ``_StateChangeStates.ANY`` to
         indicate no prerequisite state.

        :param moves_to: the expected state at the end of the method, assuming
         no exceptions raised.  Can be the constant
         ``_StateChangeStates.NO_CHANGE`` to indicate state should not change
         at the end of the method.

        :return: a decorator for methods of a :class:`_StateChange` subclass.

        The decorated method raises ``IllegalStateChangeError`` when it is
        called from a state not in ``prerequisite_states``, when it would
        change state while an enclosing decorated method has not announced
        that target via :meth:`_expect_state`, or when it returns normally
        without having reached ``moves_to``.  Exceptions raised by the
        method itself propagate unchanged.

        """
        assert prerequisite_states, "no prerequisite states sent"
        has_prerequisite_states = (
            prerequisite_states is not _StateChangeStates.ANY
        )

        # Only consulted when has_prerequisite_states is True, i.e. when the
        # argument really is a tuple of states.
        prerequisite_state_collection = cast(
            "Tuple[_StateChangeState, ...]", prerequisite_states
        )
        expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE

        @util.decorator
        def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
            current_state = self._state

            if (
                has_prerequisite_states
                and current_state not in prerequisite_state_collection
            ):
                self._raise_for_prerequisite_state(fn.__name__, current_state)

            # Snapshot the enclosing call's bookkeeping; it is restored in
            # the ``finally`` clause below however this call ends.
            next_state = self._next_state
            existing_fn = self._current_fn
            expect_state = moves_to if expect_state_change else current_state

            if (
                # destination states are restricted
                next_state is not _StateChangeStates.ANY
                # method seeks to change state
                and expect_state_change
                # destination state incorrect
                and next_state is not expect_state
            ):
                if existing_fn and next_state in (
                    _StateChangeStates.NO_CHANGE,
                    _StateChangeStates.CHANGE_IN_PROGRESS,
                ):
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' can't be called here; "
                        f"method '{existing_fn.__name__}()' is already "
                        f"in progress and this would cause an unexpected "
                        f"state change to {moves_to!r}",
                        code="isce",
                    )
                else:
                    # NOTE(review): "Cant" (sic) is left verbatim in case
                    # tests or callers match this message; confirm before
                    # correcting the spelling.
                    raise sa_exc.IllegalStateChangeError(
                        f"Cant run operation '{fn.__name__}()' here; "
                        f"will move to state {moves_to!r} where we are "
                        f"expecting {next_state!r}",
                        code="isce",
                    )

            self._current_fn = fn
            # Until ``fn`` announces a target via _expect_state(), nested
            # decorated calls that try to change state are rejected above.
            self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS
            try:
                ret_value = fn(self, *arg, **kw)

                # Post-condition checks; reached only when ``fn`` returned
                # normally.  Exceptions from ``fn`` propagate untouched.
                if self._state is expect_state:
                    return ret_value

                if self._state is current_state:
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' failed to "
                        "change state "
                        f"to {moves_to!r} as expected",
                        code="isce",
                    )
                elif existing_fn:
                    raise sa_exc.IllegalStateChangeError(
                        f"While method '{existing_fn.__name__}()' was "
                        "running, "
                        f"method '{fn.__name__}()' caused an "
                        "unexpected "
                        f"state change to {self._state!r}",
                        code="isce",
                    )
                else:
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' caused an unexpected "
                        f"state change to {self._state!r}",
                        code="isce",
                    )
            finally:
                self._next_state = next_state
                self._current_fn = existing_fn

        return _go

    @contextlib.contextmanager
    def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
        """Called within a method that changes states.

        The method must also use the ``@declare_states()`` decorator.  Inside
        the ``with`` block, nested decorated calls may move the state to
        ``expected``.  If the block completes without raising, the state must
        then be ``expected``.

        :param expected: the state the ``with`` block is expected to produce.
        :raises IllegalStateChangeError: if the block finishes normally but
         ``self._state`` is not ``expected``.

        """
        assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, (
            "Unexpected call to _expect_state outside of "
            "state-changing method"
        )

        self._next_state = expected
        try:
            yield
            # Reached only when the ``with`` body completed without raising.
            if self._state is not expected:
                raise sa_exc.IllegalStateChangeError(
                    f"Unexpected state change to {self._state!r}", code="isce"
                )
        finally:
            self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS