Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/libcst/helpers/_template.py: 19%
172 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:43 +0000
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5#
7from typing import Dict, Mapping, Optional, Set, Union
9import libcst as cst
10from libcst.helpers.common import ensure_type
# Reserved markers that ``mangled_name`` wraps around each template variable
# name before a template is parsed (the suffix is the prefix spelled backwards).
# ``mangle_template`` refuses any template that already contains either marker,
# so these strings can only appear in a parsed template as placeholders.
TEMPLATE_PREFIX: str = "__LIBCST_MANGLED_NAME_"
TEMPLATE_SUFFIX: str = "_EMAN_DELGNAM_TSCBIL__"


# Node types accepted as template replacement values. ``TemplateTransformer``
# sorts each replacement into per-type tables using these same types and raises
# for any value that is an instance of none of them.
ValidReplacementType = Union[
    cst.BaseExpression,
    cst.Annotation,
    cst.AssignTarget,
    cst.Param,
    cst.Parameters,
    cst.Arg,
    cst.BaseStatement,
    cst.BaseSmallStatement,
    cst.BaseSuite,
    cst.BaseSlice,
    cst.SubscriptElement,
    cst.Decorator,
]
def mangled_name(var: str) -> str:
    """
    Wrap ``var`` in the reserved template markers.

    The result is the placeholder identifier that :func:`mangle_template`
    substitutes for ``{var}`` so the template can be parsed as Python.
    """
    return TEMPLATE_PREFIX + var + TEMPLATE_SUFFIX
def unmangled_name(var: str) -> Optional[str]:
    """
    Recover the template variable name from a mangled identifier.

    Returns ``name`` when ``var`` starts with ``TEMPLATE_PREFIX`` and the first
    ``TEMPLATE_SUFFIX`` after that prefix ends the string, i.e. when ``var``
    has the form produced by :func:`mangled_name`. Returns ``None`` for every
    other string, including strings where the markers appear out of order or
    overlap one another.
    """
    # Nothing may precede the first prefix marker.
    head, found_prefix, rest = var.partition(TEMPLATE_PREFIX)
    if head or not found_prefix:
        return None
    # The first suffix marker after the prefix must end the string. Using
    # ``partition`` instead of unpacking ``split(..., 1)`` means a string that
    # contains both markers, but not in prefix-then-suffix order, yields None
    # rather than raising ``ValueError`` from a failed tuple unpack.
    name, found_suffix, tail = rest.partition(TEMPLATE_SUFFIX)
    if tail or not found_suffix:
        return None
    return name
def mangle_template(template: str, template_vars: Set[str]) -> str:
    """
    Rewrite every ``{name}`` placeholder in ``template`` as a mangled identifier.

    Raises ``Exception`` if the template already contains one of the reserved
    markers, or if any name in ``template_vars`` has no ``{name}`` reference in
    the template.
    """
    if any(marker in template for marker in (TEMPLATE_PREFIX, TEMPLATE_SUFFIX)):
        raise Exception("Cannot parse a template containing reserved strings")

    mangled = template
    for name in template_vars:
        placeholder = "{" + name + "}"
        if placeholder not in mangled:
            raise Exception(
                f'Template string is missing a reference to "{name}" referred to in kwargs'
            )
        mangled = mangled.replace(placeholder, mangled_name(name))
    return mangled
