1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from typing import cast, Collection, List, Set, Union
7
8import libcst as cst
9import libcst.matchers as m
10from libcst.codemod._context import CodemodContext
11from libcst.codemod._visitor import ContextAwareVisitor
12from libcst.metadata import MetadataWrapper, QualifiedNameProvider
13
# Qualified names of callables whose string arguments are treated like string
# annotations. Used as the default for the ``typing_functions`` parameter of
# ``GatherNamesFromStringAnnotationsVisitor``.
FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS = {"typing.TypeVar"}
15
16
class GatherNamesFromStringAnnotationsVisitor(ContextAwareVisitor):
    """
    Collects all names from string literals used for typing purposes.
    This includes annotations like ``foo: "SomeType"``, and parameters to
    special functions related to typing (by default only `typing.TypeVar`).

    String arguments to ``Literal[...]`` (whether ``Literal`` comes from
    ``typing`` or ``typing_extensions``) are values rather than type
    expressions, so they are never collected.

    After visiting, a set of all found names will be available on the ``names``
    attribute of this visitor.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider,)

    # Qualified names under which ``Literal`` may be referenced. Subscripts of
    # any of these are not traversed, so their string arguments are ignored.
    _LITERAL_QUALIFIED_NAMES = frozenset(
        {"typing.Literal", "typing_extensions.Literal"}
    )

    def __init__(
        self,
        context: CodemodContext,
        typing_functions: Collection[str] = FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS,
    ) -> None:
        """
        Create the visitor.

        :param context: The codemod context, forwarded to the base visitor.
        :param typing_functions: Qualified names of callables whose string
            arguments should be handled like string annotations.
        """
        super().__init__(context)
        self._typing_functions: Collection[str] = typing_functions
        # Enclosing ``Annotation`` nodes and typing-function ``Call`` nodes.
        # String literals are only inspected while this stack is non-empty.
        self._annotation_stack: List[cst.CSTNode] = []
        #: The set of names collected from string literals.
        self.names: Set[str] = set()

    def visit_Annotation(self, node: cst.Annotation) -> bool:
        """Enter an annotation: record it and traverse its children."""
        self._annotation_stack.append(node)
        return True

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        """Leave an annotation: drop the entry pushed by ``visit_Annotation``."""
        self._annotation_stack.pop()

    def visit_Subscript(self, node: cst.Subscript) -> bool:
        """
        Decide whether to traverse a subscript expression.

        Returns ``False`` for subscripts that resolve to ``Literal`` so that
        the strings inside them are not visited, and ``True`` otherwise.
        """
        qnames = self.get_metadata(QualifiedNameProvider, node)
        # A Literal["foo"] should not be interpreted as a use of the symbol "foo".
        return not any(qn.name in self._LITERAL_QUALIFIED_NAMES for qn in qnames)

    def visit_Call(self, node: cst.Call) -> bool:
        """
        Decide whether to traverse a call expression.

        Calls that resolve to one of the configured typing functions are pushed
        onto the annotation stack and traversed; all other calls are skipped.
        """
        qnames = self.get_metadata(QualifiedNameProvider, node)
        if any(qn.name in self._typing_functions for qn in qnames):
            self._annotation_stack.append(node)
            return True
        return False

    def leave_Call(self, original_node: cst.Call) -> None:
        """Pop the call from the annotation stack if ``visit_Call`` pushed it."""
        # Only calls to typing functions were pushed, so pop solely when this
        # exact node is on top of the stack (identity, not structural equality).
        if self._annotation_stack and self._annotation_stack[-1] is original_node:
            self._annotation_stack.pop()

    def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool:
        """Process a concatenated string if inside a typing context."""
        if self._annotation_stack:
            self.handle_any_string(node)
        # The string is handled as a whole; its parts are never visited.
        return False

    def visit_SimpleString(self, node: cst.SimpleString) -> bool:
        """Process a simple string literal if inside a typing context."""
        if self._annotation_stack:
            self.handle_any_string(node)
        return False

    def handle_any_string(
        self, node: Union[cst.SimpleString, cst.ConcatenatedString]
    ) -> None:
        """
        Parse the value of a string literal and record the names it references.

        Bare names are recorded as-is. For attribute accesses, every name
        produced by ``_gen_dotted_names`` is recorded (presumably the dotted
        path and its prefixes -- confirm against the scope provider). Strings
        without an evaluated value, or that fail to parse, are ignored.

        :param node: The string node whose contents should be inspected.
        """
        value = node.evaluated_value
        if value is None:
            # Nothing to parse when the literal has no evaluated value.
            return
        try:
            mod = cst.parse_module(value)
        except cst.ParserSyntaxError:
            # Not all strings inside a type annotation are meant to be valid
            # Python code.
            return
        extracted_nodes = m.extractall(
            mod,
            # Names that are not a direct child of an Attribute: capture the
            # identifier string itself under "name".
            m.Name(
                value=m.SaveMatchedNode(m.DoNotCare(), "name"),
                metadata=m.MatchMetadataIfTrue(
                    cst.metadata.ParentNodeProvider,
                    lambda parent: not isinstance(parent, cst.Attribute),
                ),
            )
            # Attribute accesses: capture the whole node under "attribute".
            | m.SaveMatchedNode(m.Attribute(), "attribute"),
            metadata_resolver=MetadataWrapper(mod, unsafe_skip_copy=True),
        )
        for values in extracted_nodes:
            if "name" in values:
                self.names.add(cast(str, values["name"]))
            if "attribute" in values:
                attribute = cast(cst.Attribute, values["attribute"])
                self.names.update(
                    name
                    for name, _ in cst.metadata.scope_provider._gen_dotted_names(
                        attribute
                    )
                )