1from __future__ import annotations
2
3import math
4from collections import deque
5from dataclasses import dataclass
6from types import TracebackType
7
8from sniffio import AsyncLibraryNotFoundError
9
10from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
11from ._eventloop import get_async_backend
12from ._exceptions import BusyResourceError, WouldBlock
13from ._tasks import CancelScope
14from ._testing import TaskInfo, get_current_task
15
16
@dataclass(frozen=True)
class EventStatistics:
    """
    Immutable statistics record returned for an :class:`~.Event`.

    :ivar int tasks_waiting: how many tasks are currently blocked in
        :meth:`~.Event.wait`
    """

    tasks_waiting: int
24
25
@dataclass(frozen=True)
class CapacityLimiterStatistics:
    """
    Immutable statistics record returned for a :class:`~.CapacityLimiter`.

    :ivar int borrowed_tokens: how many tokens tasks have borrowed at the moment
    :ivar float total_tokens: the overall number of tokens the limiter offers
    :ivar tuple borrowers: the tasks (or other objects) that currently hold a token
        borrowed from this limiter
    :ivar int tasks_waiting: how many tasks are blocked in
        :meth:`~.CapacityLimiter.acquire` or
        :meth:`~.CapacityLimiter.acquire_on_behalf_of`
    """

    borrowed_tokens: int
    total_tokens: float
    borrowers: tuple[object, ...]
    tasks_waiting: int
42
43
@dataclass(frozen=True)
class LockStatistics:
    """
    Immutable statistics record returned for a :class:`~.Lock`.

    :ivar bool locked: ``True`` while the lock is held, ``False`` otherwise
    :ivar ~anyio.TaskInfo owner: the task that holds the lock right now, or ``None``
        when no task holds it
    :ivar int tasks_waiting: how many tasks are blocked in :meth:`~.Lock.acquire`
    """

    locked: bool
    owner: TaskInfo | None
    tasks_waiting: int
56
57
@dataclass(frozen=True)
class ConditionStatistics:
    """
    Immutable statistics record returned for a :class:`~.Condition`.

    :ivar int tasks_waiting: how many tasks are blocked in :meth:`~.Condition.wait`
    :ivar ~anyio.LockStatistics lock_statistics: the statistics reported by the
        condition's underlying :class:`~.Lock`
    """

    tasks_waiting: int
    lock_statistics: LockStatistics
68
69
@dataclass(frozen=True)
class SemaphoreStatistics:
    """
    Immutable statistics record returned for a :class:`~.Semaphore`.

    :ivar int tasks_waiting: how many tasks are blocked in
        :meth:`~.Semaphore.acquire`
    """

    tasks_waiting: int
78
79
class Event:
    """
    Backend-agnostic event flag.

    Instantiating this class does not produce a plain ``Event``: the object is
    obtained from the current async backend, or, when the backend lookup raises
    :exc:`~sniffio.AsyncLibraryNotFoundError`, an :class:`EventAdapter` is returned
    instead. The methods below only define the interface.
    """

    def __new__(cls) -> Event:
        try:
            instance = get_async_backend().create_event()
        except AsyncLibraryNotFoundError:
            # No async library could be detected: fall back to the adapter
            instance = EventAdapter()

        return instance

    def set(self) -> None:
        """Raise the flag and notify every listener."""
        raise NotImplementedError

    def is_set(self) -> bool:
        """Tell whether the flag is raised (``True``) or not (``False``)."""
        raise NotImplementedError

    async def wait(self) -> None:
        """
        Block until the flag has been set.

        Returns immediately when the flag was already set at the time of the call.

        """
        raise NotImplementedError

    def statistics(self) -> EventStatistics:
        """Return statistics describing the present state of this event."""
        raise NotImplementedError