class TemplateTransformer(cst.CSTTransformer):
    """
    Replace mangled template placeholders with caller-supplied nodes.

    The constructor sorts each replacement into one table per supported node
    type. A value lands in every table whose type it is an instance of. The
    constructor raises if a replacement fits none of the tables.

    Each ``leave_*`` hook then looks for the shape a placeholder takes at that
    position in the parsed template: a bare mangled :class:`~libcst.Name`,
    possibly wrapped in an ``Expr``, ``Index`` or statement line. When the
    variable has an entry in the matching table, the hook returns a deep clone
    of that replacement.
    """

    def __init__(
        self, template_replacements: Mapping[str, ValidReplacementType]
    ) -> None:
        entries = template_replacements.items()
        self.simple_replacements: Dict[str, cst.BaseExpression] = {
            key: node for key, node in entries if isinstance(node, cst.BaseExpression)
        }
        self.annotation_replacements: Dict[str, cst.Annotation] = {
            key: node for key, node in entries if isinstance(node, cst.Annotation)
        }
        self.assignment_replacements: Dict[str, cst.AssignTarget] = {
            key: node for key, node in entries if isinstance(node, cst.AssignTarget)
        }
        self.param_replacements: Dict[str, cst.Param] = {
            key: node for key, node in entries if isinstance(node, cst.Param)
        }
        self.parameters_replacements: Dict[str, cst.Parameters] = {
            key: node for key, node in entries if isinstance(node, cst.Parameters)
        }
        self.arg_replacements: Dict[str, cst.Arg] = {
            key: node for key, node in entries if isinstance(node, cst.Arg)
        }
        self.small_statement_replacements: Dict[str, cst.BaseSmallStatement] = {
            key: node
            for key, node in entries
            if isinstance(node, cst.BaseSmallStatement)
        }
        self.statement_replacements: Dict[str, cst.BaseStatement] = {
            key: node for key, node in entries if isinstance(node, cst.BaseStatement)
        }
        self.suite_replacements: Dict[str, cst.BaseSuite] = {
            key: node for key, node in entries if isinstance(node, cst.BaseSuite)
        }
        self.subscript_element_replacements: Dict[str, cst.SubscriptElement] = {
            key: node
            for key, node in entries
            if isinstance(node, cst.SubscriptElement)
        }
        self.subscript_index_replacements: Dict[str, cst.BaseSlice] = {
            key: node for key, node in entries if isinstance(node, cst.BaseSlice)
        }
        self.decorator_replacements: Dict[str, cst.Decorator] = {
            key: node for key, node in entries if isinstance(node, cst.Decorator)
        }

        # A replacement that fits none of the tables above could never be
        # inserted by any hook below, so reject it up front.
        tables = (
            self.simple_replacements,
            self.annotation_replacements,
            self.assignment_replacements,
            self.param_replacements,
            self.parameters_replacements,
            self.arg_replacements,
            self.small_statement_replacements,
            self.statement_replacements,
            self.suite_replacements,
            self.subscript_element_replacements,
            self.subscript_index_replacements,
            self.decorator_replacements,
        )
        unsupported_vars = {
            key
            for key in template_replacements
            if not any(key in table for table in tables)
        }
        if unsupported_vars:
            raise Exception(
                f'Template replacement for "{next(iter(unsupported_vars))}" is unsupported'
            )

    @staticmethod
    def _template_var_name(node: cst.CSTNode) -> Optional[str]:
        """Return the template variable ``node`` names, if it is a mangled Name."""
        # Matchers are not usable here because of circular imports.
        if not isinstance(node, cst.Name):
            return None
        return unmangled_name(node.value)

    @classmethod
    def _sole_expression_var_name(
        cls, node: Union[cst.SimpleStatementLine, cst.SimpleStatementSuite]
    ) -> Optional[str]:
        """Return the template variable if ``node``'s only statement is a bare name."""
        # A placeholder standing alone as a statement parses as a single
        # ``Expr`` wrapping a ``Name``.
        if len(node.body) != 1:
            return None
        only_statement = node.body[0]
        if not isinstance(only_statement, cst.Expr):
            return None
        return cls._template_var_name(only_statement.value)

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> cst.BaseExpression:
        var_name = self._template_var_name(updated_node)
        if var_name in self.simple_replacements:
            return self.simple_replacements[var_name].deep_clone()
        # Not an expression placeholder; leave the name untouched.
        return updated_node

    def leave_Annotation(
        self,
        original_node: cst.Annotation,
        updated_node: cst.Annotation,
    ) -> cst.Annotation:
        var_name = self._template_var_name(updated_node.annotation)
        if var_name in self.annotation_replacements:
            return self.annotation_replacements[var_name].deep_clone()
        return updated_node

    def leave_AssignTarget(
        self,
        original_node: cst.AssignTarget,
        updated_node: cst.AssignTarget,
    ) -> cst.AssignTarget:
        var_name = self._template_var_name(updated_node.target)
        if var_name in self.assignment_replacements:
            return self.assignment_replacements[var_name].deep_clone()
        return updated_node

    def leave_Param(
        self,
        original_node: cst.Param,
        updated_node: cst.Param,
    ) -> cst.Param:
        var_name = unmangled_name(updated_node.name.value)
        if var_name not in self.param_replacements:
            return updated_node
        return self.param_replacements[var_name].deep_clone()

    def leave_Parameters(
        self,
        original_node: cst.Parameters,
        updated_node: cst.Parameters,
    ) -> cst.Parameters:
        # Special case: a single template variable standing in for the whole
        # parameter list parses as exactly one plain positional parameter with
        # no star, keyword-only or positional-only parts.
        only_one_plain_param = (
            len(updated_node.params) == 1
            and updated_node.star_arg == cst.MaybeSentinel.DEFAULT
            and len(updated_node.kwonly_params) == 0
            and updated_node.star_kwarg is None
            and len(updated_node.posonly_params) == 0
            and updated_node.posonly_ind == cst.MaybeSentinel.DEFAULT
        )
        if not only_one_plain_param:
            return updated_node
        var_name = unmangled_name(updated_node.params[0].name.value)
        if var_name not in self.parameters_replacements:
            return updated_node
        return self.parameters_replacements[var_name].deep_clone()

    def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg:
        var_name = self._template_var_name(updated_node.value)
        if var_name in self.arg_replacements:
            return self.arg_replacements[var_name].deep_clone()
        return updated_node

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.BaseStatement:
        # A placeholder on a line of its own parses as one Expr(Name) inside a
        # SimpleStatementLine; expand it to a full statement if requested.
        var_name = self._sole_expression_var_name(updated_node)
        if var_name in self.statement_replacements:
            return self.statement_replacements[var_name].deep_clone()
        return updated_node

    def leave_Expr(
        self,
        original_node: cst.Expr,
        updated_node: cst.Expr,
    ) -> cst.BaseSmallStatement:
        # Same trick as for SimpleStatementLine, but substituting a single
        # small statement in place of the Expr.
        var_name = self._template_var_name(updated_node.value)
        if var_name in self.small_statement_replacements:
            return self.small_statement_replacements[var_name].deep_clone()
        return updated_node

    def leave_SimpleStatementSuite(
        self,
        original_node: cst.SimpleStatementSuite,
        updated_node: cst.SimpleStatementSuite,
    ) -> cst.BaseSuite:
        # A placeholder used as a same-line suite parses as one Expr(Name)
        # inside a SimpleStatementSuite.
        var_name = self._sole_expression_var_name(updated_node)
        if var_name in self.suite_replacements:
            return self.suite_replacements[var_name].deep_clone()
        return updated_node

    def leave_IndentedBlock(
        self,
        original_node: cst.IndentedBlock,
        updated_node: cst.IndentedBlock,
    ) -> cst.BaseSuite:
        # A placeholder used as an indented suite parses as one
        # SimpleStatementLine holding one Expr(Name).
        if len(updated_node.body) != 1:
            return updated_node
        only_line = updated_node.body[0]
        if not isinstance(only_line, cst.SimpleStatementLine):
            return updated_node
        var_name = self._sole_expression_var_name(only_line)
        if var_name in self.suite_replacements:
            return self.suite_replacements[var_name].deep_clone()
        return updated_node

    def leave_Index(
        self,
        original_node: cst.Index,
        updated_node: cst.Index,
    ) -> cst.BaseSlice:
        var_name = self._template_var_name(updated_node.value)
        if var_name in self.subscript_index_replacements:
            return self.subscript_index_replacements[var_name].deep_clone()
        return updated_node

    def leave_SubscriptElement(
        self,
        original_node: cst.SubscriptElement,
        updated_node: cst.SubscriptElement,
    ) -> cst.SubscriptElement:
        # A placeholder used as a whole subscript element parses as an Index
        # wrapping a Name inside the SubscriptElement.
        index = updated_node.slice
        var_name = (
            self._template_var_name(index.value)
            if isinstance(index, cst.Index)
            else None
        )
        if var_name in self.subscript_element_replacements:
            return self.subscript_element_replacements[var_name].deep_clone()
        return updated_node

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        var_name = self._template_var_name(updated_node.decorator)
        if var_name in self.decorator_replacements:
            return self.decorator_replacements[var_name].deep_clone()
        return updated_node
