1# Copyright The OpenTelemetry Authors
2# SPDX-License-Identifier: Apache-2.0
3
4import contextlib
5import functools
6import inspect
7from collections.abc import Callable, Iterator
8from typing import TYPE_CHECKING, Generic, TypeVar
9
# Type variables used by the annotations in this module.
R = TypeVar("R")  # Value yielded by the generator / returned by a wrapper
V = TypeVar("V")  # The decorated callable, handed back with the same type
Pargs = TypeVar("Pargs")  # Positional arguments forwarded by a wrapper
Pkwargs = TypeVar("Pkwargs")  # Keyword arguments forwarded by a wrapper

# typing_extensions is not a runtime dependency of this module, so ParamSpec
# is only imported while type checking (it is available in CI). Once support
# for Python 3.9 is dropped, ParamSpec can be imported from typing directly.
# See https://peps.python.org/pep-0612/.
if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    P = ParamSpec("P")  # Full parameter list of a decorated function
22
23
class _AgnosticContextManager(
    contextlib._GeneratorContextManager[R],
    Generic[R],
):  # pylint: disable=protected-access
    """Context manager that can decorate both sync and async functions.

    This subclass of ``contextlib._GeneratorContextManager`` changes what
    happens when the manager is used as a decorator on a coroutine function:
    the function is wrapped in an ``async`` wrapper that awaits it inside the
    ``with`` block, so the span is ended AFTER the whole coroutine has
    finished. Otherwise spans around async functions would be reported with
    near-zero durations.

    The stdlib class is subclassed rather than reimplemented because a
    reimplementation would be a lot of code to maintain, and the class, even
    though it is marked as protected, does not seem to change much.

    For more information, see:
    https://github.com/open-telemetry/opentelemetry-python/pull/3633
    """

    def __enter__(self) -> R:
        """Advance the generator to its first ``yield`` and return that value.

        Reimplemented only for typing purposes: the inherited ``__enter__``
        is typed as returning ``Any``, whereas this one returns ``R``.

        Raises:
            RuntimeError: If the generator finishes without yielding.
        """
        # Drop the creation arguments before the generator is advanced.
        del self.args  # type: ignore
        del self.kwds  # type: ignore
        del self.func  # type: ignore
        try:
            yielded = next(self.gen)  # type: ignore
        except StopIteration:
            raise RuntimeError("generator didn't yield") from None
        return yielded

    def __call__(self, func: V) -> V:  # pyright: ignore [reportIncompatibleMethodOverride]
        """Decorate ``func`` so that every call runs inside the context manager.

        Coroutine functions get an ``async`` wrapper; any other callable is
        handed to the inherited ``__call__``.
        """
        if not inspect.iscoroutinefunction(func):
            return super().__call__(func)  # type: ignore

        @functools.wraps(func)  # type: ignore
        async def async_wrapper(*args: Pargs, **kwargs: Pkwargs) -> R:  # pyright: ignore [reportInvalidTypeVarUse]
            # Awaiting inside the ``with`` block keeps the context manager
            # open until the coroutine itself has completed.
            with self._recreate_cm():  # type: ignore
                return await func(*args, **kwargs)  # type: ignore

        return async_wrapper  # type: ignore
65
66
def _agnosticcontextmanager(
    func: "Callable[P, Iterator[R]]",
) -> "Callable[P, _AgnosticContextManager[R]]":
    """Turn a generator function into an ``_AgnosticContextManager`` factory.

    The returned callable carries ``func``'s metadata and, when called,
    builds an ``_AgnosticContextManager`` from ``func`` and the arguments it
    was called with.
    """

    def helper(*args: Pargs, **kwargs: Pkwargs) -> _AgnosticContextManager[R]:  # pyright: ignore [reportInvalidTypeVarUse]
        return _AgnosticContextManager(func, args, kwargs)  # pyright: ignore [reportArgumentType]

    # Copy name, docstring, etc. from ``func`` onto the factory. The type is
    # ignored so the declared signature stays that of the original function.
    return functools.update_wrapper(helper, func)  # type: ignore[return-value]