108
109
class EventAdapter(Event):
    """
    Stand-in for an :class:`Event` whose backend event is created on first use.

    Until something needs the real event, the adapter reports itself as unset with
    no waiting tasks.
    """

    # Backend event, created lazily by the ``_event`` property
    _internal_event: Event | None = None

    def __new__(cls) -> EventAdapter:
        # Bypass Event.__new__, which would try to resolve a backend right away
        return object.__new__(cls)

    @property
    def _event(self) -> Event:
        """Return the backend event, creating it if it does not exist yet."""
        backing = self._internal_event
        if backing is None:
            backing = self._internal_event = get_async_backend().create_event()

        return backing

    def set(self) -> None:
        self._event.set()

    def is_set(self) -> bool:
        backing = self._internal_event
        if backing is None:
            # Nothing was ever created, so nothing can have been set
            return False

        return backing.is_set()

    async def wait(self) -> None:
        await self._event.wait()

    def statistics(self) -> EventStatistics:
        backing = self._internal_event
        if backing is not None:
            return backing.statistics()

        return EventStatistics(tasks_waiting=0)
137
138
class Lock:
    """
    Task-owned mutual exclusion lock with a FIFO queue of waiting tasks.

    Usable as an async context manager: entering acquires the lock, leaving
    releases it.
    """

    # Class-level default; an instance attribute shadows it while the lock is held
    _owner_task: TaskInfo | None = None

    def __init__(self) -> None:
        # Queue of (waiting task, event used to wake it up)
        self._waiters: deque[tuple[TaskInfo, Event]] = deque()

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    async def acquire(self) -> None:
        """Acquire the lock, waiting for it to be released if necessary."""
        await checkpoint_if_cancelled()
        try:
            self.acquire_nowait()
        except WouldBlock:
            current = get_current_task()
            wakeup = Event()
            entry = (current, wakeup)
            self._waiters.append(entry)
            try:
                await wakeup.wait()
            except BaseException:
                if wakeup.is_set():
                    # Woken up and interrupted at the same time: if ownership was
                    # already handed to this task, pass the lock on
                    if self._owner_task == current:
                        self.release()
                else:
                    # Still queued: withdraw from the waiting line
                    self._waiters.remove(entry)

                raise

            assert self._owner_task == current
        else:
            # Acquired without waiting; still yield once, giving the lock back if
            # that yield raises
            try:
                await cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def acquire_nowait(self) -> None:
        """
        Acquire the lock without waiting.

        :raises RuntimeError: if the current task already holds the lock
        :raises ~anyio.WouldBlock: if the operation would block

        """
        current = get_current_task()
        if self._owner_task == current:
            raise RuntimeError("Attempted to acquire an already held Lock")

        if self._owner_task is not None:
            raise WouldBlock

        self._owner_task = current

    def release(self) -> None:
        """
        Release the lock.

        :raises RuntimeError: if the current task is not the one holding the lock

        """
        if self._owner_task != get_current_task():
            raise RuntimeError("The current task is not holding this lock")

        if not self._waiters:
            # Drop the instance attribute so the class-level ``None`` shows through
            del self._owner_task
            return

        # Hand ownership directly to the longest-waiting task and wake it up
        next_owner, wakeup = self._waiters.popleft()
        self._owner_task = next_owner
        wakeup.set()

    def locked(self) -> bool:
        """Return True if some task currently holds the lock."""
        return self._owner_task is not None

    def statistics(self) -> LockStatistics:
        """
        Return statistics describing the present state of this lock.

        .. versionadded:: 3.0
        """
        return LockStatistics(
            locked=self.locked(),
            owner=self._owner_task,
            tasks_waiting=len(self._waiters),
        )
