1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6
7import abc
8import builtins
9from collections import defaultdict
10from contextlib import contextmanager, ExitStack
11from dataclasses import dataclass
12from enum import auto, Enum
13from typing import (
14 Collection,
15 Dict,
16 Iterator,
17 List,
18 Mapping,
19 MutableMapping,
20 Optional,
21 Set,
22 Tuple,
23 Type,
24 Union,
25)
26
27import libcst as cst
28from libcst import ensure_type
29from libcst._add_slots import add_slots
30from libcst.helpers import get_full_name_for_node
31from libcst.metadata.base_provider import BatchableMetadataProvider
32from libcst.metadata.expression_context_provider import (
33 ExpressionContext,
34 ExpressionContextProvider,
35)
36
# Node classes that this module groups under the "assignment-like" label.
# Comprehensions are handled separately in _visit_comp_alike due to
# the complexity of the semantics.
# NOTE(review): cst.CompFor is nevertheless listed below, and neither the
# consumer of this tuple nor _visit_comp_alike is visible in this part of the
# file -- confirm how the two interact before adding or removing entries.
_ASSIGNMENT_LIKE_NODES = (
    cst.AnnAssign,
    cst.AsName,
    cst.Assign,
    cst.AugAssign,
    cst.ClassDef,
    cst.CompFor,
    cst.FunctionDef,
    cst.Global,
    cst.Import,
    cst.ImportFrom,
    cst.NamedExpr,
    cst.Nonlocal,
    cst.Parameters,
    cst.WithItem,
    cst.TypeVar,
    cst.TypeAlias,
    cst.TypeVarTuple,
    cst.ParamSpec,
)
59
60
@add_slots
@dataclass(frozen=False)
class Access:
    """
    An Access records a single access of a name and the assignments it may refer to.

    .. note::
       This scope analysis only analyzes access via a :class:`~libcst.Name` or a
       :class:`~libcst.Name` node embedded in another node like :class:`~libcst.Call`
       or :class:`~libcst.Attribute`. It doesn't support type annotation using a
       :class:`~libcst.SimpleString` literal for forward references. E.g. in this
       example, the ``"Tree"`` isn't parsed as an access::

           class Tree:
               def __new__(cls) -> "Tree":
                   ...
    """

    #: The node of the access. A name is an access when the expression context is
    #: :attr:`ExpressionContext.LOAD`. This is usually the name node representing the
    #: access, except for: 1) dotted imports, when it might be the attribute that
    #: represents the most specific part of the imported symbol; and 2) string
    #: annotations, when it is the entire string literal
    node: Union[cst.Name, cst.Attribute, cst.BaseString]

    #: The scope of the access. Note that an access could be in a child scope of its
    #: assignment.
    scope: "Scope"

    # NOTE(review): the two flags below are supplied by whoever creates the access;
    # this class only stores them.
    is_annotation: bool

    is_type_hint: bool

    __assignments: Set["BaseAssignment"]
    __index: int

    def __init__(
        self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool
    ) -> None:
        # The scope's assignment counter at creation time is used afterwards to
        # order this access relative to assignments made in the same scope.
        self.__index = scope._assignment_count
        self.__assignments = set()
        self.node = node
        self.scope = scope
        self.is_annotation = is_annotation
        self.is_type_hint = is_type_hint

    def __hash__(self) -> int:
        # Identity-based hash: every Access object is a distinct set member.
        return id(self)

    @property
    def referents(self) -> Collection["BaseAssignment"]:
        """Return all assignments of the access."""
        return self.__assignments

    @property
    def _index(self) -> int:
        """Value of the owning scope's assignment counter when the access was created."""
        return self.__index

    def record_assignment(self, assignment: "BaseAssignment") -> None:
        """
        Attach ``assignment`` as a referent, unless it lives in the same scope as this
        access and its index is not lower than the index of this access.
        """
        declared_elsewhere = assignment.scope != self.scope
        if declared_elsewhere or assignment._index < self.__index:
            self.__assignments.add(assignment)

    def record_assignments(self, name: str) -> None:
        """
        Resolve ``name`` starting from the scope of this access and attach the
        resulting assignments as referents.
        """
        candidates = self.scope._resolve_scope_for_access(name, self.scope)
        # Keep only assignments that happened before this access; assignments
        # coming from a different scope are always kept.
        earlier = set()
        for candidate in candidates:
            if candidate.scope != self.scope or candidate._index < self.__index:
                earlier.add(candidate)
        # Every candidate was filtered out: retry the lookup from the parent
        # scope (skipped when the scope is its own parent).
        if candidates and not earlier and self.scope.parent != self.scope:
            earlier = self.scope.parent._resolve_scope_for_access(name, self.scope)
        self.__assignments.update(earlier)
136
class QualifiedNameSource(Enum):
    """Where a :class:`QualifiedName` comes from."""

    #: Produced by ``ImportAssignment.get_qualified_names_for``.
    IMPORT = auto()
    #: Produced by ``BuiltinAssignment.get_qualified_names_for``.
    BUILTIN = auto()
    #: Produced by ``Assignment.get_qualified_names_for``.
    LOCAL = auto()
141
142
@add_slots
@dataclass(frozen=True)
class QualifiedName:
    """Immutable pair of a dotted qualified name and the source it was derived from."""

    #: Qualified name, e.g. ``a.b.c`` or ``fn.<locals>.var``.
    name: str

    #: Source of the name, either :attr:`QualifiedNameSource.IMPORT`, :attr:`QualifiedNameSource.BUILTIN`
    #: or :attr:`QualifiedNameSource.LOCAL`.
    source: QualifiedNameSource
152
153
class BaseAssignment(abc.ABC):
    """Abstract base class of :class:`Assignment` and :class:`BuiltinAssignment`."""

    #: The name of assignment.
    name: str

    #: The scope associated with the assignment.
    scope: "Scope"
    __accesses: Set[Access]

    def __init__(self, name: str, scope: "Scope") -> None:
        self.__accesses = set()
        self.name = name
        self.scope = scope

    def record_access(self, access: Access) -> None:
        """
        Attach ``access`` to this assignment, unless it belongs to the same scope
        and its index is not greater than the index of this assignment.
        """
        in_other_scope = access.scope != self.scope
        if in_other_scope or self._index < access._index:
            self.__accesses.add(access)

    def record_accesses(self, accesses: Set[Access]) -> None:
        """
        Attach every access in ``accesses`` that passes the same test as
        :meth:`record_access`; the remaining ones are forwarded to the assignments
        of the same name found through the parent scope.
        """
        later_accesses = set()
        for access in accesses:
            if access.scope != self.scope or self._index < access._index:
                later_accesses.add(access)
        self.__accesses.update(later_accesses)

        earlier_accesses = accesses - later_accesses
        if not earlier_accesses or self.scope.parent == self.scope:
            return
        # Accesses "earlier" than the relevant assignment should be attached
        # to assignments of the same name in the parent
        for shadowed_assignment in self.scope.parent[self.name]:
            shadowed_assignment.record_accesses(earlier_accesses)

    @property
    def references(self) -> Collection[Access]:
        """Return all accesses of the assignment."""
        # The annotation is the read-only ``Collection`` so that the mutable
        # set is not advertised publicly.
        return self.__accesses

    def __hash__(self) -> int:
        # Identity-based hash: every assignment object is a distinct set member.
        return id(self)

    @property
    def _index(self) -> int:
        """Return an integer that represents the order of assignments in `scope`"""
        # Base value; :class:`Assignment` overrides it with a real counter.
        return -1

    @abc.abstractmethod
    def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]: ...
