Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/libcst/codemod/visitors/_add_imports.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

136 statements  

1# Copyright (c) Meta Platforms, Inc. and affiliates. 

2# 

3# This source code is licensed under the MIT license found in the 

4# LICENSE file in the root directory of this source tree. 

5 

6from collections import defaultdict 

7from typing import Dict, List, Optional, Sequence, Set, Tuple, Union 

8 

9import libcst 

10from libcst import CSTLogicError, matchers as m, parse_statement 

11from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine 

12from libcst.codemod._context import CodemodContext 

13from libcst.codemod._visitor import ContextAwareTransformer 

14from libcst.codemod.visitors._gather_imports import _GatherImportsMixin 

15from libcst.codemod.visitors._imports import ImportItem 

16from libcst.helpers import get_absolute_module_from_package_for_import 

17from libcst.helpers.common import ensure_type 

18 

19 

class _GatherTopImportsBeforeStatements(_GatherImportsMixin):
    """
    Gather only the imports that make up a module's leading import block.

    This behaves like GatherImportsVisitor, except that scanning stops at the
    first top-level statement that is not a single ``import`` or
    ``from ... import`` line. A leading docstring or ``__strict__`` flag
    (as detected by ``_skip_first``) is skipped before scanning starts.
    """

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        # Import/ImportFrom nodes of the leading import block, in source order.
        self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

    def leave_Module(self, original_node: libcst.Module) -> None:
        first_index = 1 if _skip_first(original_node) else 0
        single_import_line = m.SimpleStatementLine(
            body=[m.ImportFrom() | m.Import()]
        )

        # Phase 1: collect the uninterrupted run of import lines at the top.
        for statement in original_node.body[first_index:]:
            if not m.matches(statement, single_import_line):
                break
            node = ensure_type(statement, SimpleStatementLine).body[0]
            # isinstance with a tuple of classes works on every supported
            # Python version (isinstance with typing.Union does not on 3.8/3.9).
            if isinstance(node, (ImportFrom, Import)):
                self.all_imports.append(node)

        # Phase 2: let the mixin record each collected import.
        for node in self.all_imports:
            if isinstance(node, Import):
                self._handle_Import(node)
            else:
                self._handle_ImportFrom(ensure_type(node, ImportFrom))

56 

57 

