Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/codemod/visitors/_add_imports.py: 19%

110 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:43 +0000

1# Copyright (c) Meta Platforms, Inc. and affiliates. 

2# 

3# This source code is licensed under the MIT license found in the 

4# LICENSE file in the root directory of this source tree. 

5 

6from collections import defaultdict 

7from typing import Dict, List, Optional, Sequence, Set, Tuple, Union 

8 

9import libcst 

10from libcst import matchers as m, parse_statement 

11from libcst.codemod._context import CodemodContext 

12from libcst.codemod._visitor import ContextAwareTransformer 

13from libcst.codemod.visitors._gather_imports import GatherImportsVisitor 

14from libcst.codemod.visitors._imports import ImportItem 

15from libcst.helpers import get_absolute_module_from_package_for_import 

16 

17 

class AddImportsVisitor(ContextAwareTransformer):
    """
    Ensures that given imports exist in a module. Given a
    :class:`~libcst.codemod.CodemodContext` and a sequence of ``ImportItem``
    entries, each naming a module to import from, optionally an object to
    import from that module, and optionally an alias to assign that import to,
    ensures that each of those imports exists. It will modify existing imports
    as necessary if the module in question is already being imported from.

    This is one of the transforms that is available automatically to you when
    running a codemod. To use it in this manner, import
    :class:`~libcst.codemod.visitors.AddImportsVisitor` and then call the static
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import` method,
    giving it the current context (found as ``self.context`` for all subclasses of
    :class:`~libcst.codemod.Codemod`), the module you wish to import from and
    optionally an object you wish to import from that module and any alias you
    would like to assign that import to.

    For example::

        AddImportsVisitor.add_needed_import(self.context, "typing", "Optional")

    This will produce the following code in a module, assuming there was no
    typing import already::

        from typing import Optional

    As another example::

        AddImportsVisitor.add_needed_import(self.context, "typing")

    This will produce the following code in a module, assuming there was no
    import already::

        import typing

    Note that this is a subclass of :class:`~libcst.CSTTransformer` so it is
    possible to instantiate it and pass it to a :class:`~libcst.Module`
    :meth:`~libcst.CSTNode.visit` method. However, it is far easier to use
    the automatic transform feature of :class:`~libcst.codemod.CodemodCommand`
    and schedule an import to be added by calling
    :meth:`~libcst.codemod.visitors.AddImportsVisitor.add_needed_import`
    """

    CONTEXT_KEY = "AddImportsVisitor"

    @staticmethod
    def _get_imports_from_context(
        context: CodemodContext,
    ) -> List[ImportItem]:
        """
        Return the list of imports scheduled on ``context`` (stored in
        ``context.scratch`` under ``CONTEXT_KEY``), or an empty list if none
        have been scheduled yet.

        Raises ``Exception`` if the scratch entry exists but is not a list.
        """
        imports = context.scratch.get(AddImportsVisitor.CONTEXT_KEY, [])
        if not isinstance(imports, list):
            raise Exception("Logic error!")
        return imports

    @staticmethod
    def add_needed_import(
        context: CodemodContext,
        module: str,
        obj: Optional[str] = None,
        asname: Optional[str] = None,
        relative: int = 0,
    ) -> None:
        """
        Schedule an import to be added in a future invocation of this class by
        updating the ``context`` to include the ``module`` and optionally ``obj``
        to be imported as well as optionally ``asname`` to alias the imported
        ``module`` or ``obj`` to. ``relative`` is forwarded to ``ImportItem`` as
        the relative-import level; the default ``0`` requests an absolute
        import. When subclassing from
        :class:`~libcst.codemod.CodemodCommand`, this will be performed for you
        after your transform finishes executing. If you are subclassing from a
        :class:`~libcst.codemod.Codemod` instead, you will need to call the
        :meth:`~libcst.codemod.Codemod.transform_module` method on the module
        under modification with an instance of this class after performing your
        transform. Note that if the particular ``module`` or ``obj`` you are
        requesting to import already exists as an import on the current module
        at the time of executing :meth:`~libcst.codemod.Codemod.transform_module`
        on an instance of :class:`~libcst.codemod.visitors.AddImportsVisitor`,
        this will perform no action in order to avoid adding duplicate imports.

        Raises ``Exception`` if ``module`` is ``"__future__"`` and no ``obj``
        is given.
        """

        if module == "__future__" and obj is None:
            raise Exception("Cannot import __future__ directly!")
        imports = AddImportsVisitor._get_imports_from_context(context)
        imports.append(ImportItem(module, obj, asname, relative))
        context.scratch[AddImportsVisitor.CONTEXT_KEY] = imports

    def __init__(
        self,
        context: CodemodContext,
        imports: Sequence[ImportItem] = (),
    ) -> None:
        # Allow for instantiation from either a context (used when multiple transforms
        # get chained) or from a direct instantiation.
        super().__init__(context)
        imps: List[ImportItem] = [
            *AddImportsVisitor._get_imports_from_context(context),
            *imports,
        ]

        # Verify that the imports are valid
        for imp in imps:
            if imp.module == "__future__" and imp.obj_name is None:
                raise Exception("Cannot import __future__ directly!")
            if imp.module == "__future__" and imp.alias is not None:
                raise Exception("Cannot import __future__ objects with aliases!")

        # Resolve relative imports if we have a module name
        imps = [imp.resolve_relative(self.context.full_package_name) for imp in imps]

        # Modules we need to ensure are imported as plain ``import module``
        self.module_imports: Set[str] = {
            imp.module for imp in imps if imp.obj_name is None and imp.alias is None
        }

        # Modules we need to check for (unaliased) object imports on
        from_imports: Set[str] = {
            imp.module for imp in imps if imp.obj_name is not None and imp.alias is None
        }
        # Mapping of modules we're adding to the objects they should import
        self.module_mapping: Dict[str, Set[str]] = {
            module: {
                imp.obj_name
                for imp in imps
                if imp.module == module
                and imp.obj_name is not None
                and imp.alias is None
            }
            for module in sorted(from_imports)
        }

        # Aliased modules we need to ensure are imported (``import module as alias``)
        self.module_aliases: Dict[str, str] = {
            imp.module: imp.alias
            for imp in imps
            if imp.obj_name is None and imp.alias is not None
        }
        # Modules we need to check for aliased object imports on
        from_imports_aliases: Set[str] = {
            imp.module
            for imp in imps
            if imp.obj_name is not None and imp.alias is not None
        }
        # Mapping of modules we're adding to the (object, alias) pairs they should
        # import. Repeated requests for the same pair are collapsed (keeping
        # first-seen order) so the same ``obj as alias`` is never emitted twice,
        # and so that visit_Module() removing an already-present pair removes
        # every request for it rather than just one.
        self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {
            module: list(
                dict.fromkeys(
                    (imp.obj_name, imp.alias)
                    for imp in imps
                    if imp.module == module
                    and imp.obj_name is not None
                    and imp.alias is not None
                )
            )
            for module in sorted(from_imports_aliases)
        }

        # Track the list of imports found in the file
        self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

    def visit_Module(self, node: libcst.Module) -> None:
        """
        Gather the imports the module already has and drop every pending
        request that those existing imports already satisfy.
        """
        # Do a preliminary pass to gather the imports we already have
        gatherer = GatherImportsVisitor(self.context)
        node.visit(gatherer)
        self.all_imports = gatherer.all_imports

        self.module_imports = self.module_imports - gatherer.module_imports
        for module, alias in gatherer.module_aliases.items():
            # Only an identical alias satisfies the request.
            if module in self.module_aliases and self.module_aliases[module] == alias:
                del self.module_aliases[module]
        for module, aliases in gatherer.alias_mapping.items():
            for obj, alias in aliases:
                if (
                    module in self.alias_mapping
                    and (obj, alias) in self.alias_mapping[module]
                ):
                    self.alias_mapping[module].remove((obj, alias))
                    if len(self.alias_mapping[module]) == 0:
                        del self.alias_mapping[module]

        for module, imports in gatherer.object_mapping.items():
            if module not in self.module_mapping:
                # We don't care about this import at all
                continue
            elif "*" in imports:
                # We already implicitly are importing everything
                del self.module_mapping[module]
            else:
                # Lets figure out what's left to import
                self.module_mapping[module] = self.module_mapping[module] - imports
                if not self.module_mapping[module]:
                    # There's nothing left, so lets delete this work item
                    del self.module_mapping[module]

    def leave_ImportFrom(
        self, original_node: libcst.ImportFrom, updated_node: libcst.ImportFrom
    ) -> libcst.ImportFrom:
        """
        Merge any pending object imports for this statement's module into the
        first ``from module import ...`` statement visited for it. Star imports
        are left untouched.
        """
        if isinstance(updated_node.names, libcst.ImportStar):
            # There's nothing to do here!
            return updated_node

        # Get the module we're importing as a string, see if we have work to do.
        module = get_absolute_module_from_package_for_import(
            self.context.full_package_name, updated_node
        )
        if (
            module is None
            or module not in self.module_mapping
            and module not in self.alias_mapping
        ):
            return updated_node

        # We have work to do, mark that we won't modify this again.
        imports_to_add = self.module_mapping.get(module, [])
        if module in self.module_mapping:
            del self.module_mapping[module]
        aliases_to_add = self.alias_mapping.get(module, [])
        if module in self.alias_mapping:
            del self.alias_mapping[module]

        # Now, do the actual update. New names are placed ahead of the existing ones.
        return updated_node.with_changes(
            names=[
                *(
                    libcst.ImportAlias(name=libcst.Name(imp))
                    for imp in sorted(imports_to_add)
                ),
                *(
                    libcst.ImportAlias(
                        name=libcst.Name(imp),
                        asname=libcst.AsName(name=libcst.Name(alias)),
                    )
                    for (imp, alias) in sorted(aliases_to_add)
                ),
                *updated_node.names,
            ]
        )

    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
        List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    ]:
        """
        Split ``updated_module``'s body into three lists:

        1. the statements that must stay ahead of any new import (a leading
           docstring or ``__strict__`` assignment),
        2. the statements up to and including the last top-level statement line
           holding an import gathered by ``visit_Module``, and
        3. everything after that.
        """
        statement_before_import_location = 0
        import_add_location = 0

        # never insert an import before initial __strict__ flag
        if m.matches(
            orig_module,
            m.Module(
                body=[
                    m.SimpleStatementLine(
                        body=[
                            m.Assign(
                                targets=[m.AssignTarget(target=m.Name("__strict__"))]
                            )
                        ]
                    ),
                    m.ZeroOrMore(),
                ]
            ),
        ):
            statement_before_import_location = import_add_location = 1

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.
        for i, statement in enumerate(orig_module.body):
            if i == 0 and m.matches(
                statement, m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])
            ):
                # A module docstring must stay the first statement.
                statement_before_import_location = import_add_location = 1
            elif isinstance(statement, libcst.SimpleStatementLine):
                for possible_import in statement.body:
                    for last_import in self.all_imports:
                        # Identity match against nodes gathered from the original tree.
                        if possible_import is last_import:
                            import_add_location = i + 1
                            break

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(
                updated_module.body[
                    statement_before_import_location:import_add_location
                ]
            ),
            list(updated_module.body[import_add_location:]),
        )

    def _insert_empty_line(
        self,
        statements: List[
            Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]
        ],
    ) -> List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]]:
        """
        Return ``statements`` with a blank line guaranteed ahead of the first
        statement: one is prepended unless its first leading line is already
        a comment-free empty line. An empty list is returned unchanged.
        """
        if len(statements) < 1:
            # No statements, nothing to add to
            return statements
        if len(statements[0].leading_lines) == 0:
            # Statement has no leading lines, add one!
            return [
                statements[0].with_changes(leading_lines=(libcst.EmptyLine(),)),
                *statements[1:],
            ]
        if statements[0].leading_lines[0].comment is None:
            # First line is empty, so its safe to leave as-is
            return statements
        # Statement has a comment first line, so lets add one more empty line
        return [
            statements[0].with_changes(
                leading_lines=(libcst.EmptyLine(), *statements[0].leading_lines)
            ),
            *statements[1:],
        ]

    def leave_Module(
        self, original_node: libcst.Module, updated_node: libcst.Module
    ) -> libcst.Module:
        """
        Insert every import that is still pending after ``visit_Module`` and
        ``leave_ImportFrom`` have run. New ``__future__`` imports go directly
        after any docstring / ``__strict__`` header. All other new imports go
        after the last existing top-level import.
        """
        # Don't try to modify if we have nothing to do
        if (
            not self.module_imports
            and not self.module_mapping
            and not self.module_aliases
            and not self.alias_mapping
        ):
            return updated_node

        # First, find the insertion point for imports
        (
            statements_before_imports,
            statements_until_add_imports,
            statements_after_imports,
        ) = self._split_module(original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(statements_after_imports)

        # Mapping of modules we're adding to the objects (with or without alias)
        # they should import. Aliased modules are inserted first, so the module
        # order follows alias_mapping, then module_mapping.
        pending: Dict[str, List[Tuple[str, Optional[str]]]] = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            pending[module].extend(aliases)
        for module, imports in self.module_mapping.items():
            pending[module].extend((obj, None) for obj in imports)
        # Sort each module's names. A ``None`` alias cannot be compared with a
        # string alias, so it sorts as "". That keeps the result identical to a
        # plain tuple sort wherever such a sort succeeds. It also avoids a
        # TypeError when one object is requested both with and without an alias
        # (``from m import x, x as y``).
        module_and_alias_mapping = {
            module: sorted(
                aliases,
                key=lambda pair: (pair[0], "" if pair[1] is None else pair[1]),
            )
            for module, aliases in pending.items()
        }

        config = updated_node.config_for_parsing

        def from_import(module: str, names: List[Tuple[str, Optional[str]]]):
            # Build a single ``from module import a, b as c`` statement.
            return parse_statement(
                f"from {module} import "
                + ", ".join(
                    obj if alias is None else f"{obj} as {alias}"
                    for (obj, alias) in names
                ),
                config=config,
            )

        # Now, add all of the imports we need!
        return updated_node.with_changes(
            # pyre-fixme[60]: Concatenation not yet support for multiple variadic tup...
            body=(
                *statements_before_imports,
                *[
                    from_import(module, names)
                    for module, names in module_and_alias_mapping.items()
                    if module == "__future__"
                ],
                *statements_until_add_imports,
                *[
                    parse_statement(f"import {module}", config=config)
                    for module in sorted(self.module_imports)
                ],
                *[
                    parse_statement(f"import {module} as {asname}", config=config)
                    for (module, asname) in self.module_aliases.items()
                ],
                *[
                    from_import(module, names)
                    for module, names in module_and_alias_mapping.items()
                    if module != "__future__"
                ],
                *statements_after_imports,
            )
        )