203
204
class Assignment(BaseAssignment):
    """An assignment records the name, CSTNode and its accesses."""

    #: The node of assignment, it could be a :class:`~libcst.Import`, :class:`~libcst.ImportFrom`,
    #: :class:`~libcst.Name`, :class:`~libcst.FunctionDef`, or :class:`~libcst.ClassDef`.
    node: cst.CSTNode
    __index: int

    def __init__(
        self, name: str, scope: "Scope", node: cst.CSTNode, index: int
    ) -> None:
        super().__init__(name, scope)
        self.node = node
        self.__index = index

    @property
    def _index(self) -> int:
        """Position of this assignment in ``scope``, as passed to the constructor."""
        return self.__index

    def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]:
        """
        Return a one-element set holding ``full_name`` qualified with the name
        prefix of the owning scope (no prefix when the scope has none), tagged as
        :attr:`QualifiedNameSource.LOCAL`.
        """
        prefix = self.scope._name_prefix
        qualified = f"{prefix}.{full_name}" if prefix else full_name
        return {QualifiedName(qualified, QualifiedNameSource.LOCAL)}
235
236
# The constructor is inherited unchanged from BaseAssignment; this class does
# not override it.
class BuiltinAssignment(BaseAssignment):
    """
    A BuiltinAssignment represents a value provided by Python as a builtin, including
    `functions <https://docs.python.org/3/library/functions.html>`_,
    `constants <https://docs.python.org/3/library/constants.html>`_, and
    `types <https://docs.python.org/3/library/stdtypes.html>`_.
    """

    def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]:
        """
        Return a one-element set with ``builtins.<name>``; ``full_name`` is not
        consulted, the assignment's own name is used instead.
        """
        qualified = QualifiedName(
            f"builtins.{self.name}", QualifiedNameSource.BUILTIN
        )
        return {qualified}
248
249
class ImportAssignment(Assignment):
    """An assignment that records the import node and its alias."""

    as_name: cst.CSTNode

    def __init__(
        self,
        name: str,
        scope: "Scope",
        node: cst.CSTNode,
        index: int,
        as_name: cst.CSTNode,
    ) -> None:
        super().__init__(name, scope, node, index)
        self.as_name = as_name

    def get_module_name_for_import(self) -> str:
        """
        Return the module part of a ``from ... import`` statement, with one leading
        dot per relative level. Any other kind of node yields the empty string.
        """
        node = self.node
        if not isinstance(node, cst.ImportFrom):
            return ""
        module = ""
        if node.module:
            module = get_full_name_for_node(node.module) or ""
        if node.relative:
            module = "." * len(node.relative) + module
        return module

    def get_qualified_names_for(self, full_name: str) -> Set[QualifiedName]:
        """
        Map ``full_name`` (as written in the code) onto the names imported by this
        statement and return the matching qualified names, all tagged as
        :attr:`QualifiedNameSource.IMPORT`. Star imports produce no result.
        """
        module = self.get_module_name_for_import()
        found: Set[QualifiedName] = set()
        assert isinstance(self.node, (cst.ImportFrom, cst.Import))
        aliases = self.node.names
        if isinstance(aliases, cst.ImportStar):
            return found

        for name in aliases:
            dotted = get_full_name_for_node(name.name)
            if not dotted:
                continue
            # The imported name can contain `.` for dotted imports; try its
            # prefixes from longest to shortest and stop at the first match.
            parts = dotted.split(".")
            for end in range(len(parts), 0, -1):
                local_name = ".".join(parts[:end])
                if module and module.endswith("."):
                    # from . import a
                    # the imported name should be ".a"
                    imported = f"{module}{local_name}"
                elif module:
                    imported = f"{module}.{local_name}"
                else:
                    imported = local_name
                if name and name.asname:
                    alias = name.evaluated_alias
                    if alias is not None:
                        local_name = alias
                if not full_name.startswith(local_name):
                    continue
                rest = full_name.partition(local_name)[2]
                if rest and not rest.startswith("."):
                    # Only a textual prefix (e.g. `ab` vs `a`), not a dotted one.
                    continue
                rest = rest.lstrip(".")
                found.add(
                    QualifiedName(
                        f"{imported}.{rest}" if rest else imported,
                        QualifiedNameSource.IMPORT,
                    )
                )
                break
        return found
320
321
class Assignments:
    """A container to provide all assignments in a scope."""

    def __init__(self, assignments: Mapping[str, Collection[BaseAssignment]]) -> None:
        self._assignments = assignments

    def __iter__(self) -> Iterator[BaseAssignment]:
        """Iterate through all assignments by ``for i in scope.assignments``."""
        for group in self._assignments.values():
            yield from group

    def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[BaseAssignment]:
        """Get assignments given a name str or :class:`~libcst.CSTNode` by ``scope.assignments[node]``"""
        name = get_full_name_for_node(node)
        if name in self._assignments:
            # Hand out a copy so callers cannot alter the stored collection.
            return set(self._assignments[name])
        return set()

    def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
        """Check if a name str or :class:`~libcst.CSTNode` has any assignment by ``node in scope.assignments``"""
        return bool(self[node])
342
343
class Accesses:
    """A container to provide all accesses in a scope."""

    def __init__(self, accesses: Mapping[str, Collection[Access]]) -> None:
        self._accesses = accesses

    def __iter__(self) -> Iterator[Access]:
        """Iterate through all accesses by ``for i in scope.accesses``."""
        for group in self._accesses.values():
            yield from group

    def __getitem__(self, node: Union[str, cst.CSTNode]) -> Collection[Access]:
        """Get accesses given a name str or :class:`~libcst.CSTNode` by ``scope.accesses[node]``"""
        name = get_full_name_for_node(node)
        if name in self._accesses:
            # The stored collection itself is returned (no copy is made).
            return self._accesses[name]
        return set()

    def __contains__(self, node: Union[str, cst.CSTNode]) -> bool:
        """Check if a name str or :class:`~libcst.CSTNode` has any access by ``node in scope.accesses``"""
        found = self[node]
        return len(found) > 0
