Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py: 58%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

130 statements  

1# ext/asyncio/base.py 

2# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors 

3# <see AUTHORS file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7 

8from __future__ import annotations 

9 

10import abc 

11import functools 

12from typing import Any 

13from typing import AsyncGenerator 

14from typing import AsyncIterator 

15from typing import Awaitable 

16from typing import Callable 

17from typing import ClassVar 

18from typing import Dict 

19from typing import Generator 

20from typing import Generic 

21from typing import NoReturn 

22from typing import Optional 

23from typing import overload 

24from typing import Tuple 

25from typing import TypeVar 

26import weakref 

27 

28from . import exc as async_exc 

29from ... import util 

30from ...util.typing import Literal 

31from ...util.typing import Self 

32 

# General-purpose type variable, bound to ``Any``.
_T = TypeVar("_T", bound=Any)

# Covariant type variable; the type parameter of ``StartableContext`` and
# ``GeneratorStartableContext`` (the value produced by ``start()``).
_T_co = TypeVar("_T_co", bound=Any, covariant=True)


# Type variable for the object wrapped ("proxied") by a ``ReversibleProxy``.
_PT = TypeVar("_PT", bound=Any)

38 

39 

class ReversibleProxy(Generic[_PT]):
    """Base for proxies that can be looked up again from their target.

    A class-wide registry, ``_proxy_objects``, maps a weak reference to the
    proxied ("target") object onto a weak reference to the proxy wrapping
    it.  Because both sides are weak references, the registry keeps neither
    object alive; weakref callbacks purge the entry when either the target
    or the proxy is collected.
    """

    # weakref(target) -> weakref(proxy); shared by every subclass
    _proxy_objects: ClassVar[
        Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]]
    ] = {}
    # proxies must be weak-referenceable to be stored in the registry
    __slots__ = ("__weakref__",)

    @overload
    def _assign_proxied(self, target: _PT) -> _PT: ...

    @overload
    def _assign_proxied(self, target: None) -> None: ...

    def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]:
        """Register ``self`` as the proxy of ``target`` and return ``target``.

        ``None`` is passed straight through and nothing is registered.
        """
        if target is None:
            return None

        # collecting the target removes its registry entry
        target_ref: weakref.ref[_PT] = weakref.ref(
            target, ReversibleProxy._target_gced
        )
        # collecting this proxy removes the same entry; the callback gets
        # target_ref bound in front of the proxy weakref argument
        on_proxy_collected = functools.partial(
            ReversibleProxy._target_gced, target_ref
        )
        ReversibleProxy._proxy_objects[target_ref] = weakref.ref(
            self, on_proxy_collected
        )
        return target

    @classmethod
    def _target_gced(
        cls,
        ref: weakref.ref[_PT],
        proxy_ref: Optional[weakref.ref[Self]] = None,  # noqa: U100
    ) -> None:
        """Weakref callback that discards the registry entry for ``ref``.

        ``proxy_ref`` is only supplied when the proxy (rather than the
        target) was collected; it is not used.
        """
        cls._proxy_objects.pop(ref, None)

    @classmethod
    def _regenerate_proxy_for_target(
        cls, target: _PT, **additional_kw: Any
    ) -> Self:
        """Build a new proxy for ``target``; subclasses must implement."""
        raise NotImplementedError()

    @overload
    @classmethod
    def _retrieve_proxy_for_target(
        cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any
    ) -> Self: ...

    @overload
    @classmethod
    def _retrieve_proxy_for_target(
        cls, target: _PT, regenerate: bool = True, **additional_kw: Any
    ) -> Optional[Self]: ...

    @classmethod
    def _retrieve_proxy_for_target(
        cls, target: _PT, regenerate: bool = True, **additional_kw: Any
    ) -> Optional[Self]:
        """Return the live proxy registered for ``target``.

        When no live proxy is registered, a new one is built through
        :meth:`_regenerate_proxy_for_target` if ``regenerate`` is true;
        otherwise ``None`` is returned.
        """
        registered_ref = cls._proxy_objects.get(weakref.ref(target))
        if registered_ref is not None:
            live_proxy = registered_ref()
            if live_proxy is not None:
                return live_proxy  # type: ignore

        if not regenerate:
            return None
        return cls._regenerate_proxy_for_target(target, **additional_kw)

