1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""A few useful function/method decorators."""
6
7from __future__ import annotations
8
9import functools
10import inspect
11import sys
12import warnings
13from collections.abc import Callable, Generator
14from typing import TypeVar
15
16from astroid import util
17from astroid.context import InferenceContext
18from astroid.exceptions import InferenceError
19from astroid.typing import InferenceResult
20
21if sys.version_info >= (3, 10):
22 from typing import ParamSpec
23else:
24 from typing_extensions import ParamSpec
25
26_R = TypeVar("_R")
27_P = ParamSpec("_P")
28
29
def path_wrapper(func):
    """Wrap an inference function so a node is not re-inferred on the same path.

    The returned generator function pushes *node* onto the
    ``InferenceContext`` (creating a fresh context when none is given). If
    ``context.push`` reports the node as already present, nothing is yielded,
    which stops infinite recursion. Otherwise the results of *func* are
    re-yielded with duplicates removed; objects whose class is named
    ``Instance`` are compared through their ``_proxied`` attribute.
    """

    @functools.wraps(func)
    def wrapped(
        node, context: InferenceContext | None = None, _func=func, **kwargs
    ) -> Generator:
        """Push *node* on the context and yield de-duplicated results."""
        if context is None:
            context = InferenceContext()
        already_visited = context.push(node)
        if not already_visited:
            seen = set()
            for result in _func(node, context, **kwargs):
                # Only true instances are unproxied -- not Const, Tuple, Dict...
                is_instance = result.__class__.__name__ == "Instance"
                key = result._proxied if is_instance else result
                if key not in seen:
                    yield result
                    seen.add(key)

    return wrapped
60
61
def yes_if_nothing_inferred(
    func: Callable[_P, Generator[InferenceResult]]
) -> Callable[_P, Generator[InferenceResult]]:
    """Return *func* wrapped so that an empty result yields ``util.Uninferable``.

    When the generator produced by *func* yields nothing at all, the wrapper
    yields ``util.Uninferable`` exactly once instead. Otherwise every value
    from the generator is passed through unchanged.
    """

    # functools.wraps keeps the decorated function's name, docstring and
    # __wrapped__ visible (matching path_wrapper), instead of every decorated
    # inference function appearing as ``inner``.
    @functools.wraps(func)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> Generator[InferenceResult]:
        generator = func(*args, **kwargs)

        try:
            yield next(generator)
        except StopIteration:
            # generator is empty
            yield util.Uninferable
            return

        yield from generator

    return inner
78
79
def raise_if_nothing_inferred(
    func: Callable[_P, Generator[InferenceResult]],
) -> Callable[_P, Generator[InferenceResult]]:
    """Return *func* wrapped so that an empty result raises ``InferenceError``.

    When the wrapper is first advanced, it pulls the first value from the
    generator produced by *func*. If that generator is already exhausted, an
    ``InferenceError`` is raised. When the ``StopIteration`` carries arguments
    (i.e. the generator returned a value), ``args[0]`` is unpacked as keyword
    arguments for the error, so it is expected to be a mapping. A
    ``RecursionError`` raised while producing that first value is also
    converted into an ``InferenceError``. All later values are passed through
    unchanged.
    """

    # functools.wraps keeps the decorated function's name, docstring and
    # __wrapped__ visible (matching path_wrapper), instead of every decorated
    # inference function appearing as ``inner``.
    @functools.wraps(func)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> Generator[InferenceResult]:
        generator = func(*args, **kwargs)
        try:
            yield next(generator)
        except StopIteration as error:
            # generator is empty
            if error.args:
                raise InferenceError(**error.args[0]) from error
            raise InferenceError(
                "StopIteration raised without any error information."
            ) from error
        except RecursionError as error:
            raise InferenceError(
                f"RecursionError raised with limit {sys.getrecursionlimit()}."
            ) from error

        yield from generator

    return inner
102
103
# Expensive decorators, only used to emit DeprecationWarnings.
# If no warnings filters other than the default DeprecationWarning one are
# enabled, fall back to passthrough implementations.
if util.check_warnings_filter():  # noqa: C901

    def deprecate_default_argument_values(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Decorator which emits a DeprecationWarning if any arguments specified
        are None or not passed at all.

        Arguments should be a key-value mapping, with the key being the argument to check
        and the value being a type annotation as string for the value of the argument.

        To improve performance, only used when DeprecationWarnings other than
        the default one are enabled.
        """
        # Helpful links
        # Decorator for DeprecationWarning: https://stackoverflow.com/a/49802489
        # Typing of stacked decorators: https://stackoverflow.com/a/68290080

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            """Decorator function."""
            # The parameter names only depend on ``func``: compute them once
            # at decoration time rather than on every call.
            keys = list(inspect.signature(func).parameters.keys())

            @functools.wraps(func)
            def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
                """Emit DeprecationWarnings if conditions are met."""
                for arg, type_annotation in arguments.items():
                    try:
                        index = keys.index(arg)
                    except ValueError:
                        raise ValueError(
                            f"Can't find argument '{arg}' for '{args[0].__class__.__qualname__}'"
                        ) from None
                    # Warn when the argument is explicitly None (passed by
                    # keyword or by position) or was not passed at all.
                    if arg in kwargs:
                        missing_or_none = kwargs[arg] is None
                    else:
                        # Too few positional args means it was not passed.
                        missing_or_none = len(args) <= index or args[index] is None
                    if missing_or_none:
                        warnings.warn(
                            f"'{arg}' will be a required argument for "
                            f"'{args[0].__class__.__qualname__}.{func.__name__}'"
                            f" in astroid {astroid_version} "
                            f"('{arg}' should be of type: '{type_annotation}')",
                            DeprecationWarning,
                            stacklevel=2,
                        )
                return func(*args, **kwargs)

            return wrapper

        return deco

    def deprecate_arguments(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Decorator which emits a DeprecationWarning if any arguments specified
        are passed.

        Arguments should be a key-value mapping, with the key being the argument to check
        and the value being a string that explains what to do instead of passing the argument.

        To improve performance, only used when DeprecationWarnings other than
        the default one are enabled.
        """

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            # The parameter names only depend on ``func``: compute them once
            # at decoration time rather than on every call.
            keys = list(inspect.signature(func).parameters.keys())

            @functools.wraps(func)
            def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
                for arg, note in arguments.items():
                    try:
                        index = keys.index(arg)
                    except ValueError:
                        raise ValueError(
                            f"Can't find argument '{arg}' for '{args[0].__class__.__qualname__}'"
                        ) from None
                    # Passed either by keyword or positionally.
                    if arg in kwargs or len(args) > index:
                        warnings.warn(
                            f"The argument '{arg}' for "
                            f"'{args[0].__class__.__qualname__}.{func.__name__}' is deprecated "
                            f"and will be removed in astroid {astroid_version} ({note})",
                            DeprecationWarning,
                            stacklevel=2,
                        )
                return func(*args, **kwargs)

            return wrapper

        return deco

else:

    def deprecate_default_argument_values(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Passthrough decorator to improve performance if DeprecationWarnings are
        disabled.
        """

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            """Decorator function."""
            return func

        return deco

    def deprecate_arguments(
        astroid_version: str = "3.0", **arguments: str
    ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
        """Passthrough decorator to improve performance if DeprecationWarnings are
        disabled.
        """

        def deco(func: Callable[_P, _R]) -> Callable[_P, _R]:
            """Decorator function."""
            return func

        return deco