1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""This module contains some base nodes that can be inherited for the different nodes.
6
7Previously these were called Mixin nodes.
8"""
9
10from __future__ import annotations
11
12import itertools
13from collections.abc import Callable, Generator, Iterator
14from functools import cached_property, lru_cache, partial
15from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
16
17from astroid import bases, nodes, util
18from astroid.const import PY310_PLUS
19from astroid.context import (
20 CallContext,
21 InferenceContext,
22 bind_context_to_node,
23)
24from astroid.exceptions import (
25 AttributeInferenceError,
26 InferenceError,
27)
28from astroid.interpreter import dunder_lookup
29from astroid.nodes.node_ng import NodeNG
30from astroid.typing import InferenceResult
31
if TYPE_CHECKING:
    from astroid.nodes.node_classes import LocalsDictNodeNG

    # Signature of the "flow factories" accepted by
    # ``OperatorNode._infer_binary_operation``. Arguments, in order: left
    # operand, left type, operator node, right operand, right type, context
    # and reverse context. The factory returns the list of inference
    # callables, which the caller then tries in order.
    GetFlowFactory = Callable[
        [
            InferenceResult,
            Optional[InferenceResult],
            Union[nodes.AugAssign, nodes.BinOp],
            InferenceResult,
            Optional[InferenceResult],
            InferenceContext,
            InferenceContext,
        ],
        list[partial[Generator[InferenceResult]]],
    ]
47
48
class Statement(NodeNG):
    """Node base class marking its subclasses as statements.

    It also gives access to the neighbouring nodes found in the parent's
    child sequence.

    NOTE: This class is part of the public API of 'astroid.nodes'.
    """

    is_statement = True
    """Whether this node indicates a statement."""

    def next_sibling(self):
        """Return the node that directly follows this one.

        :returns: The following sibling, or ``None`` when this node is the
            last entry of the parent's child sequence.
        :rtype: NodeNG or None
        """
        siblings = self.parent.child_sequence(self)
        following = siblings.index(self) + 1
        if following < len(siblings):
            return siblings[following]
        return None

    def previous_sibling(self):
        """Return the node that directly precedes this one.

        :returns: The preceding sibling, or ``None`` when this node is the
            first entry of the parent's child sequence.
        :rtype: NodeNG or None
        """
        siblings = self.parent.child_sequence(self)
        position = siblings.index(self)
        return siblings[position - 1] if position >= 1 else None
82
83
class NoChildrenNode(NodeNG):
    """Base class of the nodes that never have children, e.g. Pass."""

    def get_children(self) -> Iterator[NodeNG]:
        """Return a generator that yields no node at all."""
        yield from []
89
90
class FilterStmtsBaseNode(NodeNG):
    """Base class providing statement filtering and the assignment type."""

    def _get_filtered_stmts(self, _, node, _stmts, mystmt: Statement | None):
        """Select the statements to keep and tell whether filtering must stop.

        Used by ``_filter_stmts``.

        :returns: A ``(statements, stop)`` pair.
        """
        if self.statement() is not mystmt:
            return _stmts, False
        # The statement of the original node is this very statement
        # (gen exp, list comp): only the current node is kept.
        return [node], True

    def assign_type(self):
        """Return the node acting as assignment type, i.e. this node."""
        return self
104
105
class AssignTypeNode(NodeNG):
    """Base class of the nodes that can 'assign', such as AnnAssign."""

    def assign_type(self):
        """Return the node acting as assignment type, i.e. this node."""
        return self

    def _get_filtered_stmts(self, lookup_node, node, _stmts, mystmt: Statement | None):
        """Select the statements to keep and tell whether filtering must stop.

        Used by ``filter_stmts``.

        :returns: A ``(statements, stop)`` pair.
        """
        if self is mystmt:
            kept, stop = _stmts, True
        elif self.statement() is mystmt:
            # The statement of the original node is the assignment
            # (gen exp, list comp): only the current node is kept.
            kept, stop = [node], True
        else:
            kept, stop = _stmts, False
        return kept, stop
121
122
class ParentAssignNode(AssignTypeNode):
    """Base class of the nodes that take their assign_type from their parent."""

    def assign_type(self):
        """Return whatever the parent node reports as its assignment type."""
        owner = self.parent
        return owner.assign_type()
128
129
class ImportNode(FilterStmtsBaseNode, NoChildrenNode, Statement):
    """Shared base of the From and Import nodes."""

    modname: str | None
    """The module that is being imported from.

    This is ``None`` for relative imports.
    """

    names: list[tuple[str, str | None]]
    """What is being imported from the module.

    Each entry is a :class:`tuple` of the name being imported,
    and the alias that the name is assigned to (if any).
    """

    def _infer_name(self, frame, name):
        """Return *name* unchanged; *frame* is not used."""
        return name

    def do_import_module(self, modname: str | None = None) -> nodes.Module:
        """Return the ast of the module named <modname>, as imported by <self>.

        :param modname: The module to import; ``self.modname`` is used when
            it is ``None``.

        :returns: The result of ``import_module`` called on this node's root.
        """
        module = self.root()
        level: int | None = getattr(self, "level", None)  # Import has no level
        target = self.modname if modname is None else modname
        # When the imported module has the same name as the file holding this
        # node, the cache is bypassed so that the import system is used to get
        # the correct module.
        imports_own_name = bool(target) and (
            # pylint: disable-next=no-member # pylint doesn't recognize type of module
            module.relative_to_absolute_name(target, level) == module.name
        )

        # pylint: disable-next=no-member # pylint doesn't recognize type of module
        return module.import_module(
            target,
            level=level,
            relative_only=bool(level and level >= 1),
            use_cache=not imports_own_name,
        )

    def real_name(self, asname: str) -> str:
        """Return the original name hidden behind the 'as' name *asname*.

        :raises AttributeInferenceError: If no entry of ``names`` matches.
        """
        for imported, alias in self.names:
            if imported == "*":
                return asname
            if not alias:
                # Without an alias, only the first dotted component of the
                # imported name is compared and returned.
                imported = alias = imported.split(".", 1)[0]
            if alias == asname:
                return imported
        raise AttributeInferenceError(
            "Could not find original name for {attribute} in {target!r}",
            target=self,
            attribute=asname,
        )
190
191
class MultiLineBlockNode(NodeNG):
    """Base class of the multi-line block nodes, e.g. For and FunctionDef.

    Not every node owning a `body` field qualifies. An If node has a
    multi-line body, whereas the body of an IfExpr is not multi-line and
    therefore cannot hold Return nodes, Assign nodes, etc.
    """

    _multi_line_block_fields: ClassVar[tuple[str, ...]] = ()

    @cached_property
    def _multi_line_blocks(self):
        """The values of the fields named in ``_multi_line_block_fields``."""
        return tuple([getattr(self, name) for name in self._multi_line_block_fields])

    def _get_return_nodes_skip_functions(self):
        """Delegate to every child of the blocks, except function children."""
        for child in itertools.chain.from_iterable(self._multi_line_blocks):
            if not child.is_function:
                yield from child._get_return_nodes_skip_functions()

    def _get_yield_nodes_skip_functions(self):
        """Delegate to every child of the blocks, except function children."""
        for child in itertools.chain.from_iterable(self._multi_line_blocks):
            if not child.is_function:
                yield from child._get_yield_nodes_skip_functions()

    def _get_yield_nodes_skip_lambdas(self):
        """Delegate to every child of the blocks, except lambda children."""
        for child in itertools.chain.from_iterable(self._multi_line_blocks):
            if not child.is_lambda:
                yield from child._get_yield_nodes_skip_lambdas()

    @cached_property
    def _assign_nodes_in_scope(self) -> list[nodes.Assign]:
        """The ``_assign_nodes_in_scope`` of all block children, concatenated."""
        collected: list[nodes.Assign] = []
        for block in self._multi_line_blocks:
            for child in block:
                collected.extend(child._assign_nodes_in_scope)
        return collected
236
237
class MultiLineWithElseBlockNode(MultiLineBlockNode):
    """Base class of the multi-line blocks that accept else statements."""

    @cached_property
    def blockstart_tolineno(self):
        """The value of this node's ``lineno``."""
        return self.lineno

    def _elsed_block_range(
        self, lineno: int, orelse: list[nodes.NodeNG], last: int | None = None
    ) -> tuple[int, int]:
        """Compute the block line range for try/finally, for, if and while
        statements.

        :param lineno: The line whose enclosing block range is wanted.
        :param orelse: The nodes of the else branch, possibly empty.
        :param last: End line used when there is no else branch; falls back
            to ``self.tolineno`` when falsy.

        :returns: ``(lineno, end)`` where *end* is the last line of the range.
        """
        if lineno == self.fromlineno:
            end = lineno
        elif not orelse:
            end = last or self.tolineno
        else:
            else_start = orelse[0].fromlineno
            if lineno >= else_start:
                end = orelse[-1].tolineno
            else:
                end = else_start - 1
        return lineno, end
