1from __future__ import annotations
2
3import sys
4from collections.abc import Sequence
5from typing import Any, Callable
6
7if sys.version_info >= (3, 10): # pragma: no cover
8 from typing import ParamSpec
9else: # pragma: no cover
10 from typing_extensions import ParamSpec
11
12from starlette._utils import is_async_callable
13from starlette.concurrency import run_in_threadpool
14
# Parameter-specification variable: used as `Callable[P, Any]` together with
# `*args: P.args, **kwargs: P.kwargs` so a type checker can verify that the
# arguments given to BackgroundTask(...) / add_task(...) match the signature
# of the callable they will be passed to.
P = ParamSpec("P")
16
17
class BackgroundTask:
    """A single deferred call: a callable plus the arguments to invoke it with.

    Nothing runs at construction time. Awaiting the instance (``await task()``)
    performs the call and discards its return value. If ``func`` was detected
    as async when the task was built, the call is awaited directly; otherwise
    it is handed to ``run_in_threadpool`` instead of being called inline.
    """

    def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
        # Capture the call exactly as given; it is replayed later in __call__.
        self.func, self.args, self.kwargs = func, args, kwargs
        # Classified once here, so __call__ does not re-inspect ``func``.
        self.is_async = is_async_callable(func)

    async def __call__(self) -> None:
        """Invoke the stored callable with its stored positional/keyword args."""
        if not self.is_async:
            # Synchronous callable: delegate to the thread-pool helper.
            await run_in_threadpool(self.func, *self.args, **self.kwargs)
            return
        await self.func(*self.args, **self.kwargs)
30
31
class BackgroundTasks(BackgroundTask):
    """An ordered collection of ``BackgroundTask`` objects awaited one by one.

    NOTE(review): this subclasses ``BackgroundTask`` but never calls its
    ``__init__``, so ``func``/``args``/``kwargs``/``is_async`` are not set on
    instances. Presumably the subclassing exists so a collection can be used
    wherever a single task is accepted; confirm this before relying on those
    attributes.
    """

    def __init__(self, tasks: Sequence[BackgroundTask] | None = None):
        # Always a fresh list: ``None`` or an empty sequence starts empty, and
        # later add_task() calls never mutate the caller's sequence.
        self.tasks = [] if not tasks else list(tasks)

    def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
        """Wrap ``func`` and its arguments in a ``BackgroundTask`` and queue it last."""
        self.tasks.append(BackgroundTask(func, *args, **kwargs))

    async def __call__(self) -> None:
        """Await every queued task sequentially, in insertion order.

        There is no error handling here. If a task raises, the exception
        propagates and the remaining tasks are not run. Because iteration walks
        the live ``self.tasks`` list, tasks appended while this loop is running
        are picked up as well.
        """
        for pending in self.tasks:
            await pending()