Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_synchronization.py: 38%

116 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:38 +0000

1import threading 

2from types import TracebackType 

3from typing import Optional, Type 

4 

5import sniffio 

6 

7from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions 

8 

9# Our async synchronization primatives use either 'anyio' or 'trio' depending 

10# on if they're running under asyncio or trio. 

11 

12try: 

13 import trio 

14except ImportError: # pragma: nocover 

15 trio = None # type: ignore 

16 

17try: 

18 import anyio 

19except ImportError: # pragma: nocover 

20 anyio = None # type: ignore 

21 

22 

class AsyncLock:
    """
    An async mutual-exclusion lock usable under either trio or asyncio.

    The concrete lock is created lazily on first use, when `sniffio` is asked
    which async library is currently running.
    """

    def __init__(self) -> None:
        # Empty until `setup()` has successfully created the backend lock.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        a lock with the correct implementation.

        Raises:
            RuntimeError: if the package for the detected backend is not
                installed.
        """
        backend = sniffio.current_async_library()
        if backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio requires the 'trio' package to be installed."
                )
            self._trio_lock = trio.Lock()
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )
            self._anyio_lock = anyio.Lock()
        # Record the backend only once the lock exists. Assigning it up-front
        # would make a failed setup look complete: the next use would skip
        # `setup()` and fail with an AttributeError instead of this RuntimeError.
        self._backend = backend

    async def __aenter__(self) -> "AsyncLock":
        """Acquire the lock, creating it first if this is the first use."""
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            await self._trio_lock.acquire()
        else:
            await self._anyio_lock.acquire()

        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Release the lock. Exceptions raised in the block are not suppressed."""
        if self._backend == "trio":
            self._trio_lock.release()
        else:
            self._anyio_lock.release()

67 

68 

class AsyncEvent:
    """
    An async event usable under either trio or asyncio.

    The concrete event is created lazily on first use, when `sniffio` is asked
    which async library is currently running.
    """

    def __init__(self) -> None:
        # Empty until `setup()` has successfully created the backend event.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        an event with the correct implementation.

        Raises:
            RuntimeError: if the package for the detected backend is not
                installed.
        """
        backend = sniffio.current_async_library()
        if backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio requires the 'trio' package to be installed."
                )
            self._trio_event = trio.Event()
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )
            self._anyio_event = anyio.Event()
        # Record the backend only once the event exists. Assigning it up-front
        # would make a failed setup look complete: the next use would skip
        # `setup()` and fail with an AttributeError instead of this RuntimeError.
        self._backend = backend

    def set(self) -> None:
        """Set the event, creating it first if this is the first use."""
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            self._trio_event.set()
        else:
            self._anyio_event.set()

    async def wait(self, timeout: Optional[float] = None) -> None:
        """
        Wait until the event is set.

        Args:
            timeout: Maximum number of seconds to wait, or ``None`` to wait
                without a time limit.

        Raises:
            PoolTimeout: if ``timeout`` elapses before the event is set.
        """
        if not self._backend:
            self.setup()

        # `setup()` only records a backend after creating that backend's
        # event, so the matching package is known to be importable below.
        if self._backend == "trio":
            trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
            # trio.fail_after needs a number, so "no limit" becomes infinity.
            timeout_or_inf = float("inf") if timeout is None else timeout
            with map_exceptions(trio_exc_map):
                with trio.fail_after(timeout_or_inf):
                    await self._trio_event.wait()
        else:
            anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
            with map_exceptions(anyio_exc_map):
                # A ``None`` timeout is passed straight through to anyio.
                with anyio.fail_after(timeout):
                    await self._anyio_event.wait()

126 

127 

class AsyncSemaphore:
    """
    An async semaphore with ``bound`` permits, usable under trio or asyncio.

    The concrete semaphore is created lazily on the first `acquire()`, when
    `sniffio` is asked which async library is currently running. Both
    backends are created with ``initial_value=bound`` and ``max_value=bound``.
    """

    def __init__(self, bound: int) -> None:
        self._bound = bound
        # Empty until `setup()` has successfully created the backend semaphore.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        a semaphore with the correct implementation.

        Raises:
            RuntimeError: if the package for the detected backend is not
                installed.
        """
        backend = sniffio.current_async_library()
        if backend == "trio":
            if trio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under trio requires the 'trio' package to be installed."
                )

            self._trio_semaphore = trio.Semaphore(
                initial_value=self._bound, max_value=self._bound
            )
        else:
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )

            self._anyio_semaphore = anyio.Semaphore(
                initial_value=self._bound, max_value=self._bound
            )
        # Record the backend only once the semaphore exists. Assigning it
        # up-front would make a failed setup look complete: the next use would
        # skip `setup()` and fail with an AttributeError instead of this
        # RuntimeError.
        self._backend = backend

    async def acquire(self) -> None:
        """Take one permit, waiting if none is free; sets up on first use."""
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            await self._trio_semaphore.acquire()
        else:
            await self._anyio_semaphore.acquire()

    async def release(self) -> None:
        """
        Return one permit.

        Does not perform setup: calling this before any `acquire()` fails
        with AttributeError because no backend semaphore exists yet.
        """
        if self._backend == "trio":
            self._trio_semaphore.release()
        else:
            self._anyio_semaphore.release()

