1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""This module contains some base nodes that can be inherited for the different nodes.
6
7Previously these were called Mixin nodes.
8"""
9
10from __future__ import annotations
11
12import itertools
13from collections.abc import Callable, Generator, Iterator
14from functools import cached_property, lru_cache, partial
15from typing import TYPE_CHECKING, Any, ClassVar
16
17from astroid import bases, nodes, util
18from astroid.context import (
19 CallContext,
20 InferenceContext,
21 bind_context_to_node,
22)
23from astroid.exceptions import (
24 AttributeInferenceError,
25 InferenceError,
26)
27from astroid.interpreter import dunder_lookup
28from astroid.nodes.node_ng import NodeNG
29from astroid.typing import InferenceResult
30
if TYPE_CHECKING:
    from astroid.nodes.node_classes import LocalsDictNodeNG

    # Signature shared by OperatorNode._get_binop_flow and
    # OperatorNode._get_aug_flow: given
    # (left, left_type, opnode, right, right_type, context, reverse_context)
    # it returns the ordered list of zero-argument callables that
    # OperatorNode._infer_binary_operation tries one after another.
    GetFlowFactory = Callable[
        [
            InferenceResult,
            InferenceResult | None,
            nodes.AugAssign | nodes.BinOp,
            InferenceResult,
            InferenceResult | None,
            InferenceContext,
            InferenceContext,
        ],
        list[partial[Generator[InferenceResult]]],
    ]
46
47
class Statement(NodeNG):
    """Statement node adding a few attributes.

    NOTE: This class is part of the public API of 'astroid.nodes'.
    """

    is_statement = True
    """Whether this node indicates a statement."""

    def next_sibling(self):
        """Return the statement that follows this one in its parent.

        :returns: The next sibling statement node, or ``None`` when this
            statement is the last one of the parent's child sequence.
        :rtype: NodeNG or None
        """
        siblings = self.parent.child_sequence(self)
        following = siblings.index(self) + 1
        return siblings[following] if following < len(siblings) else None

    def previous_sibling(self):
        """Return the statement that precedes this one in its parent.

        :returns: The previous sibling statement node, or ``None`` when this
            statement is the first one of the parent's child sequence.
        :rtype: NodeNG or None
        """
        siblings = self.parent.child_sequence(self)
        position = siblings.index(self)
        return None if position == 0 else siblings[position - 1]
81
82
class NoChildrenNode(NodeNG):
    """Base nodes for nodes with no children, e.g. Pass."""

    def get_children(self) -> Iterator[NodeNG]:
        """Yield nothing: nodes of this kind never have children.

        This is still a generator function, so callers receive an
        (immediately exhausted) generator object.
        """
        yield from []
88
89
class FilterStmtsBaseNode(NodeNG):
    """Base node for statement filtering and assignment type."""

    def _get_filtered_stmts(self, _, node, _stmts, mystmt: Statement | None):
        """Method used in _filter_stmts to get statements and trigger break.

        :returns: A ``(statements, done)`` pair. When this node's statement
            is *mystmt*, only *node* is kept and ``done`` is ``True``;
            otherwise *_stmts* is returned untouched with ``done`` ``False``.
        """
        if self.statement() is not mystmt:
            return _stmts, False
        # original node's statement is the assignment, only keep
        # current node (gen exp, list comp)
        return [node], True

    def assign_type(self):
        """Return the node that defines the assignment type: this node."""
        return self
103
104
class AssignTypeNode(NodeNG):
    """Base node for nodes that can 'assign' such as AnnAssign."""

    def assign_type(self):
        """Return the node that defines the assignment type: this node."""
        return self

    def _get_filtered_stmts(self, lookup_node, node, _stmts, mystmt: Statement | None):
        """Method used in filter_stmts.

        :returns: A ``(statements, done)`` pair:

            * this node *is* ``mystmt``: keep *_stmts*, stop;
            * this node's statement is ``mystmt``: keep only *node*, stop;
            * otherwise: keep *_stmts*, continue.
        """
        if self is mystmt:
            return _stmts, True
        enclosing_statement = self.statement()
        if enclosing_statement is not mystmt:
            return _stmts, False
        # original node's statement is the assignment, only keep
        # current node (gen exp, list comp)
        return [node], True
120
121
class ParentAssignNode(AssignTypeNode):
    """Base node for nodes whose assign_type is determined by the parent node."""

    def assign_type(self):
        """Delegate to the parent node's ``assign_type``."""
        owner = self.parent
        return owner.assign_type()
