1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from collections import defaultdict
7from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
8
9import libcst
10from libcst import CSTLogicError, matchers as m, parse_statement
11from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine
12from libcst.codemod._context import CodemodContext
13from libcst.codemod._visitor import ContextAwareTransformer
14from libcst.codemod.visitors._gather_imports import _GatherImportsMixin
15from libcst.codemod.visitors._imports import ImportItem
16from libcst.helpers import get_absolute_module_from_package_for_import
17from libcst.helpers.common import ensure_type
18
19
class _GatherTopImportsBeforeStatements(_GatherImportsMixin):
    """
    Collect only the imports that sit at the very top of a module.

    This mirrors GatherImportsVisitor, except that scanning stops at the first
    statement that is not a single-statement ``import``/``from ... import``
    line. A leading docstring or ``__strict__`` flag (see ``_skip_first``) is
    stepped over before scanning begins.
    """

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        # Top-of-module Import/ImportFrom nodes, in source order.
        self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

    def leave_Module(self, original_node: libcst.Module) -> None:
        offset = 1 if _skip_first(original_node) else 0
        import_line = m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])

        for statement in original_node.body[offset:]:
            # The first non-import statement ends the top-of-module import block.
            if not m.matches(statement, import_line):
                break
            only_stmt = ensure_type(statement, SimpleStatementLine).body[0]
            # Workaround for python 3.8 and 3.9, won't accept Union for isinstance
            if m.matches(only_stmt, m.ImportFrom()):
                self.all_imports.append(ensure_type(only_stmt, ImportFrom))
            elif m.matches(only_stmt, m.Import()):
                self.all_imports.append(ensure_type(only_stmt, Import))

        # Feed every gathered node through the mixin's bookkeeping.
        for found in self.all_imports:
            if m.matches(found, m.Import()):
                self._handle_Import(ensure_type(found, Import))
            else:
                self._handle_ImportFrom(ensure_type(found, ImportFrom))