364
365
class Scope(abc.ABC):
    """
    Base class of all scope classes. A Scope object stores assignments from imports,
    variable assignments, function definitions or class definitions.
    A scope has a parent scope which represents the inheritance relationship. That means
    an assignment in the parent scope is viewable from the child scope, and the child scope
    may overwrite the assignment by using the same name.

    Use ``name in scope`` to check whether a name is viewable in the scope.
    Use ``scope[name]`` to retrieve all viewable assignments in the scope.

    .. note::
       This scope analysis module only analyzes local variable names and it doesn't handle
       attribute names; for example, given ``a.b.c = 1``, local variable name ``a`` is recorded
       as an assignment instead of ``c`` or ``a.b.c``. To analyze the assignment/access of
       arbitrary object attributes, we leave the job to a type inference metadata provider
       coming in the future.
    """

    #: Parent scope. The parent of a GlobalScope is a BuiltinScope, and the parent
    #: of a BuiltinScope is itself.
    parent: "Scope"

    #: Refers to the GlobalScope.
    globals: "GlobalScope"
    _assignments: MutableMapping[str, Set[BaseAssignment]]
    _assignment_count: int
    _accesses_by_name: MutableMapping[str, Set[Access]]
    _accesses_by_node: MutableMapping[cst.CSTNode, Set[Access]]
    _name_prefix: str

    def __init__(self, parent: "Scope") -> None:
        super().__init__()
        self.parent = parent
        self.globals = parent.globals
        self._name_prefix = ""
        self._assignment_count = 0
        self._assignments = defaultdict(set)
        self._accesses_by_name = defaultdict(set)
        self._accesses_by_node = defaultdict(set)

    def record_assignment(self, name: str, node: cst.CSTNode) -> None:
        """Store an :class:`Assignment` of ``name`` in the scope that owns the name."""
        target = self._find_assignment_target(name)
        assignment = Assignment(
            name=name, scope=target, node=node, index=target._assignment_count
        )
        target._assignments[name].add(assignment)

    def record_import_assignment(
        self, name: str, node: cst.CSTNode, as_name: cst.CSTNode
    ) -> None:
        """Store an :class:`ImportAssignment` of ``name`` in the scope that owns the name."""
        target = self._find_assignment_target(name)
        assignment = ImportAssignment(
            name=name,
            scope=target,
            node=node,
            as_name=as_name,
            index=target._assignment_count,
        )
        target._assignments[name].add(assignment)

    def _find_assignment_target(self, name: str) -> "Scope":
        # By default a scope stores its own assignments; subclasses may redirect.
        return self

    def record_access(self, name: str, access: Access) -> None:
        """Index ``access`` both by ``name`` and by its node."""
        self._accesses_by_node[access.node].add(access)
        self._accesses_by_name[name].add(access)

    def _is_visible_from_children(self, from_scope: "Scope") -> bool:
        """Returns if the assignments in this scope can be accessed from children.

        This is normally True, except for class scopes::

            def outer_fn():
                v = ...  # outer_fn's declaration
                class InnerCls:
                    v = ...  # shadows outer_fn's declaration
                    class InnerInnerCls:
                        v = ...  # shadows all previous declarations of v
                        def inner_fn():
                            nonlocal v
                            v = ...  # this refers to outer_fn's declaration
                                     # and not to any of the inner classes' as those are
                                     # hidden from their children.
        """
        return True

    def _next_visible_parent(
        self, from_scope: "Scope", first: Optional["Scope"] = None
    ) -> "Scope":
        """
        Walk up the parent chain, starting at ``first`` when given and at
        ``self.parent`` otherwise, and return the first scope that reports itself
        as visible from ``from_scope``.
        """
        candidate = self.parent if first is None else first
        while not candidate._is_visible_from_children(from_scope):
            candidate = candidate.parent
        return candidate

    @abc.abstractmethod
    def __contains__(self, name: str) -> bool:
        """Check if the name str exist in current scope by ``name in scope``."""
        ...

    def __getitem__(self, name: str) -> Set[BaseAssignment]:
        """
        Get assignments given a name str by ``scope[name]``.

        .. note::
           *Why does it return a list of assignments given a name instead of just one assignment?*

           Many programming languages differentiate variable declaration and assignment.
           Further, those programming languages often disallow duplicate declarations within
           the same scope, and will often hoist the declaration (without its assignment) to
           the top of the scope. These design decisions make static analysis much easier,
           because it's possible to match a name against its single declaration for a given scope.

           As an example, the following code would be valid in JavaScript::

               function fn() {
                 console.log(value);  // value is defined here, because the declaration is hoisted, but is currently 'undefined'.
                 var value = 5;  // A function-scoped declaration.
               }
               fn();  // prints 'undefined'.

           In contrast, Python's declaration and assignment are identical and are not hoisted::

               if conditional_value:
                   value = 5
               elif other_conditional_value:
                   value = 10
               print(value)  # possibly valid, depending on conditional execution

           This code may throw a ``NameError`` if both conditional values are falsy.
           It also means that depending on the codepath taken, the original declaration
           could come from either ``value = ...`` assignment node.
           As a result, instead of returning a single declaration,
           we're forced to return a collection of all of the assignments we think could have
           defined a given name by the time a piece of code is executed.
           For the above example, value would resolve to a set of both assignments.
        """
        return self._resolve_scope_for_access(name, self)

    @abc.abstractmethod
    def _resolve_scope_for_access(
        self, name: str, from_scope: "Scope"
    ) -> Set[BaseAssignment]: ...

    def __hash__(self) -> int:
        # Identity-based hash: scopes are used as-is in sets and dict keys.
        return id(self)

    @abc.abstractmethod
    def record_global_overwrite(self, name: str) -> None: ...

    @abc.abstractmethod
    def record_nonlocal_overwrite(self, name: str) -> None: ...

    def get_qualified_names_for(
        self, node: Union[str, cst.CSTNode]
    ) -> Collection[QualifiedName]:
        """Get all :class:`~libcst.metadata.QualifiedName` in current scope given a
        :class:`~libcst.CSTNode`.
        The source of a qualified name can be either :attr:`QualifiedNameSource.IMPORT`,
        :attr:`QualifiedNameSource.BUILTIN` or :attr:`QualifiedNameSource.LOCAL`.
        Given the following example, ``c`` has qualified name ``a.b.c`` with source ``IMPORT``,
        ``f`` has qualified name ``Cls.f`` with source ``LOCAL``, ``a`` has qualified name
        ``Cls.f.<locals>.a``, ``i`` has qualified name ``Cls.f.<locals>.<comprehension>.i``,
        and the builtin ``int`` has qualified name ``builtins.int`` with source ``BUILTIN``::

            from a.b import c
            class Cls:
                def f(self) -> "c":
                    c()
                    a = int("1")
                    [i for i in c()]

        We extend `PEP-3155 <https://www.python.org/dev/peps/pep-3155/>`_
        (defines ``__qualname__`` for class and function only; function namespace is followed
        by a ``<locals>``) to provide qualified names for all :class:`~libcst.CSTNode`
        recorded by :class:`~libcst.metadata.Assignment` and :class:`~libcst.metadata.Access`.
        The namespace of a comprehension (:class:`~libcst.ListComp`, :class:`~libcst.SetComp`,
        :class:`~libcst.DictComp`) is represented with ``<comprehension>``.

        An imported name may be used for type annotation with :class:`~libcst.SimpleString` and
        currently resolving the qualified name given a :class:`~libcst.SimpleString` is not
        supported considering it could be a complex type annotation in the string which is
        hard to resolve, e.g. ``List[Union[int, str]]``.
        """
        # If this node is a recorded access, its referents are already known
        # and their names can be used directly.
        if isinstance(node, cst.CSTNode):
            recorded = self._accesses_by_node.get(node)
            if recorded:
                qnames = set()
                for access in recorded:
                    for referent in access.referents:
                        qnames.update(referent.get_qualified_names_for(referent.name))
                return qnames

        full_name = get_full_name_for_node(node)
        if full_name is None:
            return set()

        # Find the longest dotted prefix of the name that is known to this scope.
        assignments = set()
        prefix = full_name
        while prefix:
            if prefix in self:
                assignments = self[prefix]
                break
            head, dot, _ = prefix.rpartition(".")
            prefix = head if dot else None

        # When the node itself is part of one of those assignments, only that
        # assignment is relevant.
        if not isinstance(node, str):
            for assignment in assignments:
                if not isinstance(assignment, Assignment):
                    continue
                if _is_assignment(node, assignment.node):
                    return assignment.get_qualified_names_for(full_name)

        results = set()
        for assignment in assignments:
            results.update(assignment.get_qualified_names_for(full_name))
        return results

    @property
    def assignments(self) -> Assignments:
        """Return an :class:`~libcst.metadata.Assignments` that contains all assignments in current scope."""
        return Assignments(self._assignments)

    @property
    def accesses(self) -> Accesses:
        """Return an :class:`~libcst.metadata.Accesses` that contains all accesses in current scope."""
        return Accesses(self._accesses_by_name)
