1# Copyright The OpenTelemetry Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import asyncio
import contextlib
import functools
import inspect
import sys
from typing import TYPE_CHECKING, Callable, Generic, Iterator, TypeVar
19
V = TypeVar("V")  # Type of the function passed to (and returned by) __call__
R = TypeVar("R")  # Type yielded by the wrapped generator / returned by __enter__
Pargs = TypeVar("Pargs")  # Generic type for positional arguments
Pkwargs = TypeVar("Pkwargs")  # Generic type for keyword arguments

# We don't actually depend on typing_extensions, but we can use it in CI with this
# conditional import. ParamSpec can be imported directly from typing once support
# for Python 3.9 is dropped:
# https://peps.python.org/pep-0612/.
if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    P = ParamSpec("P")  # Generic type for all arguments
32
33
# ``asyncio.iscoroutinefunction`` is deprecated since Python 3.14 (and slated
# for removal) in favour of ``inspect.iscoroutinefunction``. Older interpreters
# keep using the asyncio helper so coroutine detection there is unchanged.
if sys.version_info >= (3, 14):
    _iscoroutinefunction = inspect.iscoroutinefunction
else:
    _iscoroutinefunction = asyncio.iscoroutinefunction


class _AgnosticContextManager(
    contextlib._GeneratorContextManager[R],
    Generic[R],
):  # pylint: disable=protected-access
    """Context manager that can decorate both async and sync functions.

    This is an overridden version of the contextlib._GeneratorContextManager
    class that decorates async functions with a wrapper which keeps the
    context manager open until the entire coroutine has finished, so the span
    ends AFTER the awaited function completes.

    Otherwise it would report near-zero span durations for async functions.

    We are overriding the contextlib._GeneratorContextManager class because
    reimplementing it is a lot of code to maintain, and this class (even if
    it is marked as protected) doesn't seem to be evolving a lot.

    For more information, see:
    https://github.com/open-telemetry/opentelemetry-python/pull/3633
    """

    def __enter__(self) -> R:
        """Reimplementing __enter__ to avoid the type error.

        The original __enter__ method returns Any type, but we want to return R.

        Returns:
            The first value yielded by the wrapped generator.

        Raises:
            RuntimeError: If the generator finishes without yielding.
        """
        # Drop the stored callable and its arguments once the generator is
        # being consumed, so they are not kept alive longer than needed.
        del self.args, self.kwds, self.func  # type: ignore
        try:
            return next(self.gen)  # type: ignore
        except StopIteration:
            raise RuntimeError("generator didn't yield") from None

    def __call__(self, func: V) -> V:  # pyright: ignore [reportIncompatibleMethodOverride]
        """Decorate ``func`` so that every call runs inside this context manager.

        Coroutine functions get an async wrapper that awaits ``func`` while
        the context manager is still active; anything else is delegated to
        the parent class implementation.
        """
        if _iscoroutinefunction(func):

            @functools.wraps(func)  # type: ignore
            async def async_wrapper(*args: Pargs, **kwargs: Pkwargs) -> R:  # pyright: ignore [reportInvalidTypeVarUse]
                # Use a fresh context manager for each invocation and exit it
                # only after the awaited call has completed.
                with self._recreate_cm():  # type: ignore
                    return await func(*args, **kwargs)  # type: ignore

            return async_wrapper  # type: ignore
        return super().__call__(func)  # type: ignore
75
76
def _agnosticcontextmanager(
    func: "Callable[P, Iterator[R]]",
) -> "Callable[P, _AgnosticContextManager[R]]":
    """Turn a generator function into a factory of _AgnosticContextManager.

    The returned callable accepts the same arguments as ``func`` and, when
    called, builds an ``_AgnosticContextManager`` from ``func`` and those
    arguments. Metadata of ``func`` is copied onto the returned callable.
    """

    def helper(*args: Pargs, **kwargs: Pkwargs) -> _AgnosticContextManager[R]:  # pyright: ignore [reportInvalidTypeVarUse]
        manager = _AgnosticContextManager(func, args, kwargs)  # pyright: ignore [reportArgumentType]
        return manager

    wrapped = functools.update_wrapper(helper, func)
    # The type is ignored so the original signature of the function is kept
    return wrapped  # type: ignore[return-value]