1from __future__ import annotations
2
3import threading
4import types
5
6from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions
7
# Our async synchronization primitives use either 'anyio' or 'trio', depending
# on whether they're running under asyncio or trio.
10
11try:
12 import trio
13except (ImportError, NotImplementedError): # pragma: nocover
14 trio = None # type: ignore
15
16try:
17 import anyio
18except ImportError: # pragma: nocover
19 anyio = None # type: ignore
20
21
def current_async_library() -> str:
    """
    Return the name of the running async framework: "asyncio" or "trio".

    Detection is delegated to `sniffio` when it is importable; without it we
    fall back to assuming "asyncio".
    See https://sniffio.readthedocs.io/en/latest/

    Raises `RuntimeError` if the detected framework is unsupported, or if the
    optional dependency required for the detected framework is not installed.
    """
    try:
        import sniffio
    except ImportError:  # pragma: nocover
        detected = "asyncio"
    else:
        detected = sniffio.current_async_library()

    if detected == "asyncio":
        if anyio is None:  # pragma: nocover
            raise RuntimeError(
                "Running with asyncio requires installation of 'httpcore[asyncio]'."
            )
    elif detected == "trio":
        if trio is None:  # pragma: nocover
            raise RuntimeError(
                "Running with trio requires installation of 'httpcore[trio]'."
            )
    else:  # pragma: nocover
        raise RuntimeError("Running under an unsupported async environment.")

    return detected
46
47
class AsyncLock:
    """
    This is a standard lock.

    In the sync case `Lock` provides thread locking.
    In the async case `AsyncLock` provides async locking.

    The backend-specific lock is only created on first entry, so that
    backend detection (`current_async_library`) runs inside the async context.
    """

    def __init__(self) -> None:
        # An empty string means the backend has not been detected yet.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        a lock with the correct implementation.
        """
        backend = current_async_library()
        self._backend = backend
        if backend == "trio":
            self._trio_lock = trio.Lock()
        elif backend == "asyncio":
            self._anyio_lock = anyio.Lock()

    async def __aenter__(self) -> AsyncLock:
        if not self._backend:
            self.setup()

        backend = self._backend
        if backend == "asyncio":
            await self._anyio_lock.acquire()
        elif backend == "trio":
            await self._trio_lock.acquire()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        # If `setup()` has not run yet there is no underlying lock to release.
        backend = self._backend
        if backend == "asyncio":
            self._anyio_lock.release()
        elif backend == "trio":
            self._trio_lock.release()
91
92
class AsyncThreadLock:
    """
    This is a threading-only lock for no-I/O contexts.

    In the sync case `ThreadLock` provides thread locking.
    In the async case `AsyncThreadLock` is a no-op.
    """

    def __enter__(self) -> AsyncThreadLock:
        # Deliberately acquires nothing; only the sync counterpart locks.
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        # Nothing was acquired, so nothing to release.
        return None
111
112
class AsyncEvent:
    """
    An event that async tasks can wait on, backed by trio or anyio.

    The backend-specific event is created lazily, on the first call to
    `set()` or `wait()`, so that backend detection runs in the async context.
    """

    def __init__(self) -> None:
        # An empty string means the backend has not been detected yet.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        an event with the correct implementation.
        """
        backend = current_async_library()
        self._backend = backend
        if backend == "trio":
            self._trio_event = trio.Event()
        elif backend == "asyncio":
            self._anyio_event = anyio.Event()

    def set(self) -> None:
        """Mark the event as set."""
        if not self._backend:
            self.setup()

        if self._backend == "asyncio":
            self._anyio_event.set()
        elif self._backend == "trio":
            self._trio_event.set()

    async def wait(self, timeout: float | None = None) -> None:
        """
        Wait until the event is set.

        `timeout` is in seconds; `None` means no limit. The backend's timeout
        error is re-raised as `PoolTimeout`.
        """
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
            # trio.fail_after needs a number, so "no limit" becomes infinity.
            limit = float("inf") if timeout is None else timeout
            with map_exceptions(trio_exc_map), trio.fail_after(limit):
                await self._trio_event.wait()
        elif self._backend == "asyncio":
            anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
            with map_exceptions(anyio_exc_map), anyio.fail_after(timeout):
                await self._anyio_event.wait()
