1import inspect 
    2import typing as t 
    3from functools import WRAPPER_ASSIGNMENTS 
    4from functools import wraps 
    5 
    6from .utils import _PassArg 
    7from .utils import pass_eval_context 
    8 
    9if t.TYPE_CHECKING: 
    10    import typing_extensions as te 
    11 
# Generic type variable used in the annotations of the async helpers in
# this module (awaited values, iterator items, list elements).
V = t.TypeVar("V")
    13 
    14 
    15def async_variant(normal_func):  # type: ignore 
    16    def decorator(async_func):  # type: ignore 
    17        pass_arg = _PassArg.from_obj(normal_func) 
    18        need_eval_context = pass_arg is None 
    19 
    20        if pass_arg is _PassArg.environment: 
    21 
    22            def is_async(args: t.Any) -> bool: 
    23                return t.cast(bool, args[0].is_async) 
    24 
    25        else: 
    26 
    27            def is_async(args: t.Any) -> bool: 
    28                return t.cast(bool, args[0].environment.is_async) 
    29 
    30        # Take the doc and annotations from the sync function, but the 
    31        # name from the async function. Pallets-Sphinx-Themes 
    32        # build_function_directive expects __wrapped__ to point to the 
    33        # sync function. 
    34        async_func_attrs = ("__module__", "__name__", "__qualname__") 
    35        normal_func_attrs = tuple(set(WRAPPER_ASSIGNMENTS).difference(async_func_attrs)) 
    36 
    37        @wraps(normal_func, assigned=normal_func_attrs) 
    38        @wraps(async_func, assigned=async_func_attrs, updated=()) 
    39        def wrapper(*args, **kwargs):  # type: ignore 
    40            b = is_async(args) 
    41 
    42            if need_eval_context: 
    43                args = args[1:] 
    44 
    45            if b: 
    46                return async_func(*args, **kwargs) 
    47 
    48            return normal_func(*args, **kwargs) 
    49 
    50        if need_eval_context: 
    51            wrapper = pass_eval_context(wrapper) 
    52 
    53        wrapper.jinja_async_variant = True  # type: ignore[attr-defined] 
    54        return wrapper 
    55 
    56    return decorator 
    57 
    58 
    59_common_primitives = {int, float, bool, str, list, dict, tuple, type(None)} 
    60 
    61 
    62async def auto_await(value: t.Union[t.Awaitable["V"], "V"]) -> "V": 
    63    # Avoid a costly call to isawaitable 
    64    if type(value) in _common_primitives: 
    65        return t.cast("V", value) 
    66 
    67    if inspect.isawaitable(value): 
    68        return await t.cast("t.Awaitable[V]", value) 
    69 
    70    return value 
    71 
    72 
    73class _IteratorToAsyncIterator(t.Generic[V]): 
    74    def __init__(self, iterator: "t.Iterator[V]"): 
    75        self._iterator = iterator 
    76 
    77    def __aiter__(self) -> "te.Self": 
    78        return self 
    79 
    80    async def __anext__(self) -> V: 
    81        try: 
    82            return next(self._iterator) 
    83        except StopIteration as e: 
    84            raise StopAsyncIteration(e.value) from e 
    85 
    86 
    87def auto_aiter( 
    88    iterable: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", 
    89) -> "t.AsyncIterator[V]": 
    90    if hasattr(iterable, "__aiter__"): 
    91        return iterable.__aiter__() 
    92    else: 
    93        return _IteratorToAsyncIterator(iter(iterable)) 
    94 
    95 
    96async def auto_to_list( 
    97    value: "t.Union[t.AsyncIterable[V], t.Iterable[V]]", 
    98) -> t.List["V"]: 
    99    return [x async for x in auto_aiter(value)]