Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/helpers/_template.py: 19%

172 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:43 +0000

1# Copyright (c) Meta Platforms, Inc. and affiliates. 

2# 

3# This source code is licensed under the MIT license found in the 

4# LICENSE file in the root directory of this source tree. 

5# 

6 

7from typing import Dict, Mapping, Optional, Set, Union 

8 

9import libcst as cst 

10from libcst.helpers.common import ensure_type 

11 

# Markers that mangled_name() places around a template variable name to build
# a placeholder identifier; the suffix is the prefix spelled backwards.
# mangle_template() refuses any template that already contains either marker.
TEMPLATE_PREFIX: str = "__LIBCST_MANGLED_NAME_"
TEMPLATE_SUFFIX: str = "_EMAN_DELGNAM_TSCBIL__"

14 

15 

# Every node type that may be supplied as a template replacement value.
# TemplateTransformer keeps one lookup table per entry listed here and raises
# for any replacement value that is not an instance of at least one of them.
ValidReplacementType = Union[
    cst.BaseExpression,
    cst.Annotation,
    cst.AssignTarget,
    cst.Param,
    cst.Parameters,
    cst.Arg,
    cst.BaseStatement,
    cst.BaseSmallStatement,
    cst.BaseSuite,
    cst.BaseSlice,
    cst.SubscriptElement,
    cst.Decorator,
]

30 

31 

def mangled_name(var: str) -> str:
    """Return ``var`` wrapped in ``TEMPLATE_PREFIX`` and ``TEMPLATE_SUFFIX``."""
    return "{}{}{}".format(TEMPLATE_PREFIX, var, TEMPLATE_SUFFIX)

34 

35 

def unmangled_name(var: str) -> Optional[str]:
    """
    Recover the template variable name from a mangled identifier.

    Returns the text between ``TEMPLATE_PREFIX`` and ``TEMPLATE_SUFFIX`` when
    ``var`` begins with the prefix and the first suffix found after it ends
    the string. Returns ``None`` for every other input.
    """
    leading, prefix_found, remainder = var.partition(TEMPLATE_PREFIX)
    if not prefix_found or leading:
        # No prefix at all, or something precedes it: not a mangled name.
        return None

    # partition() always yields three parts, so a suffix that is missing from
    # the remainder (e.g. it only occurs before the prefix, or overlaps the
    # prefix's trailing underscore) is reported as "not mangled" instead of
    # failing while unpacking the result of split().
    name, suffix_found, trailing = remainder.partition(TEMPLATE_SUFFIX)
    if not suffix_found or trailing:
        return None
    return name

44 

45 

def mangle_template(template: str, template_vars: Set[str]) -> str:
    """
    Rewrite every ``{var}`` placeholder in ``template`` to its mangled name.

    Raises:
        Exception: if ``template`` already contains ``TEMPLATE_PREFIX`` or
            ``TEMPLATE_SUFFIX``, or if a name in ``template_vars`` has no
            ``{name}`` placeholder in the template.
    """
    reserved_markers = (TEMPLATE_PREFIX, TEMPLATE_SUFFIX)
    if any(marker in template for marker in reserved_markers):
        raise Exception("Cannot parse a template containing reserved strings")

    mangled = template
    for var in template_vars:
        placeholder = "{{{}}}".format(var)
        if placeholder not in mangled:
            raise Exception(
                f'Template string is missing a reference to "{var}" referred to in kwargs'
            )
        mangled = mangled.replace(placeholder, mangled_name(var))
    return mangled

58 

59 