class TemplateChecker(cst.CSTVisitor):
    """
    Verify that no template placeholder survived substitution.

    Raises ``Exception`` on the first :class:`~libcst.Name` whose value is still
    the mangled spelling of one of ``template_vars``.
    """

    def __init__(self, template_vars: Set[str]) -> None:
        self.template_vars = template_vars
        # Mangle each variable once here instead of once per visited Name.
        # mangled_name is injective, so each mangled spelling maps back to
        # exactly one variable.
        self._vars_by_mangled_name: Dict[str, str] = {
            mangled_name(var): var for var in template_vars
        }

    def visit_Name(self, node: cst.Name) -> None:
        if node.value in self._vars_by_mangled_name:
            var = self._vars_by_mangled_name[node.value]
            raise Exception(f'Template variable "{var}" was not replaced properly')
def unmangle_nodes(
    tree: cst.CSTNode,
    template_replacements: Mapping[str, ValidReplacementType],
) -> cst.CSTNode:
    """
    Substitute the replacement nodes for every mangled placeholder in ``tree``.

    Raises ``Exception`` (from :class:`TemplateTransformer`) if a replacement
    value is not of a supported node type.
    """
    transformed = tree.visit(TemplateTransformer(template_replacements))
    return ensure_type(transformed, cst.CSTNode)