222
223
class Condition:
    """
    Condition variable built on top of a :class:`Lock`.

    Usable as an async context manager: entering acquires the underlying lock,
    leaving releases it.

    :param lock: the lock to use; a new :class:`Lock` is created when omitted
    """

    # Task that last acquired the underlying lock through this condition
    _owner_task: TaskInfo | None = None

    def __init__(self, lock: Lock | None = None):
        self._lock = lock or Lock()
        # One event per task blocked in wait(), in arrival order
        self._waiters: deque[Event] = deque()

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    def _check_acquired(self) -> None:
        """Raise :exc:`RuntimeError` unless the current task is the recorded owner."""
        if self._owner_task != get_current_task():
            raise RuntimeError("The current task is not holding the underlying lock")

    async def acquire(self) -> None:
        """Acquire the underlying lock."""
        await self._lock.acquire()
        self._owner_task = get_current_task()

    def acquire_nowait(self) -> None:
        """
        Acquire the underlying lock, without blocking.

        :raises ~anyio.WouldBlock: if the operation would block

        """
        self._lock.acquire_nowait()
        self._owner_task = get_current_task()

    def release(self) -> None:
        """Release the underlying lock."""
        self._lock.release()

    def locked(self) -> bool:
        """Return True if the lock is set."""
        return self._lock.locked()

    def notify(self, n: int = 1) -> None:
        """
        Notify up to ``n`` listeners.

        Fewer than ``n`` listeners are notified when fewer are waiting.

        :raises RuntimeError: if the current task is not the recorded owner

        """
        self._check_acquired()
        for _ in range(n):
            if not self._waiters:
                break

            self._waiters.popleft().set()

    def notify_all(self) -> None:
        """
        Notify all the listeners.

        :raises RuntimeError: if the current task is not the recorded owner

        """
        self._check_acquired()
        for event in self._waiters:
            event.set()

        self._waiters.clear()

    async def wait(self) -> None:
        """
        Wait for a notification.

        The underlying lock is released while waiting and is always re-acquired
        before this method returns or raises.

        """
        await checkpoint()
        event = Event()
        # Release BEFORE enqueuing: if release() raises (e.g. the lock is not held
        # by this task), no orphaned waiter is left behind to swallow a later
        # notification. Nothing is awaited between the two statements, so no other
        # task can observe the intermediate state.
        self.release()
        self._waiters.append(event)
        try:
            await event.wait()
        except BaseException:
            if not event.is_set():
                # Never notified: withdraw from the waiting line
                self._waiters.remove(event)

            raise
        finally:
            # Re-acquire under a shield so the lock is held again even if the wait
            # was interrupted
            with CancelScope(shield=True):
                await self.acquire()

    def statistics(self) -> ConditionStatistics:
        """
        Return statistics about the current state of this condition.

        .. versionadded:: 3.0
        """
        return ConditionStatistics(len(self._waiters), self._lock.statistics())
