1# ext/asyncio/base.py
2# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8from __future__ import annotations
9
10import abc
11import functools
12from typing import Any
13from typing import AsyncGenerator
14from typing import AsyncIterator
15from typing import Awaitable
16from typing import Callable
17from typing import ClassVar
18from typing import Dict
19from typing import Generator
20from typing import Generic
21from typing import NoReturn
22from typing import Optional
23from typing import overload
24from typing import Tuple
25from typing import TypeVar
26import weakref
27
28from . import exc as async_exc
29from ... import util
30from ...util.typing import Literal
31from ...util.typing import Self
32
# Generic type variable; ``bound=Any`` places no real restriction on it.
_T = TypeVar("_T", bound=Any)
# Covariant type variable, for values that are only produced (never
# consumed), such as the result type of an awaitable / startable context.
_T_co = TypeVar("_T_co", bound=Any, covariant=True)
35
36
# type of the object being proxied
_PT = TypeVar("_PT", bound=Any)


class ReversibleProxy(Generic[_PT]):
    """Base for proxy objects that can be looked up again from their target.

    A class-level registry maps a weak reference to each target onto a weak
    reference to the proxy most recently assigned to it, so the registry
    keeps neither side alive.  The entry is removed when the target is
    garbage collected, or when the proxy it currently refers to is
    garbage collected.
    """

    # target weakref -> proxy weakref.  Declared once here; every access in
    # this class goes through ``ReversibleProxy._proxy_objects`` or
    # ``cls._proxy_objects``, so subclasses share it unless they redefine it.
    _proxy_objects: ClassVar[
        Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
    ] = {}
    __slots__ = ("__weakref__",)

    @overload
    def _assign_proxied(self, target: _PT) -> _PT: ...

    @overload
    def _assign_proxied(self, target: None) -> None: ...

    def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
        """Register ``self`` as the proxy for ``target``; return ``target``.

        ``None`` is passed through without registering anything.  If
        ``target`` already has a registered proxy, ``self`` replaces it.
        """
        if target is not None:
            # target collected -> drop its registry entry
            target_ref: weakref.ref[_PT] = weakref.ref(
                target, ReversibleProxy._target_gced
            )
            # this proxy collected -> drop the entry, but only if it still
            # points at this proxy (see _target_gced)
            proxy_ref = weakref.ref(
                self,
                functools.partial(ReversibleProxy._target_gced, target_ref),
            )
            ReversibleProxy._proxy_objects[target_ref] = proxy_ref

        return target

    @classmethod
    def _target_gced(
        cls,
        ref: weakref.ref[_PT],
        proxy_ref: Optional[weakref.ref[Self]] = None,
    ) -> None:
        """Weakref callback that removes a registry entry.

        Called with only ``ref`` when a target is collected, and with
        ``(ref, proxy_ref)`` when a proxy is collected.  In the latter case
        the entry is removed only if it still refers to that very proxy:
        when a newer proxy has replaced an older one for the same, still
        alive target, collecting the older proxy must not evict the newer.
        """
        if (
            proxy_ref is not None
            and cls._proxy_objects.get(ref) is not proxy_ref
        ):
            return
        cls._proxy_objects.pop(ref, None)

    @classmethod
    def _regenerate_proxy_for_target(
        cls, target: _PT, **additional_kw: Any
    ) -> Self:
        """Create a new proxy for ``target``; subclasses must implement."""
        raise NotImplementedError()

    @overload
    @classmethod
    def _retrieve_proxy_for_target(
        cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
    ) -> Self: ...

    @overload
    @classmethod
    def _retrieve_proxy_for_target(
        cls, target: _PT, regenerate: bool = True, **additional_kw: Any
    ) -> Optional[Self]: ...

    @classmethod
    def _retrieve_proxy_for_target(
        cls, target: _PT, regenerate: bool = True, **additional_kw: Any
    ) -> Optional[Self]:
        """Return the live proxy registered for ``target``.

        If no live proxy is registered, a new one is built via
        :meth:`_regenerate_proxy_for_target` (receiving ``additional_kw``)
        when ``regenerate`` is true; otherwise ``None`` is returned.
        The lookup creates ``weakref.ref(target)``, so ``target`` must be
        weak-referenceable and hashable.
        """
        try:
            proxy_ref = cls._proxy_objects[weakref.ref(target)]
        except KeyError:
            pass
        else:
            proxy = proxy_ref()
            if proxy is not None:
                return proxy  # type: ignore

        if regenerate:
            return cls._regenerate_proxy_for_target(target, **additional_kw)
        else:
            return None
