Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/codemod/visitors/_gather_string_annotation_names.py: 39%
44 statements
coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
6from typing import cast, Collection, List, Set, Union
8import libcst as cst
9import libcst.matchers as m
10from libcst.codemod._context import CodemodContext
11from libcst.codemod._visitor import ContextAwareVisitor
12from libcst.metadata import MetadataWrapper, QualifiedNameProvider
# Qualified names of callables whose string arguments are inspected as if they
# were annotations (e.g. ``TypeVar("T", bound="Foo")``). This is the default
# ``typing_functions`` of GatherNamesFromStringAnnotationsVisitor.
FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS: Set[str] = {
    "typing.TypeVar",
}
class GatherNamesFromStringAnnotationsVisitor(ContextAwareVisitor):
    """
    Collects all names from string literals used for typing purposes.
    This includes annotations like ``foo: "SomeType"``, and parameters to
    special functions related to typing (currently only `typing.TypeVar`).

    Strings inside ``Literal[...]`` subscripts are literal values rather than
    references to symbols, so they are not inspected. Strings that are not
    valid Python source (for example free-form metadata such as
    ``Annotated[int, "some description"]``) are skipped instead of aborting
    the traversal.

    After visiting, a set of all found names will be available on the ``names``
    attribute of this visitor.
    """

    METADATA_DEPENDENCIES = (QualifiedNameProvider,)

    # Qualified names of subscriptable typing constructs whose string
    # arguments are values, not forward references.
    _LITERAL_TYPES: Collection[str] = frozenset(
        {"typing.Literal", "typing_extensions.Literal"}
    )

    def __init__(
        self,
        context: CodemodContext,
        typing_functions: Collection[str] = FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS,
    ) -> None:
        super().__init__(context)
        self._typing_functions: Collection[str] = typing_functions
        # Enclosing Annotation nodes and typing-function Call nodes. While this
        # is non-empty, string literals are parsed for names.
        self._annotation_stack: List[cst.CSTNode] = []
        #: The set of names collected from string literals.
        self.names: Set[str] = set()

    def visit_Annotation(self, node: cst.Annotation) -> bool:
        self._annotation_stack.append(node)
        return True

    def leave_Annotation(self, original_node: cst.Annotation) -> None:
        self._annotation_stack.pop()

    def visit_Subscript(self, node: cst.Subscript) -> bool:
        # Outside an annotation context strings are ignored anyway, so there
        # is nothing to prune; skip the metadata lookup.
        if not self._annotation_stack:
            return True
        qnames = self.get_metadata(QualifiedNameProvider, node)
        # ``Literal["foo"]`` denotes the value "foo", not a use of a symbol
        # named ``foo``; returning False keeps its strings from being parsed.
        return not any(qn.name in self._LITERAL_TYPES for qn in qnames)

    def visit_Call(self, node: cst.Call) -> bool:
        qnames = self.get_metadata(QualifiedNameProvider, node)
        if any(qn.name in self._typing_functions for qn in qnames):
            self._annotation_stack.append(node)
            return True
        # Arguments of any other call are not treated as type references:
        # returning False stops traversal into this call's children.
        return False

    def leave_Call(self, original_node: cst.Call) -> None:
        # Pop only if this exact node was pushed by visit_Call (identity check,
        # since non-typing calls never push).
        if self._annotation_stack and self._annotation_stack[-1] is original_node:
            self._annotation_stack.pop()

    def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool:
        if self._annotation_stack:
            self.handle_any_string(node)
        # Do not visit the individual parts; the concatenation was handled
        # as a whole above.
        return False

    def visit_SimpleString(self, node: cst.SimpleString) -> bool:
        if self._annotation_stack:
            self.handle_any_string(node)
        return False

    def handle_any_string(
        self, node: Union[cst.SimpleString, cst.ConcatenatedString]
    ) -> None:
        """
        Parse the string's evaluated value as Python source and add every
        referenced name (plain names and dotted attribute chains) to
        ``self.names``. Strings whose value cannot be computed, or that do not
        parse, are ignored.
        """
        value = node.evaluated_value
        # None when libcst cannot compute a static value for the string.
        if value is None:
            return
        try:
            mod = cst.parse_module(value)
        except cst.ParserSyntaxError:
            # Not every string in an annotation context is a type expression
            # (e.g. ``Annotated[int, "free text"]``); skip it rather than
            # crashing the whole traversal.
            return
        extracted_nodes = m.extractall(
            mod,
            m.Name(
                value=m.SaveMatchedNode(m.DoNotCare(), "name"),
                metadata=m.MatchMetadataIfTrue(
                    cst.metadata.ParentNodeProvider,
                    # Names inside an Attribute chain are collected through the
                    # "attribute" branch below instead.
                    lambda parent: not isinstance(parent, cst.Attribute),
                ),
            )
            | m.SaveMatchedNode(m.Attribute(), "attribute"),
            metadata_resolver=MetadataWrapper(mod, unsafe_skip_copy=True),
        )
        plain_names = {
            cast(str, values["name"]) for values in extracted_nodes if "name" in values
        }
        # _gen_dotted_names presumably yields each dotted prefix of the chain
        # (e.g. "a.b" and "a" for ``a.b``) — libcst-internal helper.
        dotted_names = {
            name
            for values in extracted_nodes
            if "attribute" in values
            for name, _ in cst.metadata.scope_provider._gen_dotted_names(
                cast(cst.Attribute, values["attribute"])
            )
        }
        self.names.update(plain_names | dotted_names)