597
598
class BuiltinScope(Scope):
    """
    A BuiltinScope represents python builtin declarations. See https://docs.python.org/3/library/builtins.html
    """

    def __init__(self, globals: Scope) -> None:
        # Scope.__init__ reads ``parent.globals`` and the parent here is this very
        # object, so the attribute must exist before the base constructor runs.
        self.globals: Scope = globals
        super().__init__(parent=self)

    def __contains__(self, name: str) -> bool:
        return hasattr(builtins, name)

    def _resolve_scope_for_access(
        self, name: str, from_scope: "Scope"
    ) -> Set[BaseAssignment]:
        known = self._assignments
        if name not in known:
            if not hasattr(builtins, name):
                return set()
            # note - we only see the builtin assignments during the deferred
            # access resolution. unfortunately that means we have to create the
            # assignment here, which can cause the set to mutate during iteration
            known[name].add(BuiltinAssignment(name, self))
        return known[name]

    def record_global_overwrite(self, name: str) -> None:
        raise NotImplementedError("global overwrite in builtin scope are not allowed")

    def record_nonlocal_overwrite(self, name: str) -> None:
        raise NotImplementedError("declarations in builtin scope are not allowed")

    def _find_assignment_target(self, name: str) -> "Scope":
        raise NotImplementedError("assignments in builtin scope are not allowed")
632
633
class GlobalScope(Scope):
    """
    A GlobalScope is the scope of module. All module level assignments are recorded in GlobalScope.
    """

    def __init__(self) -> None:
        # The parent is a fresh BuiltinScope whose ``globals`` points back here.
        super().__init__(parent=BuiltinScope(self))

    def __contains__(self, name: str) -> bool:
        own = self._assignments.get(name)
        if own is not None:
            return len(own) > 0
        return name in self._next_visible_parent(self)

    def _resolve_scope_for_access(
        self, name: str, from_scope: "Scope"
    ) -> Set[BaseAssignment]:
        if name not in self._assignments:
            # Unknown at module level: defer to the next visible parent.
            return self._next_visible_parent(from_scope)[name]
        return self._assignments[name]

    def record_global_overwrite(self, name: str) -> None:
        # Nothing is recorded for a ``global`` declaration at module level.
        pass

    def record_nonlocal_overwrite(self, name: str) -> None:
        raise NotImplementedError("nonlocal declaration not allowed at module level")
661
662
class LocalScope(Scope, abc.ABC):
    """Base class of the non-module scopes; tracks ``global``/``nonlocal`` overwrites."""

    # Maps a name to the scope recorded for it by record_global_overwrite or
    # record_nonlocal_overwrite.
    _scope_overwrites: Dict[str, Scope]

    #: Name of function. Used as qualified name.
    name: Optional[str]

    #: The :class:`~libcst.CSTNode` node defines the current scope.
    node: cst.CSTNode

    def __init__(
        self, parent: Scope, node: cst.CSTNode, name: Optional[str] = None
    ) -> None:
        super().__init__(parent)
        self._scope_overwrites = {}
        self.node = node
        self.name = name
        # The prefix is derived from ``parent`` and ``name``, so it is computed last.
        # pyre-fixme[4]: Attribute `_name_prefix` of class `LocalScope` has type `str` but no type is specified.
        self._name_prefix = self._make_name_prefix()

    def record_global_overwrite(self, name: str) -> None:
        self._scope_overwrites[name] = self.globals

    def record_nonlocal_overwrite(self, name: str) -> None:
        self._scope_overwrites[name] = self.parent

    def _find_assignment_target(self, name: str) -> "Scope":
        override = self._scope_overwrites.get(name)
        if override is None:
            return super()._find_assignment_target(name)
        # Start at the recorded scope, skip scopes that are not visible from
        # here, and let the scope found decide where the assignment goes.
        visible = self._next_visible_parent(self, override)
        return visible._find_assignment_target(name)

    def __contains__(self, name: str) -> bool:
        override = self._scope_overwrites.get(name)
        if override is not None:
            return name in override
        own = self._assignments.get(name)
        if own is not None:
            return len(own) > 0
        return name in self._next_visible_parent(self)

    def _resolve_scope_for_access(
        self, name: str, from_scope: "Scope"
    ) -> Set[BaseAssignment]:
        override = self._scope_overwrites.get(name)
        if override is not None:
            start = self._next_visible_parent(from_scope, override)
        elif name in self._assignments:
            return self._assignments[name]
        else:
            start = self._next_visible_parent(from_scope)
        return start._resolve_scope_for_access(name, from_scope)

    def _make_name_prefix(self) -> str:
        # Empty and None components are left out of the dotted prefix.
        components = (self.parent._name_prefix, self.name, "<locals>")
        return ".".join(part for part in components if part)
