1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5from __future__ import annotations
6
7from astroid import nodes
8from astroid.bases import Instance
9from astroid.context import CallContext, InferenceContext
10from astroid.exceptions import InferenceError, NoDefault
11from astroid.typing import InferenceResult
12from astroid.util import Uninferable, UninferableBase, safe_infer
13
14
class CallSite:
    """Class for understanding arguments passed into a call site.

    It needs a call context, which contains the arguments and the
    keyword arguments that were passed into a given call site.
    In order to infer what an argument represents, call :meth:`infer_argument`
    with the corresponding function node and the argument name.

    :param callcontext:
        An instance of :class:`astroid.context.CallContext`, that holds
        the arguments for the call site.
    :param argument_context_map:
        Additional contexts per node, passed in from
        :attr:`astroid.context.InferenceContext.extra_context`.
    :param context:
        An instance of :class:`astroid.context.InferenceContext`.

    Attributes populated by the constructor:

    * ``positional_arguments``: unpacked positional argument nodes, with
      every argument that could not be unpacked removed.
    * ``keyword_arguments``: mapping of keyword name to argument node, with
      every entry that could not be unpacked removed.
    * ``duplicated_keywords``: keyword names that were supplied more than
      once, e.g. ``f(a=1, **{"a": 2})`` or ``f(**{"a": 2}, a=1)``.
    """

    def __init__(
        self,
        callcontext: CallContext,
        argument_context_map=None,
        context: InferenceContext | None = None,
    ):
        if argument_context_map is None:
            argument_context_map = {}
        self.argument_context_map = argument_context_map
        args = callcontext.args
        keywords = callcontext.keywords
        # Filled in by _unpack_keywords as a side effect, so it must exist
        # before that method runs.
        self.duplicated_keywords: set[str] = set()
        self._unpacked_args = self._unpack_args(args, context=context)
        self._unpacked_kwargs = self._unpack_keywords(keywords, context=context)

        # Drop the Uninferable placeholders. has_invalid_arguments() and
        # has_invalid_keywords() detect them by comparing these filtered
        # containers against the unpacked ones.
        self.positional_arguments = [
            arg for arg in self._unpacked_args if not isinstance(arg, UninferableBase)
        ]
        self.keyword_arguments = {
            key: value
            for key, value in self._unpacked_kwargs.items()
            if not isinstance(value, UninferableBase)
        }

    @classmethod
    def from_call(cls, call_node: nodes.Call, context: InferenceContext | None = None):
        """Get a CallSite object from the given Call node.

        context will be used to force a single inference path.
        """
        # The call context is built from the Call node's own arguments;
        # ``context`` is only used while unpacking */** arguments.
        context = context or InferenceContext()
        callcontext = CallContext(call_node.args, call_node.keywords)
        return cls(callcontext, context=context)

    def has_invalid_arguments(self) -> bool:
        """Check if in the current CallSite were passed *invalid* arguments.

        This can mean multiple things. For instance, if an unpacking
        of an invalid object was passed, then this method will return True.
        Other cases can be when the arguments can't be inferred by astroid,
        for example, by passing objects which aren't known statically.
        """
        return len(self.positional_arguments) != len(self._unpacked_args)

    def has_invalid_keywords(self) -> bool:
        """Check if in the current CallSite were passed *invalid* keyword arguments.

        For instance, unpacking a dictionary with integer keys is invalid
        (**{1:2}), because the keys must be strings, which makes this
        method return True. It also returns True when objects which can't
        be inferred were unpacked, or when a keyword was passed more than once.
        """
        return len(self.keyword_arguments) != len(self._unpacked_kwargs)

    def _unpack_keywords(
        self,
        keywords: list[tuple[str | None, nodes.NodeNG]],
        context: InferenceContext | None = None,
    ) -> dict[str | None, InferenceResult]:
        """Flatten the call's keyword arguments into a name -> value mapping.

        ``**`` unpackings are expanded when they infer to a ``Dict`` whose
        keys infer to string constants. Anything that cannot be expanded is
        recorded as ``Uninferable`` under the ``None`` key. A name seen more
        than once is mapped to ``Uninferable`` and added to
        ``self.duplicated_keywords``.

        Note: ``context.extra_context`` is overwritten with
        ``self.argument_context_map``.
        """
        values: dict[str | None, InferenceResult] = {}
        context = context or InferenceContext()
        context.extra_context = self.argument_context_map
        for name, value in keywords:
            if name is None:
                # Then it's an unpacking operation (**)
                inferred = safe_infer(value, context=context)
                if not isinstance(inferred, nodes.Dict):
                    # Not something we can work with.
                    values[name] = Uninferable
                    continue

                for dict_key, dict_value in inferred.items:
                    dict_key = safe_infer(dict_key, context=context)
                    if not isinstance(dict_key, nodes.Const):
                        values[name] = Uninferable
                        continue
                    if not isinstance(dict_key.value, str):
                        values[name] = Uninferable
                        continue
                    if dict_key.value in values:
                        # The name is already in the dictionary
                        values[dict_key.value] = Uninferable
                        self.duplicated_keywords.add(dict_key.value)
                        continue
                    values[dict_key.value] = dict_value
            elif name in values:
                # An explicit keyword repeating a name that an earlier ``**``
                # unpacking already supplied, e.g. ``f(**{"a": 1}, a=2)``.
                # Python rejects this at call time exactly like the reverse
                # order handled above, so treat it as a duplicate too instead
                # of silently letting the later value win.
                values[name] = Uninferable
                self.duplicated_keywords.add(name)
            else:
                values[name] = value
        return values

    def _unpack_args(self, args, context: InferenceContext | None = None):
        """Flatten the call's positional arguments into a list of nodes.

        A ``*`` unpacking is expanded into its ``elts`` when the starred value
        infers to something that has them; otherwise a single ``Uninferable``
        placeholder is appended in its place.

        Note: ``context.extra_context`` is overwritten with
        ``self.argument_context_map``.
        """
        values = []
        context = context or InferenceContext()
        context.extra_context = self.argument_context_map
        for arg in args:
            if not isinstance(arg, nodes.Starred):
                values.append(arg)
                continue
            inferred = safe_infer(arg.value, context=context)
            # The UninferableBase check must come first: hasattr() on an
            # Uninferable object is not a reliable test for "elts".
            if isinstance(inferred, UninferableBase) or not hasattr(inferred, "elts"):
                values.append(Uninferable)
                continue
            values.extend(inferred.elts)
        return values

    def infer_argument(
        self, funcnode: InferenceResult, name: str, context: InferenceContext
    ):  # noqa: C901
        """Infer a function argument value according to the call context.

        :param funcnode: the ``FunctionDef`` or ``Lambda`` being called.
        :param name: the name of the parameter whose value is wanted.
        :param context: the inference context; its ``boundnode`` is used to
            resolve the implicit first argument of methods and classmethods.
        :returns: an iterator over the inferred values of the parameter.
        :raises InferenceError: if ``funcnode`` is not a function, ``name``
            was passed more than once, too many positional arguments were
            passed to a function without ``*args``, the contents of
            ``*args``/``**kwargs`` cannot all be inferred, or no value (not
            even a default) can be found for ``name``.
        """
        # pylint: disable = too-many-branches

        if not isinstance(funcnode, (nodes.FunctionDef, nodes.Lambda)):
            raise InferenceError(
                f"Can not infer function argument value for non-function node {funcnode!r}.",
                call_site=self,
                func=funcnode,
                arg=name,
                context=context,
            )

        if name in self.duplicated_keywords:
            raise InferenceError(
                "The arguments passed to {func!r} have duplicate keywords.",
                call_site=self,
                func=funcnode,
                arg=name,
                context=context,
            )

        # Look into the keywords first, maybe it's already there.
        try:
            return self.keyword_arguments[name].infer(context)
        except KeyError:
            pass

        # Too many arguments given and no variable arguments.
        if len(self.positional_arguments) > len(funcnode.args.args):
            if not funcnode.args.vararg and not funcnode.args.posonlyargs:
                raise InferenceError(
                    "Too many positional arguments "
                    "passed to {func!r} that does "
                    "not have *args.",
                    call_site=self,
                    func=funcnode,
                    arg=name,
                    context=context,
                )

        # Split the positional arguments into those bound to named
        # parameters and the overflow that is collected by ``*args``.
        positional = self.positional_arguments[: len(funcnode.args.args)]
        vararg = self.positional_arguments[len(funcnode.args.args) :]

        # preserving previous behavior, when vararg and kwarg were not included in find_argname results
        if name in [funcnode.args.vararg, funcnode.args.kwarg]:
            argindex = None
        else:
            argindex = funcnode.args.find_argname(name)[0]

        # Keywords that are candidates for ``**kwargs``: keyword-only
        # parameters are never collected there.
        kwonlyargs = {arg.name for arg in funcnode.args.kwonlyargs}
        kwargs = {
            key: value
            for key, value in self.keyword_arguments.items()
            if key not in kwonlyargs
        }
        # If there are too few positionals compared to
        # what the function expects to receive, check to see
        # if the missing positional arguments were passed
        # as keyword arguments and if so, place them into the
        # positional args list.
        if len(positional) < len(funcnode.args.args):
            for func_arg in funcnode.args.args:
                if func_arg.name in kwargs:
                    arg = kwargs.pop(func_arg.name)
                    positional.append(arg)

        if argindex is not None:
            boundnode = context.boundnode
            # 1. first argument of instance/class method
            if argindex == 0 and funcnode.type in {"method", "classmethod"}:
                # context.boundnode is None when an instance method is called with
                # the class, e.g. MyClass.method(obj, ...). In this case, self
                # is the first argument.
                if boundnode is None and funcnode.type == "method" and positional:
                    return positional[0].infer(context=context)
                if boundnode is None:
                    # XXX can do better ?
                    boundnode = funcnode.parent.frame()

                if isinstance(boundnode, nodes.ClassDef):
                    # Verify that we're accessing a method
                    # of the metaclass through a class, as in
                    # `cls.metaclass_method`. In this case, the
                    # first argument is always the class.
                    method_scope = funcnode.parent.scope()
                    if method_scope is boundnode.metaclass(context=context):
                        return iter((boundnode,))

                if funcnode.type == "method":
                    if not isinstance(boundnode, Instance):
                        boundnode = boundnode.instantiate_class()
                    return iter((boundnode,))
                if funcnode.type == "classmethod":
                    return iter((boundnode,))
            # if we have a method, extract one position
            # from the index, so we'll take in account
            # the extra parameter represented by `self` or `cls`
            if funcnode.type in {"method", "classmethod"} and boundnode:
                argindex -= 1
            # 2. search arg index
            try:
                return self.positional_arguments[argindex].infer(context)
            except IndexError:
                pass

        if funcnode.args.kwarg == name:
            # It wants all the keywords that were passed into
            # the call site.
            if self.has_invalid_keywords():
                raise InferenceError(
                    "Inference failed to find values for all keyword arguments "
                    "to {func!r}: {unpacked_kwargs!r} doesn't correspond to "
                    "{keyword_arguments!r}.",
                    keyword_arguments=self.keyword_arguments,
                    unpacked_kwargs=self._unpacked_kwargs,
                    call_site=self,
                    func=funcnode,
                    arg=name,
                    context=context,
                )
            kwarg = nodes.Dict(
                lineno=funcnode.args.lineno,
                col_offset=funcnode.args.col_offset,
                parent=funcnode.args,
                end_lineno=funcnode.args.end_lineno,
                end_col_offset=funcnode.args.end_col_offset,
            )
            kwarg.postinit(
                [(nodes.const_factory(key), value) for key, value in kwargs.items()]
            )
            return iter((kwarg,))
        if funcnode.args.vararg == name:
            # It wants all the args that were passed into
            # the call site.
            if self.has_invalid_arguments():
                raise InferenceError(
                    "Inference failed to find values for all positional "
                    "arguments to {func!r}: {unpacked_args!r} doesn't "
                    "correspond to {positional_arguments!r}.",
                    positional_arguments=self.positional_arguments,
                    unpacked_args=self._unpacked_args,
                    call_site=self,
                    func=funcnode,
                    arg=name,
                    context=context,
                )
            # Give the synthesized tuple the same full position span as the
            # synthesized ``**kwargs`` dict above.
            args = nodes.Tuple(
                lineno=funcnode.args.lineno,
                col_offset=funcnode.args.col_offset,
                parent=funcnode.args,
                end_lineno=funcnode.args.end_lineno,
                end_col_offset=funcnode.args.end_col_offset,
            )
            args.postinit(vararg)
            return iter((args,))

        # Check if it's a default parameter.
        try:
            return funcnode.args.default_value(name).infer(context)
        except NoDefault:
            pass
        raise InferenceError(
            "No value found for argument {arg} to {func!r}",
            call_site=self,
            func=funcnode,
            arg=name,
            context=context,
        )