172 

173 

class AsyncShieldCancellation:
    """
    Context manager that shields its body from cancellation.

    Used where connections are being closed during exception handling and the
    clean-up must not itself be cancelled::

        with AsyncShieldCancellation():
            ...  # clean-up operations, shielded from cancellation.
    """

    def __init__(self) -> None:
        """
        Work out whether 'trio' or 'asyncio' is running and build a shielded
        cancel scope from the matching library.

        Raises:
            RuntimeError: if the package for the detected backend is missing.
        """
        self._backend = sniffio.current_async_library()

        if self._backend != "trio":
            if anyio is None:  # pragma: nocover
                raise RuntimeError(
                    "Running under asyncio requires the 'anyio' package to be installed."
                )
            self._anyio_shield = anyio.CancelScope(shield=True)
            return

        if trio is None:  # pragma: nocover
            raise RuntimeError(
                "Running under trio requires the 'trio' package to be installed."
            )
        self._trio_shield = trio.CancelScope(shield=True)

    def __enter__(self) -> "AsyncShieldCancellation":
        """Enter the backend cancel scope and return this wrapper."""
        scope = self._trio_shield if self._backend == "trio" else self._anyio_shield
        scope.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """
        Exit the backend cancel scope.

        The scope's own return value is discarded, so this always returns
        None and never suppresses an exception from the body.
        """
        scope = self._trio_shield if self._backend == "trio" else self._anyio_shield
        scope.__exit__(exc_type, exc_value, traceback)

221 

222 

223# Our thread-based synchronization primitives... 

224 

225 

class Lock:
    """
    Context-manager wrapper around a :class:`threading.Lock`.

    The thread-based counterpart of ``AsyncLock``::

        with lock:
            ...  # critical section
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def __enter__(self) -> "Lock":
        """Block until the underlying lock is held, then return this wrapper."""
        self._lock.acquire()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Release the underlying lock; exceptions from the body propagate."""
        self._lock.release()

241 

242 

class Event:
    """Thread-based event flag built on :class:`threading.Event`."""

    def __init__(self) -> None:
        self._event = threading.Event()

    def set(self) -> None:
        """Set the flag, releasing every thread blocked in `wait()`."""
        self._event.set()

    def wait(self, timeout: Optional[float] = None) -> None:
        """
        Block until the flag is set.

        Args:
            timeout: Maximum number of seconds to wait, or ``None`` to wait
                without a time limit.

        Raises:
            PoolTimeout: if ``timeout`` elapses before the flag is set.
        """
        flag_was_set = self._event.wait(timeout=timeout)
        if flag_was_set:
            return
        raise PoolTimeout()  # pragma: nocover

253 

254 

class Semaphore:
    """
    Thread-based counting semaphore holding at most ``bound`` permits.

    Backed by :class:`threading.BoundedSemaphore`, so releasing more permits
    than were acquired raises ``ValueError`` instead of silently raising the
    limit. This matches ``AsyncSemaphore``, whose trio/anyio semaphores are
    created with ``max_value=bound``.
    """

    def __init__(self, bound: int) -> None:
        self._semaphore = threading.BoundedSemaphore(value=bound)

    def acquire(self) -> None:
        """Take one permit, blocking until one is available."""
        self._semaphore.acquire()

    def release(self) -> None:
        """
        Return one permit.

        Raises:
            ValueError: if this would push the count above ``bound``.
        """
        self._semaphore.release()

264 

265 

class ShieldCancellation:
    """
    Synchronous counterpart of ``AsyncShieldCancellation``.

    Thread-based code has no cancellation semantics, so there is nothing to
    shield. The class exists only so the sync and async code paths can share
    the same structure. Entering and exiting do nothing, and exceptions raised
    inside the ``with`` block propagate unchanged.
    """

    def __enter__(self) -> "ShieldCancellation":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        # Intentionally a no-op; returning None lets any exception propagate.
        return None