720
721
# The constructor and all behavior are inherited unchanged from LocalScope.
class FunctionScope(LocalScope):
    """
    When a function is defined, it creates a FunctionScope.
    """
729
730
# The constructor is inherited unchanged from LocalScope.
class ClassScope(LocalScope):
    """
    When a class is defined, it creates a ClassScope.
    """

    def _is_visible_from_children(self, from_scope: "Scope") -> bool:
        # Only an AnnotationScope that is a direct child can see into the class.
        if from_scope.parent is not self:
            return False
        return isinstance(from_scope, AnnotationScope)

    def _make_name_prefix(self) -> str:
        # Unlike the LocalScope version, no "<locals>" component is appended;
        # empty and None components are left out.
        components = (self.parent._name_prefix, self.name)
        return ".".join(part for part in components if part)
743
744
# The constructor is inherited unchanged from LocalScope.
class ComprehensionScope(LocalScope):
    """
    Comprehensions and generator expressions create their own scope. For example, in

        [i for i in range(10)]

    The variable ``i`` is only viewable within the ComprehensionScope.
    """

    # TODO: Assignment expressions (Python 3.8) will complicate ComprehensionScopes,
    # and will require us to handle such assignments as non-local.
    # https://www.python.org/dev/peps/pep-0572/#scope-of-the-target

    def _make_name_prefix(self) -> str:
        # The scope's own name is not used; an empty parent prefix is left out.
        components = (self.parent._name_prefix, "<comprehension>")
        return ".".join(part for part in components if part)
762
763
class AnnotationScope(LocalScope):
    """
    Scopes used for type aliases and type parameters as defined by PEP-695.

    These scopes are created for type parameters using the special syntax, as well as
    type aliases. See https://peps.python.org/pep-0695/#scoping-behavior for more.
    """

    def _make_name_prefix(self) -> str:
        # These scopes are transparent for the purposes of qualified names, so
        # the parent's prefix is reused unchanged.
        parent_prefix = self.parent._name_prefix
        return parent_prefix
775
776
# Generates dotted names from an Attribute or Name node, longest name first:
# Attribute(value=Name(value="a"), attr=Name(value="b")) -> ("a.b", "a")
# Each string is yielded together with the CSTNode it corresponds to.
def _gen_dotted_names(
    node: Union[cst.Attribute, cst.Name],
) -> Iterator[Tuple[str, Union[cst.Attribute, cst.Name]]]:
    if isinstance(node, cst.Name):
        yield node.value, node
        return

    value = node.value
    if isinstance(value, cst.Call):
        # An attribute of a call result: only the names leading to the called
        # function are produced, the attribute itself adds no dotted name.
        func = value.func
        if isinstance(func, (cst.Attribute, cst.Name)):
            yield from _gen_dotted_names(func)
    elif isinstance(value, (cst.Attribute, cst.Name)):
        inner_names = _gen_dotted_names(value)
        first = next(inner_names, None)
        if first is None:
            return
        head_name, head_node = first
        yield f"{head_name}.{node.attr.value}", node
        yield head_name, head_node
        yield from inner_names
808
809
def _is_assignment(node: cst.CSTNode, assignment_node: cst.CSTNode) -> bool:
    """
    Returns true if ``node`` is part of the assignment at ``assignment_node``.

    Normally this is just a simple identity check, except for imports where the
    assignment is attached to the entire import statement but we are interested in
    ``Name`` nodes inside the statement.
    """
    if node is assignment_node:
        return True
    if not isinstance(assignment_node, (cst.Import, cst.ImportFrom)):
        return False
    aliases = assignment_node.names
    if isinstance(aliases, cst.ImportStar):
        return False
    # ``node`` matches when it is the imported name or the ``as`` name of an alias.
    return any(
        alias.name is node
        or (alias.asname is not None and alias.asname.name is node)
        for alias in aliases
    )
832
833
@dataclass(frozen=True)
class DeferredAccess:
    # Immutable record of an access plus two optional enclosing nodes.
    # NOTE(review): the code that creates and consumes these records is outside
    # this part of the file; the field meanings below are inferred from the
    # names and types only -- confirm against the consumer.

    # The access being deferred.
    access: Access
    # Presumably the outermost attribute expression containing the access, if any.
    enclosing_attribute: Optional[cst.Attribute]
    # Presumably the string annotation the access was parsed from, if any.
    enclosing_string_annotation: Optional[cst.BaseString]
839
840
841class ScopeVisitor(cst.CSTVisitor):
    # TODO(review): the opening of this comment is missing. The surviving text
    # says that something is probably not useful and that dropping it can make
    # this visitor cleaner.
843 def __init__(self, provider: "ScopeProvider") -> None:
844 super().__init__()
845 self.provider: ScopeProvider = provider
846 self.scope: Scope = GlobalScope()
847 self.__deferred_accesses: List[DeferredAccess] = []
848 self.__top_level_attribute_stack: List[Optional[cst.Attribute]] = [None]
849 self.__in_annotation_stack: List[bool] = [False]
850 self.__in_type_hint_stack: List[bool] = [False]
851 self.__in_ignored_subscript: Set[cst.Subscript] = set()
852 self.__last_string_annotation: Optional[cst.BaseString] = None
853 self.__ignore_annotation: int = 0
854
855 @contextmanager
856 def _new_scope(
857 self, kind: Type[LocalScope], node: cst.CSTNode, name: Optional[str] = None
858 ) -> Iterator[None]:
859 parent_scope = self.scope
860 self.scope = kind(parent_scope, node, name)
861 try:
862 yield
863 finally:
864 self.scope = parent_scope
865
866 @contextmanager
867 def _switch_scope(self, scope: Scope) -> Iterator[None]:
868 current_scope = self.scope
869 self.scope = scope
870 try:
871 yield
872 finally:
873 self.scope = current_scope
874
875 def _visit_import_alike(self, node: Union[cst.Import, cst.ImportFrom]) -> bool:
876 names = node.names
877 if isinstance(names, cst.ImportStar):
878 return False
879
880 # make sure node.names is Sequence[ImportAlias]
881 for name in names:
882 self.provider.set_metadata(name, self.scope)
883 asname = name.asname
884 if asname is not None:
885 name_values = _gen_dotted_names(cst.ensure_type(asname.name, cst.Name))
886 import_node_asname = asname.name
887 else:
888 name_values = _gen_dotted_names(name.name)
889 import_node_asname = name.name
890
891 for name_value, _ in name_values:
892 self.scope.record_import_assignment(
893 name_value, node, import_node_asname
894 )
895 return False
896
    def visit_Import(self, node: cst.Import) -> Optional[bool]:
        """Record the names bound by an ``import`` statement via ``_visit_import_alike``."""
        return self._visit_import_alike(node)