312
313
class Semaphore:
    """
    Counting semaphore with an optional upper bound.

    Usable as an async context manager: entering decrements the value (waiting if
    it is zero), leaving increments it again.

    :param initial_value: the starting value; must be a non-negative integer
    :param max_value: optional upper bound for the value; must be an integer that
        is not lower than ``initial_value``
    :raises TypeError: if ``initial_value`` or ``max_value`` has the wrong type
    :raises ValueError: if ``initial_value`` is negative or exceeds ``max_value``
    """

    def __init__(self, initial_value: int, *, max_value: int | None = None):
        if not isinstance(initial_value, int):
            raise TypeError("initial_value must be an integer")
        elif initial_value < 0:
            raise ValueError("initial_value must be >= 0")

        if max_value is not None:
            if not isinstance(max_value, int):
                raise TypeError("max_value must be an integer or None")
            elif max_value < initial_value:
                raise ValueError(
                    "max_value must be equal to or higher than initial_value"
                )

        self._value = initial_value
        self._max_value = max_value
        # One event per task blocked in acquire(), in arrival order
        self._waiters: deque[Event] = deque()

    async def __aenter__(self) -> Semaphore:
        await self.acquire()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    async def acquire(self) -> None:
        """Decrement the semaphore value, waiting first if it is zero."""
        await checkpoint_if_cancelled()
        try:
            self.acquire_nowait()
        except WouldBlock:
            wakeup = Event()
            self._waiters.append(wakeup)
            try:
                await wakeup.wait()
            except BaseException:
                if wakeup.is_set():
                    # A release was already handed to this task: pass it on
                    self.release()
                else:
                    # Still queued: withdraw from the waiting line
                    self._waiters.remove(wakeup)

                raise
        else:
            # Acquired without waiting; still yield once, undoing the decrement if
            # that yield raises
            try:
                await cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def acquire_nowait(self) -> None:
        """
        Decrement the semaphore value, without blocking.

        :raises ~anyio.WouldBlock: if the operation would block

        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """
        Increment the semaphore value, or wake the longest-waiting task instead.

        :raises ValueError: if the value already equals the configured maximum

        """
        limit = self._max_value
        if limit is not None and self._value == limit:
            raise ValueError("semaphore released too many times")

        if not self._waiters:
            self._value += 1
            return

        # Hand the released slot straight to a waiter; the value stays unchanged
        self._waiters.popleft().set()

    @property
    def value(self) -> int:
        """The current value of the semaphore."""
        return self._value

    @property
    def max_value(self) -> int | None:
        """The maximum value of the semaphore (``None`` if unbounded)."""
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        """
        Return statistics describing the present state of this semaphore.

        .. versionadded:: 3.0
        """
        return SemaphoreStatistics(tasks_waiting=len(self._waiters))
407
408
class CapacityLimiter:
    """
    Backend-agnostic capacity limiter interface.

    Instantiating this class does not produce a plain ``CapacityLimiter``: the
    object is obtained from the current async backend, or, when the backend lookup
    raises :exc:`~sniffio.AsyncLibraryNotFoundError`, a
    :class:`CapacityLimiterAdapter` is returned instead. The members below only
    define the interface.
    """

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        try:
            limiter = get_async_backend().create_capacity_limiter(total_tokens)
        except AsyncLibraryNotFoundError:
            # No async library could be detected: fall back to the adapter
            limiter = CapacityLimiterAdapter(total_tokens)

        return limiter

    async def __aenter__(self) -> None:
        raise NotImplementedError

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        raise NotImplementedError

    @property
    def total_tokens(self) -> float:
        """
        The overall number of tokens that can be borrowed.

        This property can be both read and written. Raising the total grants tokens
        to the proportionate number of tasks waiting on this limiter.

        .. versionchanged:: 3.0
            The property is now writable.

        """
        raise NotImplementedError

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        raise NotImplementedError

    @property
    def borrowed_tokens(self) -> int:
        """How many tokens are borrowed at the moment."""
        raise NotImplementedError

    @property
    def available_tokens(self) -> float:
        """How many tokens can still be borrowed at the moment."""
        raise NotImplementedError

    def acquire_nowait(self) -> None:
        """
        Borrow a token for the current task without waiting for one to become
        available.

        :raises ~anyio.WouldBlock: if no token is available for borrowing

        """
        raise NotImplementedError

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Borrow a token for the given borrower without waiting for one to become
        available.

        :param borrower: the entity borrowing a token
        :raises ~anyio.WouldBlock: if no token is available for borrowing

        """
        raise NotImplementedError

    async def acquire(self) -> None:
        """
        Borrow a token for the current task, waiting for one to become available
        if necessary.

        """
        raise NotImplementedError

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Borrow a token for the given borrower, waiting for one to become available
        if necessary.

        :param borrower: the entity borrowing a token

        """
        raise NotImplementedError

    def release(self) -> None:
        """
        Give back the token held by the current task.

        :raises RuntimeError: if the current task has not borrowed a token from this
            limiter.

        """
        raise NotImplementedError

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Give back the token held by the given borrower.

        :raises RuntimeError: if the borrower has not borrowed a token from this
            limiter.

        """
        raise NotImplementedError

    def statistics(self) -> CapacityLimiterStatistics:
        """
        Return statistics describing the present state of this limiter.

        .. versionadded:: 3.0

        """
        raise NotImplementedError