127
128
class ImportNode(FilterStmtsBaseNode, NoChildrenNode, Statement):
    """Base node for From and Import Nodes."""

    modname: str | None
    """The module that is being imported from.

    This is ``None`` for relative imports.
    """

    names: list[tuple[str, str | None]]
    """What is being imported from the module.

    Each entry is a :class:`tuple` of the name being imported,
    and the alias that the name is assigned to (if any).
    """

    def _infer_name(self, frame, name):
        """Return *name* unchanged; *frame* is not used."""
        return name

    def do_import_module(self, modname: str | None = None) -> nodes.Module:
        """Return the ast for a module whose name is <modname> imported by <self>.

        :param modname: Module to import; defaults to ``self.modname``.
        :returns: The module returned by the root module's ``import_module``.
        """
        root = self.root()
        # ``Import`` nodes have no ``level`` attribute, hence the default.
        level: int | None = getattr(self, "level", None)
        target = self.modname if modname is None else modname
        # When the imported module has the same name as the file containing
        # this import, skip the cache so the import system resolves it.
        imports_own_name = bool(
            target
            # pylint: disable-next=no-member # pylint doesn't recognize type of root
            and root.relative_to_absolute_name(target, level) == root.name
        )
        # pylint: disable-next=no-member # pylint doesn't recognize type of root
        return root.import_module(
            target,
            level=level,
            relative_only=bool(level and level >= 1),
            use_cache=not imports_own_name,
        )

    def real_name(self, asname: str) -> str:
        """Get name from 'as' name.

        :param asname: The name as bound in the importing scope.
        :returns: The imported name bound to *asname*; for a ``*`` entry,
            *asname* itself.
        :raises AttributeInferenceError: If no entry binds *asname*.
        """
        for imported, alias in self.names:
            if imported == "*":
                return asname
            if alias:
                bound, original = alias, imported
            else:
                # Without an alias, only the first dotted component is
                # compared and returned.
                original = bound = imported.split(".", 1)[0]
            if bound == asname:
                return original
        raise AttributeInferenceError(
            "Could not find original name for {attribute} in {target!r}",
            target=self,
            attribute=asname,
        )
189
190
class MultiLineBlockNode(NodeNG):
    """Base node for multi-line blocks, e.g. For and FunctionDef.

    Note that this does not apply to every node with a `body` field.
    For instance, an If node has a multi-line body, but the body of an
    IfExpr is not multi-line, and hence cannot contain Return nodes,
    Assign nodes, etc.
    """

    # Attribute names holding this node's child statement lists.
    # Empty here; subclasses provide their own.
    _multi_line_block_fields: ClassVar[tuple[str, ...]] = ()

    @cached_property
    def _multi_line_blocks(self):
        """The statement lists named by ``_multi_line_block_fields``, in order."""
        return tuple([getattr(self, field) for field in self._multi_line_block_fields])

    def _get_return_nodes_skip_functions(self):
        # Walk every statement of every block; nested functions are not entered.
        for statement in itertools.chain.from_iterable(self._multi_line_blocks):
            if not statement.is_function:
                yield from statement._get_return_nodes_skip_functions()

    def _get_yield_nodes_skip_functions(self):
        # Walk every statement of every block; nested functions are not entered.
        for statement in itertools.chain.from_iterable(self._multi_line_blocks):
            if not statement.is_function:
                yield from statement._get_yield_nodes_skip_functions()

    def _get_yield_nodes_skip_lambdas(self):
        # Walk every statement of every block; lambdas are not entered.
        for statement in itertools.chain.from_iterable(self._multi_line_blocks):
            if not statement.is_lambda:
                yield from statement._get_yield_nodes_skip_lambdas()

    @cached_property
    def _assign_nodes_in_scope(self) -> list[nodes.Assign]:
        # Concatenate the children's own assign-node lists, block by block.
        return [
            assign
            for statement in itertools.chain.from_iterable(self._multi_line_blocks)
            for assign in statement._assign_nodes_in_scope
        ]
