Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/templates.py: 22%
144 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""AST conversion templates.
17Adapted from Tangent.
18"""
20import ast
21import textwrap
23import gast
25from tensorflow.python.autograph.pyct import anno
26from tensorflow.python.autograph.pyct import ast_util
27from tensorflow.python.autograph.pyct import parser
28from tensorflow.python.autograph.pyct import qual_names
class ContextAdjuster(gast.NodeTransformer):
    """Adjusts the ctx field of nodes to ensure consistency.

    This transformer can change the ctx fields of a variable, tuple and other
    AST elements that allow one, based on whether the element is being read or
    written.

    The override is either a ctx class (such as `gast.Load`) or None. When it
    is not None, each node it is applied to receives a fresh instance of that
    class as its ctx. Attribute bases and subscript values/slices are always
    read, so those nodes switch the override to `gast.Load` for their children.
    Calls, dicts, comprehensions and lambdas switch it to None, leaving their
    children's ctx untouched. Every `visit` restores the override on exit, so
    these switches only affect the subtree currently being visited.
    """

    def __init__(self, override_value):
        """Initializes the adjuster.

        Args:
          override_value: a ctx node class to instantiate and assign to the ctx
            of adjusted nodes, or None to leave ctx fields as they are.
        """
        self._ctx_override = override_value

    def visit(self, node):
        original_override = self._ctx_override
        try:
            node = super(ContextAdjuster, self).visit(node)
            if hasattr(node, 'ctx'):
                # Whatever the override did, a node carrying a ctx field must
                # end up with one set.
                assert node.ctx is not None, 'node {} has ctx unset'.format(node)
        finally:
            # Override changes made inside this subtree must not leak to the
            # siblings visited afterwards.
            self._ctx_override = original_override
        return node

    def _apply_override(self, node):
        """Sets node.ctx to a new instance of the current override, if any."""
        if self._ctx_override is not None:
            node.ctx = self._ctx_override()

    def visit_Attribute(self, node):
        self._apply_override(node)
        # The object whose attribute is accessed is only ever read.
        self._ctx_override = gast.Load
        node = self.generic_visit(node)
        return node

    def visit_Tuple(self, node):
        self._apply_override(node)
        return self.generic_visit(node)

    def visit_List(self, node):
        self._apply_override(node)
        return self.generic_visit(node)

    def visit_Name(self, node):
        self._apply_override(node)
        return self.generic_visit(node)

    def visit_Call(self, node):
        self._apply_override(node)
        # We may be able to override these to Load(), but for now it's simpler
        # to just assert that they're set.
        self._ctx_override = None
        return self.generic_visit(node)

    def visit_Dict(self, node):
        # We may be able to override these to Load(), but for now it's simpler
        # to just assert that they're set.
        self._ctx_override = None
        return self.generic_visit(node)

    def visit_Subscript(self, node):
        self._apply_override(node)
        # The subscripted value and the slice are read even when the subscript
        # itself is an assignment target (e.g. `a[i] = x`).
        self._ctx_override = gast.Load
        # generic_visit already visits node.value. Visiting it explicitly as
        # well re-traverses it, doubling the work at every nesting level.
        return self.generic_visit(node)

    def visit_comprehension(self, node):
        # We may be able to override some of these, but for now it's simpler
        # to just assert that they're set.
        self._ctx_override = None
        return self.generic_visit(node)

    def visit_Lambda(self, node):
        # We may be able to override some of these, but for now it's simpler
        # to just assert that they're set.
        self._ctx_override = None
        return self.generic_visit(node)
class ReplaceTransformer(gast.NodeTransformer):
    """Replace AST nodes.

    Placeholders may appear as variable names, keyword argument names, function
    names and attribute names. Each position accepts a different kind of
    replacement; see the individual visit_* methods.
    """

    def __init__(self, replacements):
        """Create a new ReplaceTransformer.

        Args:
          replacements: A mapping from placeholder names to (lists of) AST nodes
            that these placeholders will be replaced by.
        """
        self.replacements = replacements
        # Not read by any method of this class; kept in case external code
        # inspects it.
        self.in_replacements = False
        # Annotations kept when a replacement is copied into the template.
        self.preserved_annos = {
            anno.Basic.DIRECTIVES,
            anno.Basic.EXTRA_LOOP_TEST,
            anno.Basic.ORIGIN,
            anno.Basic.SKIP_PROCESSING,
            anno.Static.ORIG_DEFINITIONS,
            'function_context_name',
        }

    def _prepare_replacement(self, replaced, key):
        """Prepares a replacement AST that's safe to swap in for a node.

        The replacement is passed through `ast_util.copy_clean`, keeping only
        the annotations listed in `self.preserved_annos`. Presumably this is a
        copy, so the caller's nodes are not shared with the result — see
        ast_util to confirm.

        Args:
          replaced: ast.AST, the node being replaced (currently unused)
          key: Hashable, the key of the replacement AST
        Returns:
          A sequence of AST nodes; a single-node replacement is wrapped in a
          list.
        """
        repl = self.replacements[key]

        new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
        if isinstance(new_nodes, gast.AST):
            new_nodes = [new_nodes]

        return new_nodes

    def visit_Expr(self, node):
        """Unwraps expression statements whose value was replaced."""
        # When replacing a placeholder with an entire statement, the replacement
        # must stand on its own and not be wrapped in an Expr.
        new_value = self.visit(node.value)
        if new_value is node.value:
            return node
        return new_value

    def visit_keyword(self, node):
        """Replaces a keyword argument whose name is a placeholder.

        Raises:
          ValueError: if the replacement is not a keyword or a non-empty list
            of keywords.
        """
        if node.arg not in self.replacements:
            return self.generic_visit(node)

        repl = self._prepare_replacement(node, node.arg)
        if isinstance(repl, gast.keyword):
            return repl
        elif (repl and isinstance(repl, (list, tuple)) and
              all(isinstance(r, gast.keyword) for r in repl)):
            return repl
        # TODO(mdan): We may allow replacing with a string as well.
        # For example, if one wanted to replace foo with bar in foo=baz, then
        # we could allow changing just node arg, so that we end up with bar=baz.
        raise ValueError(
            'a keyword argument may only be replaced by another keyword or a '
            'non-empty list of keywords. Found: {} for keyword {}'.format(
                repl, node.arg))

    def visit_FunctionDef(self, node):
        """Renames a function whose name is a placeholder.

        Raises:
          ValueError: if the replacement is not a Name node.
        """
        node = self.generic_visit(node)
        if node.name not in self.replacements:
            return node

        repl = self.replacements[node.name]
        if not isinstance(repl, (gast.Name, ast.Name)):
            raise ValueError(
                'a function name can only be replaced by a Name node. Found: %s' %
                repl)
        node.name = repl.id
        return node

    def visit_Attribute(self, node):
        """Renames an attribute whose name is a placeholder.

        Raises:
          ValueError: if the replacement is not a Name node.
        """
        node = self.generic_visit(node)
        if node.attr not in self.replacements:
            return node

        repl = self.replacements[node.attr]
        # Accept the same Name types as visit_FunctionDef for consistency.
        if not isinstance(repl, (gast.Name, ast.Name)):
            raise ValueError(
                'An attribute can only be replaced by a Name node. Found: %s' % repl)
        node.attr = repl.id
        return node

    def visit_Name(self, node):
        """Replaces a Name placeholder with its replacement node(s).

        The ctx of each replacement node is adjusted to the placeholder's ctx.
        Returns a single node when there is exactly one replacement, otherwise
        the sequence of nodes; an empty sequence removes the placeholder.
        """
        if node.id not in self.replacements:
            return node

        new_nodes = self._prepare_replacement(node, node.id)

        if not new_nodes:
            return new_nodes

        # Preserve the target context.
        adjuster = ContextAdjuster(type(node.ctx))
        for n in new_nodes:
            if hasattr(n, 'ctx'):
                adjuster.visit(n)

        if len(new_nodes) == 1:
            new_nodes, = new_nodes

        return new_nodes
def _convert_to_ast(n):
    """Normalizes a replacement value into AST form.

    A `str` becomes a `gast.Name` with that id and a `qual_names.QN` is
    converted with its `ast()` method. A list or tuple is converted element by
    element, producing a new list or tuple respectively. Any other value is
    assumed to already be an AST node (or a value the caller knows how to use)
    and is returned as is.
    """
    # Note: When generating AST nodes from strings/QNs in isolation, ctx is
    # unknown. ctx must be filled in according to the template being used.
    # See ReplaceTransformer.visit_Name.
    if isinstance(n, str):
        return gast.Name(id=n, ctx=None, annotation=None, type_comment=None)
    if isinstance(n, qual_names.QN):
        return n.ast()
    if isinstance(n, (list, tuple)):
        converted = map(_convert_to_ast, n)
        return list(converted) if isinstance(n, list) else tuple(converted)
    return n
def replace(template, **replacements):
    """Replaces placeholders in a Python template.

    AST Name and Tuple nodes always receive the context that is inferred from
    the template. However, when replacing more complex nodes (that can
    potentially contain Name children), the caller is responsible for setting
    the appropriate context.

    Args:
      template: A string representing Python code. Any symbol name that appears
        in the template code can be used as a placeholder.
      **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by. String values are also
        supported as a shorthand for AST Name nodes with the respective ID.

    Returns:
      A list of AST nodes: the nodes parsed from the template, with the
      replacements made. A node replaced by a list of nodes contributes each of
      them to the result, and one replaced by an empty list is dropped.

    Raises:
      ValueError: if the template is not a string, or if a placeholder is
        replaced by a node kind that is not allowed in its position.
    """
    if not isinstance(template, str):
        raise ValueError('Expected string template, got %s' % type(template))
    # Normalize shorthand values (strings, QNs, nested lists/tuples) into AST.
    # `replacements` is a fresh dict built from **kwargs, so rebinding it does
    # not affect the caller.
    replacements = {k: _convert_to_ast(v) for k, v in replacements.items()}
    template_str = parser.STANDARD_PREAMBLE + textwrap.dedent(template)
    nodes = parser.parse(
        template_str,
        preamble_len=parser.STANDARD_PREAMBLE_LEN,
        single_node=False)
    # ReplaceTransformer does not modify its own state while visiting, so a
    # single instance can serve every top-level node.
    transformer = ReplaceTransformer(replacements)
    results = []
    for node in nodes:
        node = transformer.visit(node)
        # A replacement can expand one node into several (or none).
        if isinstance(node, (list, tuple)):
            results.extend(node)
        else:
            results.append(node)
    return [qual_names.resolve(r) for r in results]
def replace_as_expression(template, **replacements):
    """Variant of replace that generates expressions, instead of code blocks.

    Args:
      template: A string of Python code, as accepted by `replace`.
      **replacements: Placeholder replacements, as accepted by `replace`.

    Returns:
      The single expression produced by the template. If the result is an
      expression statement, its wrapped value is returned; if it is a Name,
      the Name itself is returned.

    Raises:
      ValueError: if the template does not produce exactly one node, or if that
        node is neither an expression statement nor a Name.
    """
    produced = replace(template, **replacements)
    if len(produced) != 1:
        raise ValueError(
            'single expression expected; for more general templates use replace')
    (only_node,) = produced

    if isinstance(only_node, gast.Name):
        return only_node
    if isinstance(only_node, gast.Expr):
        return only_node.value

    raise ValueError(
        'the template is expected to generate an expression or a name node;'
        ' instead found %s' % only_node)