1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""This module contains some base nodes that can be inherited for the different nodes.
6
7Previously these were called Mixin nodes.
8"""
9
10from __future__ import annotations
11
12import itertools
13from collections.abc import Callable, Generator, Iterator
14from functools import cached_property, lru_cache, partial
15from typing import TYPE_CHECKING, Any, ClassVar
16
17from astroid import bases, nodes, util
18from astroid.context import (
19 CallContext,
20 InferenceContext,
21 bind_context_to_node,
22)
23from astroid.exceptions import (
24 AttributeInferenceError,
25 InferenceError,
26)
27from astroid.interpreter import dunder_lookup
28from astroid.nodes.node_ng import NodeNG
29from astroid.typing import InferenceResult
30
if TYPE_CHECKING:
    from astroid.nodes.node_classes import LocalsDictNodeNG

    # Type of the ``flow_factory`` argument of
    # ``OperatorNode._infer_binary_operation``; ``OperatorNode._get_binop_flow``
    # and ``OperatorNode._get_aug_flow`` both have this signature.
    # Arguments, in order: left operand, left operand's type, the operator
    # node, right operand, right operand's type, the context for
    # ``left.__op__(right)`` and the context for ``right.__rop__(left)``.
    # Returns the ordered list of zero-argument callables to try.
    GetFlowFactory = Callable[
        [
            InferenceResult,
            InferenceResult | None,
            nodes.AugAssign | nodes.BinOp,
            InferenceResult,
            InferenceResult | None,
            InferenceContext,
            InferenceContext,
        ],
        list[partial[Generator[InferenceResult]]],
    ]
46
47
class Statement(NodeNG):
    """Statement node adding a few attributes.

    NOTE: This class is part of the public API of 'astroid.nodes'.
    """

    is_statement = True
    """Whether this node indicates a statement."""

    def next_sibling(self):
        """Return the statement that follows this one in the parent's sequence.

        :returns: The next sibling statement node, or ``None`` when this
            statement is the last one.
        :rtype: NodeNG or None
        """
        siblings = self.parent.child_sequence(self)
        following = siblings.index(self) + 1
        if following < len(siblings):
            return siblings[following]
        return None

    def previous_sibling(self):
        """Return the statement that precedes this one in the parent's sequence.

        :returns: The previous sibling statement node, or ``None`` when this
            statement is the first one.
        :rtype: NodeNG or None
        """
        siblings = self.parent.child_sequence(self)
        position = siblings.index(self)
        # Position 0 has no predecessor; avoid Python's negative indexing.
        return None if position == 0 else siblings[position - 1]
81
82
class NoChildrenNode(NodeNG):
    """Base nodes for nodes with no children, e.g. Pass."""

    def get_children(self) -> Iterator[NodeNG]:
        """Yield no child nodes; nodes of this kind are always leaves."""
        no_children: tuple[NodeNG, ...] = ()
        yield from no_children
88
89
class FilterStmtsBaseNode(NodeNG):
    """Base node for statement filtering and assignment type."""

    def _get_filtered_stmts(self, _, node, _stmts, mystmt: Statement | None):
        """Decide which statements to keep while filtering, and whether to stop.

        :returns: A ``(statements, stop)`` pair. When this node's statement is
            *mystmt*, only ``[node]`` is kept and filtering stops; otherwise
            *_stmts* is returned unchanged and filtering continues.
        """
        if self.statement() is not mystmt:
            return _stmts, False
        # The original node's statement is this assignment itself, so only
        # the current node (e.g. a generator expression or list comprehension)
        # is relevant.
        return [node], True

    def assign_type(self):
        """Return the node that determines the assignment type: this node."""
        return self
103
104
class AssignTypeNode(NodeNG):
    """Base node for nodes that can 'assign' such as AnnAssign."""

    def _get_filtered_stmts(self, lookup_node, node, _stmts, mystmt: Statement | None):
        """Decide which statements to keep while filtering, and whether to stop.

        :returns: A ``(statements, stop)`` pair:

            * this node *is* *mystmt*: keep *_stmts* and stop;
            * this node's statement is *mystmt*: keep only ``[node]`` and stop;
            * otherwise: keep *_stmts* and continue.
        """
        if self is mystmt:
            return _stmts, True
        if self.statement() is not mystmt:
            return _stmts, False
        # The original node's statement is the assignment: only the current
        # node (gen exp, list comp) is kept.
        return [node], True

    def assign_type(self):
        """Return the node that determines the assignment type: this node."""
        return self
120
121
class ParentAssignNode(AssignTypeNode):
    """Base node for nodes whose assign_type is determined by the parent node."""

    def assign_type(self):
        """Delegate the assignment-type lookup to the parent node."""
        parent_node = self.parent
        return parent_node.assign_type()
127
128
class ImportNode(FilterStmtsBaseNode, NoChildrenNode, Statement):
    """Base node for From and Import Nodes."""

    modname: str | None
    """The module that is being imported from.

    This is ``None`` for relative imports.
    """

    names: list[tuple[str, str | None]]
    """What is being imported from the module.

    Each entry is a :class:`tuple` of the name being imported,
    and the alias that the name is assigned to (if any).
    """

    def _infer_name(self, frame, name):
        """Return *name* unchanged; *frame* is not used."""
        return name

    def do_import_module(self, modname: str | None = None) -> nodes.Module:
        """Return the ast for a module whose name is <modname> imported by <self>.

        :param modname: Module to import; defaults to ``self.modname``.
        :returns: The module returned by the root module's ``import_module``.
        """
        current_module = self.root()
        # ``Import`` nodes have no ``level`` attribute.
        level: int | None = getattr(self, "level", None)
        target = self.modname if modname is None else modname
        # When the imported name resolves to the module that contains this
        # node, skip the cache so the import system locates the correct module.
        importing_own_name = bool(
            target
            # pylint: disable-next=no-member # pylint doesn't recognize type of mymodule
            and current_module.relative_to_absolute_name(target, level)
            == current_module.name
        )
        # pylint: disable-next=no-member # pylint doesn't recognize type of mymodule
        return current_module.import_module(
            target,
            level=level,
            relative_only=bool(level and level >= 1),
            use_cache=not importing_own_name,
        )

    def real_name(self, asname: str) -> str:
        """Get name from 'as' name.

        Entries of ``self.names`` are scanned in order:

        * if a ``*`` entry is reached, *asname* is returned unchanged;
        * an aliased entry matching *asname* returns the imported name;
        * an un-aliased entry is matched (and returned) by its first dotted
          component, since ``import a.b`` binds only ``a``.

        :raises AttributeInferenceError: If no entry binds *asname*.
        """
        for imported, alias in self.names:
            if imported == "*":
                return asname
            if alias:
                if asname == alias:
                    return imported
                continue
            top_level = imported.split(".", 1)[0]
            if asname == top_level:
                return top_level
        raise AttributeInferenceError(
            "Could not find original name for {attribute} in {target!r}",
            target=self,
            attribute=asname,
        )
189
190
class MultiLineBlockNode(NodeNG):
    """Base node for multi-line blocks, e.g. For and FunctionDef.

    Note that this does not apply to every node with a `body` field.
    For instance, an If node has a multi-line body, but the body of an
    IfExpr is not multi-line, and hence cannot contain Return nodes,
    Assign nodes, etc.
    """

    # Attribute names holding this node's statement blocks; empty here,
    # presumably set by each concrete subclass.
    _multi_line_block_fields: ClassVar[tuple[str, ...]] = ()

    @cached_property
    def _multi_line_blocks(self):
        """The blocks named in ``_multi_line_block_fields``, in that order."""
        return tuple(
            getattr(self, attribute_name)
            for attribute_name in self._multi_line_block_fields
        )

    def _get_return_nodes_skip_functions(self):
        """Yield from each block child's ``_get_return_nodes_skip_functions``,
        skipping children that are functions."""
        for statement in itertools.chain.from_iterable(self._multi_line_blocks):
            if not statement.is_function:
                yield from statement._get_return_nodes_skip_functions()

    def _get_yield_nodes_skip_functions(self):
        """Yield from each block child's ``_get_yield_nodes_skip_functions``,
        skipping children that are functions."""
        for statement in itertools.chain.from_iterable(self._multi_line_blocks):
            if not statement.is_function:
                yield from statement._get_yield_nodes_skip_functions()

    def _get_yield_nodes_skip_lambdas(self):
        """Yield from each block child's ``_get_yield_nodes_skip_lambdas``,
        skipping children that are lambdas."""
        for statement in itertools.chain.from_iterable(self._multi_line_blocks):
            if not statement.is_lambda:
                yield from statement._get_yield_nodes_skip_lambdas()

    @cached_property
    def _assign_nodes_in_scope(self) -> list[nodes.Assign]:
        """Concatenate every block child's ``_assign_nodes_in_scope``, in order."""
        return [
            assign_node
            for statement in itertools.chain.from_iterable(self._multi_line_blocks)
            for assign_node in statement._assign_nodes_in_scope
        ]