# Parser configuration the ``parse_template_*`` helpers below fall back on when
# the caller does not pass ``config``; the instance is created once, at import.
_DEFAULT_PARTIAL_PARSER_CONFIG: cst.PartialParserConfig = cst.PartialParserConfig()
def parse_template_module(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.Module:
    """
    Parse an entire Python module from a template string.

    The template is used verbatim, including all leading and trailing
    whitespace. Each keyword argument names a ``{placeholder}`` in the template
    and supplies the :class:`~libcst.CSTNode` to insert there, much like an
    f-string expansion. For example::

        module = parse_template_module("from {mod} import Foo\\n", mod=Name("bar"))

    parses to a module containing a single :class:`~libcst.FromImport`
    statement that imports object ``Foo`` from module ``bar``. When parsing a
    template as part of a substitution inside a transform, it is considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.

    Unlike :func:`~libcst.parse_module`, this function does not accept bytes,
    because the input is processed as a template before it is parsed as a
    module.

    Raises ``Exception`` if the template contains the reserved mangling markers,
    lacks a reference to one of the keyword arguments, receives an unsupported
    replacement node type, or still contains an unreplaced placeholder after
    substitution.
    """
    template_vars = {name for name in template_replacements}
    mangled_source = mangle_template(template, template_vars)
    module = ensure_type(
        unmangle_nodes(cst.parse_module(mangled_source, config), template_replacements),
        cst.Module,
    )
    module.visit(TemplateChecker(template_vars))
    return module
def parse_template_statement(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]:
    """
    Parse a single statement from a template string.

    The template should be one statement followed by a trailing newline; if a
    trailing newline is not provided, one will be added. Each keyword argument
    names a ``{placeholder}`` in the template and supplies the
    :class:`~libcst.CSTNode` to insert there, much like an f-string expansion.
    For example::

        statement = parse_template_statement("assert x > 0, {msg}", msg=SimpleString('"Uh oh!"'))

    parses to an assert statement checking that some variable ``x`` is greater
    than zero, with the assert message ``"Uh oh!"``.

    When parsing a template as part of a substitution inside a transform, it is
    considered :ref:`best practice <libcst-config_best_practice>` to pass in a
    ``config`` from the current module under transformation.

    Raises ``Exception`` if the substituted result is not a simple or compound
    statement, if the template contains the reserved mangling markers or lacks
    a reference to one of the keyword arguments, if a replacement node type is
    unsupported, or if a placeholder remains unreplaced.
    """
    template_vars = {name for name in template_replacements}
    parsed = cst.parse_statement(mangle_template(template, template_vars), config)
    result = unmangle_nodes(parsed, template_replacements)
    if isinstance(result, (cst.SimpleStatementLine, cst.BaseCompoundStatement)):
        result.visit(TemplateChecker(template_vars))
        return result
    raise Exception(f"Expected a statement but got a {result.__class__.__name__}!")
def parse_template_expression(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.BaseExpression:
    """
    Accepts an expression template on a single line. Leading and trailing whitespace
    is not valid (there's nowhere to store it on the expression node). Any
    :class:`~libcst.CSTNode` provided as a keyword argument to this function will
    be inserted into the template at the appropriate location similar to an
    f-string expansion. For example::

        expression = parse_template_expression("x + {foo}", foo=Name("y"))

    The above code will parse to a :class:`~libcst.BinaryOperation` expression
    adding two names (``x`` and ``y``) together.

    Remember that if you are parsing a template as part of a substitution inside
    a transform, it's considered :ref:`best practice <libcst-config_best_practice>`
    to pass in a ``config`` from the current module under transformation.

    Raises ``Exception`` if the template contains the reserved mangling markers,
    lacks a reference to one of the keyword arguments, receives an unsupported
    replacement node type, or still contains an unreplaced placeholder after
    substitution.
    """
    # Build the variable set once; mangling and the final check use the same names.
    template_vars = {name for name in template_replacements}
    source = mangle_template(template, template_vars)
    expression = cst.parse_expression(source, config)
    new_expression = ensure_type(
        unmangle_nodes(expression, template_replacements), cst.BaseExpression
    )
    new_expression.visit(TemplateChecker(template_vars))
    return new_expression