899
    def visit_ImportFrom(self, node: cst.ImportFrom) -> Optional[bool]:
        """Record the names bound by a ``from ... import`` statement via ``_visit_import_alike``."""
        return self._visit_import_alike(node)
902
    def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:
        """
        Visit only the ``value`` side of an attribute expression, remembering the
        outermost attribute while doing so. Always returns ``False``.
        """
        # Claim the "top-level attribute" slot only when no enclosing attribute
        # already holds it.
        if self.__top_level_attribute_stack[-1] is None:
            self.__top_level_attribute_stack[-1] = node
        node.value.visit(self)  # explicitly not visiting attr
        # Release the slot only if this node is the one that claimed it.
        if self.__top_level_attribute_stack[-1] is node:
            self.__top_level_attribute_stack[-1] = None
        return False
910
911 def visit_Call(self, node: cst.Call) -> Optional[bool]:
912 self.__top_level_attribute_stack.append(None)
913 self.__in_type_hint_stack.append(False)
914 qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)}
915 if "typing.NewType" in qnames or "typing.TypeVar" in qnames:
916 node.func.visit(self)
917 self.__in_type_hint_stack[-1] = True
918 for arg in node.args[1:]:
919 arg.visit(self)
920 return False
921 if "typing.cast" in qnames:
922 node.func.visit(self)
923 if len(node.args) > 0:
924 self.__in_type_hint_stack.append(True)
925 node.args[0].visit(self)
926 self.__in_type_hint_stack.pop()
927 for arg in node.args[1:]:
928 arg.visit(self)
929 return False
930 return True
931
    def leave_Call(self, original_node: cst.Call) -> None:
        """Drop the attribute and type-hint stack entries pushed by ``visit_Call``."""
        self.__top_level_attribute_stack.pop()
        self.__in_type_hint_stack.pop()
935
    def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]:
        """Push ``True`` on the annotation stack; returns ``None`` implicitly."""
        self.__in_annotation_stack.append(True)
938
    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        """Pop the annotation stack entry pushed by ``visit_Annotation``."""
        self.__in_annotation_stack.pop()
941
def visit_SimpleString(self, node: cst.SimpleString) -> Optional[bool]:
    """Give the string a chance to be handled as an annotation; never descend."""
    # The result is irrelevant here: children are skipped either way.
    self._handle_string_annotation(node)
    return False
945
def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> Optional[bool]:
    """Handle the whole concatenation as an annotation when possible."""
    handled = self._handle_string_annotation(node)
    # Descend into the individual parts only when the whole was not handled.
    return not handled
948
def _handle_string_annotation(
    self, node: Union[cst.SimpleString, cst.ConcatenatedString]
) -> bool:
    """
    Parse a string literal found in a type position and visit its contents.

    Returns ``True`` when the string was treated as an annotation (even if its
    contents failed to parse), and ``False`` when it is not in a type-hint or
    annotation position, an ignored subscript is active, or its evaluated value
    is empty.
    """
    in_type_position = (
        self.__in_type_hint_stack[-1] or self.__in_annotation_stack[-1]
    )
    if not in_type_position or self.__in_ignored_subscript:
        return False

    source = node.evaluated_value
    if not source:
        return False

    # Only the outermost string annotation is remembered, so names found in
    # nested string annotations are attributed to the outermost one.
    is_outermost = self.__last_string_annotation is None
    if is_outermost:
        self.__last_string_annotation = node
    try:
        cst.parse_module(source).visit(self)
    except cst.ParserSyntaxError:
        # swallow string annotation parsing errors
        # this is the same behavior as cPython
        pass
    if is_outermost:
        self.__last_string_annotation = None
    return True
972
def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
    """
    Push a type-hint frame for the subscript.

    The frame is ``True`` when the subscripted value is a plain name that
    resolves to something under ``typing.`` or ``typing_extensions.``.
    ``Literal[...]`` subscripts are additionally registered as ignored, which
    stops strings inside them from being handled as annotations.
    """
    is_typing_subscript = False
    subscripted = node.value
    if isinstance(subscripted, cst.Name):
        qualified = {
            qname.name for qname in self.scope.get_qualified_names_for(subscripted)
        }
        is_typing_subscript = any(
            qname.startswith(("typing.", "typing_extensions."))
            for qname in qualified
        )
        if qualified & {"typing.Literal", "typing_extensions.Literal"}:
            self.__in_ignored_subscript.add(node)

    self.__in_type_hint_stack.append(is_typing_subscript)
    return True
984
def leave_Subscript(self, original_node: cst.Subscript) -> None:
    """Undo ``visit_Subscript``: pop its frame and unregister the node."""
    self.__in_type_hint_stack.pop()
    # ``discard`` because only ``Literal`` subscripts were ever registered.
    self.__in_ignored_subscript.discard(original_node)
988
def visit_Name(self, node: cst.Name) -> Optional[bool]:
    """
    Record a name as an assignment or queue it as a deferred access.

    A ``STORE`` context records an assignment immediately. ``LOAD``, ``DEL``
    and names without any expression context are queued as a
    ``DeferredAccess`` for later resolution by ``infer_accesses``.
    """
    # not all Name have ExpressionContext
    context = self.provider.get_metadata(ExpressionContextProvider, node, None)

    if context == ExpressionContext.STORE:
        self.scope.record_assignment(node.value, node)
        return None

    if context not in (ExpressionContext.LOAD, ExpressionContext.DEL, None):
        return None

    in_annotation = bool(
        self.__in_annotation_stack[-1] and not self.__ignore_annotation
    )
    in_type_hint = bool(self.__in_type_hint_stack[-1])
    access = Access(
        node,
        self.scope,
        is_annotation=in_annotation,
        is_type_hint=in_type_hint,
    )
    deferred = DeferredAccess(
        access=access,
        enclosing_attribute=self.__top_level_attribute_stack[-1],
        enclosing_string_annotation=self.__last_string_annotation,
    )
    self.__deferred_accesses.append(deferred)
    return None