235
236
class MultiLineWithElseBlockNode(MultiLineBlockNode):
    """Base node for multi-line blocks that can have else statements."""

    @cached_property
    def blockstart_tolineno(self):
        """Last line of the block header; for these nodes it is ``self.lineno``."""
        return self.lineno

    def _elsed_block_range(
        self, lineno: int, orelse: list[nodes.NodeNG], last: int | None = None
    ) -> tuple[int, int]:
        """Handle block line numbers range for try/finally, for, if and while
        statements.

        :param lineno: The line whose enclosing range is requested.
        :param orelse: The ``else`` statements, possibly empty.
        :param last: Fallback end line when there is no ``else``; a falsy value
            means ``self.tolineno``.
        :returns: ``(start, end)`` line numbers.
        """
        if lineno == self.fromlineno:
            return lineno, lineno
        if not orelse:
            return lineno, last or self.tolineno
        first_else_line = orelse[0].fromlineno
        if lineno < first_else_line:
            # Inside the main body: it ends just before the else branch.
            return lineno, first_else_line - 1
        return lineno, orelse[-1].tolineno
257
258
class LookupMixIn(NodeNG):
    """Mixin to look up a name in the right scope."""

    @lru_cache  # noqa
    def lookup(self, name: str) -> tuple[LocalsDictNodeNG, list[NodeNG]]:
        """Lookup where the given variable is assigned.

        The lookup starts from self's scope. If self is not a frame itself
        and the name is found in the inner frame locals, statements will be
        filtered to remove ignorable statements according to self's location.
        Results are memoized by ``lru_cache`` per ``(self, name)``.

        :param name: The name of the variable to find assignments for.

        :returns: The scope node and the list of assignments associated to the
            given name according to the scope where it has been found (locals,
            globals or builtin).
        """
        enclosing_scope = self.scope()
        return enclosing_scope.scope_lookup(self, name)

    def ilookup(self, name):
        """Lookup the inferred values of the given variable.

        :param name: The variable name to find values for.
        :type name: str

        :returns: The inferred values of the statements returned from
            :meth:`lookup`, inferred with a fresh :class:`InferenceContext`.
        :rtype: iterable
        """
        scope_node, assignments = self.lookup(name)
        return bases._infer_stmts(assignments, InferenceContext(), scope_node)