235
236
class MultiLineWithElseBlockNode(MultiLineBlockNode):
    """Base node for multi-line blocks that can have else statements.

    Provides ``block_range``, which maps a line number to the range of
    lines it governs, given a ``body`` and an optional ``orelse`` section.
    """

    body: list[NodeNG]
    """The contents of the block."""

    orelse: list[NodeNG]
    """The contents of the ``else`` block."""

    @cached_property
    def blockstart_tolineno(self):
        """End line of the block start; this base class uses ``self.lineno``."""
        return self.lineno

    def block_range(self, lineno: int) -> tuple[int, int]:
        """Get a range from the given line number to where this node ends.

        :param lineno: The line number to start the range at.

        :returns: The range of line numbers that this node belongs to,
            starting at the given line number.
        """
        # Before this node starts: the range extends to the node's last line.
        if lineno < self.fromlineno:
            return lineno, self.tolineno
        # The first line of the body is a single-line range.
        if lineno == self.body[0].fromlineno:
            return lineno, lineno
        # Any other line inside the body extends to the end of the body.
        if lineno <= self.body[-1].tolineno:
            return lineno, self.body[-1].tolineno
        # Past the body: resolve against the ``else`` part. The line right
        # before the body is the fallback end used when ``orelse`` is empty.
        return self._elsed_block_range(lineno, self.orelse, self.body[0].fromlineno - 1)

    def _elsed_block_range(
        self, lineno: int, orelse: list[nodes.NodeNG], last: int | None = None
    ) -> tuple[int, int]:
        """Handle block line numbers range for try/finally, for, if and while
        statements.

        :param lineno: The line number to start the range at.
        :param orelse: The ``else`` statements to take into account.
        :param last: End line returned when *orelse* is empty; a falsy value
            (``None`` or ``0``) falls back to ``self.tolineno``.

        :returns: A ``(start, end)`` range starting at *lineno*.
        """
        # If at the end of the node, return same line
        if lineno == self.tolineno:
            return lineno, lineno
        if orelse:
            # If the lineno is beyond the body of the node we check the orelse
            if lineno >= self.body[-1].tolineno + 1:
                # If the orelse has a scope of its own we determine the block range there
                # (recursing with that node's own orelse; ``last`` is not forwarded,
                # so the nested fallback is the nested node's ``tolineno``).
                if isinstance(orelse[0], MultiLineWithElseBlockNode):
                    return orelse[0]._elsed_block_range(lineno, orelse[0].orelse)
                # Return last line of orelse
                return lineno, orelse[-1].tolineno
            # If the lineno is within the body we take the last line of the body
            return lineno, self.body[-1].tolineno
        return lineno, last or self.tolineno
286
287
class LookupMixIn(NodeNG):
    """Mixin to look up a name in the right scope."""

    # Kept as a plain ``lru_cache`` so ``LookupMixIn.lookup.cache_clear()``
    # remains available to callers.
    @lru_cache  # noqa
    def lookup(self, name: str) -> tuple[LocalsDictNodeNG, list[NodeNG]]:
        """Find where *name* is assigned, starting from this node's scope.

        If this node is not a frame itself and the name is found in the
        inner frame's locals, statements are filtered according to this
        node's location so that ignorable ones are dropped.

        :param name: The name of the variable to find assignments for.

        :returns: The scope node and the list of assignments associated to the
            given name according to the scope where it has been found (locals,
            globals or builtin).
        """
        enclosing_scope = self.scope()
        return enclosing_scope.scope_lookup(self, name)

    def ilookup(self, name):
        """Infer the values assigned to *name*.

        :param name: The variable name to find values for.
        :type name: str

        :returns: The inferred values of the statements found by
            :meth:`lookup`, inferred with a fresh :class:`InferenceContext`.
        :rtype: iterable
        """
        scope_node, assignments = self.lookup(name)
        return bases._infer_stmts(assignments, InferenceContext(), scope_node)
320
321
def _reflected_name(name) -> str:
    """Turn a dunder such as ``__add__`` into its reflected form ``__radd__``."""
    return "".join(("__r", name[2:]))


def _augmented_name(name) -> str:
    """Turn a dunder such as ``__add__`` into its in-place form ``__iadd__``."""
    return "".join(("__i", name[2:]))


# Binary operator symbol -> dunder method implementing ``left <op> right``.
BIN_OP_METHOD = {
    "+": "__add__",
    "-": "__sub__",
    "/": "__truediv__",
    "//": "__floordiv__",
    "*": "__mul__",
    "**": "__pow__",
    "%": "__mod__",
    "&": "__and__",
    "|": "__or__",
    "^": "__xor__",
    "<<": "__lshift__",
    ">>": "__rshift__",
    "@": "__matmul__",
}

