Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/codemod/visitors/_add_imports.py: 19%
110 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
6from collections import defaultdict
7from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
9import libcst
10from libcst import matchers as m, parse_statement
11from libcst.codemod._context import CodemodContext
12from libcst.codemod._visitor import ContextAwareTransformer
13from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
14from libcst.codemod.visitors._imports import ImportItem
15from libcst.helpers import get_absolute_module_from_package_for_import
class AddImportsVisitor(ContextAwareTransformer):
    """
    Ensures that given imports exist in a module. Given a
    :class:`~libcst.codemod.CodemodContext` and a sequence of imports, each
    naming a module to import from (as a string), optionally an object to
    import from that module, and optionally an alias to assign to that import,
    ensures that each of those imports exists. It will modify existing imports
    as necessary if the module in question is already being imported from.

    This is one of the transforms that is available automatically to you when
    running a codemod. To use it in this manner, import
    :class:`~libcst.codemod.visitors.AddImportsVisitor` and then call the static
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import` method,
    giving it the current context (found as ``self.context`` for all subclasses of
    :class:`~libcst.codemod.Codemod`), the module you wish to import from and
    optionally an object you wish to import from that module and any alias you
    would like to assign that import to.

    For example::

        AddImportsVisitor.add_needed_import(self.context, "typing", "Optional")

    This will produce the following code in a module, assuming there was no
    typing import already::

        from typing import Optional

    As another example::

        AddImportsVisitor.add_needed_import(self.context, "typing")

    This will produce the following code in a module, assuming there was no
    import already::

        import typing

    Note that this is a subclass of :class:`~libcst.CSTTransformer` so it is
    possible to instantiate it and pass it to a :class:`~libcst.Module`
    :meth:`~libcst.CSTNode.visit` method. However, it is far easier to use
    the automatic transform feature of :class:`~libcst.codemod.CodemodCommand`
    and schedule an import to be added by calling
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import`.
    """

    CONTEXT_KEY = "AddImportsVisitor"

    @staticmethod
    def _get_imports_from_context(
        context: CodemodContext,
    ) -> List[ImportItem]:
        """
        Return the list of imports scheduled in ``context.scratch`` under
        :attr:`CONTEXT_KEY`, or an empty list if none were scheduled.

        Raises ``Exception`` if the scratch entry exists but is not a list.
        """
        imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
        if not isinstance(imports, list):
            raise Exception("Logic error!")
        return imports

    @staticmethod
    def add_needed_import(
        context: CodemodContext,
        module: str,
        obj: Optional[str] = None,
        asname: Optional[str] = None,
        relative: int = 0,
    ) -> None:
        """
        Schedule an import to be added in a future invocation of this class by
        updating the ``context`` to include the ``module`` and optionally ``obj``
        to be imported as well as optionally ``asname`` to alias the imported
        ``module`` or ``obj`` to. ``relative`` is stored on the scheduled
        :class:`ImportItem` (presumably the relative-import level) and is
        resolved against the context's ``full_package_name`` when this visitor
        is constructed. When subclassing from
        :class:`~libcst.codemod.CodemodCommand`, this will be performed for you
        after your transform finishes executing. If you are subclassing from a
        :class:`~libcst.codemod.Codemod` instead, you will need to call the
        :meth:`~libcst.codemod.Codemod.transform_module` method on the module
        under modification with an instance of this class after performing your
        transform. Note that if the particular ``module`` or ``obj`` you are
        requesting to import already exists as an import on the current module
        at the time of executing :meth:`~libcst.codemod.Codemod.transform_module`
        on an instance of :class:`~libcst.codemod.visitors.AddImportsVisitor`,
        this will perform no action in order to avoid adding duplicate imports.

        Raises ``Exception`` when asked to import ``__future__`` itself
        (i.e. ``module == "__future__"`` with no ``obj``).
        """

        if module == "__future__" and obj is None:
            raise Exception("Cannot import __future__ directly!")
        imports = AddImportsVisitor._get_imports_from_context(context)
        imports.append(ImportItem(module, obj, asname, relative))
        context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports

    def __init__(
        self,
        context: CodemodContext,
        imports: Sequence[ImportItem] = (),
    ) -> None:
        # Allow for instantiation from either a context (used when multiple transforms
        # get chained) or from a direct instantiation.
        super().__init__(context)
        imps: List[ImportItem] = [
            *AddImportsVisitor._get_imports_from_context(context),
            *imports,
        ]

        # Verify that the imports are valid
        for imp in imps:
            if imp.module == "__future__" and imp.obj_name is None:
                raise Exception("Cannot import __future__ directly!")
            if imp.module == "__future__" and imp.alias is not None:
                raise Exception("Cannot import __future__ objects with aliases!")

        # Resolve relative imports if we have a module name
        imps = [imp.resolve_relative(self.context.full_package_name) for imp in imps]

        # Modules we need to ensure are imported via ``import module``.
        self.module_imports: Set[str] = set()
        # Modules we need to ensure are imported via ``import module as alias``.
        # If one module is requested with several aliases, the last one wins.
        self.module_aliases: Dict[str, str] = {}
        # Single grouping pass; sets deduplicate repeated requests so the same
        # name is never emitted twice.
        plain_objects: Dict[str, Set[str]] = defaultdict(set)
        aliased_objects: Dict[str, Set[Tuple[str, str]]] = defaultdict(set)
        for imp in imps:
            if imp.obj_name is None:
                if imp.alias is None:
                    self.module_imports.add(imp.module)
                else:
                    self.module_aliases[imp.module] = imp.alias
            elif imp.alias is None:
                plain_objects[imp.module].add(imp.obj_name)
            else:
                aliased_objects[imp.module].add((imp.obj_name, imp.alias))

        # Mapping of modules we're adding to the objects they should import
        # (keys in sorted order).
        self.module_mapping: Dict[str, Set[str]] = {
            module: plain_objects[module] for module in sorted(plain_objects)
        }
        # Mapping of modules we're adding to the (object, alias) pairs they
        # should import (keys in sorted order, pairs unique).
        self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
            module: sorted(aliased_objects[module])
            for module in sorted(aliased_objects)
        }

        # Track the list of imports found in the file
        self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

    def visit_Module(self, node: libcst.Module) -> None:
        # Do a preliminary pass to gather the imports we already have
        gatherer = GatherImportsVisitor(self.context)
        node.visit(gatherer)
        self.all_imports = gatherer.all_imports

        # Plain ``import module`` statements that already exist need no work.
        self.module_imports = self.module_imports - gatherer.module_imports
        # ``import module as alias`` is only satisfied by the exact same alias.
        for module, alias in gatherer.module_aliases.items():
            if module in self.module_aliases and self.module_aliases[module] == alias:
                del self.module_aliases[module]
        # Drop aliased object imports that already exist verbatim.
        for module, aliases in gatherer.alias_mapping.items():
            for obj, alias in aliases:
                if (
                    module in self.alias_mapping
                    and (obj, alias) in self.alias_mapping[module]
                ):
                    self.alias_mapping[module].remove((obj, alias))
                    if len(self.alias_mapping[module]) == 0:
                        del self.alias_mapping[module]

        for module, imports in gatherer.object_mapping.items():
            if module not in self.module_mapping:
                # We don't care about this import at all
                continue
            elif "*" in imports:
                # We already implicitly are importing everything
                del self.module_mapping[module]
            else:
                # Lets figure out what's left to import
                self.module_mapping[module] = self.module_mapping[module] - imports
                if not self.module_mapping[module]:
                    # There's nothing left, so lets delete this work item
                    del self.module_mapping[module]

    def leave_ImportFrom(
        self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
    ) -> libcst.ImportFrom:
        if isinstance(updated_node.names, libcst.ImportStar):
            # There's nothing to do here!
            return updated_node

        # Get the module we're importing as a string, see if we have work to do.
        module = get_absolute_module_from_package_for_import(
            self.context.full_package_name, updated_node
        )
        if (
            module is None
            or module not in self.module_mapping
            and module not in self.alias_mapping
        ):
            return updated_node

        # We have work to do. Pop the work items so we never extend another
        # ``from module import ...`` for the same module, and so leave_Module
        # does not emit a separate statement for it.
        imports_to_add = self.module_mapping.pop(module, set())
        aliases_to_add = self.alias_mapping.pop(module, [])

        # Now, do the actual update: new plain names, then new aliased names,
        # then whatever the statement already imported.
        return updated_node.with_changes(
            names=[
                *(
                    libcst.ImportAlias(name=libcst.Name(imp))
                    for imp in sorted(imports_to_add)
                ),
                *(
                    libcst.ImportAlias(
                        name=libcst.Name(imp),
                        asname=libcst.AsName(name=libcst.Name(alias)),
                    )
                    for (imp, alias) in sorted(aliases_to_add)
                ),
                *updated_node.names,
            ]
        )

    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    ]:
        """
        Split ``updated_module.body`` into three lists:

        1. statements that must stay above any inserted import (a leading
           ``__strict__`` assignment or module docstring),
        2. statements from there up to and including the last top-level
           statement that contains an import gathered in :meth:`visit_Module`,
        3. everything after that.
        """
        statement_before_import_location = 0
        import_add_location = 0

        # never insert an import before initial __strict__ flag
        if m.matches(
            orig_module,
            m.Module(
                body=[
                    m.SimpleStatementLine(
                        body=[
                            m.Assign(
                                targets=[m.AssignTarget(target=m.Name("__strict__"))]
                            )
                        ]
                    ),
                    m.ZeroOrMore(),
                ]
            ),
        ):
            statement_before_import_location = import_add_location = 1

        # Imports were gathered from the original tree, so match them by object
        # identity; a set of ids makes each membership test O(1).
        gathered_import_ids = {id(imp) for imp in self.all_imports}

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.
        for i, statement in enumerate(orig_module.body):
            if i == 0 and m.matches(
                statement, m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])
            ):
                statement_before_import_location = import_add_location = 1
            elif isinstance(statement, libcst.SimpleStatementLine) and any(
                id(small_statement) in gathered_import_ids
                for small_statement in statement.body
            ):
                import_add_location = i + 1

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(
                updated_module.body[
                    statement_before_import_location:import_add_location
                ]
            ),
            list(updated_module.body[import_add_location:]),
        )

    def _insert_empty_line(
        self,
        statements: List[
            Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]
        ],
    ) -> List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]]:
        """
        Return ``statements`` with the first statement guaranteed to start with
        a blank (comment-free) leading line.
        """
        if len(statements) < 1:
            # No statements, nothing to add to
            return statements
        if len(statements[0].leading_lines) == 0:
            # Statement has no leading lines, add one!
            return [
                statements[0].with_changes(leading_lines=(libcst.EmptyLine(),)),
                *statements[1:],
            ]
        if statements[0].leading_lines[0].comment is None:
            # First line is empty, so its safe to leave as-is
            return statements
        # Statement has a comment first line, so lets add one more empty line
        return [
            statements[0].with_changes(
                leading_lines=(libcst.EmptyLine(), *statements[0].leading_lines)
            ),
            *statements[1:],
        ]

    @staticmethod
    def _name_sort_key(item: Tuple[str, Optional[str]]) -> Tuple[str, bool, str]:
        """
        Sort key for ``(object, alias)`` pairs. Orders by object name, putting
        the unaliased form first. Unlike plain tuple comparison, this never
        compares ``None`` with ``str``, which raises ``TypeError``.
        """
        obj, alias = item
        return (obj, alias is not None, alias or "")

    @staticmethod
    def _build_from_import(
        module: str,
        names: Sequence[Tuple[str, Optional[str]]],
        config: libcst.PartialParserConfig,
    ) -> Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]:
        """Parse ``from <module> import a, b as c, ...`` for the given pairs."""
        rendered = ", ".join(
            obj if alias is None else f"{obj} as {alias}" for (obj, alias) in names
        )
        return parse_statement(f"from {module} import {rendered}", config=config)

    def leave_Module(
        self, original_node: libcst.Module, updated_node: libcst.Module
    ) -> libcst.Module:
        # Don't try to modify if we have nothing to do
        if (
            not self.module_imports
            and not self.module_mapping
            and not self.module_aliases
            and not self.alias_mapping
        ):
            return updated_node

        # First, find the insertion point for imports
        (
            statements_before_imports,
            statements_until_add_imports,
            statements_after_imports,
        ) = self._split_module(original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(statements_after_imports)

        # Mapping of modules we're adding to the (object, alias-or-None) pairs
        # they should import; aliased modules first, then unaliased ones.
        pending: Dict[str, List[Tuple[str, Optional[str]]]] = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            pending[module].extend(aliases)
        for module, objects in self.module_mapping.items():
            pending[module].extend((obj, None) for obj in objects)
        # The same object may be requested both with and without an alias, so
        # sort with an explicit key rather than comparing str against None.
        from_imports = {
            module: sorted(names, key=self._name_sort_key)
            for module, names in pending.items()
        }

        config = updated_node.config_for_parsing
        # Now, add all of the imports we need!
        return updated_node.with_changes(
            # pyre-fixme[60]: Concatenation not yet support for multiple variadic tup...
            body=(
                *statements_before_imports,
                # __future__ imports must precede every other import.
                *[
                    self._build_from_import(module, names, config)
                    for module, names in from_imports.items()
                    if module == "__future__"
                ],
                *statements_until_add_imports,
                *[
                    parse_statement(f"import {module}", config=config)
                    for module in sorted(self.module_imports)
                ],
                *[
                    parse_statement(f"import {module} as {asname}", config=config)
                    for (module, asname) in self.module_aliases.items()
                ],
                *[
                    self._build_from_import(module, names, config)
                    for module, names in from_imports.items()
                    if module != "__future__"
                ],
                *statements_after_imports,
            )
        )