291
292
def _reflected_name(name) -> str:
    """Map a dunder such as ``__add__`` to its reflected form ``__radd__``."""
    return f"__r{name[2:]}"


def _augmented_name(name) -> str:
    """Map a dunder such as ``__add__`` to its in-place form ``__iadd__``."""
    return f"__i{name[2:]}"


# Binary operator token -> dunder method looked up on the left operand.
BIN_OP_METHOD = {
    "+": "__add__",
    "-": "__sub__",
    "/": "__truediv__",
    "//": "__floordiv__",
    "*": "__mul__",
    "**": "__pow__",
    "%": "__mod__",
    "&": "__and__",
    "|": "__or__",
    "^": "__xor__",
    "<<": "__lshift__",
    ">>": "__rshift__",
    "@": "__matmul__",
}

# Binary operator token -> reflected dunder (e.g. "+" -> "__radd__").
REFLECTED_BIN_OP_METHOD = {
    symbol: _reflected_name(dunder) for symbol, dunder in BIN_OP_METHOD.items()
}
# Augmented operator token -> in-place dunder (e.g. "+=" -> "__iadd__").
AUGMENTED_OP_METHOD = {
    f"{symbol}=": _augmented_name(dunder) for symbol, dunder in BIN_OP_METHOD.items()
}
323
324
class OperatorNode(NodeNG):
    """Base node with the shared logic for inferring binary operations.

    Every helper is a static method working on already-inferred operands; the
    signatures accept ``nodes.BinOp`` and ``nodes.AugAssign`` operator nodes.
    The order in which dunder methods are tried is described on
    ``_get_binop_flow`` and ``_get_aug_flow``.
    """

    @staticmethod
    def _filter_operation_errors(
        infer_callable: Callable[
            [InferenceContext | None],
            Generator[InferenceResult | util.BadOperationMessage],
        ],
        context: InferenceContext | None,
        error: type[util.BadOperationMessage],
    ) -> Generator[InferenceResult]:
        """Run *infer_callable* and replace every *error* result with Uninferable.

        For the sake of ``.infer()`` operation errors do not matter (reporting
        them is the job of a linter), so they are turned into ``Uninferable``
        to show that the result cannot be inferred.
        """
        for result in infer_callable(context):
            yield util.Uninferable if isinstance(result, error) else result

    @staticmethod
    def _is_not_implemented(const) -> bool:
        """Check if the given const node is NotImplemented."""
        return isinstance(const, nodes.Const) and const.value is NotImplemented

    @staticmethod
    def _infer_old_style_string_formatting(
        instance: nodes.Const, other: nodes.NodeNG, context: InferenceContext
    ) -> tuple[util.UninferableBase | nodes.Const]:
        """Infer the result of '"string" % ...'.

        :param instance: The ``Const`` string on the left of ``%``.
        :param other: The right operand. A ``Tuple`` supplies positional
            values, a ``Dict`` supplies mapping values, and a ``Const``
            supplies a single value.
        :param context: Inference context used to infer the operand's parts.
        :returns: A one-element tuple. It holds the formatted ``Const``, or
            ``Uninferable`` in two cases: some value cannot be inferred to a
            constant, or the formatting itself raises.

        TODO: Instead of returning Uninferable we should rely
        on the call to '%' to see if the result is actually uninferable.
        """
        if isinstance(other, nodes.Tuple):
            if util.Uninferable in other.elts:
                return (util.Uninferable,)
            inferred_positional = [util.safe_infer(i, context) for i in other.elts]
            if not all(isinstance(i, nodes.Const) for i in inferred_positional):
                # Formatting with a placeholder value would fabricate a result,
                # e.g. ``"%s" % (unknown,)`` would be inferred as ``"None"``.
                return (util.Uninferable,)
            values: Any = tuple(i.value for i in inferred_positional)
        elif isinstance(other, nodes.Dict):
            mapping: dict[Any, Any] = {}
            for pair in other.items:
                key = util.safe_infer(pair[0], context)
                if not isinstance(key, nodes.Const):
                    return (util.Uninferable,)
                value = util.safe_infer(pair[1], context)
                if not isinstance(value, nodes.Const):
                    return (util.Uninferable,)
                mapping[key.value] = value.value
            values = mapping
        elif isinstance(other, nodes.Const):
            values = other.value
        else:
            return (util.Uninferable,)

        try:
            return (nodes.const_factory(instance.value % values),)
        except (TypeError, KeyError, ValueError):
            # Mismatched placeholders/arguments: Python itself would raise.
            return (util.Uninferable,)

    @staticmethod
    def _invoke_binop_inference(
        instance: InferenceResult,
        opnode: nodes.AugAssign | nodes.BinOp,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        method_name: str,
    ) -> Generator[InferenceResult]:
        """Invoke binary operation inference on the given instance.

        :param instance: The operand whose dunder method is used.
        :param opnode: The operator node being inferred.
        :param op: The operator token, e.g. ``"+"`` or ``"+="``.
        :param other: The other operand.
        :param context: Inference context; it is bound to *instance* and its
            call context's callee is set to the looked-up method.
        :param method_name: The dunder method to look up on *instance*.
        :returns: An iterator over the inferred results.
        :raises InferenceError: When the method or the operation can't be
            inferred. Failures of ``dunder_lookup.lookup`` propagate to the
            caller unchanged.
        """
        methods = dunder_lookup.lookup(instance, method_name)
        context = bind_context_to_node(context, instance)
        method = methods[0]
        context.callcontext.callee = method

        if (
            isinstance(instance, nodes.Const)
            and isinstance(instance.value, str)
            and op == "%"
        ):
            return iter(
                OperatorNode._infer_old_style_string_formatting(
                    instance, other, context
                )
            )

        try:
            inferred = next(method.infer(context=context))
        except StopIteration as e:
            raise InferenceError(node=method, context=context) from e
        if isinstance(inferred, util.UninferableBase):
            raise InferenceError
        if not isinstance(
            instance,
            (nodes.Const, nodes.Tuple, nodes.List, nodes.ClassDef, bases.Instance),
        ):
            raise InferenceError  # pragma: no cover # Used as a failsafe
        return instance.infer_binary_op(opnode, op, other, context, inferred)

    @staticmethod
    def _aug_op(
        instance: InferenceResult,
        opnode: nodes.AugAssign,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        reverse: bool = False,
    ) -> partial[Generator[InferenceResult]]:
        """Get an inference callable for an augmented binary operation.

        :param op: The augmented operator token, e.g. ``"+="``.
        :param reverse: Accepted for signature parity with ``_bin_op``; it is
            ignored because augmented methods have no reflected variant.
        """
        method_name = AUGMENTED_OP_METHOD[op]
        return partial(
            OperatorNode._invoke_binop_inference,
            instance=instance,
            op=op,
            opnode=opnode,
            other=other,
            context=context,
            method_name=method_name,
        )

    @staticmethod
    def _bin_op(
        instance: InferenceResult,
        opnode: nodes.AugAssign | nodes.BinOp,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        reverse: bool = False,
    ) -> partial[Generator[InferenceResult]]:
        """Get an inference callable for a normal binary operation.

        If *reverse* is True, then the reflected method will be used instead.
        """
        method_name = REFLECTED_BIN_OP_METHOD[op] if reverse else BIN_OP_METHOD[op]
        return partial(
            OperatorNode._invoke_binop_inference,
            instance=instance,
            op=op,
            opnode=opnode,
            other=other,
            context=context,
            method_name=method_name,
        )

    @staticmethod
    def _bin_op_or_union_type(
        left: bases.UnionType | nodes.ClassDef | nodes.Const,
        right: bases.UnionType | nodes.ClassDef | nodes.Const,
    ) -> Generator[InferenceResult]:
        """Create a new UnionType instance for binary or, e.g. int | str."""
        yield bases.UnionType(left, right)

    @staticmethod
    def _get_binop_contexts(context, left, right):
        """Get contexts for binary operations.

        This will return two inference contexts, the first one
        for x.__op__(y), the other one for y.__rop__(x), where
        only the arguments are inversed.
        """
        # The order is important, since the first one should be
        # left.__op__(right).
        for arg in (right, left):
            new_context = context.clone()
            new_context.callcontext = CallContext(args=[arg])
            new_context.boundnode = None
            yield new_context

    @staticmethod
    def _same_type(type1, type2) -> bool:
        """Check if type1 is the same as type2 (by qualified name)."""
        return type1.qname() == type2.qname()

    @staticmethod
    def _get_aug_flow(
        left: InferenceResult,
        left_type: InferenceResult | None,
        aug_opnode: nodes.AugAssign,
        right: InferenceResult,
        right_type: InferenceResult | None,
        context: InferenceContext,
        reverse_context: InferenceContext,
    ) -> list[partial[Generator[InferenceResult]]]:
        """Get the flow for augmented binary operations.

        The rules are a bit messy:

        * if left and right have the same type, then left.__augop__(right)
          is first tried and then left.__op__(right).
        * if left and right are unrelated typewise, then
          left.__augop__(right) is tried, then left.__op__(right)
          is tried and then right.__rop__(left) is tried.
        * if left is a subtype of right, then left.__augop__(right)
          is tried and then left.__op__(right).
        * if left is a supertype of right, then left.__augop__(right)
          is tried, then right.__rop__(left) and then
          left.__op__(right)
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        aug_op = aug_opnode.op
        bin_op = aug_op.strip("=")
        augmented = OperatorNode._aug_op(left, aug_opnode, aug_op, right, context)
        forward = OperatorNode._bin_op(left, aug_opnode, bin_op, right, context)
        if OperatorNode._same_type(left_type, right_type) or helpers.is_subtype(
            left_type, right_type
        ):
            return [augmented, forward]
        reflected = OperatorNode._bin_op(
            right, aug_opnode, bin_op, left, reverse_context, reverse=True
        )
        if helpers.is_supertype(left_type, right_type):
            return [augmented, reflected, forward]
        return [augmented, forward, reflected]

    @staticmethod
    def _get_binop_flow(
        left: InferenceResult,
        left_type: InferenceResult | None,
        binary_opnode: nodes.AugAssign | nodes.BinOp,
        right: InferenceResult,
        right_type: InferenceResult | None,
        context: InferenceContext,
        reverse_context: InferenceContext,
    ) -> list[partial[Generator[InferenceResult]]]:
        """Get the flow for binary operations.

        The rules are a bit messy:

        * if left and right have the same type, then only one
          method will be called, left.__op__(right)
        * if left and right are unrelated typewise, then first
          left.__op__(right) is tried and if this does not exist
          or returns NotImplemented, then right.__rop__(left) is tried.
        * if left is a subtype of right, then only left.__op__(right)
          is tried.
        * if left is a supertype of right, then right.__rop__(left)
          is first tried and then left.__op__(right)

        For ``|`` between classes, unions or ``None``, building a
        ``UnionType`` is appended as a last resort.
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        op = binary_opnode.op
        forward = OperatorNode._bin_op(left, binary_opnode, op, right, context)
        if OperatorNode._same_type(left_type, right_type) or helpers.is_subtype(
            left_type, right_type
        ):
            methods = [forward]
        else:
            reflected = OperatorNode._bin_op(
                right, binary_opnode, op, left, reverse_context, reverse=True
            )
            if helpers.is_supertype(left_type, right_type):
                methods = [reflected, forward]
            else:
                methods = [forward, reflected]

        def _can_form_union(operand: InferenceResult) -> bool:
            # Operands accepted by the ``X | Y`` union fallback.
            return isinstance(operand, (bases.UnionType, nodes.ClassDef)) or (
                isinstance(operand, nodes.Const) and operand.value is None
            )

        if op == "|" and _can_form_union(left) and _can_form_union(right):
            methods.append(partial(OperatorNode._bin_op_or_union_type, left, right))
        return methods

    @staticmethod
    def _infer_binary_operation(
        left: InferenceResult,
        right: InferenceResult,
        binary_opnode: nodes.AugAssign | nodes.BinOp,
        context: InferenceContext,
        flow_factory: GetFlowFactory,
    ) -> Generator[InferenceResult | util.BadBinaryOperationMessage]:
        """Infer a binary operation between a left operand and a right operand.

        This is used by both normal binary operations and augmented binary
        operations, the only difference is the flow factory used.

        Candidate methods from *flow_factory* are tried in order. The first
        one producing real results wins. ``Uninferable`` is yielded when a
        method fails to infer or mixes ``NotImplemented`` with other results.
        ``BadBinaryOperationMessage`` is yielded when no method applies.
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        context, reverse_context = OperatorNode._get_binop_contexts(
            context, left, right
        )
        left_type = helpers.object_type(left)
        right_type = helpers.object_type(right)
        methods = flow_factory(
            left, left_type, binary_opnode, right, right_type, context, reverse_context
        )
        for method in methods:
            try:
                results = list(method())
            except (AttributeError, AttributeInferenceError):
                # This candidate method is not available; try the next one.
                continue
            except InferenceError:
                yield util.Uninferable
                return

            if any(isinstance(result, util.UninferableBase) for result in results):
                yield util.Uninferable
                return
            if all(map(OperatorNode._is_not_implemented, results)):
                # Only NotImplemented (or nothing): fall through to the next method.
                continue
            if any(map(OperatorNode._is_not_implemented, results)):
                # A mix of NotImplemented and real values: can't infer yet.
                yield util.Uninferable
                return

            yield from results
            return

        # The operation doesn't seem to be supported so let the caller know about it
        yield util.BadBinaryOperationMessage(left_type, binary_opnode.op, right_type)