258
259
class LookupMixIn(NodeNG):
    """Mixin resolving a name starting from the right scope."""

    # NOTE(review): the cache is keyed on ``(self, name)``, so it holds on to
    # the nodes it has seen; presumably it gets cleared elsewhere through
    # ``lookup.cache_clear()`` -- confirm before changing the decorator.
    @lru_cache  # noqa
    def lookup(self, name: str) -> tuple[LocalsDictNodeNG, list[NodeNG]]:
        """Find where the given variable is assigned.

        The search begins in self's scope. When self is not a frame itself
        and the name is found in the inner frame locals, the statements are
        filtered so that the ones ignorable from self's location are removed.

        :param name: The name of the variable to find assignments for.

        :returns: The scope node and the list of assignments associated to the
            given name according to the scope where it has been found (locals,
            globals or builtin).
        """
        enclosing_scope = self.scope()
        return enclosing_scope.scope_lookup(self, name)

    def ilookup(self, name):
        """Infer the values of the given variable.

        :param name: The variable name to find values for.
        :type name: str

        :returns: The inferred values of the statements returned from
            :meth:`lookup`.
        :rtype: iterable
        """
        scope_node, assignments = self.lookup(name)
        return bases._infer_stmts(assignments, InferenceContext(), scope_node)
292
293
294def _reflected_name(name) -> str:
295 return "__r" + name[2:]
296
297
298def _augmented_name(name) -> str:
299 return "__i" + name[2:]
300
301
# Binary operator symbol -> name of the dunder method used for it.
BIN_OP_METHOD = {
    "+": "__add__",
    "-": "__sub__",
    "/": "__truediv__",
    "//": "__floordiv__",
    "*": "__mul__",
    "**": "__pow__",
    "%": "__mod__",
    "&": "__and__",
    "|": "__or__",
    "^": "__xor__",
    "<<": "__lshift__",
    ">>": "__rshift__",
    "@": "__matmul__",
}