1010
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
    """
    Record the function in the current scope and visit its parts by hand.

    Type parameters, when present, open an ``AnnotationScope`` that stays
    active for the rest of the traversal. Parameters and body are visited
    inside a new ``FunctionScope``; decorators and the return annotation are
    visited after that scope has been closed again.
    """
    enclosing = self.scope
    enclosing.record_assignment(node.name.value, node)
    # Children are not traversed automatically (this method returns False),
    # so the name node gets its scope metadata here.
    self.provider.set_metadata(node.name, enclosing)

    with ExitStack() as open_scopes:
        if node.type_parameters:
            open_scopes.enter_context(
                self._new_scope(AnnotationScope, node, None)
            )
            node.type_parameters.visit(self)

        function_name = get_full_name_for_node(node.name)
        with self._new_scope(FunctionScope, node, function_name):
            node.params.visit(self)
            node.body.visit(self)

        for decorator in node.decorators:
            decorator.visit(self)
        if node.returns:
            node.returns.visit(self)

    return False
1033
def visit_Lambda(self, node: cst.Lambda) -> Optional[bool]:
    """Visit a lambda's parameters and body inside a new ``FunctionScope``."""
    with self._new_scope(FunctionScope, node):
        for part in (node.params, node.body):
            part.visit(self)
    return False
1039
def visit_Param(self, node: cst.Param) -> Optional[bool]:
    """
    Record the parameter in the current scope, then visit its default and
    annotation (in that order) with the scope switched to the parent scope.
    """
    self.scope.record_assignment(node.name.value, node)
    self.provider.set_metadata(node.name, self.scope)

    with self._switch_scope(self.scope.parent):
        if node.default:
            node.default.visit(self)
        if node.annotation:
            node.annotation.visit(self)

    return False
1049
def visit_Arg(self, node: cst.Arg) -> bool:
    """Visit only the value of a call argument."""
    # The keyword of Arg is neither an Assignment nor an Access and we explicitly don't visit it.
    if node.value:
        node.value.visit(self)
    return False
1056
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
    """
    Record the class in the current scope and visit its parts by hand.

    Decorators are visited first, in the current scope. Type parameters, when
    present, open an ``AnnotationScope`` that covers the bases, the keywords
    and the body. The body statements are visited inside a new ``ClassScope``.
    """
    enclosing = self.scope
    enclosing.record_assignment(node.name.value, node)
    # Children are not traversed automatically (this method returns False),
    # so the name node gets its scope metadata here.
    self.provider.set_metadata(node.name, enclosing)

    for decorator in node.decorators:
        decorator.visit(self)

    with ExitStack() as open_scopes:
        if node.type_parameters:
            open_scopes.enter_context(
                self._new_scope(AnnotationScope, node, None)
            )
            node.type_parameters.visit(self)

        for base in node.bases:
            base.visit(self)
        for keyword in node.keywords:
            keyword.visit(self)

        class_name = get_full_name_for_node(node.name)
        with self._new_scope(ClassScope, node, class_name):
            for statement in node.body.body:
                statement.visit(self)

    return False
1077
def visit_ClassDef_bases(self, node: cst.ClassDef) -> None:
    """Entering the class bases: bump the counter that stops names from
    being flagged as annotations in ``visit_Name``."""
    self.__ignore_annotation = self.__ignore_annotation + 1
1080
def leave_ClassDef_bases(self, node: cst.ClassDef) -> None:
    """Leaving the class bases: undo the bump from ``visit_ClassDef_bases``."""
    self.__ignore_annotation = self.__ignore_annotation - 1
1083
def visit_Global(self, node: cst.Global) -> Optional[bool]:
    """Register each name of a ``global`` statement as a global overwrite
    in the current scope."""
    scope = self.scope
    for item in node.names:
        scope.record_global_overwrite(item.name.value)
    return False
1088
def visit_Nonlocal(self, node: cst.Nonlocal) -> Optional[bool]:
    """Register each name of a ``nonlocal`` statement as a nonlocal overwrite
    in the current scope."""
    scope = self.scope
    for item in node.names:
        scope.record_nonlocal_overwrite(item.name.value)
    return False
1093
def visit_ListComp(self, node: cst.ListComp) -> Optional[bool]:
    """Handle a list comprehension with the shared comprehension logic."""
    outcome = self._visit_comp_alike(node)
    return outcome
1096
def visit_SetComp(self, node: cst.SetComp) -> Optional[bool]:
    """Handle a set comprehension with the shared comprehension logic."""
    outcome = self._visit_comp_alike(node)
    return outcome
1099
def visit_DictComp(self, node: cst.DictComp) -> Optional[bool]:
    """Handle a dict comprehension with the shared comprehension logic."""
    outcome = self._visit_comp_alike(node)
    return outcome
1102
def visit_GeneratorExp(self, node: cst.GeneratorExp) -> Optional[bool]:
    """Handle a generator expression with the shared comprehension logic."""
    outcome = self._visit_comp_alike(node)
    return outcome
1105
def _visit_comp_alike(
    self, node: Union[cst.ListComp, cst.SetComp, cst.DictComp, cst.GeneratorExp]
) -> bool:
    """
    Cheat sheet: `[elt for target in iter if ifs]`

    Terminology:
        target: The variable or pattern we're storing each element of the iter in.
        iter: The thing we're iterating over.
        ifs: A list of conditions provided
        elt: The value that will be computed and "yielded" each time the loop
            iterates. For most comprehensions, this is just the `node.elt`, but
            DictComp has `key` and `value`, which behave like `node.elt` would.


    Nested Comprehension: ``[a for b in c for a in b]`` is a "nested" ListComp.
    The outer iterator is in ``node.for_in`` and the inner iterator is in
    ``node.for_in.inner_for_in``.


    The first comprehension object's iter in generators is evaluated
    outside of the ComprehensionScope. Every other comprehension's iter is
    evaluated inside the ComprehensionScope. That may look surprising, but it
    appears to be how it works.

    non_flat = [ [1,2,3], [4,5,6], [7,8] ]
    flat = [y for x in non_flat for y in x]  # this works fine

    # This will give a "NameError: name 'x' is not defined":
    flat = [y for x in x for y in x]
    # x isn't defined, because the first iter is evaluated outside the scope.

    # This will give an UnboundLocalError, indicating that the second
    # comprehension's iter value is evaluated inside the scope as its elt.
    # UnboundLocalError: local variable 'y' referenced before assignment
    flat = [y for x in non_flat for y in y]
    """
    outer_for = node.for_in

    # The first iterable is visited before the comprehension scope is opened.
    outer_for.iter.visit(self)
    self.provider.set_metadata(outer_for, self.scope)

    with self._new_scope(ComprehensionScope, node):
        outer_for.target.visit(self)
        # Things from here on can refer to the target.
        self.scope._assignment_count += 1

        for condition in outer_for.ifs:
            condition.visit(self)

        if outer_for.inner_for_in:
            outer_for.inner_for_in.visit(self)

        # DictComp yields a key/value pair; every other kind yields ``elt``.
        if isinstance(node, cst.DictComp):
            yielded = (node.key, node.value)
        else:
            yielded = (node.elt,)
        for expression in yielded:
            expression.visit(self)

    return False
