1import threading
2from types import TracebackType
3from typing import Optional, Type
4
5from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions
6
# Our async synchronization primitives use either 'anyio' or 'trio' depending
# on whether they're running under asyncio or trio.
9
10try:
11 import trio
12except ImportError: # pragma: nocover
13 trio = None # type: ignore
14
15try:
16 import anyio
17except ImportError: # pragma: nocover
18 anyio = None # type: ignore
19
20
def current_async_library() -> str:
    """
    Return the name of the async library we are currently running under.

    Detection is delegated to `sniffio` when it is installed; without it we
    fall back to assuming "asyncio". Only "asyncio" and "trio" are supported,
    and the matching optional package (anyio or trio) must have imported
    successfully at module load time.

    Raises:
        RuntimeError: if the environment is neither asyncio nor trio, or if
            the package required for the detected environment is missing.
    """
    # See https://sniffio.readthedocs.io/en/latest/
    try:
        import sniffio
    except ImportError:  # pragma: nocover
        detected = "asyncio"
    else:
        detected = sniffio.current_async_library()

    if detected == "asyncio":
        if anyio is None:  # pragma: nocover
            raise RuntimeError(
                "Running with asyncio requires installation of 'httpcore[asyncio]'."
            )
    elif detected == "trio":
        if trio is None:  # pragma: nocover
            raise RuntimeError(
                "Running with trio requires installation of 'httpcore[trio]'."
            )
    else:  # pragma: nocover
        raise RuntimeError("Running under an unsupported async environment.")

    return detected
45
46
class AsyncLock:
    """
    This is a standard lock.

    In the sync case `Lock` provides thread locking.
    In the async case `AsyncLock` provides async locking.

    The concrete trio/anyio lock is created lazily on first entry rather than
    in `__init__` (presumably so backend detection runs inside the active
    event loop).
    """

    def __init__(self) -> None:
        # Empty string until `setup()` has recorded the detected backend.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect whether we're running under 'asyncio' or 'trio' and build the
        matching underlying lock.
        """
        backend = current_async_library()
        self._backend = backend
        if backend == "asyncio":
            self._anyio_lock = anyio.Lock()
        elif backend == "trio":
            self._trio_lock = trio.Lock()

    async def __aenter__(self) -> "AsyncLock":
        if self._backend == "":
            self.setup()

        if self._backend == "asyncio":
            await self._anyio_lock.acquire()
        elif self._backend == "trio":
            await self._trio_lock.acquire()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        # If the lock was never entered, there is nothing to release.
        if self._backend == "asyncio":
            self._anyio_lock.release()
        elif self._backend == "trio":
            self._trio_lock.release()
90
91
class AsyncThreadLock:
    """
    This is a threading-only lock for no-I/O contexts.

    In the sync case `ThreadLock` provides thread locking.
    In the async case `AsyncThreadLock` is a no-op: entering and exiting do
    nothing, and exceptions raised inside the block propagate unchanged.
    """

    def __enter__(self) -> "AsyncThreadLock":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        # Returning None lets any in-flight exception propagate.
        return None
110
111
class AsyncEvent:
    """
    An async event flag backed by trio or anyio.

    The underlying event object is created lazily, the first time either
    `set()` or `wait()` is called.
    """

    def __init__(self) -> None:
        # Empty string until `setup()` has recorded the detected backend.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        an event with the correct implementation.
        """
        self._backend = current_async_library()
        if self._backend == "asyncio":
            self._anyio_event = anyio.Event()
        elif self._backend == "trio":
            self._trio_event = trio.Event()

    def set(self) -> None:
        """Mark the event as set, creating the backend event if needed."""
        if self._backend == "":
            self.setup()

        if self._backend == "asyncio":
            self._anyio_event.set()
        elif self._backend == "trio":
            self._trio_event.set()

    async def wait(self, timeout: Optional[float] = None) -> None:
        """
        Wait until the event is set.

        `timeout` is in seconds; `None` means wait without a deadline.
        Raises `PoolTimeout` if the deadline expires first.
        """
        if self._backend == "":
            self.setup()

        if self._backend == "trio":
            # trio.fail_after() takes a number of seconds, so "no timeout"
            # is expressed as infinity.
            deadline = timeout if timeout is not None else float("inf")
            trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
            with map_exceptions(trio_exc_map), trio.fail_after(deadline):
                await self._trio_event.wait()
        elif self._backend == "asyncio":
            anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
            with map_exceptions(anyio_exc_map), anyio.fail_after(timeout):
                await self._anyio_event.wait()