class AddImportsVisitor(ContextAwareTransformer):
    """
    Ensures that given imports exist in a module.

    Given a :class:`~libcst.codemod.CodemodContext` and a sequence of
    ``ImportItem`` entries, it ensures that each requested import exists.
    Each entry names a module to import, optionally an object to import from
    that module, and optionally an alias to assign to that import. It will
    modify existing imports as necessary if the module in question is already
    being imported from.

    This is one of the transforms that is available automatically to you when
    running a codemod. To use it in this manner, import
    :class:`~libcst.codemod.visitors.AddImportsVisitor` and then call the static
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import` method.
    Pass it the current context (found as ``self.context`` for all subclasses of
    :class:`~libcst.codemod.Codemod`) and the module you wish to import from.
    Optionally also pass an object you wish to import from that module and any
    alias you would like to assign that import to.

    For example::

        AddImportsVisitor.add_needed_import(self.context, "typing", "Optional")

    This will produce the following code in a module, assuming there was no
    typing import already::

        from typing import Optional

    As another example::

        AddImportsVisitor.add_needed_import(self.context, "typing")

    This will produce the following code in a module, assuming there was no
    import already::

        import typing

    Note that this is a subclass of :class:`~libcst.CSTTransformer` so it is
    possible to instantiate it and pass it to a :class:`~libcst.Module`
    :meth:`~libcst.CSTNode.visit` method. However, it is far easier to use
    the automatic transform feature of :class:`~libcst.codemod.CodemodCommand`
    and schedule an import to be added by calling
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import`
    """

    CONTEXT_KEY = "AddImportsVisitor"

    @staticmethod
    def _get_imports_from_context(
        context: CodemodContext,
    ) -> List[ImportItem]:
        """
        Return the imports scheduled in ``context.scratch``, or ``[]``.

        Raises ``CSTLogicError`` if the scratch entry exists but is not a list.
        """
        imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
        if not isinstance(imports, list):
            raise CSTLogicError("Logic error!")
        return imports

    @staticmethod
    def add_needed_import(
        context: CodemodContext,
        module: str,
        obj: Optional[str] = None,
        asname: Optional[str] = None,
        relative: int = 0,
    ) -> None:
        """
        Schedule an import to be added in a future invocation of this class.

        The ``context`` is updated to include the ``module``. Optionally it
        also includes an ``obj`` to import from that module, and optionally an
        ``asname`` to alias the imported ``module`` or ``obj`` to. ``relative``
        is stored on the scheduled ``ImportItem``; it is resolved against
        ``context.full_package_name`` when this visitor is constructed.

        When subclassing from :class:`~libcst.codemod.CodemodCommand`, the
        import will be added for you after your transform finishes executing.
        If you are subclassing from a :class:`~libcst.codemod.Codemod` instead,
        you need to run an instance of this class yourself after performing
        your transform. Call the :meth:`~libcst.codemod.Codemod.transform_module`
        method on the module under modification.

        If the requested ``module`` or ``obj`` already exists as an import on
        the current module when
        :meth:`~libcst.codemod.Codemod.transform_module` runs on an instance of
        :class:`~libcst.codemod.visitors.AddImportsVisitor`, no action is
        taken. This avoids adding duplicate imports.

        Raises ``ValueError`` when asked to import ``__future__`` itself.
        """

        if module == "__future__" and obj is None:
            raise ValueError("Cannot import __future__ directly!")
        imports = AddImportsVisitor._get_imports_from_context(context)
        imports.append(ImportItem(module, obj, asname, relative))
        context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports

    def __init__(
        self,
        context: CodemodContext,
        imports: Sequence[ImportItem] = (),
    ) -> None:
        # Allow for instantiation from either a context (used when multiple
        # transforms get chained) or from a direct instantiation.
        super().__init__(context)
        imps: List[ImportItem] = [
            *AddImportsVisitor._get_imports_from_context(context),
            *imports,
        ]

        # Verify that the imports are valid
        for imp in imps:
            if imp.module == "__future__" and imp.obj_name is None:
                raise ValueError("Cannot import __future__ directly!")
            if imp.module == "__future__" and imp.alias is not None:
                raise ValueError("Cannot import __future__ objects with aliases!")

        # Resolve relative imports if we have a module name
        imps = [imp.resolve_relative(self.context.full_package_name) for imp in imps]

        # Modules we need to ensure are imported as ``import module``.
        self.module_imports: Set[str] = {
            imp.module for imp in imps if imp.obj_name is None and imp.alias is None
        }

        # Modules we need to check for un-aliased object imports on.
        from_imports: Set[str] = {
            imp.module for imp in imps if imp.obj_name is not None and imp.alias is None
        }
        # Mapping of modules we're adding to the objects they should import.
        self.module_mapping: Dict[str, Set[str]] = {
            module: {
                imp.obj_name
                for imp in imps
                if imp.module == module
                and imp.obj_name is not None
                and imp.alias is None
            }
            for module in sorted(from_imports)
        }

        # Aliased modules we need to ensure are imported as
        # ``import module as alias`` (one alias per module; the last request wins).
        self.module_aliases: Dict[str, str] = {
            imp.module: imp.alias
            for imp in imps
            if imp.obj_name is None and imp.alias is not None
        }
        # Modules we need to check for aliased object imports on.
        from_imports_aliases: Set[str] = {
            imp.module
            for imp in imps
            if imp.obj_name is not None and imp.alias is not None
        }
        # Mapping of modules we're adding to the (object, alias) pairs they
        # should import. Duplicate requests are collapsed (first-seen order is
        # kept). Otherwise the same ``obj as alias`` would be emitted twice, and
        # visit_Module's single ``remove`` would leave a stale copy behind.
        self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
            module: list(
                dict.fromkeys(
                    (imp.obj_name, imp.alias)
                    for imp in imps
                    if imp.module == module
                    and imp.obj_name is not None
                    and imp.alias is not None
                )
            )
            for module in sorted(from_imports_aliases)
        }

        # Track the list of imports found at the top of the file
        self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

    def visit_Module(self, node: libcst.Module) -> None:
        """Drop every pending import that the leading import block already has."""
        # Do a preliminary pass to gather the imports we already have at the top
        gatherer = _GatherTopImportsBeforeStatements(self.context)
        node.visit(gatherer)
        self.all_imports = gatherer.all_imports

        self.module_imports = self.module_imports - gatherer.module_imports
        for module, alias in gatherer.module_aliases.items():
            if module in self.module_aliases and self.module_aliases[module] == alias:
                del self.module_aliases[module]
        for module, aliases in gatherer.alias_mapping.items():
            for obj, alias in aliases:
                if (
                    module in self.alias_mapping
                    and (obj, alias) in self.alias_mapping[module]
                ):
                    self.alias_mapping[module].remove((obj, alias))
                    if len(self.alias_mapping[module]) == 0:
                        del self.alias_mapping[module]

        for module, imports in gatherer.object_mapping.items():
            if module not in self.module_mapping:
                # We don't care about this import at all
                continue
            elif "*" in imports:
                # We already implicitly are importing everything
                del self.module_mapping[module]
            else:
                # Lets figure out what's left to import
                self.module_mapping[module] = self.module_mapping[module] - imports
                if not self.module_mapping[module]:
                    # There's nothing left, so lets delete this work item
                    del self.module_mapping[module]

    def leave_ImportFrom(
        self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
    ) -> libcst.ImportFrom:
        """Extend a leading ``from module import ...`` with pending names."""
        if isinstance(updated_node.names, libcst.ImportStar):
            # There's nothing to do here!
            return updated_node

        # Ensure this is one of the imports at the top
        if original_node not in self.all_imports:
            return updated_node

        # Get the module we're importing as a string, see if we have work to do.
        module = get_absolute_module_from_package_for_import(
            self.context.full_package_name, updated_node
        )
        if module is None or (
            module not in self.module_mapping and module not in self.alias_mapping
        ):
            return updated_node

        # We have work to do; claim it so it is not added again, either by a
        # later ImportFrom of the same module or by leave_Module.
        imports_to_add = self.module_mapping.pop(module, set())
        aliases_to_add = self.alias_mapping.pop(module, [])

        # Now, do the actual update.
        return updated_node.with_changes(
            names=[
                *(
                    libcst.ImportAlias(name=libcst.Name(imp))
                    for imp in sorted(imports_to_add)
                ),
                *(
                    libcst.ImportAlias(
                        name=libcst.Name(imp),
                        asname=libcst.AsName(name=libcst.Name(alias)),
                    )
                    for (imp, alias) in sorted(aliases_to_add)
                ),
                *updated_node.names,
            ]
        )

    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    ]:
        """
        Split the updated module body into three slices.

        The slices are (a) the leading docstring/``__strict__`` statement, if
        any, (b) the leading import block, and (c) everything after it.
        """
        statement_before_import_location = 0
        import_add_location = 0

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.

        # Finds the location to add imports. It is the end of the first import
        # block that occurs before any other statement (save for docstrings).

        # Never insert an import before initial __strict__ flag or docstring
        if _skip_first(orig_module):
            statement_before_import_location = import_add_location = 1

        for i, statement in enumerate(
            orig_module.body[statement_before_import_location:]
        ):
            if m.matches(
                statement, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])
            ):
                import_add_location = i + statement_before_import_location + 1
            else:
                break

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(
                updated_module.body[
                    statement_before_import_location:import_add_location
                ]
            ),
            list(updated_module.body[import_add_location:]),
        )

    def _insert_empty_line(
        self,
        statements: List[
            Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]
        ],
    ) -> List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]]:
        """Ensure the first statement is preceded by a blank (non-comment) line."""
        if len(statements) < 1:
            # No statements, nothing to add to
            return statements
        if len(statements[0].leading_lines) == 0:
            # Statement has no leading lines, add one!
            return [
                statements[0].with_changes(leading_lines=(libcst.EmptyLine(),)),
                *statements[1:],
            ]
        if statements[0].leading_lines[0].comment is None:
            # First line is empty, so its safe to leave as-is
            return statements
        # Statement has a comment first line, so lets add one more empty line
        return [
            statements[0].with_changes(
                leading_lines=(libcst.EmptyLine(), *statements[0].leading_lines)
            ),
            *statements[1:],
        ]

    @staticmethod
    def _name_sort_key(pair: Tuple[str, Optional[str]]) -> Tuple[str, bool, str]:
        """
        Sort key for ``(object, alias)`` pairs whose alias may be ``None``.

        Orders by object name, then puts the un-aliased form before any
        aliased form of the same object, then orders by alias. Plain tuple
        comparison raises ``TypeError`` when it has to compare ``None``
        against a ``str`` alias for the same object name.
        """
        obj_name, alias = pair
        return (obj_name, alias is not None, alias or "")

    @staticmethod
    def _render_from_import(
        module: str, names: Sequence[Tuple[str, Optional[str]]]
    ) -> str:
        """Render ``from module import a, b as c`` source text."""
        rendered = ", ".join(
            obj if alias is None else f"{obj} as {alias}" for (obj, alias) in names
        )
        return f"from {module} import {rendered}"

    def leave_Module(
        self, original_node: libcst.Module, updated_node: libcst.Module
    ) -> libcst.Module:
        """Insert every import that is still pending into the module body."""
        # Don't try to modify if we have nothing to do
        if (
            not self.module_imports
            and not self.module_mapping
            and not self.module_aliases
            and not self.alias_mapping
        ):
            return updated_node

        # First, find the insertion point for imports
        (
            statements_before_imports,
            statements_until_add_imports,
            statements_after_imports,
        ) = self._split_module(original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(statements_after_imports)

        # Mapping of modules we're adding to the objects (with and without
        # alias) they should import. Aliased modules are inserted first, so
        # they come first in the output order.
        grouped: Dict[str, List[Tuple[str, Optional[str]]]] = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            grouped[module].extend(aliases)
        for module, imports in self.module_mapping.items():
            grouped[module].extend((obj_name, None) for obj_name in imports)
        module_and_alias_mapping = {
            module: sorted(names, key=self._name_sort_key)
            for module, names in grouped.items()
        }

        config = updated_node.config_for_parsing
        # ``__future__`` imports must precede every other import.
        future_imports = [
            parse_statement(self._render_from_import(module, names), config=config)
            for module, names in module_and_alias_mapping.items()
            if module == "__future__"
        ]
        plain_module_imports = [
            parse_statement(f"import {module}", config=config)
            for module in sorted(self.module_imports)
        ]
        aliased_module_imports = [
            parse_statement(f"import {module} as {asname}", config=config)
            for (module, asname) in self.module_aliases.items()
        ]
        from_imports = [
            parse_statement(self._render_from_import(module, names), config=config)
            for module, names in module_and_alias_mapping.items()
            if module != "__future__"
        ]

        # Now, add all of the imports we need!
        return updated_node.with_changes(
            # pyre-fixme[60]: Concatenation not yet supported for multiple variadic tuples.
            body=(
                *statements_before_imports,
                *future_imports,
                *statements_until_add_imports,
                *plain_module_imports,
                *aliased_module_imports,
                *from_imports,
                *statements_after_imports,
            )
        )

448 

449 

def _skip_first(orig_module: libcst.Module) -> bool:
    """
    Report whether the module's first statement must stay in first position.

    That is the case when the module opens with a ``__strict__ = ...``
    assignment (single target) or with a statement that is just a simple
    string literal (a docstring). Imports are never inserted before it.
    """
    statements = orig_module.body
    if not statements:
        return False

    strict_flag_line = m.SimpleStatementLine(
        body=[m.Assign(targets=[m.AssignTarget(target=m.Name("__strict__"))])]
    )
    docstring_line = m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])
    return m.matches(statements[0], strict_flag_line | docstring_line)