# Binary operator symbol -> reflected dunder, looked up on the right operand.
REFLECTED_BIN_OP_METHOD = dict(
    zip(BIN_OP_METHOD, map(_reflected_name, BIN_OP_METHOD.values()))
)
# Augmented operator symbol (``"+="`` etc.) -> in-place dunder (``__iadd__`` etc.).
AUGMENTED_OP_METHOD = {
    f"{symbol}=": _augmented_name(dunder) for symbol, dunder in BIN_OP_METHOD.items()
}
352
353
class OperatorNode(NodeNG):
    """Base node with helpers to infer binary and augmented binary operations.

    The static methods below implement the order in which operator dunder
    methods are tried when inferring ``BinOp`` and ``AugAssign`` nodes.
    """

    @staticmethod
    def _filter_operation_errors(
        infer_callable: Callable[
            [InferenceContext | None],
            Generator[InferenceResult | util.BadOperationMessage],
        ],
        context: InferenceContext | None,
        error: type[util.BadOperationMessage],
    ) -> Generator[InferenceResult]:
        """Run *infer_callable*, replacing operation error messages by Uninferable.

        :param infer_callable: Produces inferred values and/or instances of *error*.
        :param context: Inference context forwarded to *infer_callable*.
        :param error: The message type to replace with ``util.Uninferable``.
        """
        for result in infer_callable(context):
            if isinstance(result, error):
                # For the sake of .infer(), we don't care about operation
                # errors, which is the job of a linter. So return something
                # which shows that we can't infer the result.
                yield util.Uninferable
            else:
                yield result

    @staticmethod
    def _is_not_implemented(const) -> bool:
        """Check if the given const node is NotImplemented."""
        return isinstance(const, nodes.Const) and const.value is NotImplemented

    @staticmethod
    def _infer_old_style_string_formatting(
        instance: nodes.Const, other: nodes.NodeNG, context: InferenceContext
    ) -> tuple[util.UninferableBase | nodes.Const]:
        """Infer the result of '"string" % ...'.

        The right operand may be a tuple, a dict or a single constant; every
        value involved must infer to a ``Const``, otherwise the result is
        ``Uninferable``. Formatting errors also yield ``Uninferable``.

        TODO: Instead of returning Uninferable we should rely
        on the call to '%' to see if the result is actually uninferable.
        """
        if isinstance(other, nodes.Tuple):
            if util.Uninferable in other.elts:
                return (util.Uninferable,)
            inferred_positional = [util.safe_infer(i, context) for i in other.elts]
            if not all(isinstance(i, nodes.Const) for i in inferred_positional):
                # Do not substitute a placeholder such as ``None`` here:
                # ``"%s" % None`` would wrongly infer the constant "None".
                return (util.Uninferable,)
            values = tuple(i.value for i in inferred_positional)
        elif isinstance(other, nodes.Dict):
            values: dict[Any, Any] = {}
            for pair in other.items:
                key = util.safe_infer(pair[0], context)
                if not isinstance(key, nodes.Const):
                    return (util.Uninferable,)
                value = util.safe_infer(pair[1], context)
                if not isinstance(value, nodes.Const):
                    return (util.Uninferable,)
                values[key.value] = value.value
        elif isinstance(other, nodes.Const):
            values = other.value
        else:
            return (util.Uninferable,)

        try:
            return (nodes.const_factory(instance.value % values),)
        except (TypeError, KeyError, ValueError):
            return (util.Uninferable,)

    @staticmethod
    def _invoke_binop_inference(
        instance: InferenceResult,
        opnode: nodes.AugAssign | nodes.BinOp,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        method_name: str,
    ) -> Generator[InferenceResult]:
        """Invoke binary operation inference on the given instance.

        :param instance: The operand whose *method_name* dunder is used.
        :param opnode: The operation node being inferred.
        :param op: The operator symbol.
        :param other: The other operand.
        :param context: Inference context; bound to *instance* before use.
        :param method_name: The dunder method to look up on *instance*.
        :raises InferenceError: If the dunder cannot be inferred, infers to
            Uninferable, or *instance* is not a supported node type.
        """
        # Raises AttributeInferenceError when the dunder is missing; the
        # caller treats that as "try the next method".
        methods = dunder_lookup.lookup(instance, method_name)
        context = bind_context_to_node(context, instance)
        method = methods[0]
        context.callcontext.callee = method

        if (
            isinstance(instance, nodes.Const)
            and isinstance(instance.value, str)
            and op == "%"
        ):
            return iter(
                OperatorNode._infer_old_style_string_formatting(
                    instance, other, context
                )
            )

        try:
            inferred = next(method.infer(context=context))
        except StopIteration as e:
            raise InferenceError(node=method, context=context) from e
        if isinstance(inferred, util.UninferableBase):
            raise InferenceError
        if not isinstance(
            instance,
            (nodes.Const, nodes.Tuple, nodes.List, nodes.ClassDef, bases.Instance),
        ):
            raise InferenceError  # pragma: no cover # Used as a failsafe
        return instance.infer_binary_op(opnode, op, other, context, inferred)

    @staticmethod
    def _aug_op(
        instance: InferenceResult,
        opnode: nodes.AugAssign,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        reverse: bool = False,
    ) -> partial[Generator[InferenceResult]]:
        """Get an inference callable for an augmented binary operation.

        *op* is the augmented operator (e.g. ``"+="``). *reverse* is accepted
        for signature parity with :meth:`_bin_op` but is not used.
        """
        method_name = AUGMENTED_OP_METHOD[op]
        return partial(
            OperatorNode._invoke_binop_inference,
            instance=instance,
            op=op,
            opnode=opnode,
            other=other,
            context=context,
            method_name=method_name,
        )

    @staticmethod
    def _bin_op(
        instance: InferenceResult,
        opnode: nodes.AugAssign | nodes.BinOp,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        reverse: bool = False,
    ) -> partial[Generator[InferenceResult]]:
        """Get an inference callable for a normal binary operation.

        If *reverse* is True, then the reflected method will be used instead.
        """
        if reverse:
            method_name = REFLECTED_BIN_OP_METHOD[op]
        else:
            method_name = BIN_OP_METHOD[op]
        return partial(
            OperatorNode._invoke_binop_inference,
            instance=instance,
            op=op,
            opnode=opnode,
            other=other,
            context=context,
            method_name=method_name,
        )

    @staticmethod
    def _bin_op_or_union_type(
        left: bases.UnionType | nodes.ClassDef | nodes.Const,
        right: bases.UnionType | nodes.ClassDef | nodes.Const,
    ) -> Generator[InferenceResult]:
        """Create a new UnionType instance for binary or, e.g. int | str."""
        yield bases.UnionType(left, right)

    @staticmethod
    def _get_binop_contexts(context, left, right):
        """Get contexts for binary operations.

        This will return two inference contexts, the first one
        for x.__op__(y), the other one for y.__rop__(x), where
        only the arguments are swapped.
        """
        # The order is important, since the first one should be
        # left.__op__(right).
        for arg in (right, left):
            new_context = context.clone()
            new_context.callcontext = CallContext(args=[arg])
            new_context.boundnode = None
            yield new_context

    @staticmethod
    def _same_type(type1, type2) -> bool:
        """Check if type1 is the same as type2 (compared by qualified name)."""
        return type1.qname() == type2.qname()

    @staticmethod
    def _get_aug_flow(
        left: InferenceResult,
        left_type: InferenceResult | None,
        aug_opnode: nodes.AugAssign,
        right: InferenceResult,
        right_type: InferenceResult | None,
        context: InferenceContext,
        reverse_context: InferenceContext,
    ) -> list[partial[Generator[InferenceResult]]]:
        """Get the flow for augmented binary operations.

        The rules are a bit messy:

            * if left and right have the same type, then left.__augop__(right)
            is first tried and then left.__op__(right).
            * if left and right are unrelated typewise, then
            left.__augop__(right) is tried, then left.__op__(right)
            is tried and then right.__rop__(left) is tried.
            * if left is a subtype of right, then left.__augop__(right)
            is tried and then left.__op__(right).
            * if left is a supertype of right, then left.__augop__(right)
            is tried, then right.__rop__(left) and then
            left.__op__(right)
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        # "+=" -> "+", "**=" -> "**", ...
        bin_op = aug_opnode.op.strip("=")
        aug_op = aug_opnode.op
        if OperatorNode._same_type(left_type, right_type):
            methods = [
                OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
            ]
        elif helpers.is_subtype(left_type, right_type):
            methods = [
                OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
            ]
        elif helpers.is_supertype(left_type, right_type):
            methods = [
                OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                OperatorNode._bin_op(
                    right, aug_opnode, bin_op, left, reverse_context, reverse=True
                ),
                OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
            ]
        else:
            methods = [
                OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
                OperatorNode._bin_op(
                    right, aug_opnode, bin_op, left, reverse_context, reverse=True
                ),
            ]
        return methods

    @staticmethod
    def _get_binop_flow(
        left: InferenceResult,
        left_type: InferenceResult | None,
        binary_opnode: nodes.AugAssign | nodes.BinOp,
        right: InferenceResult,
        right_type: InferenceResult | None,
        context: InferenceContext,
        reverse_context: InferenceContext,
    ) -> list[partial[Generator[InferenceResult]]]:
        """Get the flow for binary operations.

        The rules are a bit messy:

            * if left and right have the same type, then only one
            method will be called, left.__op__(right)
            * if left and right are unrelated typewise, then first
            left.__op__(right) is tried and if this does not exist
            or returns NotImplemented, then right.__rop__(left) is tried.
            * if left is a subtype of right, then only left.__op__(right)
            is tried.
            * if left is a supertype of right, then right.__rop__(left)
            is first tried and then left.__op__(right)

        For ``|`` between classes, union types or ``None``, building a
        ``UnionType`` is appended as a last resort.
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        op = binary_opnode.op
        if OperatorNode._same_type(left_type, right_type):
            methods = [OperatorNode._bin_op(left, binary_opnode, op, right, context)]
        elif helpers.is_subtype(left_type, right_type):
            methods = [OperatorNode._bin_op(left, binary_opnode, op, right, context)]
        elif helpers.is_supertype(left_type, right_type):
            methods = [
                OperatorNode._bin_op(
                    right, binary_opnode, op, left, reverse_context, reverse=True
                ),
                OperatorNode._bin_op(left, binary_opnode, op, right, context),
            ]
        else:
            methods = [
                OperatorNode._bin_op(left, binary_opnode, op, right, context),
                OperatorNode._bin_op(
                    right, binary_opnode, op, left, reverse_context, reverse=True
                ),
            ]

        # pylint: disable = too-many-boolean-expressions
        if (
            op == "|"
            and (
                isinstance(left, (bases.UnionType, nodes.ClassDef))
                or (isinstance(left, nodes.Const) and left.value is None)
            )
            and (
                isinstance(right, (bases.UnionType, nodes.ClassDef))
                or (isinstance(right, nodes.Const) and right.value is None)
            )
        ):
            methods.append(partial(OperatorNode._bin_op_or_union_type, left, right))
        return methods

    @staticmethod
    def _infer_binary_operation(
        left: InferenceResult,
        right: InferenceResult,
        binary_opnode: nodes.AugAssign | nodes.BinOp,
        context: InferenceContext,
        flow_factory: GetFlowFactory,
    ) -> Generator[InferenceResult | util.BadBinaryOperationMessage]:
        """Infer a binary operation between a left operand and a right operand.

        This is used by both normal binary operations and augmented binary
        operations, the only difference is the flow factory used.

        Candidate methods are tried in order. A missing method moves on to
        the next candidate, as does a method whose results are all
        ``NotImplemented``. If no candidate applies, a
        ``BadBinaryOperationMessage`` is yielded.
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        context, reverse_context = OperatorNode._get_binop_contexts(
            context, left, right
        )
        left_type = helpers.object_type(left)
        right_type = helpers.object_type(right)
        methods = flow_factory(
            left, left_type, binary_opnode, right, right_type, context, reverse_context
        )
        for method in methods:
            try:
                results = list(method())
            except (AttributeError, AttributeInferenceError):
                # The operand does not provide this dunder: try the next one.
                continue
            except InferenceError:
                yield util.Uninferable
                return
            else:
                if any(isinstance(result, util.UninferableBase) for result in results):
                    yield util.Uninferable
                    return

                if all(map(OperatorNode._is_not_implemented, results)):
                    continue
                not_implemented = sum(
                    1 for result in results if OperatorNode._is_not_implemented(result)
                )
                if not_implemented and not_implemented != len(results):
                    # Can't infer yet what this is.
                    yield util.Uninferable
                    return

                yield from results
                return

        # The operation doesn't seem to be supported so let the caller know about it
        yield util.BadBinaryOperationMessage(left_type, binary_opnode.op, right_type)