152
153
class AsyncSemaphore:
    """
    A semaphore for async tasks with `bound` slots, backed by trio or anyio.

    The backend semaphore is created lazily on the first `acquire()`.
    """

    def __init__(self, bound: int) -> None:
        self._bound = bound
        # An empty string means the backend has not been detected yet.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        a semaphore with the correct implementation.
        """
        self._backend = current_async_library()
        limit = self._bound
        if self._backend == "trio":
            self._trio_semaphore = trio.Semaphore(
                initial_value=limit, max_value=limit
            )
        elif self._backend == "asyncio":
            self._anyio_semaphore = anyio.Semaphore(
                initial_value=limit, max_value=limit
            )

    async def acquire(self) -> None:
        """Take one slot, waiting until one is available."""
        if not self._backend:
            self.setup()

        if self._backend == "asyncio":
            await self._anyio_semaphore.acquire()
        elif self._backend == "trio":
            await self._trio_semaphore.acquire()

    async def release(self) -> None:
        """
        Give back one slot.

        Does nothing if the backend has not been set up yet.
        """
        if self._backend == "asyncio":
            self._anyio_semaphore.release()
        elif self._backend == "trio":
            self._trio_semaphore.release()
188
189
class AsyncShieldCancellation:
    """
    Shield a block of code from cancellation.

    For certain portions of our codebase where we're dealing with closing
    connections during exception handling we want to shield the operation
    from being cancelled::

        with AsyncShieldCancellation():
            ...  # clean-up operations, shielded from cancellation.

    The backend is detected eagerly in `__init__`, so instances must be
    created from within the running async context.
    """

    def __init__(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        a shielded scope with the correct implementation.
        """
        self._backend = current_async_library()

        if self._backend == "trio":
            self._trio_shield = trio.CancelScope(shield=True)
        elif self._backend == "asyncio":
            self._anyio_shield = anyio.CancelScope(shield=True)

    def __enter__(self) -> AsyncShieldCancellation:
        if self._backend == "trio":
            self._trio_shield.__enter__()
        elif self._backend == "asyncio":
            self._anyio_shield.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> bool | None:
        # Forward the wrapped scope's result: a truthy value from its
        # `__exit__` means "suppress the in-flight exception", and discarding
        # it would re-raise an exception the scope chose to swallow.
        if self._backend == "trio":
            return self._trio_shield.__exit__(exc_type, exc_value, traceback)
        if self._backend == "asyncio":
            return self._anyio_shield.__exit__(exc_type, exc_value, traceback)
        return None
227
228
229# Our thread-based synchronization primitives...
230
231
class Lock:
    """
    This is a standard lock.

    In the sync case `Lock` provides thread locking.
    In the async case `AsyncLock` provides async locking.
    """

    def __init__(self) -> None:
        # The underlying thread lock acquired on entry and released on exit.
        self._lock = threading.Lock()

    def __enter__(self) -> Lock:
        lock = self._lock
        lock.acquire()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        lock = self._lock
        lock.release()
254
255
class ThreadLock:
    """
    This is a threading-only lock for no-I/O contexts.

    In the sync case `ThreadLock` provides thread locking.
    In the async case `AsyncThreadLock` is a no-op.
    """

    def __init__(self) -> None:
        # The underlying thread lock acquired on entry and released on exit.
        self._lock = threading.Lock()

    def __enter__(self) -> ThreadLock:
        lock = self._lock
        lock.acquire()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        lock = self._lock
        lock.release()
278
279
280class Event:
281 def __init__(self) -> None:
282 self._event = threading.Event()
283
284 def set(self) -> None:
285 self._event.set()
286
287 def wait(self, timeout: float | None = None) -> None:
288 if timeout == float("inf"): # pragma: no cover
289 timeout = None
290 if not self._event.wait(timeout=timeout):
291 raise PoolTimeout() # pragma: nocover
292
293
class Semaphore:
    """
    Thread-based semaphore allowing at most `bound` concurrent holders.

    Backed by `threading.BoundedSemaphore`, so releasing more times than it
    was acquired raises `ValueError` instead of silently growing past
    `bound`. This matches `AsyncSemaphore`, which passes `max_value=bound`
    to its trio/anyio semaphores.
    """

    def __init__(self, bound: int) -> None:
        self._semaphore = threading.BoundedSemaphore(value=bound)

    def acquire(self) -> None:
        """Take one slot, blocking until one is available."""
        self._semaphore.acquire()

    def release(self) -> None:
        """Give back one slot; raises `ValueError` on over-release."""
        self._semaphore.release()
303
304
class ShieldCancellation:
    """
    No-op counterpart of `AsyncShieldCancellation` for the sync code paths.

    Thread-synchronous codebases don't support cancellation semantics. This
    class exists only so that the async and sync cases within the package
    mirror each other.
    """

    def __enter__(self) -> ShieldCancellation:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: types.TracebackType | None = None,
    ) -> None:
        # Nothing to undo; exceptions are never suppressed.
        return None