# Same operator symbols, mapped to the reflected (``__r...``) method names.
REFLECTED_BIN_OP_METHOD = {
    symbol: _reflected_name(dunder) for symbol, dunder in BIN_OP_METHOD.items()
}
# Operator symbols suffixed with ``=``, mapped to the augmented (``__i...``)
# method names.
AUGMENTED_OP_METHOD = {
    f"{symbol}=": _augmented_name(dunder) for symbol, dunder in BIN_OP_METHOD.items()
}
324
325
class OperatorNode(NodeNG):
    """Base node gathering the helpers used to infer binary operations.

    Every helper is a static method and is referenced through
    ``OperatorNode`` inside this class.
    """

    @staticmethod
    def _filter_operation_errors(
        infer_callable: Callable[
            [InferenceContext | None],
            Generator[InferenceResult | util.BadOperationMessage],
        ],
        context: InferenceContext | None,
        error: type[util.BadOperationMessage],
    ) -> Generator[InferenceResult]:
        """Yield the results of *infer_callable* with operation errors masked.

        :param infer_callable: Called with *context*; yields inference
            results and/or bad-operation messages.
        :param context: The inference context handed to *infer_callable*.
        :param error: The message class that has to be masked.

        :returns: The results of *infer_callable*, where each instance of
            *error* is replaced by ``util.Uninferable``.
        """
        for result in infer_callable(context):
            if isinstance(result, error):
                # For the sake of .infer(), we don't care about operation
                # errors, which is the job of a linter. So return something
                # which shows that we can't infer the result.
                yield util.Uninferable
            else:
                yield result

    @staticmethod
    def _is_not_implemented(const) -> bool:
        """Check if the given const node is NotImplemented."""
        return isinstance(const, nodes.Const) and const.value is NotImplemented

    @staticmethod
    def _infer_old_style_string_formatting(
        instance: nodes.Const, other: nodes.NodeNG, context: InferenceContext
    ) -> tuple[util.UninferableBase | nodes.Const]:
        """Infer the result of '"string" % ...'.

        :param instance: The constant holding the format string.
        :param other: The right-hand operand; ``Tuple``, ``Dict`` and
            ``Const`` nodes are supported.
        :param context: The context used to infer the elements of *other*.

        :returns: A one-element tuple holding either the formatted constant
            or ``util.Uninferable``.

        TODO: Instead of returning Uninferable we should rely
        on the call to '%' to see if the result is actually uninferable.
        """
        values: Any
        if isinstance(other, nodes.Tuple):
            if util.Uninferable in other.elts:
                return (util.Uninferable,)
            inferred_positional = [util.safe_infer(i, context) for i in other.elts]
            if not all(isinstance(i, nodes.Const) for i in inferred_positional):
                # Give up explicitly. Substituting ``None`` for the operand
                # is wrong: a lone non-tuple value is a valid argument for a
                # format with a single conversion specifier, so
                # '"%s" % None' succeeds and would produce the string 'None'.
                return (util.Uninferable,)
            values = tuple(i.value for i in inferred_positional)
        elif isinstance(other, nodes.Dict):
            values = {}
            for pair in other.items:
                key = util.safe_infer(pair[0], context)
                if not isinstance(key, nodes.Const):
                    return (util.Uninferable,)
                value = util.safe_infer(pair[1], context)
                if not isinstance(value, nodes.Const):
                    return (util.Uninferable,)
                values[key.value] = value.value
        elif isinstance(other, nodes.Const):
            values = other.value
        else:
            return (util.Uninferable,)

        try:
            return (nodes.const_factory(instance.value % values),)
        except (TypeError, KeyError, ValueError):
            # The format string and the values do not fit together.
            return (util.Uninferable,)

    @staticmethod
    def _invoke_binop_inference(
        instance: InferenceResult,
        opnode: nodes.AugAssign | nodes.BinOp,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        method_name: str,
    ) -> Generator[InferenceResult]:
        """Invoke binary operation inference on the given instance.

        :param instance: The operand on which *method_name* is looked up.
        :param opnode: The node of the operation being inferred.
        :param op: The operator symbol.
        :param other: The other operand of the operation.
        :param context: The inference context; it is bound to *instance*
            before being used.
        :param method_name: The name of the dunder method to look up.

        :returns: An iterator over the inferred results.

        :raises InferenceError: If the method cannot be inferred, infers to
            ``Uninferable``, or *instance* is not a Const, Tuple, List,
            ClassDef or Instance.
        """
        methods = dunder_lookup.lookup(instance, method_name)
        context = bind_context_to_node(context, instance)
        # Only the first method returned by the lookup is considered.
        method = methods[0]
        context.callcontext.callee = method

        if (
            isinstance(instance, nodes.Const)
            and isinstance(instance.value, str)
            and op == "%"
        ):
            # 'str % values' is evaluated directly on the constant values.
            return iter(
                OperatorNode._infer_old_style_string_formatting(
                    instance, other, context
                )
            )

        try:
            inferred = next(method.infer(context=context))
        except StopIteration as e:
            raise InferenceError(node=method, context=context) from e
        if isinstance(inferred, util.UninferableBase):
            raise InferenceError
        if not isinstance(
            instance,
            (nodes.Const, nodes.Tuple, nodes.List, nodes.ClassDef, bases.Instance),
        ):
            raise InferenceError  # pragma: no cover # Used as a failsafe
        return instance.infer_binary_op(opnode, op, other, context, inferred)

    @staticmethod
    def _aug_op(
        instance: InferenceResult,
        opnode: nodes.AugAssign,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        reverse: bool = False,
    ) -> partial[Generator[InferenceResult]]:
        """Get an inference callable for an augmented binary operation.

        *op* must be a key of ``AUGMENTED_OP_METHOD``. *reverse* is accepted
        but not used by this helper.
        """
        method_name = AUGMENTED_OP_METHOD[op]
        return partial(
            OperatorNode._invoke_binop_inference,
            instance=instance,
            op=op,
            opnode=opnode,
            other=other,
            context=context,
            method_name=method_name,
        )

    @staticmethod
    def _bin_op(
        instance: InferenceResult,
        opnode: nodes.AugAssign | nodes.BinOp,
        op: str,
        other: InferenceResult,
        context: InferenceContext,
        reverse: bool = False,
    ) -> partial[Generator[InferenceResult]]:
        """Get an inference callable for a normal binary operation.

        If *reverse* is True, then the reflected method will be used instead.
        """
        if reverse:
            method_name = REFLECTED_BIN_OP_METHOD[op]
        else:
            method_name = BIN_OP_METHOD[op]
        return partial(
            OperatorNode._invoke_binop_inference,
            instance=instance,
            op=op,
            opnode=opnode,
            other=other,
            context=context,
            method_name=method_name,
        )

    @staticmethod
    def _bin_op_or_union_type(
        left: bases.UnionType | nodes.ClassDef | nodes.Const,
        right: bases.UnionType | nodes.ClassDef | nodes.Const,
    ) -> Generator[InferenceResult]:
        """Create a new UnionType instance for binary or, e.g. int | str."""
        yield bases.UnionType(left, right)

    @staticmethod
    def _get_binop_contexts(context, left, right):
        """Get contexts for binary operations.

        This will yield two inference contexts, the first one
        for x.__op__(y), the other one for y.__rop__(x), where
        only the arguments are swapped.
        """
        # The order is important, since the first one should be
        # left.__op__(right).
        for arg in (right, left):
            new_context = context.clone()
            new_context.callcontext = CallContext(args=[arg])
            new_context.boundnode = None
            yield new_context

    @staticmethod
    def _same_type(type1, type2) -> bool:
        """Check if type1 is the same as type2, comparing their ``qname()``."""
        return type1.qname() == type2.qname()

    @staticmethod
    def _get_aug_flow(
        left: InferenceResult,
        left_type: InferenceResult | None,
        aug_opnode: nodes.AugAssign,
        right: InferenceResult,
        right_type: InferenceResult | None,
        context: InferenceContext,
        reverse_context: InferenceContext,
    ) -> list[partial[Generator[InferenceResult]]]:
        """Get the flow for augmented binary operations.

        The rules are a bit messy:

            * if left and right have the same type, then left.__augop__(right)
              is first tried and then left.__op__(right).
            * if left and right are unrelated typewise, then
              left.__augop__(right) is tried, then left.__op__(right)
              is tried and then right.__rop__(left) is tried.
            * if left is a subtype of right, then left.__augop__(right)
              is tried and then left.__op__(right).
            * if left is a supertype of right, then left.__augop__(right)
              is tried, then right.__rop__(left) and then
              left.__op__(right)

        :returns: The inference callables, in the order they must be tried.
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        bin_op = aug_opnode.op.strip("=")
        aug_op = aug_opnode.op
        if OperatorNode._same_type(left_type, right_type) or helpers.is_subtype(
            left_type, right_type
        ):
            # Same type and subtype share the very same flow.
            methods = [
                OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
            ]
        elif helpers.is_supertype(left_type, right_type):
            methods = [
                OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                OperatorNode._bin_op(
                    right, aug_opnode, bin_op, left, reverse_context, reverse=True
                ),
                OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
            ]
        else:
            methods = [
                OperatorNode._aug_op(left, aug_opnode, aug_op, right, context),
                OperatorNode._bin_op(left, aug_opnode, bin_op, right, context),
                OperatorNode._bin_op(
                    right, aug_opnode, bin_op, left, reverse_context, reverse=True
                ),
            ]
        return methods

    @staticmethod
    def _get_binop_flow(
        left: InferenceResult,
        left_type: InferenceResult | None,
        binary_opnode: nodes.AugAssign | nodes.BinOp,
        right: InferenceResult,
        right_type: InferenceResult | None,
        context: InferenceContext,
        reverse_context: InferenceContext,
    ) -> list[partial[Generator[InferenceResult]]]:
        """Get the flow for binary operations.

        The rules are a bit messy:

            * if left and right have the same type, then only one
              method will be called, left.__op__(right)
            * if left and right are unrelated typewise, then first
              left.__op__(right) is tried and if this does not exist
              or returns NotImplemented, then right.__rop__(left) is tried.
            * if left is a subtype of right, then only left.__op__(right)
              is tried.
            * if left is a supertype of right, then right.__rop__(left)
              is first tried and then left.__op__(right)

        :returns: The inference callables, in the order they must be tried.
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        op = binary_opnode.op
        if OperatorNode._same_type(left_type, right_type) or helpers.is_subtype(
            left_type, right_type
        ):
            # Same type and subtype share the very same flow.
            methods = [OperatorNode._bin_op(left, binary_opnode, op, right, context)]
        elif helpers.is_supertype(left_type, right_type):
            methods = [
                OperatorNode._bin_op(
                    right, binary_opnode, op, left, reverse_context, reverse=True
                ),
                OperatorNode._bin_op(left, binary_opnode, op, right, context),
            ]
        else:
            methods = [
                OperatorNode._bin_op(left, binary_opnode, op, right, context),
                OperatorNode._bin_op(
                    right, binary_opnode, op, left, reverse_context, reverse=True
                ),
            ]

        # When both operands are a UnionType, a ClassDef or the ``None``
        # constant, building a new UnionType is added as the last candidate.
        # pylint: disable = too-many-boolean-expressions
        if (
            PY310_PLUS
            and op == "|"
            and (
                isinstance(left, (bases.UnionType, nodes.ClassDef))
                or (isinstance(left, nodes.Const) and left.value is None)
            )
            and (
                isinstance(right, (bases.UnionType, nodes.ClassDef))
                or (isinstance(right, nodes.Const) and right.value is None)
            )
        ):
            methods.extend([partial(OperatorNode._bin_op_or_union_type, left, right)])
        return methods

    @staticmethod
    def _infer_binary_operation(
        left: InferenceResult,
        right: InferenceResult,
        binary_opnode: nodes.AugAssign | nodes.BinOp,
        context: InferenceContext,
        flow_factory: GetFlowFactory,
    ) -> Generator[InferenceResult | util.BadBinaryOperationMessage]:
        """Infer a binary operation between a left operand and a right operand.

        This is used by both normal binary operations and augmented binary
        operations, the only difference is the flow factory used.

        :param left: The left operand.
        :param right: The right operand.
        :param binary_opnode: The node of the operation being inferred.
        :param context: The inference context the per-call contexts are
            cloned from.
        :param flow_factory: Builds the ordered list of inference callables.

        :returns: The results of the first callable giving a usable answer,
            ``util.Uninferable`` when the outcome cannot be determined, or a
            ``util.BadBinaryOperationMessage`` when no callable applies.
        """
        from astroid import helpers  # pylint: disable=import-outside-toplevel

        context, reverse_context = OperatorNode._get_binop_contexts(
            context, left, right
        )
        left_type = helpers.object_type(left)
        right_type = helpers.object_type(right)
        methods = flow_factory(
            left, left_type, binary_opnode, right, right_type, context, reverse_context
        )
        for method in methods:
            try:
                results = list(method())
            except (AttributeError, AttributeInferenceError):
                # Presumably the operand does not provide the method:
                # move on to the next candidate.
                continue
            except InferenceError:
                yield util.Uninferable
                return

            if any(isinstance(result, util.UninferableBase) for result in results):
                yield util.Uninferable
                return

            not_implemented = sum(
                1 for result in results if OperatorNode._is_not_implemented(result)
            )
            if not_implemented == len(results):
                # Nothing but NotImplemented (or no result at all): move on
                # to the next candidate.
                continue
            if not_implemented:
                # Mixed outcome: can't infer yet what this is.
                yield util.Uninferable
                return

            yield from results
            return

        # The operation doesn't seem to be supported so let the caller know about it
        yield util.BadBinaryOperationMessage(left_type, binary_opnode.op, right_type)