1# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4
5"""Classes representing different types of constraints on inference values."""
6from __future__ import annotations
7
8import sys
9from abc import ABC, abstractmethod
10from collections.abc import Iterator
11from typing import TYPE_CHECKING, Union
12
13from astroid import nodes, util
14from astroid.typing import InferenceResult
15
16if sys.version_info >= (3, 11):
17 from typing import Self
18else:
19 from typing_extensions import Self
20
21if TYPE_CHECKING:
22 from astroid import bases
23
# Node types accepted as the variable/attribute reference that a constraint
# is attached to (see ``Constraint.match`` and ``get_constraints``).
_NameNodes = Union[
    nodes.AssignAttr,
    nodes.Attribute,
    nodes.AssignName,
    nodes.Name,
]
25
26
class Constraint(ABC):
    """Abstract base for a single condition placed on a variable's inferred values.

    Concrete subclasses know how to recognise their pattern in a test
    expression (``match``) and how to check a candidate inferred value against
    it (``satisfied_by``).
    """

    def __init__(self, node: nodes.NodeNG, negate: bool) -> None:
        self.negate = negate
        """Whether the constraint is inverted, e.g. "is not" rather than "is"."""
        self.node = node
        """The node the constraint is attached to."""

    @classmethod
    @abstractmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Build a constraint on ``node`` from ``expr`` when ``expr`` fits this
        constraint's pattern; otherwise return None.

        When ``negate`` is True the produced constraint is inverted.
        """

    @abstractmethod
    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Tell whether the given inferred value meets this constraint."""
50
51
class NoneConstraint(Constraint):
    """Represents an "is None" or "is not None" constraint."""

    # Reference ``None`` literal node, compared against via ``_matches``.
    CONST_NONE: nodes.Const = nodes.Const(None)

    @classmethod
    def match(
        cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
    ) -> Self | None:
        """Return a new constraint for node matched from expr, if expr matches
        the constraint pattern.

        Recognised patterns are single comparisons of the form
        ``<node> is None`` / ``<node> is not None`` and the operand-swapped
        ``None is <node>`` / ``None is not <node>``. Identity comparison is
        symmetric, so both operand orders express the same constraint.

        The returned constraint is negated when exactly one of ``negate`` is
        True or the operator is ``is not``. Returns None if expr does not
        match.
        """
        # Only a single, non-chained comparison can be matched.
        if not (isinstance(expr, nodes.Compare) and len(expr.ops) == 1):
            return None

        op, right = expr.ops[0]
        if op not in {"is", "is not"}:
            return None

        left = expr.left
        node_is_none = _matches(left, node) and _matches(right, cls.CONST_NONE)
        none_is_node = _matches(left, cls.CONST_NONE) and _matches(right, node)
        if not (node_is_none or none_is_node):
            return None

        # "is not" flips the sense of the constraint; an outer negation
        # (e.g. from an ``else`` branch) flips it once more.
        return cls(node=node, negate=negate ^ (op == "is not"))

    def satisfied_by(self, inferred: InferenceResult) -> bool:
        """Return True if this constraint is satisfied by the given inferred value."""
        # Uninferable values cannot be ruled out, so assume they satisfy it.
        if isinstance(inferred, util.UninferableBase):
            return True

        # Satisfied when "is None-ness" of the value disagrees with negate,
        # i.e. the XOR of self.negate and whether the value is the None constant.
        return self.negate ^ _matches(inferred, self.CONST_NONE)
85
86
def get_constraints(
    expr: _NameNodes, frame: nodes.LocalsDictNodeNG
) -> dict[nodes.If, set[Constraint]]:
    """Collect the constraints that the surrounding code places on ``expr``.

    Walks from ``expr`` up through its ancestors, stopping at ``frame`` (or
    at the root). For every enclosing ``if`` statement whose body or else
    branch contains ``expr``, the ``if`` test is matched against all known
    constraint classes. A match from the else branch is inverted. Only ``if``
    statements that produce at least one constraint appear in the result,
    which maps each such ``if`` node to its constraints.
    """
    found: dict[nodes.If, set[Constraint]] = {}
    child: nodes.NodeNG | None = expr
    while child is not None and child is not frame:
        ancestor = child.parent
        if isinstance(ancestor, nodes.If):
            location, _ = ancestor.locate_child(child)
            # A node sitting in the test itself is not guarded by that test.
            if location in ("body", "orelse"):
                matched = set(
                    _match_constraint(
                        expr, ancestor.test, invert=location == "orelse"
                    )
                )
                if matched:
                    found[ancestor] = matched
        child = ancestor

    return found
115
116
ALL_CONSTRAINT_CLASSES = frozenset([NoneConstraint])
"""Every constraint class consulted when matching a test expression."""
119
120
def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool:
    """Tell whether two nodes denote the same name, attribute chain or constant.

    Names compare by identifier; attribute accesses compare by attribute name
    and, recursively, by the expression they are taken from; constants compare
    their values with ``==``. Any other combination counts as a mismatch.
    """
    if isinstance(node1, nodes.Name):
        return isinstance(node2, nodes.Name) and node1.name == node2.name

    if isinstance(node1, nodes.Attribute):
        if not isinstance(node2, nodes.Attribute):
            return False
        return node1.attrname == node2.attrname and _matches(node1.expr, node2.expr)

    if isinstance(node1, nodes.Const):
        return isinstance(node2, nodes.Const) and node1.value == node2.value

    return False
131
132
def _match_constraint(
    node: _NameNodes, expr: nodes.NodeNG, invert: bool = False
) -> Iterator[Constraint]:
    """Lazily yield each constraint on ``node`` that some constraint class
    recognises in ``expr``; ``invert`` is forwarded as the negation flag."""
    candidates = (
        constraint_cls.match(node, expr, invert)
        for constraint_cls in ALL_CONSTRAINT_CLASSES
    )
    # ``filter(None, ...)`` drops the non-matches (None) by truthiness.
    yield from filter(None, candidates)