108
109
class StartableContext(Awaitable[_T_co], abc.ABC):
    """Abstract object usable both as ``await obj`` and ``async with obj``.

    Both forms go through :meth:`start`; its ``is_ctxmanager`` flag tells
    the implementation which form is in use.  Subclasses implement
    :meth:`start` and :meth:`__aexit__`.
    """

    __slots__ = ()

    @abc.abstractmethod
    async def start(self, is_ctxmanager: bool = False) -> _T_co:
        """Begin the context and return its value.

        :param is_ctxmanager: ``True`` when entered via ``async with``,
         ``False`` when the object is simply awaited.
        """
        raise NotImplementedError()

    def __await__(self) -> Generator[Any, Any, _T_co]:
        # plain ``await obj`` -> start() with the default is_ctxmanager=False
        start_coro = self.start()
        return start_coro.__await__()

    async def __aenter__(self) -> _T_co:
        entered_value = await self.start(is_ctxmanager=True)
        return entered_value

    @abc.abstractmethod
    async def __aexit__(
        self, type_: Any, value: Any, traceback: Any
    ) -> Optional[bool]:
        """Finish the context; semantics are defined by subclasses."""
        return None

    def _raise_for_not_started(self) -> NoReturn:
        """Raise ``AsyncContextNotStarted`` naming this object's class."""
        class_name = self.__class__.__name__
        message = (
            "%s context has not been started and object has not been awaited."
            % class_name
        )
        raise async_exc.AsyncContextNotStarted(message)
134
135
class GeneratorStartableContext(StartableContext[_T_co]):
    """``StartableContext`` driven by an async generator.

    :meth:`start` advances the generator to its first ``yield``; the yielded
    value is the result of both ``await ctx`` and ``async with ctx``.  When
    used as a context manager, :meth:`__aexit__` resumes or throws into the
    generator following the protocol of ``contextlib.asynccontextmanager``.
    """

    __slots__ = ("gen",)

    gen: AsyncGenerator[_T_co, Any]

    def __init__(
        self,
        func: Callable[..., AsyncIterator[_T_co]],
        args: Tuple[Any, ...],
        kwds: Dict[str, Any],
    ):
        # for an async generator function this only creates the generator
        # object; its body first runs when start() advances it
        self.gen = func(*args, **kwds)  # type: ignore

    async def start(self, is_ctxmanager: bool = False) -> _T_co:
        """Advance the generator to its first ``yield`` and return the value.

        :param is_ctxmanager: ``True`` when called from ``__aenter__``.  When
         ``False`` (plain ``await``) the generator is closed immediately,
         which raises ``GeneratorExit`` at its ``yield``.
        :raises RuntimeError: if the generator finishes without yielding.
        """
        try:
            start_value = await util.anext_(self.gen)
        except StopAsyncIteration:
            raise RuntimeError("generator didn't yield") from None

        # if not a context manager, then interrupt the generator, don't
        # let it complete. this step is technically not needed, as the
        # generator will close in any case at gc time. not clear if having
        # this here is a good idea or not (though it helps for clarity IMO)
        if not is_ctxmanager:
            await self.gen.aclose()

        return start_value

    async def __aexit__(
        self, typ: Any, value: Any, traceback: Any
    ) -> Optional[bool]:
        """Resume (normal exit) or throw into (error exit) the generator.

        Returns ``True`` to suppress the exception raised in the ``async
        with`` body, ``False`` to let it propagate.

        :raises RuntimeError: if the generator yields again instead of
         finishing; the generator is closed before the error propagates.
        """
        # vendored from contextlib.py (_AsyncGeneratorContextManager),
        # including the later upstream changes that close a generator which
        # failed to stop and restore the caller's traceback on re-raise
        if typ is None:
            try:
                await util.anext_(self.gen)
            except StopAsyncIteration:
                return False
            else:
                try:
                    raise RuntimeError("generator didn't stop")
                finally:
                    # don't leave the generator suspended past its yield
                    await self.gen.aclose()
        else:
            if value is None:
                # Need to force instantiation so we can reliably
                # tell if we get the same exception back
                value = typ()
            try:
                await self.gen.athrow(value)
            except StopAsyncIteration as exc:
                # Suppress StopAsyncIteration *unless* it's the same
                # exception that was passed to athrow(). This prevents a
                # StopAsyncIteration raised inside the "with" statement from
                # being suppressed.
                return exc is not value
            except RuntimeError as exc:
                # Don't re-raise the passed in exception. (issue27122)
                if exc is value:
                    # drop the generator frames athrow() appended
                    exc.__traceback__ = traceback
                    return False
                # Avoid suppressing if a Stop(Async)Iteration exception
                # was passed to athrow() and later wrapped into a RuntimeError
                # (see PEP 479 for sync generators; async generators also
                # have this behavior). But do this only if the exception
                # wrapped by the RuntimeError is actually
                # Stop(Async)Iteration (see issue29692).
                if (
                    isinstance(value, (StopIteration, StopAsyncIteration))
                    and exc.__cause__ is value
                ):
                    value.__traceback__ = traceback
                    return False
                raise
            except BaseException as exc:
                # only re-raise if it's *not* the exception that was
                # passed to throw(), because __exit__() must not raise
                # an exception unless __exit__() itself failed. But throw()
                # has to raise the exception to signal propagation, so this
                # fixes the impedance mismatch between the throw() protocol
                # and the __exit__() protocol.
                if exc is not value:
                    raise
                exc.__traceback__ = traceback
                return False
            try:
                raise RuntimeError("generator didn't stop after athrow()")
            finally:
                # the generator swallowed the exception and yielded again;
                # close it rather than leaving it suspended
                await self.gen.aclose()
