Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transformer.py: 31%
160 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A node transformer that includes utilities for SCT."""
17import collections
18import enum
20import gast
22from tensorflow.python.autograph.pyct import anno
23from tensorflow.python.autograph.pyct import parser
24from tensorflow.python.autograph.pyct import pretty_printer
25from tensorflow.python.autograph.pyct import templates
class AnalysisLevel(enum.IntEnum):
  """Enumeration of analysis levels.

  Members are integers (0 through 3), so they can be compared and ordered
  like plain ints.
  """

  NONE = 0
  ACTIVITY = 1
  DEFINEDNESS = 2
  LIVENESS = 3
36# TODO(znado): Use namedtuple.
class Context(object):
  """Mutable record describing a source code transformation in progress.

  Instances are updated while a conversion runs. Not thread safe.

  Attributes:
    info: EntityInfo, immutable.
    namer: naming.Namer.
    current_origin: origin_info.OriginInfo, holds the OriginInfo of the last
      AST node to be processed successfully. Useful for error handling.
    user: A user-supplied context object. The object is opaque to the
      infrastructure, but will be passed through to all custom transformations.
  """

  def __init__(self, info, namer, user_context):
    """Stores the given objects; `current_origin` starts out as None."""
    self.info = info
    self.namer = namer
    # Nothing has been processed yet, so there is no origin to report.
    self.current_origin = None
    self.user = user_context
58# TODO(mdan): Move to a standalone file.
class EntityInfo(
    collections.namedtuple('EntityInfo', (
        'name',
        'source_code',
        'source_file',
        'future_features',
        'namespace',
    ))):
  """Immutable description of a Python entity, such as a function or class.

  Attributes:
    name: The name that identifies this entity.
    source_code: The entity's source code.
    source_file: The entity's source file.
    future_features: Tuple[Text], the future features that this entity was
      compiled with. See
      https://docs.python.org/2/reference/simple_stmts.html#future.
    namespace: Dict[str, ], containing symbols visible to the entity (excluding
      parameters).
  """
83class _StateStack(object):
84 """Templated context manager.
86 This class provides syntactic sugar for a stack of objects of known
87 type. It allows accessing attributes of the object at the top of the stack
88 directly against this object, which allows for very terse syntax.
90 For example, this code:
92 stack = _StateStack(Foo)
93 stack.enter()
94 stack.bar
96 Is equivalent to:
98 stack = []
99 stack.append(Foo())
100 foo = stack[-1]
101 foo.bar
103 See _State for more on how this is used.
105 Attributes:
106 type: Any, the type of objects that this stack holds
107 level: int, the current stack depth
108 stack: List[Any], the actual stack
109 value: Any, the instance of the object at the top of the stack
110 """
112 def __init__(self, type_):
113 # Because we override __setattr__, we need to attach these attributes using
114 # the superclass' setattr.
115 object.__setattr__(self, 'type', type_)
116 object.__setattr__(self, '_stack', [])
117 if not hasattr(type_, 'no_root'):
118 self.enter()
120 def __enter__(self):
121 self.enter()
122 return self
124 def __exit__(self, exc_type, exc_value, traceback):
125 self.exit()
127 def enter(self):
128 self._stack.append(self.type())
130 def exit(self):
131 self._stack.pop()
133 @property
134 def stack(self):
135 return self._stack
137 @property
138 def level(self):
139 return len(self._stack)
141 @property
142 def value(self):
143 return self._stack[-1]
145 def __iter__(self):
146 return iter(self._stack)
148 def __getattr__(self, key):
149 return getattr(self._stack[-1], key)
151 def __setattr__(self, key, value):
152 setattr(self._stack[-1], key, value)
class _State(object):
  """Lazily populated mapping from a type to its _StateStack.

  This structure offers syntactic sugar over a dict of stacks of objects
  of known type. These structures are useful to keep state during AST walks.
  Multiple different scopes can be tracked in parallel. For example:

    s = _State()

    s[foo].enter()
    s[bar].enter()  # this will not affect s[foo]

  Element access has special semantics:
    * keys are a data type
    * element values are _StateStack(type=key) objects
    * missing elements are automatically added, similarly to defaultdict

  For example, the following block:

    s = _State()
    s[Foo]

  Is equivalent to:

    s = {}
    if Foo not in s:
      s[Foo] = _StateStack(Foo)
    s[Foo]

  See Base for how it's used.
  """

  def __init__(self):
    # Maps each key (a type) to the _StateStack created for it.
    self._value = {}

  def __getitem__(self, key):
    """Returns the stack registered for `key`, creating it on first access."""
    stack = self._value.get(key)
    if stack is None:
      # Stored values are always _StateStack instances, so None can only mean
      # that the key has not been seen before.
      stack = _StateStack(key)
      self._value[key] = stack
    return stack
class NodeStateTracker(object):
  """Base class for general-purpose Python code transformation.

  This abstract class provides helpful functions, like state tracking within
  the scope of arbitrary node, helpers for processing code blocks, debugging,
  mapping of transformed code to original code, and others.

  The transformer allows keeping state across calls that is local
  to arbitrary nodes and their descendants, using the self.state attribute.
  Multiple independent scopes are allowed and automatically constructed.

  For example, to keep track of the `If` node that encloses any `Name` node,
  one can write:

  ```
    class FooType(object):

      def __init__(self):
        self.foo_property = None

    class DummyTransformer(NodeStateTracker, ast.NodeTransformer):

      def visit_If(self, node):
        self.state[FooType].enter()
        self.state[FooType].foo_property = node
        node = self.generic_visit(node)
        self.state[FooType].exit()
        return node

      def visit_Name(self, node):
        self.state[FooType].foo_property  # will hold the innermost enclosing if
  ```

  Alternatively, the `enter()`/`exit()` calls can be managed by a `with`
  statement:

  ```
    def visit_If(self, node):
      with self.state[FooType] as foo:
        foo.foo_property = node
        return self.generic_visit(node)
  ```
  """

  # TODO(mdan): Document all extra features.

  def __init__(self, ctx):
    """Initialize the transformer.

    Subclasses should call this.

    Args:
      ctx: A Context object.
    """
    self._lineno = 0
    self._col_offset = 0
    self.ctx = ctx

    # Per-type stacks of state objects that persist across calls to visit_*
    # methods. Each key gets its own independent stack, created on first
    # access; see _State.
    self.state = _State()

  def debug_print(self, node):
    """Debugging helper: prints the AST of `node` and returns `node`."""
    if __debug__:
      print(pretty_printer.fmt(node))
    return node

  def debug_print_src(self, node):
    """Debugging helper: prints `node` as source code and returns `node`."""
    if __debug__:
      print(parser.unparse(node))
    return node

  def visit_block(self, nodes, before_visit=None, after_visit=None):
    """A more powerful version of generic_visit for statement blocks.

    An example of a block is the body of an if statement.

    This function allows specifying a postprocessing callback (the
    after_visit argument) which can be used to move nodes to a new
    destination. This is done by after_visit by returning a non-null
    second return value, e.g. return new_node, new_destination.

    For example, a transformer could perform the following move:

        foo()
        bar()
        baz()

        foo()
        if cond:
          bar()
          baz()

    The above could be done with a postprocessor of this kind:

        def after_visit(node):
          if node_is_function_call(bar):
            new_container_node = build_cond()
            new_container_node.body.append(node)
            return new_container_node, new_container_node.body
          else:
            # Once we set a new destination, all subsequent items will be
            # moved to it, so we don't need to explicitly handle baz.
            return node, None

    Args:
      nodes: enumerable of AST node objects. If None, the function returns None.
      before_visit: optional callable, invoked without arguments before
        visiting each item in nodes
      after_visit: optional callable that takes in an AST node and returns a
        tuple (new_node, new_destination). It is called after visiting each
        item in nodes, provided the visit produced a truthy result. It is used
        in the same way as the visit_* methods: new_node will replace the
        node; if not None, new_destination must be a list, and subsequent
        nodes will be placed in this list instead of the list returned by
        visit_block.

    Returns:
      A list of AST node objects containing the transformed items from nodes,
      except those nodes that have been relocated using after_visit.
    """
    if nodes is None:
      return None

    collected = []
    # The list currently receiving output; starts as the returned list and may
    # be swapped by after_visit.
    sink = collected
    for item in nodes:
      if before_visit:
        # TODO(mdan): We can modify node here too, if ever needed.
        before_visit()

      visited = self.visit(item)

      redirect = None
      if after_visit and visited:
        visited, redirect = after_visit(visited)

      # Falsy results (e.g. None) mean the node was removed.
      if visited:
        if isinstance(visited, (list, tuple)):
          sink.extend(visited)
        else:
          sink.append(visited)

      # Allow the postprocessor to reroute the remaining nodes to a new list.
      # The current item has already been placed in the previous destination.
      if redirect is not None:
        sink = redirect
    return collected
355# TODO(mdan): Rename to PythonCodeTransformer.
class Base(NodeStateTracker, gast.NodeTransformer):
  """Base class for general-purpose Python-to-Python code transformation.

  This is an extension of ast.NodeTransformer that provides the additional
  functions offered by NodeStateTracker.
  """

  def create_assignment(self, target, expression):
    """Builds `target = expression` by filling in a code template."""
    template = """
      target = expression
    """
    return templates.replace(template, target=target, expression=expression)

  # TODO(mdan): Remove.
  def apply_to_single_assignments(self, targets, values, apply_fn):
    """Applies a function to each individual assignment.

    This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
    It tries to break down the unpacking if possible. In effect, it has the same
    effect as passing the assigned values in SSA form to apply_fn.

    Examples:

    The following will result in apply_fn(a, c), apply_fn(b, d):

        a, b = c, d

    The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):

        a, b = c

    The following will result in apply_fn(a, (b, c)):

        a = b, c

    It uses the visitor pattern to allow subclasses to process single
    assignments individually.

    Args:
      targets: list, tuple of or individual AST node. Should be used with the
        targets field of an ast.Assign node.
      values: an AST node.
      apply_fn: a function of two arguments, which will be called with the
        respective nodes of each single assignment. The signature is
        apply_fn(target, value), no return value.
    """
    if not isinstance(targets, (list, tuple)):
      targets = (targets,)

    for target in targets:
      if not isinstance(target, (gast.Tuple, gast.List)):
        # TODO(mdan): Look into allowing to rewrite the AST here.
        apply_fn(target, values)
        continue

      # The target unpacks; pair each element with the matching value element
      # when the value is a literal tuple/list, or with a subscript otherwise.
      values_are_literal = isinstance(values, (gast.Tuple, gast.List))
      for index, target_el in enumerate(target.elts):
        if values_are_literal:
          value_el = values.elts[index]
        else:
          # NOTE(review): the subscript is built with a raw int index and a
          # Store context even though it denotes a value being read - confirm
          # this is intended before relying on the node beyond analysis.
          value_el = gast.Subscript(values, index, ctx=gast.Store())
        self.apply_to_single_assignments(target_el, value_el, apply_fn)

  def visit(self, node):
    """Visits `node`, tracking origin info on the context and on replacements.

    Args:
      node: a gast.AST node.

    Returns:
      `node` itself if it is annotated with SKIP_PROCESSING; otherwise the
      result of the regular NodeTransformer dispatch, with an enclosing Expr
      removed when its value was replaced by statements.

    Raises:
      ValueError: if `node` is not a gast.AST instance.
    """
    if not isinstance(node, gast.AST):
      # This is not that uncommon a mistake: various node bodies are lists, for
      # example, posing a land mine for transformers that need to recursively
      # call `visit`. The error needs to be raised before the exception handler
      # below is installed, because said handler will mess up if `node` is not,
      # in fact, a node.
      msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
             ' visit lists of nodes, use "visit_block" instead').format(
                 type(node))
      raise ValueError(msg)

    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      return node

    parent_origin = self.ctx.current_origin
    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    try:
      visiting_expr = isinstance(node, gast.Expr)
      if visiting_expr:
        # Remember the value so a replacement can be detected after the visit.
        value_on_entry = node.value

      result = super(Base, self).visit(node)

      # Adjust for consistency: replacing the value of an Expr with
      # an Assign node removes the need for the Expr node.
      if (visiting_expr and isinstance(result, gast.Expr) and
          result.value is not value_on_entry and
          # When the replacement is a list, it is assumed that the list came
          # from a template that contained a number of statements, which
          # themselves are standalone and don't require an enclosing Expr.
          isinstance(result.value,
                     (list, tuple, gast.Assign, gast.AugAssign))):
        result = result.value

      # By default, all replacements receive the origin info of the replaced
      # node.
      if result is not node and result is not None:
        inherited_origin = anno.getanno(
            node, anno.Basic.ORIGIN, default=parent_origin)
        if inherited_origin is not None:
          if isinstance(result, (list, tuple)):
            replacements = result
          else:
            replacements = (result,)
          for replacement in replacements:
            if not anno.hasanno(replacement, anno.Basic.ORIGIN):
              anno.setanno(replacement, anno.Basic.ORIGIN, inherited_origin)
    finally:
      # Restore the caller's origin whether or not the visit succeeded.
      self.ctx.current_origin = parent_origin

    return result
class CodeGenerator(NodeStateTracker, gast.NodeVisitor):
  """Base class for general-purpose Python-to-string code transformation.

  Similar to Base, but outputs arbitrary strings instead of a Python AST.

  This uses the same visitor mechanism that the standard NodeVisitor uses,
  meaning that subclasses write handlers for the different kinds of nodes.
  New code is generated using the emit method, which appends to a code buffer
  that can be afterwards obtained from code_buffer.

  Example:

    class SimpleCodeGen(CodeGenerator):

      def visit_If(self, node):
        self.emit('if ')
        self.visit(node.test)
        self.emit(' { ')
        for stmt in node.body:
          self.visit(stmt)
        self.emit(' } else { ')
        for stmt in node.orelse:
          self.visit(stmt)
        self.emit(' } ')

    node = ast.parse(...)
    gen = SimpleCodeGen(ctx)
    gen.visit(node)
    # gen.code_buffer contains the resulting code
  """

  def __init__(self, ctx):
    """Initializes the generator with an empty buffer and source map.

    Args:
      ctx: A Context object.
    """
    super(CodeGenerator, self).__init__(ctx)

    self._output_code = ''
    # Maps (start, end) offsets in the output buffer to the origin info of the
    # node whose visit produced that span.
    self.source_map = {}

  def emit(self, code):
    """Appends `code` to the output buffer."""
    self._output_code += code

  @property
  def code_buffer(self):
    """The code emitted so far, as a single string."""
    return self._output_code

  def visit(self, node):
    """Visits `node` and records which output span it produced.

    Nodes annotated with SKIP_PROCESSING are ignored.

    Args:
      node: the AST node to visit.

    Returns:
      Whatever the node's handler returns, or None when the node is skipped.
    """
    if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
      return

    parent_origin = self.ctx.current_origin
    span_start = len(self._output_code)
    if anno.hasanno(node, anno.Basic.ORIGIN):
      self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

    try:
      ret = super(CodeGenerator, self).visit(node)

      # Any text emitted during the visit is attributed to the node's origin,
      # falling back to the origin that was current before the visit.
      span_end = len(self._output_code)
      if span_end != span_start:
        inherited_origin = anno.getanno(
            node, anno.Basic.ORIGIN, default=parent_origin)
        if inherited_origin is not None:
          self.source_map[(span_start, span_end)] = inherited_origin
      return ret
    finally:
      # Restore the caller's origin whether or not the visit succeeded.
      self.ctx.current_origin = parent_origin