1161
def visit_For(self, node: cst.For) -> Optional[bool]:
    """
    Visit a ``for`` statement by hand: the target first, then the iterable,
    body, ``else`` clause and ``async`` marker, skipping any that are ``None``.
    """
    node.target.visit(self)
    # Bump the assignment counter once the target has been recorded.
    self.scope._assignment_count += 1

    remaining = (node.iter, node.body, node.orelse, node.asynchronous)
    for part in remaining:
        if part is None:
            continue
        part.visit(self)
    return False
1169
def infer_accesses(self) -> None:
    """
    Resolve every deferred access collected during traversal, link each one
    to the assignments it refers to, and clear the deferred list.
    """
    # Aggregate access with the same name and batch add with set union as an optimization.
    # In worst case, all accesses (m) and assignments (n) refer to the same name,
    # the time complexity is O(m x n), this optimizes it as O(m + n).
    grouped_accesses = defaultdict(set)

    for deferred in self.__deferred_accesses:
        access = deferred.access
        attribute = deferred.enclosing_attribute
        string_annotation = deferred.enclosing_string_annotation

        name = ensure_type(access.node, cst.Name).value
        if attribute is not None:
            # if _gen_dotted_names doesn't generate any values, fall back to
            # the original name node above
            for dotted_name, dotted_node in _gen_dotted_names(attribute):
                if dotted_name in access.scope:
                    access.node = dotted_node
                    name = dotted_name
                    break

        # An enclosing string annotation takes precedence as the access node.
        if string_annotation is not None:
            access.node = string_annotation

        grouped_accesses[(access.scope, name)].add(access)
        access.record_assignments(name)
        access.scope.record_access(name, access)

    for (scope, name), accesses in grouped_accesses.items():
        for assignment in scope._resolve_scope_for_access(name, scope):
            assignment.record_accesses(accesses)

    self.__deferred_accesses = []
1203
def on_leave(self, original_node: cst.CSTNode) -> None:
    """
    Attach the current scope as metadata to every node being left, and bump
    the scope's assignment counter for assignment-like nodes.
    """
    current = self.scope
    self.provider.set_metadata(original_node, current)
    if isinstance(original_node, _ASSIGNMENT_LIKE_NODES):
        current._assignment_count += 1
    super().on_leave(original_node)
1209
def visit_TypeAlias(self, node: cst.TypeAlias) -> Optional[bool]:
    """
    Record a ``type X = ...`` alias in the current scope and visit its type
    parameters and value inside a new ``AnnotationScope``.

    Returns ``False`` because the children are visited by hand.
    """
    self.scope.record_assignment(node.name.value, node)
    # Children are not traversed automatically (this method returns False), so
    # the name node never reaches ``on_leave``. Set its scope metadata
    # explicitly, as ``visit_FunctionDef`` and ``visit_ClassDef`` do.
    self.provider.set_metadata(node.name, self.scope)

    with self._new_scope(AnnotationScope, node, None):
        if node.type_parameters is not None:
            node.type_parameters.visit(self)
        node.value.visit(self)

    return False
1219
def visit_TypeVar(self, node: cst.TypeVar) -> Optional[bool]:
    """
    Record a type variable in the current scope and visit its bound, if any.

    Returns ``False`` because the children are visited by hand.
    """
    self.scope.record_assignment(node.name.value, node)
    # Children are not traversed automatically (this method returns False), so
    # the name node never reaches ``on_leave``. Set its scope metadata
    # explicitly, as ``visit_FunctionDef`` and ``visit_Param`` do.
    self.provider.set_metadata(node.name, self.scope)

    if node.bound is not None:
        node.bound.visit(self)

    return False
1227
def visit_TypeVarTuple(self, node: cst.TypeVarTuple) -> Optional[bool]:
    """
    Record a type variable tuple in the current scope.

    Returns ``False``, so no children are traversed.
    """
    self.scope.record_assignment(node.name.value, node)
    # The name node never reaches ``on_leave`` because children are skipped.
    # Set its scope metadata explicitly, as ``visit_FunctionDef`` and
    # ``visit_Param`` do.
    self.provider.set_metadata(node.name, self.scope)
    return False
1231
def visit_ParamSpec(self, node: cst.ParamSpec) -> Optional[bool]:
    """
    Record a parameter specification in the current scope.

    Returns ``False``, so no children are traversed.
    """
    self.scope.record_assignment(node.name.value, node)
    # The name node never reaches ``on_leave`` because children are skipped.
    # Set its scope metadata explicitly, as ``visit_FunctionDef`` and
    # ``visit_Param`` do.
    self.provider.set_metadata(node.name, self.scope)
    return False
1235
1236
class ScopeProvider(BatchableMetadataProvider[Optional[Scope]]):
    """
    :class:`ScopeProvider` traverses the entire module and creates the scope inheritance
    structure. It provides the scope of name assignment and accesses. It is useful for
    more advanced static analysis. E.g. given a :class:`~libcst.FunctionDef`
    node, we can check the type of its Scope to figure out whether it is a class method
    (:class:`ClassScope`) or a regular function (:class:`GlobalScope`).

    Scope metadata is available for most node types other than formatting information nodes
    (whitespace, parentheses, etc.).
    """

    # ScopeVisitor looks up ExpressionContextProvider metadata for each Name.
    METADATA_DEPENDENCIES = (ExpressionContextProvider,)

    def visit_Module(self, node: cst.Module) -> Optional[bool]:
        """Run a :class:`ScopeVisitor` over the module, then resolve the
        accesses it deferred during traversal."""
        scope_visitor = ScopeVisitor(self)
        node.visit(scope_visitor)
        scope_visitor.infer_accesses()
        return None