151
152
class AsyncSemaphore:
    """
    An async semaphore that admits at most `bound` concurrent holders.

    The trio/anyio semaphore is created lazily by `setup()` on the first
    `acquire()`, with both its initial and its maximum value set to `bound`.
    """

    def __init__(self, bound: int) -> None:
        self._bound = bound
        # Empty string until `setup()` has recorded the detected backend.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        a semaphore with the correct implementation.
        """
        self._backend = current_async_library()
        limit = self._bound
        if self._backend == "asyncio":
            self._anyio_semaphore = anyio.Semaphore(
                initial_value=limit, max_value=limit
            )
        elif self._backend == "trio":
            self._trio_semaphore = trio.Semaphore(
                initial_value=limit, max_value=limit
            )

    async def acquire(self) -> None:
        """Wait for and take one slot, creating the semaphore if needed."""
        if self._backend == "":
            self.setup()

        if self._backend == "asyncio":
            await self._anyio_semaphore.acquire()
        elif self._backend == "trio":
            await self._trio_semaphore.acquire()

    async def release(self) -> None:
        """
        Return one slot. This does nothing if `acquire()` has never run.

        It is declared `async` to match `acquire()`, although it does not
        await anything.
        """
        if self._backend == "asyncio":
            self._anyio_semaphore.release()
        elif self._backend == "trio":
            self._trio_semaphore.release()
187
188
class AsyncShieldCancellation:
    """
    Shield a block of code from cancellation.

    Used for portions of the codebase that close connections during
    exception handling, where the clean-up must not itself be cancelled:

        with AsyncShieldCancellation():
            ...  # clean-up operations, shielded from cancellation.

    Unlike the other async primitives here, the backend is detected eagerly
    in `__init__`. The wrapped trio/anyio `CancelScope` is created once per
    instance, so each instance is intended for a single `with` block.
    """

    def __init__(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        a shielded scope with the correct implementation.
        """
        backend = current_async_library()
        self._backend = backend

        if backend == "asyncio":
            self._anyio_shield = anyio.CancelScope(shield=True)
        elif backend == "trio":
            self._trio_shield = trio.CancelScope(shield=True)

    def __enter__(self) -> "AsyncShieldCancellation":
        if self._backend == "asyncio":
            self._anyio_shield.__enter__()
        elif self._backend == "trio":
            self._trio_shield.__enter__()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        # The scope's own return value is discarded, so this always
        # returns None.
        if self._backend == "asyncio":
            self._anyio_shield.__exit__(exc_type, exc_value, traceback)
        elif self._backend == "trio":
            self._trio_shield.__exit__(exc_type, exc_value, traceback)
226
227
# Our thread-based synchronization primitives, used by the sync code paths.
229
230
class Lock:
    """
    This is a standard lock.

    In the sync case `Lock` provides thread locking.
    In the async case `AsyncLock` provides async locking.

    This implementation wraps a single `threading.Lock`, which is acquired
    on `__enter__` and released on `__exit__`.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def __enter__(self) -> "Lock":
        # Blocks until the lock is available; the acquire() result is unused.
        underlying = self._lock
        underlying.acquire()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        underlying = self._lock
        underlying.release()
253
254
class ThreadLock:
    """
    This is a threading-only lock for no-I/O contexts.

    In the sync case `ThreadLock` provides thread locking.
    In the async case `AsyncThreadLock` is a no-op.

    This implementation wraps a single `threading.Lock`, which is acquired
    on `__enter__` and released on `__exit__`.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def __enter__(self) -> "ThreadLock":
        # Blocks until the lock is available; the acquire() result is unused.
        underlying = self._lock
        underlying.acquire()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        underlying = self._lock
        underlying.release()
277
278
class Event:
    """
    A thread-based event flag.

    `set()` marks the event. `wait()` blocks until the event is set and
    raises `PoolTimeout` if `timeout` seconds elapse first.
    """

    def __init__(self) -> None:
        self._event = threading.Event()

    def set(self) -> None:
        self._event.set()

    def wait(self, timeout: Optional[float] = None) -> None:
        """
        Block until the event is set.

        `timeout=None` waits indefinitely. Any timeout larger than
        `threading.TIMEOUT_MAX`, including `float("inf")`, is also treated as
        "wait indefinitely". `threading` raises `OverflowError` for such values
        instead of waiting.

        Raises:
            PoolTimeout: if the timeout elapses before the event is set.
        """
        if timeout is not None and timeout > threading.TIMEOUT_MAX:
            timeout = None
        if not self._event.wait(timeout=timeout):
            raise PoolTimeout()  # pragma: nocover
291
292
class Semaphore:
    """
    A thread-based semaphore that admits at most `bound` concurrent holders.

    A `threading.BoundedSemaphore` is used so that releasing more times than
    acquired raises `ValueError` instead of silently raising the limit above
    `bound`. This matches `AsyncSemaphore`, whose trio/anyio semaphores are
    created with `max_value=bound`.
    """

    def __init__(self, bound: int) -> None:
        self._semaphore = threading.BoundedSemaphore(value=bound)

    def acquire(self) -> None:
        """Block until a slot is available, then take it."""
        self._semaphore.acquire()

    def release(self) -> None:
        """Return a slot. Raises ValueError if that would exceed `bound`."""
        self._semaphore.release()
302
303
class ShieldCancellation:
    """
    No-op counterpart of `AsyncShieldCancellation` for the sync code paths.

    Thread-synchronous code has no cancellation semantics, so there is
    nothing to shield. The class exists only so that the async and sync
    variants of the package mirror each other.
    """

    def __enter__(self) -> "ShieldCancellation":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        # Returning None lets any in-flight exception propagate.
        return None