108 

109 

class StartableContext(Awaitable[_T_co], abc.ABC):
    """Abstract object usable either as ``await obj`` or ``async with obj``.

    Both forms obtain their result from :meth:`start`; the only difference
    is the ``is_ctxmanager`` flag passed to it.  Subclasses implement
    :meth:`start` and :meth:`__aexit__`.
    """

    __slots__ = ()

    @abc.abstractmethod
    async def start(self, is_ctxmanager: bool = False) -> _T_co:
        """Perform startup and return the resulting object.

        :param is_ctxmanager: ``True`` when called from ``__aenter__``,
          ``False`` when the object is simply awaited.
        """
        raise NotImplementedError()

    def __await__(self) -> Generator[Any, Any, _T_co]:
        # plain ``await obj`` runs start() with is_ctxmanager left False
        starting = self.start()
        return starting.__await__()

    async def __aenter__(self) -> _T_co:
        started = await self.start(is_ctxmanager=True)
        return started

    @abc.abstractmethod
    async def __aexit__(
        self, type_: Any, value: Any, traceback: Any
    ) -> Optional[bool]:
        """Leave the ``async with`` block; subclasses must implement."""

    def _raise_for_not_started(self) -> NoReturn:
        """Raise ``AsyncContextNotStarted`` naming this object's class."""
        class_name = self.__class__.__name__
        raise async_exc.AsyncContextNotStarted(
            f"{class_name} context has not been started and object has "
            "not been awaited."
        )

134 

135 

class GeneratorStartableContext(StartableContext[_T_co]):
    """:class:`StartableContext` driven by an async generator.

    :meth:`start` returns the generator's first yielded value.  When the
    object is simply awaited, the generator is closed right after that
    value is produced.  When it is used with ``async with``,
    :meth:`__aexit__` either resumes the generator (clean exit) or throws
    the in-flight exception into it.
    """

    __slots__ = ("gen",)

    gen: AsyncGenerator[_T_co, Any]

    def __init__(
        self,
        func: Callable[..., AsyncIterator[_T_co]],
        args: Tuple[Any, ...],
        kwds: Dict[str, Any],
    ):
        # for an async generator function this only creates the generator
        # object; its body does not run until start() advances it
        self.gen = func(*args, **kwds)  # type: ignore

    async def start(self, is_ctxmanager: bool = False) -> _T_co:
        """Advance the generator to its first ``yield`` and return the value.

        :param is_ctxmanager: when ``False`` (plain ``await``), the
          generator is closed via ``aclose()`` right after yielding.
        :raises RuntimeError: if the generator finishes without yielding.
        """
        try:
            start_value = await util.anext_(self.gen)
        except StopAsyncIteration:
            raise RuntimeError("generator didn't yield") from None

        # if not a context manager, then interrupt the generator, don't
        # let it complete. this step is technically not needed, as the
        # generator will close in any case at gc time. not clear if having
        # this here is a good idea or not (though it helps for clarity IMO)
        if not is_ctxmanager:
            await self.gen.aclose()

        return start_value

    async def __aexit__(
        self, typ: Any, value: Any, traceback: Any
    ) -> Optional[bool]:
        """Finish the generator after the ``async with`` body.

        On a clean exit the generator is advanced once more and must stop.
        If the body raised, the exception is thrown into the generator via
        ``athrow()``.  Returns ``True`` to suppress the exception and
        ``False`` to let it propagate.

        :raises RuntimeError: if the generator yields again instead of
          stopping.  The generator is closed before the error propagates.
        """
        # vendored from contextlib.py (_AsyncGeneratorContextManager)
        if typ is None:
            try:
                await util.anext_(self.gen)
            except StopAsyncIteration:
                return False
            else:
                # the generator yielded a second time.  Close it so it is
                # not left suspended until garbage collection (mirrors
                # current CPython contextlib), then report the misuse.
                try:
                    raise RuntimeError("generator didn't stop")
                finally:
                    await self.gen.aclose()
        else:
            if value is None:
                # Need to force instantiation so we can reliably
                # tell if we get the same exception back
                value = typ()
            try:
                await self.gen.athrow(value)
            except StopAsyncIteration as exc:
                # Suppress StopAsyncIteration *unless* it's the same
                # exception that was passed to athrow(). This prevents a
                # StopAsyncIteration raised inside the "async with"
                # statement from being suppressed.
                return exc is not value
            except RuntimeError as exc:
                # Don't re-raise the passed in exception. (issue27122)
                if exc is value:
                    return False
                # Avoid suppressing if a Stop(Async)Iteration exception
                # was passed to athrow() and later wrapped into a RuntimeError
                # (see PEP 479 for sync generators; async generators also
                # have this behavior). But do this only if the exception
                # wrapped by the RuntimeError is actually Stop(Async)Iteration
                # (see issue29692).
                if (
                    isinstance(value, (StopIteration, StopAsyncIteration))
                    and exc.__cause__ is value
                ):
                    return False
                raise
            except BaseException as exc:
                # only re-raise if it's *not* the exception that was
                # passed to athrow(), because __aexit__() must not raise
                # an exception unless __aexit__() itself failed. But athrow()
                # has to raise the exception to signal propagation, so this
                # fixes the impedance mismatch between the athrow() protocol
                # and the __aexit__() protocol.
                if exc is not value:
                    raise
                return False
            # the generator caught the thrown exception and yielded again;
            # close it before reporting the misuse.
            try:
                raise RuntimeError("generator didn't stop after athrow()")
            finally:
                await self.gen.aclose()

