1from __future__ import annotations
2
3import sys
4import typing
5
6if sys.version_info >= (3, 10): # pragma: no cover
7 from typing import ParamSpec
8else: # pragma: no cover
9 from typing_extensions import ParamSpec
10
11from starlette._utils import is_async_callable
12from starlette.concurrency import run_in_threadpool
13
# Parameter specification shared by BackgroundTask.__init__ and
# BackgroundTasks.add_task: it binds the extra *args/**kwargs those methods
# accept to the signature of the ``func`` callable passed alongside them,
# so type checkers can verify the arguments match the callable.
P = ParamSpec("P")
15
16
class BackgroundTask:
    """A single callable captured together with its arguments, run on demand.

    Constructing the object only records the callable and its arguments;
    nothing executes until the instance itself is awaited via ``__call__``.
    Whether the callable is asynchronous is decided once, at construction
    time, using ``is_async_callable``.
    """

    def __init__(
        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        self.func, self.args, self.kwargs = func, args, kwargs
        # Cached so __call__ does not have to re-inspect the callable each run.
        self.is_async = is_async_callable(func)

    async def __call__(self) -> None:
        """Invoke the stored callable with the stored positional/keyword args.

        Async callables are awaited directly. Synchronous callables are
        handed to ``run_in_threadpool`` (from ``starlette.concurrency``) and
        that call is awaited instead. Any exception raised by the callable
        propagates to the awaiter.
        """
        target, positional, named = self.func, self.args, self.kwargs
        if not self.is_async:
            await run_in_threadpool(target, *positional, **named)
            return
        await target(*positional, **named)
31
32
class BackgroundTasks(BackgroundTask):
    """An ordered queue of ``BackgroundTask`` objects awaited one after another.

    It subclasses ``BackgroundTask`` (presumably so a group can be passed
    wherever a single task is accepted), but deliberately does not call
    ``BackgroundTask.__init__``. The ``func``/``args``/``kwargs``/``is_async``
    attributes are therefore never set on instances of this class.
    """

    def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
        # Always own a fresh list, so add_task() never mutates the caller's
        # sequence. None and empty sequences both yield an empty queue.
        self.tasks = [] if not tasks else list(tasks)

    def add_task(
        self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
    ) -> None:
        """Wrap *func* and its arguments in a ``BackgroundTask`` and enqueue it."""
        self.tasks.append(BackgroundTask(func, *args, **kwargs))

    async def __call__(self) -> None:
        """Await every queued task sequentially, in insertion order.

        No exception handling is done here: if one task raises, the error
        propagates immediately and the tasks after it are not run.
        """
        for queued in self.tasks:
            await queued()