1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5#
6from typing import Dict, List, Sequence, Set, Tuple, Union
7
8import libcst
9from libcst.codemod._context import CodemodContext
10from libcst.codemod._visitor import ContextAwareVisitor
11from libcst.codemod.visitors._imports import ImportItem
12from libcst.helpers import get_absolute_module_from_package_for_import
13
14
class _GatherImportsMixin(ContextAwareVisitor):
    """
    A Mixin class for tracking visited imports.

    Subclasses call :meth:`_handle_Import` / :meth:`_handle_ImportFrom` from
    their own visitor methods; results accumulate in the attributes created
    in ``__init__``.
    """

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        # Modules imported directly without an alias, e.g. ``import typing``.
        self.module_imports: Set[str] = set()
        # Module -> names imported from it without an alias, e.g.
        # ``from typing import Optional``. The value ``{"*"}`` marks a star
        # import, which subsumes any individually named objects.
        self.object_mapping: Dict[str, Set[str]] = {}
        # Module -> local alias, e.g. ``import typing as t``.
        self.module_aliases: Dict[str, str] = {}
        # Module -> (original name, alias) pairs, e.g.
        # ``from typing import Optional as Opt``.
        self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {}
        # Alias (or imported name when unaliased) -> ImportItem for every
        # individually named import. Star imports are not recorded here,
        # since the names they bind are not visible in the statement.
        self.symbol_mapping: Dict[str, ImportItem] = {}

    def _handle_Import(self, node: libcst.Import) -> None:
        """Record every name in an ``import a[, b as c, ...]`` statement."""
        for name in node.names:
            alias = name.evaluated_alias
            imp = ImportItem(name.evaluated_name, alias=alias)
            if alias is not None:
                # Track this as an aliased module
                self.module_aliases[name.evaluated_name] = alias
                self.symbol_mapping[alias] = imp
            else:
                # Plain, unaliased module import.
                self.module_imports.add(name.evaluated_name)
                self.symbol_mapping[name.evaluated_name] = imp

    def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None:
        """
        Record a ``from module import ...`` statement.

        Relative imports that cannot be resolved against
        ``self.context.full_package_name`` are ignored. A star import replaces
        the module's ``object_mapping`` entry with ``{"*"}``. For named imports:
        aliased names are appended to ``alias_mapping``; unaliased names are
        added to ``object_mapping`` unless the module was already star-imported;
        and every name is recorded in ``symbol_mapping``.
        """
        # Get the module we're importing as a string.
        module = get_absolute_module_from_package_for_import(
            self.context.full_package_name, node
        )
        if module is None:
            # Can't get the absolute import from relative, so we can't
            # support this.
            return
        nodenames = node.names
        if isinstance(nodenames, libcst.ImportStar):
            # A star import covers everything, so it supersedes any names
            # previously recorded for this module.
            self.object_mapping[module] = {"*"}
            return
        elif isinstance(nodenames, Sequence):
            # Get the list of imports we're aliasing in this import
            new_aliases = [
                (ia.evaluated_name, ia.evaluated_alias)
                for ia in nodenames
                if ia.asname is not None
            ]
            if new_aliases:
                # pyre-ignore We know that aliases are not None here.
                self.alias_mapping.setdefault(module, []).extend(new_aliases)

            # Get the list of imports we're importing in this import
            new_objects = {ia.evaluated_name for ia in nodenames if ia.asname is None}
            if new_objects:
                existing = self.object_mapping.setdefault(module, set())
                # A prior star import already covers these names; keep the
                # ``{"*"}`` marker intact rather than adding to it.
                if "*" not in existing:
                    existing.update(new_objects)

            # Every explicitly named import binds a symbol, even when the
            # module was also star-imported, so always record it here.
            for ia in nodenames:
                imp = ImportItem(
                    module, obj_name=ia.evaluated_name, alias=ia.evaluated_alias
                )
                key = ia.evaluated_alias or ia.evaluated_name
                self.symbol_mapping[key] = imp
90
class GatherImportsVisitor(_GatherImportsMixin):
    """
    Collects every import in a module and exposes the results as attributes.

    Create an instance and pass it to a :class:`~libcst.Module`'s
    :meth:`~libcst.CSTNode.visit` method to gather information about the
    module's imports. This is not a replacement for scope analysis or
    qualified name support; see :ref:`libcst-scope-tutorial` for a more robust
    way to determine the qualified name and definition of an arbitrary node.

    Once the module has been visited, these attributes are populated:

    module_imports
        A set of module names imported directly without an alias, such as
        ``typing`` for ``import typing``.
    object_mapping
        A mapping from a module name to the set of object names imported from
        it without an alias, such as ``{"typing": {"Optional"}}`` for
        ``from typing import Optional``. A star import is stored as ``{"*"}``.
    module_aliases
        A mapping from an aliased module name to its local alias, such as
        ``{"typing": "t"}`` for ``import typing as t``.
    alias_mapping
        A mapping from a module name to a list of ``(original name, alias)``
        tuples for objects imported from it using ``as``, such as
        ``{"typing": [("Optional", "opt")]}`` for
        ``from typing import Optional as opt``.
    symbol_mapping
        A mapping from an import's alias (or its imported name when it has no
        alias) to the ``ImportItem`` describing that import.
    all_imports
        A list of every :class:`~libcst.Import` and :class:`~libcst.ImportFrom`
        statement encountered, in visitation order. Relative imports that
        cannot be resolved to an absolute module appear only here.
    """

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        # Raw import statements, appended as the visitor reaches them.
        self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

    def visit_Import(self, node: libcst.Import) -> None:
        """Keep the raw statement, then let the mixin index its names."""
        statements = self.all_imports
        statements.append(node)
        self._handle_Import(node)

    def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
        """Keep the raw statement, then let the mixin index its names."""
        statements = self.all_imports
        statements.append(node)
        self._handle_ImportFrom(node)