56
57
class AddImportsVisitor(ContextAwareTransformer):
    """
    Ensures that given imports exist in a module. Given a
    :class:`~libcst.codemod.CodemodContext` and a sequence of ``ImportItem``
    objects, each naming a module to import from as a string and, optionally,
    an object to import from that module and an alias to assign that import
    to, ensures that each of those imports exists. It will modify existing
    imports as necessary if the module in question is already being imported
    from.

    This is one of the transforms that is available automatically to you when
    running a codemod. To use it in this manner, import
    :class:`~libcst.codemod.visitors.AddImportsVisitor` and then call the static
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import` method,
    giving it the current context (found as ``self.context`` for all subclasses of
    :class:`~libcst.codemod.Codemod`), the module you wish to import from and
    optionally an object you wish to import from that module and any alias you
    would like to assign that import to.

    For example::

        AddImportsVisitor.add_needed_import(self.context, "typing", "Optional")

    This will produce the following code in a module, assuming there was no
    typing import already::

        from typing import Optional

    As another example::

        AddImportsVisitor.add_needed_import(self.context, "typing")

    This will produce the following code in a module, assuming there was no
    import already::

        import typing

    Note that this is a subclass of :class:`~libcst.CSTTransformer` so it is
    possible to instantiate it and pass it to a :class:`~libcst.Module`
    :meth:`~libcst.CSTNode.visit` method. However, it is far easier to use
    the automatic transform feature of :class:`~libcst.codemod.CodemodCommand`
    and schedule an import to be added by calling
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import`
    """

    CONTEXT_KEY = "AddImportsVisitor"

    @staticmethod
    def _get_imports_from_context(
        context: CodemodContext,
    ) -> List[ImportItem]:
        """
        Return the list of imports scheduled in ``context.scratch``.

        When nothing has been scheduled yet, a fresh empty list is returned
        (it is not stored back into the context here).

        Raises:
            CSTLogicError: if the scratch entry exists but is not a list.
        """
        imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
        if not isinstance(imports, list):
            raise CSTLogicError("Logic error!")
        return imports

    @staticmethod
    def add_needed_import(
        context: CodemodContext,
        module: str,
        obj: Optional[str] = None,
        asname: Optional[str] = None,
        relative: int = 0,
    ) -> None:
        """
        Schedule an import to be added in a future invocation of this class by
        updating the ``context`` to include the ``module`` and optionally ``obj``
        to be imported as well as optionally ``asname`` to alias the imported
        ``module`` or ``obj`` to. ``relative`` is stored on the scheduled
        ``ImportItem`` and resolved against ``context.full_package_name`` when
        the visitor is constructed (presumably the number of leading dots of a
        relative import; ``0`` means absolute). When subclassing from
        :class:`~libcst.codemod.CodemodCommand`, this will be performed for you
        after your transform finishes executing. If you are subclassing from a
        :class:`~libcst.codemod.Codemod` instead, you will need to call the
        :meth:`~libcst.codemod.Codemod.transform_module` method on the module
        under modification with an instance of this class after performing your
        transform. Note that if the particular ``module`` or ``obj`` you are
        requesting to import already exists as an import on the current module
        at the time of executing :meth:`~libcst.codemod.Codemod.transform_module`
        on an instance of :class:`~libcst.codemod.visitors.AddImportsVisitor`,
        this will perform no action in order to avoid adding duplicate imports.

        Raises:
            ValueError: if ``module`` is ``__future__`` and no ``obj`` is given.
        """

        if module == "__future__" and obj is None:
            raise ValueError("Cannot import __future__ directly!")
        imports = AddImportsVisitor._get_imports_from_context(context)
        imports.append(ImportItem(module, obj, asname, relative))
        context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports

    def __init__(
        self,
        context: CodemodContext,
        imports: Sequence[ImportItem] = (),
    ) -> None:
        # Allow for instantiation from either a context (used when multiple transforms
        # get chained) or from a direct instantiation.
        super().__init__(context)
        imps: List[ImportItem] = [
            *AddImportsVisitor._get_imports_from_context(context),
            *imports,
        ]

        # Verify that the imports are valid
        for imp in imps:
            if imp.module == "__future__" and imp.obj_name is None:
                raise ValueError("Cannot import __future__ directly!")
            if imp.module == "__future__" and imp.alias is not None:
                raise ValueError("Cannot import __future__ objects with aliases!")

        # Resolve relative imports if we have a module name
        imps = [imp.resolve_relative(self.context.full_package_name) for imp in imps]

        # List of modules we need to ensure are imported
        self.module_imports: Set[str] = {
            imp.module for imp in imps if imp.obj_name is None and imp.alias is None
        }

        # List of modules we need to check for object imports on
        from_imports: Set[str] = {
            imp.module for imp in imps if imp.obj_name is not None and imp.alias is None
        }
        # Mapping of modules we're adding to the object they should import
        self.module_mapping: Dict[str, Set[str]] = {
            module: {
                imp.obj_name
                for imp in imps
                if imp.module == module
                and imp.obj_name is not None
                and imp.alias is None
            }
            for module in sorted(from_imports)
        }

        # List of aliased modules we need to ensure are imported. Keyed by
        # module, so if one module is requested under several aliases the
        # last request wins.
        self.module_aliases: Dict[str, str] = {
            imp.module: imp.alias
            for imp in imps
            if imp.obj_name is None and imp.alias is not None
        }
        # List of modules we need to check for object imports on
        from_imports_aliases: Set[str] = {
            imp.module
            for imp in imps
            if imp.obj_name is not None and imp.alias is not None
        }
        # Mapping of modules we're adding to the object with alias they should
        # import. Repeated requests are collapsed (first-seen order kept) so
        # scheduling the same aliased import twice does not emit
        # ``from m import a as b, a as b``.
        self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
            module: list(
                dict.fromkeys(
                    (imp.obj_name, imp.alias)
                    for imp in imps
                    if imp.module == module
                    and imp.obj_name is not None
                    and imp.alias is not None
                )
            )
            for module in sorted(from_imports_aliases)
        }

        # Track the list of imports found at the top of the file
        self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

    def visit_Module(self, node: libcst.Module) -> None:
        """
        Drop every pending import that already exists at the top of ``node``.

        Also records the top-of-module import nodes in ``self.all_imports`` so
        that ``leave_ImportFrom`` only extends imports from that block.
        """
        # Do a preliminary pass to gather the imports we already have at the top
        gatherer = _GatherTopImportsBeforeStatements(self.context)
        node.visit(gatherer)
        self.all_imports = gatherer.all_imports

        self.module_imports = self.module_imports - gatherer.module_imports
        for module, alias in gatherer.module_aliases.items():
            if module in self.module_aliases and self.module_aliases[module] == alias:
                del self.module_aliases[module]
        for module, aliases in gatherer.alias_mapping.items():
            for obj, alias in aliases:
                if (
                    module in self.alias_mapping
                    and (obj, alias) in self.alias_mapping[module]
                ):
                    self.alias_mapping[module].remove((obj, alias))
                    if len(self.alias_mapping[module]) == 0:
                        del self.alias_mapping[module]

        for module, imports in gatherer.object_mapping.items():
            if module not in self.module_mapping:
                # We don't care about this import at all
                continue
            elif "*" in imports:
                # We already implicitly are importing everything
                del self.module_mapping[module]
            else:
                # Lets figure out what's left to import
                self.module_mapping[module] = self.module_mapping[module] - imports
                if not self.module_mapping[module]:
                    # There's nothing left, so lets delete this work item
                    del self.module_mapping[module]

    def leave_ImportFrom(
        self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
    ) -> libcst.ImportFrom:
        """
        Merge pending names into an existing top-of-module ``from`` import.

        Star imports and ``from`` imports outside the top block are left
        alone. Pending names for the module are prepended (sorted, un-aliased
        first, then aliased) and the module's work items are consumed, so a
        later ``from`` import of the same module is not touched again.
        """
        if isinstance(updated_node.names, libcst.ImportStar):
            # There's nothing to do here!
            return updated_node

        # Ensure this is one of the imports at the top. CST node equality is
        # identity-based, so this checks for this exact original node.
        if original_node not in self.all_imports:
            return updated_node

        # Get the module we're importing as a string, see if we have work to do.
        module = get_absolute_module_from_package_for_import(
            self.context.full_package_name, updated_node
        )
        if (
            module is None
            or module not in self.module_mapping
            and module not in self.alias_mapping
        ):
            return updated_node

        # We have work to do, mark that we won't modify this again.
        imports_to_add = self.module_mapping.pop(module, set())
        aliases_to_add = self.alias_mapping.pop(module, [])

        # Now, do the actual update.
        return updated_node.with_changes(
            names=[
                *(
                    libcst.ImportAlias(name=libcst.Name(imp))
                    for imp in sorted(imports_to_add)
                ),
                *(
                    libcst.ImportAlias(
                        name=libcst.Name(imp),
                        asname=libcst.AsName(name=libcst.Name(alias)),
                    )
                    for (imp, alias) in sorted(aliases_to_add)
                ),
                *updated_node.names,
            ]
        )

    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    ]:
        """
        Split ``updated_module.body`` into three consecutive slices.

        Returns (leading docstring/``__strict__`` statement if any,
        the top-of-module import block, everything after it).
        """
        statement_before_import_location = 0
        import_add_location = 0

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.

        # Finds the location to add imports. It is the end of the first import
        # block that occurs before any other statement (save for docstrings).

        # Never insert an import before initial __strict__ flag or docstring
        if _skip_first(orig_module):
            statement_before_import_location = import_add_location = 1

        for i, statement in enumerate(
            orig_module.body[statement_before_import_location:]
        ):
            if m.matches(
                statement, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])
            ):
                import_add_location = i + statement_before_import_location + 1
            else:
                break

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(
                updated_module.body[
                    statement_before_import_location:import_add_location
                ]
            ),
            list(updated_module.body[import_add_location:]),
        )

    def _insert_empty_line(
        self,
        statements: List[
            Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]
        ],
    ) -> List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]]:
        """Ensure the first statement is preceded by a blank (comment-free) line."""
        if len(statements) < 1:
            # No statements, nothing to add to
            return statements
        if len(statements[0].leading_lines) == 0:
            # Statement has no leading lines, add one!
            return [
                statements[0].with_changes(leading_lines=(libcst.EmptyLine(),)),
                *statements[1:],
            ]
        if statements[0].leading_lines[0].comment is None:
            # First line is empty, so its safe to leave as-is
            return statements
        # Statement has a comment first line, so lets add one more empty line
        return [
            statements[0].with_changes(
                leading_lines=(libcst.EmptyLine(), *statements[0].leading_lines)
            ),
            *statements[1:],
        ]

    @staticmethod
    def _render_from_import(
        module: str, names: Sequence[Tuple[str, Optional[str]]]
    ) -> str:
        """Render ``from <module> import a, b as c`` source text for ``names``."""
        return f"from {module} import " + ", ".join(
            [obj if alias is None else f"{obj} as {alias}" for (obj, alias) in names]
        )

    def leave_Module(
        self, original_node: libcst.Module, updated_node: libcst.Module
    ) -> libcst.Module:
        """
        Insert every import still pending after the tree walk.

        ``from __future__`` imports go right after any leading docstring or
        ``__strict__`` flag. All other new imports go after the existing
        top-of-module import block, in this order: ``import m`` (sorted),
        ``import m as a``, then ``from m import ...``.
        """
        # Don't try to modify if we have nothing to do
        if (
            not self.module_imports
            and not self.module_mapping
            and not self.module_aliases
            and not self.alias_mapping
        ):
            return updated_node

        # First, find the insertion point for imports
        (
            statements_before_imports,
            statements_until_add_imports,
            statements_after_imports,
        ) = self._split_module(original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(statements_after_imports)

        # Mapping of modules we're adding to the objects (with and without
        # alias) they should import. Modules with aliased names are inserted
        # first, which fixes the order of the emitted ``from`` statements.
        pending_from_imports = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            pending_from_imports[module].extend(aliases)
        for module, imports in self.module_mapping.items():
            pending_from_imports[module].extend((obj, None) for obj in imports)
        # Sort by object name, un-aliased form first, then by alias. Sorting
        # the raw tuples raises TypeError when one object is requested both
        # with and without an alias, because None and str are not orderable.
        module_and_alias_mapping = {
            module: sorted(
                names,
                key=lambda name: (name[0], name[1] is not None, name[1] or ""),
            )
            for module, names in pending_from_imports.items()
        }
        config = updated_node.config_for_parsing
        # Now, add all of the imports we need!
        return updated_node.with_changes(
            # pyre-fixme[60]: Concatenation not yet support for multiple variadic tup...
            body=(
                *statements_before_imports,
                *[
                    parse_statement(
                        self._render_from_import(module, names), config=config
                    )
                    for module, names in module_and_alias_mapping.items()
                    if module == "__future__"
                ],
                *statements_until_add_imports,
                *[
                    parse_statement(f"import {module}", config=config)
                    for module in sorted(self.module_imports)
                ],
                *[
                    parse_statement(f"import {module} as {asname}", config=config)
                    for (module, asname) in self.module_aliases.items()
                ],
                *[
                    parse_statement(
                        self._render_from_import(module, names), config=config
                    )
                    for module, names in module_and_alias_mapping.items()
                    if module != "__future__"
                ],
                *statements_after_imports,
            )
        )
448
449
def _skip_first(orig_module: libcst.Module) -> bool:
    """
    Report whether the module's first statement must stay ahead of any imports.

    That is the case when the first statement is a single-statement line that
    is either a ``__strict__ = ...`` assignment (one target) or a bare string
    expression such as a docstring.
    """
    strict_flag_first = m.Module(
        body=[
            m.SimpleStatementLine(
                body=[m.Assign(targets=[m.AssignTarget(target=m.Name("__strict__"))])]
            ),
            m.ZeroOrMore(),
        ]
    )
    docstring_first = m.Module(
        body=[
            m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
            m.ZeroOrMore(),
        ]
    )
    return bool(m.matches(orig_module, strict_flag_first | docstring_first))