class TemplateTransformer(cst.CSTTransformer):
    """
    Swaps mangled template placeholders in a parsed tree for replacement nodes.

    Each ``leave_*`` hook looks, in the position that matches its node type,
    for a :class:`~libcst.Name` whose value unmangles to a template variable
    registered for that node type. When one is found the hook returns a deep
    clone of the registered replacement; otherwise the node is returned
    unchanged. Matchers are deliberately not used in this class (the original
    implementation notes that they would cause circular imports).
    """

    def __init__(
        self, template_replacements: Mapping[str, ValidReplacementType]
    ) -> None:
        """
        Sort ``template_replacements`` into one lookup table per node type.

        A replacement is stored in every table whose node type it is an
        instance of.

        Raises:
            Exception: if a replacement is not an instance of any supported
                node type.
        """
        self.simple_replacements: Dict[str, cst.BaseExpression] = {}
        self.annotation_replacements: Dict[str, cst.Annotation] = {}
        self.assignment_replacements: Dict[str, cst.AssignTarget] = {}
        self.param_replacements: Dict[str, cst.Param] = {}
        self.parameters_replacements: Dict[str, cst.Parameters] = {}
        self.arg_replacements: Dict[str, cst.Arg] = {}
        self.small_statement_replacements: Dict[str, cst.BaseSmallStatement] = {}
        self.statement_replacements: Dict[str, cst.BaseStatement] = {}
        self.suite_replacements: Dict[str, cst.BaseSuite] = {}
        self.subscript_element_replacements: Dict[str, cst.SubscriptElement] = {}
        self.subscript_index_replacements: Dict[str, cst.BaseSlice] = {}
        self.decorator_replacements: Dict[str, cst.Decorator] = {}

        # Variables whose replacement matched none of the tables above.
        unsupported_vars: Set[str] = set()
        for name, node in template_replacements.items():
            recognised = False
            if isinstance(node, cst.BaseExpression):
                self.simple_replacements[name] = node
                recognised = True
            if isinstance(node, cst.Annotation):
                self.annotation_replacements[name] = node
                recognised = True
            if isinstance(node, cst.AssignTarget):
                self.assignment_replacements[name] = node
                recognised = True
            if isinstance(node, cst.Param):
                self.param_replacements[name] = node
                recognised = True
            if isinstance(node, cst.Parameters):
                self.parameters_replacements[name] = node
                recognised = True
            if isinstance(node, cst.Arg):
                self.arg_replacements[name] = node
                recognised = True
            if isinstance(node, cst.BaseSmallStatement):
                self.small_statement_replacements[name] = node
                recognised = True
            if isinstance(node, cst.BaseStatement):
                self.statement_replacements[name] = node
                recognised = True
            if isinstance(node, cst.BaseSuite):
                self.suite_replacements[name] = node
                recognised = True
            if isinstance(node, cst.SubscriptElement):
                self.subscript_element_replacements[name] = node
                recognised = True
            if isinstance(node, cst.BaseSlice):
                self.subscript_index_replacements[name] = node
                recognised = True
            if isinstance(node, cst.Decorator):
                self.decorator_replacements[name] = node
                recognised = True
            if not recognised:
                unsupported_vars.add(name)

        if unsupported_vars:
            culprit = next(iter(unsupported_vars))
            raise Exception(f'Template replacement for "{culprit}" is unsupported')

    @staticmethod
    def _lone_name_expression(
        holder: Union[cst.SimpleStatementLine, cst.SimpleStatementSuite]
    ) -> Optional[cst.Name]:
        """
        Return the Name when ``holder.body`` is exactly one bare-name ``Expr``.

        Returns ``None`` when the body has a different length, when its only
        entry is not an :class:`~libcst.Expr`, or when that expression's
        value is not a :class:`~libcst.Name`.
        """
        if len(holder.body) != 1:
            return None
        only_entry = holder.body[0]
        if not isinstance(only_entry, cst.Expr):
            return None
        value = only_entry.value
        return value if isinstance(value, cst.Name) else None

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> cst.BaseExpression:
        """Replace a mangled name with its registered expression, if any."""
        key = unmangled_name(updated_node.value)
        if key is not None and key in self.simple_replacements:
            return self.simple_replacements[key].deep_clone()
        # Not a placeholder with an expression replacement; keep the name.
        return updated_node

    def leave_Annotation(
        self,
        original_node: cst.Annotation,
        updated_node: cst.Annotation,
    ) -> cst.Annotation:
        """Replace an annotation that consists solely of a placeholder name."""
        inner = updated_node.annotation
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.annotation_replacements:
            return updated_node
        return self.annotation_replacements[key].deep_clone()

    def leave_AssignTarget(
        self,
        original_node: cst.AssignTarget,
        updated_node: cst.AssignTarget,
    ) -> cst.AssignTarget:
        """Replace an assignment target that is a placeholder name."""
        inner = updated_node.target
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.assignment_replacements:
            return updated_node
        return self.assignment_replacements[key].deep_clone()

    def leave_Param(
        self,
        original_node: cst.Param,
        updated_node: cst.Param,
    ) -> cst.Param:
        """Replace a parameter whose name is a placeholder."""
        key = unmangled_name(updated_node.name.value)
        if key not in self.param_replacements:
            return updated_node
        return self.param_replacements[key].deep_clone()

    def leave_Parameters(
        self,
        original_node: cst.Parameters,
        updated_node: cst.Parameters,
    ) -> cst.Parameters:
        """
        Replace a whole parameter list that is a single placeholder.

        This only applies when the list holds exactly one regular parameter
        and nothing else (no star arg, keyword-only or positional-only
        parameters, star kwarg, or positional-only indicator).
        """
        only_one_plain_param = (
            len(updated_node.params) == 1
            and updated_node.star_arg == cst.MaybeSentinel.DEFAULT
            and len(updated_node.kwonly_params) == 0
            and updated_node.star_kwarg is None
            and len(updated_node.posonly_params) == 0
            and updated_node.posonly_ind == cst.MaybeSentinel.DEFAULT
        )
        if not only_one_plain_param:
            return updated_node
        # The lone parameter may stand in for the complete parameter list.
        key = unmangled_name(updated_node.params[0].name.value)
        if key not in self.parameters_replacements:
            return updated_node
        return self.parameters_replacements[key].deep_clone()

    def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg:
        """Replace a call argument whose value is a placeholder name."""
        inner = updated_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.arg_replacements:
            return updated_node
        return self.arg_replacements[key].deep_clone()

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.BaseStatement:
        """
        Replace a line holding only a placeholder with a full statement.

        A name alone on a line parses as an ``Expr`` inside a
        ``SimpleStatementLine``, which is the shape detected here.
        """
        placeholder = self._lone_name_expression(updated_node)
        if placeholder is None:
            return updated_node
        key = unmangled_name(placeholder.value)
        if key not in self.statement_replacements:
            return updated_node
        return self.statement_replacements[key].deep_clone()

    def leave_Expr(
        self,
        original_node: cst.Expr,
        updated_node: cst.Expr,
    ) -> cst.BaseSmallStatement:
        """Replace an expression statement that is a bare placeholder name
        with its registered small statement."""
        inner = updated_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.small_statement_replacements:
            return updated_node
        return self.small_statement_replacements[key].deep_clone()

    def leave_SimpleStatementSuite(
        self,
        original_node: cst.SimpleStatementSuite,
        updated_node: cst.SimpleStatementSuite,
    ) -> cst.BaseSuite:
        """
        Replace a one-line suite holding only a placeholder with a suite.

        A name alone in a simple suite parses as an ``Expr`` inside a
        ``SimpleStatementSuite``, which is the shape detected here.
        """
        placeholder = self._lone_name_expression(updated_node)
        if placeholder is None:
            return updated_node
        key = unmangled_name(placeholder.value)
        if key not in self.suite_replacements:
            return updated_node
        return self.suite_replacements[key].deep_clone()

    def leave_IndentedBlock(
        self,
        original_node: cst.IndentedBlock,
        updated_node: cst.IndentedBlock,
    ) -> cst.BaseSuite:
        """
        Replace an indented block holding only a placeholder with a suite.

        The block must contain exactly one ``SimpleStatementLine`` which in
        turn contains exactly one bare-name ``Expr``.
        """
        if len(updated_node.body) != 1:
            return updated_node
        only_statement = updated_node.body[0]
        if not isinstance(only_statement, cst.SimpleStatementLine):
            return updated_node
        placeholder = self._lone_name_expression(only_statement)
        if placeholder is None:
            return updated_node
        key = unmangled_name(placeholder.value)
        if key not in self.suite_replacements:
            return updated_node
        return self.suite_replacements[key].deep_clone()

    def leave_Index(
        self,
        original_node: cst.Index,
        updated_node: cst.Index,
    ) -> cst.BaseSlice:
        """Replace a subscript index whose value is a placeholder name."""
        inner = updated_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.subscript_index_replacements:
            return updated_node
        return self.subscript_index_replacements[key].deep_clone()

    def leave_SubscriptElement(
        self,
        original_node: cst.SubscriptElement,
        updated_node: cst.SubscriptElement,
    ) -> cst.SubscriptElement:
        """
        Replace a subscript element that is a placeholder.

        The placeholder is recognised as a ``Name`` held by an ``Index`` in
        the element's ``slice``.
        """
        slice_node = updated_node.slice
        if not isinstance(slice_node, cst.Index):
            return updated_node
        inner = slice_node.value
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.subscript_element_replacements:
            return updated_node
        return self.subscript_element_replacements[key].deep_clone()

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Replace a decorator whose expression is a placeholder name."""
        inner = updated_node.decorator
        if not isinstance(inner, cst.Name):
            return updated_node
        key = unmangled_name(inner.value)
        if key not in self.decorator_replacements:
            return updated_node
        return self.decorator_replacements[key].deep_clone()

344 

345 

class TemplateChecker(cst.CSTVisitor):
    """Visitor that raises if a mangled template variable name is still
    present in the tree it visits."""

    def __init__(self, template_vars: Set[str]) -> None:
        """Remember the template variable names to look for."""
        self.template_vars = template_vars

    def visit_Name(self, node: cst.Name) -> None:
        """
        Raise if ``node`` is still the mangled form of a template variable.

        Raises:
            Exception: naming the variable that was left in the tree.
        """
        leftovers = [
            var for var in self.template_vars if node.value == mangled_name(var)
        ]
        if leftovers:
            raise Exception(
                f'Template variable "{leftovers[0]}" was not replaced properly'
            )

354 

355 

def unmangle_nodes(
    tree: cst.CSTNode,
    template_replacements: Mapping[str, ValidReplacementType],
) -> cst.CSTNode:
    """
    Run a :class:`TemplateTransformer` over ``tree`` and return the result.

    The visit result is passed through ``ensure_type`` with
    :class:`~libcst.CSTNode` before being returned.
    """
    visited = tree.visit(TemplateTransformer(template_replacements))
    return ensure_type(visited, cst.CSTNode)

362 

363 

# Shared default value for the ``config`` parameter of the parse_template_*
# functions below, built once at import time with no arguments.
_DEFAULT_PARTIAL_PARSER_CONFIG: cst.PartialParserConfig = cst.PartialParserConfig()

365 

366 

def parse_template_module(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.Module:
    """
    Parse a complete Python module template, including all of its leading and
    trailing whitespace. Every :class:`~libcst.CSTNode` passed as a keyword
    argument is inserted into the template wherever ``{keyword}`` appears,
    much like an f-string expansion. For example::

        module = parse_template_module("from {mod} import Foo\\n", mod=Name("bar"))

    This parses to a module containing a single :class:`~libcst.FromImport`
    statement that references module ``bar`` and imports object ``Foo`` from
    it. If you are parsing a template as part of a substitution inside a
    transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.

    Unlike :func:`~libcst.parse_module`, this function does not accept bytes
    as input, because the input is processed as a template before it is
    parsed as a module.
    """

    variable_names = {key for key in template_replacements}
    parsed = cst.parse_module(mangle_template(template, variable_names), config)
    result = ensure_type(unmangle_nodes(parsed, template_replacements), cst.Module)
    # Fail loudly if any placeholder survived the substitution pass.
    result.visit(TemplateChecker(variable_names))
    return result

397 

398 

def parse_template_statement(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]:
    """
    Parse a statement template followed by a trailing newline. If the trailing
    newline is missing, one will be added. Every :class:`~libcst.CSTNode`
    passed as a keyword argument is inserted into the template wherever
    ``{keyword}`` appears, much like an f-string expansion. For example::

        statement = parse_template_statement("assert x > 0, {msg}", msg=SimpleString('"Uh oh!"'))

    This parses to an assert statement checking that some variable ``x`` is
    greater than zero, or providing the assert message ``"Uh oh!"``.

    If you are parsing a template as part of a substitution inside a
    transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.
    """

    variable_names = {key for key in template_replacements}
    mangled_source = mangle_template(template, variable_names)
    parsed = cst.parse_statement(mangled_source, config)
    result = unmangle_nodes(parsed, template_replacements)
    if isinstance(result, (cst.SimpleStatementLine, cst.BaseCompoundStatement)):
        # Fail loudly if any placeholder survived the substitution pass.
        result.visit(TemplateChecker(variable_names))
        return result
    raise Exception(f"Expected a statement but got a {result.__class__.__name__}!")

432 

433 

def parse_template_expression(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.BaseExpression:
    """
    Parse an expression template written on a single line. Leading and
    trailing whitespace is not valid (there's nowhere to store it on the
    expression node). Every :class:`~libcst.CSTNode` passed as a keyword
    argument is inserted into the template wherever ``{keyword}`` appears,
    much like an f-string expansion. For example::

        expression = parse_template_expression("x + {foo}", foo=Name("y"))

    This parses to a :class:`~libcst.BinaryOperation` expression adding two
    names (``x`` and ``y``) together.

    If you are parsing a template as part of a substitution inside a
    transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.
    """

    variable_names = {key for key in template_replacements}
    parsed = cst.parse_expression(mangle_template(template, variable_names), config)
    result = ensure_type(
        unmangle_nodes(parsed, template_replacements), cst.BaseExpression
    )
    # Fail loudly if any placeholder survived the substitution pass.
    result.visit(TemplateChecker(variable_names))
    return result