215
216
def asyncstartablecontext(
    func: Callable[..., AsyncIterator[_T_co]],
) -> Callable[..., GeneratorStartableContext[_T_co]]:
    """Decorate an async generator function so that calling it returns a
    :class:`.GeneratorStartableContext`.

    Unlike ``@contextlib.asynccontextmanager``, the returned object may be
    used either as ``async with fn()`` **or** as ``await fn()``.  The two
    forms finish the generator differently, so the generator has to be
    written to tell them apart.

    Typical usage:

    .. sourcecode:: text

        @asyncstartablecontext
        async def some_async_generator(<arguments>):
            <setup>
            try:
                yield <value>
            except GeneratorExit:
                # return value was awaited, no context manager is present
                # and caller will .close() the resource explicitly
                pass
            else:
                <context manager cleanup>


    When the result is awaited, the generator is closed right after its
    first ``yield``, so ``GeneratorExit`` is raised there.  In that case the
    cleanup must **not** run, because the caller now owns the resource.
    This is why the pattern avoids a ``finally`` block.

    When ``GeneratorExit`` is not raised, the generator was resumed from
    ``__aexit__``, meaning it was used as a context manager, and cleanup
    should proceed.

    """

    def create_startable(
        *args: Any, **kwds: Any
    ) -> GeneratorStartableContext[_T_co]:
        # defer running ``func``; the context object drives the generator
        return GeneratorStartableContext(func, args, kwds)

    # copy __name__, __doc__, __wrapped__ etc. from the decorated function
    return functools.wraps(func)(create_startable)
259
260
class ProxyComparable(ReversibleProxy[_PT]):
    """``ReversibleProxy`` whose equality delegates to the proxied object.

    Two proxies compare equal when ``other`` is an instance of
    ``self.__class__`` and their ``_proxied`` objects compare equal.
    Hashing, however, is based on the identity of the proxy itself.

    NOTE(review): two distinct proxies over equal targets therefore compare
    equal while hashing differently, which departs from the usual
    ``__eq__``/``__hash__`` contract -- confirm this is intended.
    """

    __slots__ = ()

    @util.ro_non_memoized_property
    def _proxied(self) -> _PT:
        # subclasses provide the object being proxied
        raise NotImplementedError()

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return self._proxied == other._proxied

    def __ne__(self, other: Any) -> bool:
        if isinstance(other, self.__class__):
            return self._proxied != other._proxied
        return True

    def __hash__(self) -> int:
        # identity-based, independent of the proxied object
        return id(self)