215 

216 

def asyncstartablecontext(
    func: Callable[..., AsyncIterator[_T_co]],
) -> Callable[..., GeneratorStartableContext[_T_co]]:
    """@asyncstartablecontext decorator.

    A function decorated this way may be invoked either as
    ``async with fn()`` **or** as ``await fn()``.  That dual usage is not
    something ``@contextlib.asynccontextmanager`` supports, and the
    generator has to be written differently as a result.

    Typical usage:

    .. sourcecode:: text

        @asyncstartablecontext
        async def some_async_generator(<arguments>):
            <setup>
            try:
                yield <value>
            except GeneratorExit:
                # return value was awaited, no context manager is present
                # and caller will .close() the resource explicitly
                pass
            else:
                <context manager cleanup>

    Above, ``GeneratorExit`` is raised at the ``yield`` when the function
    is used with ``await``.  In that case the cleanup must **not** run, so
    there should be no ``finally`` block.

    When the function is used as a context manager, ``__aexit__`` resumes
    the generator after the ``async with`` block instead.  On a clean exit
    no ``GeneratorExit`` is raised, so the ``else`` cleanup runs.
    """

    def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]:
        # the generator is created only when the decorated function is called
        return GeneratorStartableContext(func, args, kwds)

    return functools.wraps(func)(helper)

259 

260 

class ProxyComparable(ReversibleProxy[_PT]):
    """:class:`ReversibleProxy` whose equality delegates to the proxied object.

    Two proxies are equal when they are instances of the same class (per
    ``isinstance``) and their ``_proxied`` objects compare equal.

    NOTE(review): ``__hash__`` is identity-based (``id(self)``) while
    ``__eq__`` compares the proxied objects.  Two distinct proxy instances
    can therefore compare equal yet hash differently.
    """

    __slots__ = ()

    @util.ro_non_memoized_property
    def _proxied(self) -> _PT:
        # the wrapped object; not available on this base class
        raise NotImplementedError()

    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return self._proxied == other._proxied

    def __ne__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            return True
        return self._proxied != other._proxied