521
522
class CapacityLimiterAdapter(CapacityLimiter):
    """
    Stand-in for a :class:`CapacityLimiter` whose backend limiter is created on
    first use.

    Until the backend limiter exists, the adapter stores the configured total
    itself and reports zero borrowed tokens and no waiting tasks.

    :param total_tokens: the total number of tokens (an ``int`` or infinity)
    """

    # Backend limiter, created lazily by the ``_limiter`` property
    _internal_limiter: CapacityLimiter | None = None

    def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter:
        # Bypass CapacityLimiter.__new__, which would try to resolve a backend
        return object.__new__(cls)

    def __init__(self, total_tokens: float) -> None:
        # Goes through the property setter so the value is validated
        self.total_tokens = total_tokens

    @property
    def _limiter(self) -> CapacityLimiter:
        """Return the backend limiter, creating it if it does not exist yet."""
        if self._internal_limiter is None:
            self._internal_limiter = get_async_backend().create_capacity_limiter(
                self._total_tokens
            )

        return self._internal_limiter

    async def __aenter__(self) -> None:
        await self._limiter.__aenter__()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        return await self._limiter.__aexit__(exc_type, exc_val, exc_tb)

    @property
    def total_tokens(self) -> float:
        if self._internal_limiter is None:
            return self._total_tokens

        return self._internal_limiter.total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Validate and store the new total.

        :raises TypeError: if ``value`` is neither an ``int`` nor positive infinity
        :raises ValueError: if ``value`` is lower than 1
        """
        if isinstance(value, float) and value == math.inf:
            # Compare by value, not identity, so float("inf") is accepted as well.
            # Normalise to the math.inf object itself in case the backend limiter
            # checks identity (NOTE(review): backend behaviour not visible here).
            value = math.inf
        elif not isinstance(value, int):
            raise TypeError("total_tokens must be an int or math.inf")

        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        if self._internal_limiter is None:
            # No backend limiter yet: remember the value for its later creation
            self._total_tokens = value
            return

        self._limiter.total_tokens = value

    @property
    def borrowed_tokens(self) -> int:
        if self._internal_limiter is None:
            return 0

        return self._internal_limiter.borrowed_tokens

    @property
    def available_tokens(self) -> float:
        if self._internal_limiter is None:
            return self._total_tokens

        return self._internal_limiter.available_tokens

    def acquire_nowait(self) -> None:
        self._limiter.acquire_nowait()

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        self._limiter.acquire_on_behalf_of_nowait(borrower)

    async def acquire(self) -> None:
        await self._limiter.acquire()

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        await self._limiter.acquire_on_behalf_of(borrower)

    def release(self) -> None:
        self._limiter.release()

    def release_on_behalf_of(self, borrower: object) -> None:
        self._limiter.release_on_behalf_of(borrower)

    def statistics(self) -> CapacityLimiterStatistics:
        if self._internal_limiter is None:
            # Nothing has been borrowed yet because no backend limiter exists
            return CapacityLimiterStatistics(
                borrowed_tokens=0,
                total_tokens=self.total_tokens,
                borrowers=(),
                tasks_waiting=0,
            )

        return self._internal_limiter.statistics()
614
615
class ResourceGuard:
    """
    A context manager that makes sure only one task at a time uses a resource.

    If the context manager is entered again before the previous user has left it,
    :exc:`BusyResourceError` is raised.

    :param action: the action to guard against (visible in the :exc:`BusyResourceError`
        when triggered, e.g. "Another task is already {action} this resource")

    .. versionadded:: 4.1
    """

    __slots__ = ("action", "_guarded")

    def __init__(self, action: str = "using"):
        self.action: str = action
        # True while some task is inside the guarded section
        self._guarded = False

    def __enter__(self) -> None:
        if not self._guarded:
            self._guarded = True
            return

        raise BusyResourceError(self.action)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Always clear the flag; returning None lets any exception propagate
        self._guarded = False
        return None