Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/directives.py: 25%
91 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles directives.
17This converter removes the directive functions from the code and moves the
18information they specify into AST annotations. It is a specialized form of
19static analysis, one that is specific to AutoGraph.
21Note that this requires that the actual directive functions are static - that
22is, they do not change at runtime. So if you do something like this:
24 tf.autograph.set_loop_options = <new function>
26Then the directive may no longer be recognized. Furthermore, if the
27converted function is cached, such an action may be irreversible.
28"""
30import inspect
32import gast
34from tensorflow.python.autograph.core import converter
35from tensorflow.python.autograph.lang import directives
36from tensorflow.python.autograph.pyct import anno
37from tensorflow.python.util import tf_inspect
# Annotation key: visit_Name and visit_Attribute store the resolved Python
# value of a symbol under this key, and visit_Expr reads it back to recognize
# directive calls.
STATIC_VALUE = 'static_value'
"""Used for AST annotations, see visit_Name."""
44class _LoopScope(object):
46 def __init__(self):
47 self.ast_node = None
48 self.statements_visited = 0
def _map_args(call_node, function):
  """Maps AST call nodes to the actual function's arguments.

  Args:
    call_node: ast.Call
    function: Callable[..., Any], the actual function matching call_node
  Returns:
    Dict[Text, ast.AST], mapping each of the function's argument names to
    the respective AST node. Arguments left at the UNSPECIFIED default are
    omitted.
  Raises:
    ValueError: if the default arguments are not correctly set
  """
  args = call_node.args
  kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
  call_args = tf_inspect.getcallargs(function, *args, **kwds)

  # Keyword arguments not specified in kwds will be mapped to their defaults,
  # which are Python values. Since we don't currently have a way to transform
  # those into AST references, we simply remove them. By convention, directives
  # use UNSPECIFIED as default value for optional arguments. No other
  # defaults should be present.
  unexpected_defaults = [
      name for name, value in call_args.items()
      if (name not in kwds
          and value not in args
          and value is not directives.UNSPECIFIED)
  ]
  if unexpected_defaults:
    # Build a real list of (name, value) pairs: formatting a bare zip object
    # with %s would only print "<zip object at 0x...>" on Python 3, hiding
    # the very values the message is meant to report.
    offending = [(name, call_args[name]) for name in unexpected_defaults]
    raise ValueError('Unexpected keyword argument values, %s, for function %s'
                     % (offending, function))
  return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}
class DirectivesTransformer(converter.Base):
  """Parses compiler directives and converts them into AST annotations."""

  def _process_symbol_directive(self, call_node, directive):
    """Attaches a symbol directive to the definitions of its target.

    Args:
      call_node: the AST call node invoking the directive.
      directive: the directive function that call_node resolves to.
    Returns:
      call_node, unchanged.
    Raises:
      ValueError: if the call has no positional argument to act as target.
    """
    if not call_node.args:
      raise ValueError('"%s" requires a positional first argument'
                       ' as the target' % directive.__name__)
    symbol = call_node.args[0]
    # Each definition receives its own freshly mapped argument dict.
    for definition in anno.getanno(symbol, anno.Static.ORIG_DEFINITIONS):
      definition.directives[directive] = _map_args(call_node, directive)
    return call_node

  def _process_statement_directive(self, call_node, directive):
    """Records a statement directive on the AST node of the current loop.

    Args:
      call_node: the AST call node invoking the directive.
      directive: the directive function that call_node resolves to.
    Returns:
      call_node, unchanged.
    Raises:
      ValueError: if more than one statement has been visited in the current
        scope, or if the scope level is below 2.
    """
    scope = self.state[_LoopScope]
    if scope.statements_visited > 1:
      raise ValueError(
          '"%s" must be the first statement in the loop block' % (
              directive.__name__))
    # NOTE(review): `level` comes from the state container, not from
    # _LoopScope itself; presumably it is the scope nesting depth - confirm.
    if scope.level < 2:
      raise ValueError(
          '"%s" must be used inside a statement' % directive.__name__)
    loop_node = scope.ast_node
    loop_directives = anno.getanno(loop_node, anno.Basic.DIRECTIVES, {})
    loop_directives[directive] = _map_args(call_node, directive)
    anno.setanno(loop_node, anno.Basic.DIRECTIVES, loop_directives)
    return call_node

  def visit_Name(self, node):
    """Annotates loaded names lacking definitions with their namespace value."""
    node = self.generic_visit(node)
    if not isinstance(node.ctx, gast.Load):
      return node
    if anno.getanno(node, anno.Static.DEFINITIONS, ()):
      # The name has known definitions, so no static value is attached.
      return node
    namespace = self.ctx.info.namespace
    if node.id in namespace:
      anno.setanno(node, STATIC_VALUE, namespace[node.id])
    return node

  def visit_Attribute(self, node):
    """Propagates static values through attribute access on modules."""
    node = self.generic_visit(node)
    owner = anno.getanno(node.value, STATIC_VALUE, default=None)
    if owner is None or not inspect.ismodule(owner):
      return node
    if hasattr(owner, node.attr):
      anno.setanno(node, STATIC_VALUE, getattr(owner, node.attr))
    return node

  def visit_Assign(self, node):
    """Counts the assignment as a visited statement, then visits it."""
    scope = self.state[_LoopScope]
    scope.statements_visited += 1
    return self.generic_visit(node)

  def visit_AugAssign(self, node):
    """Counts the augmented assignment as a visited statement, then visits it."""
    scope = self.state[_LoopScope]
    scope.statements_visited += 1
    return self.generic_visit(node)

  def visit_Expr(self, node):
    """Processes and strips expression statements that call a directive.

    Returns:
      None when the statement was a recognized directive call (which removes
      it from the generated code), otherwise the visited node.
    """
    scope = self.state[_LoopScope]
    scope.statements_visited += 1
    node = self.generic_visit(node)
    if not isinstance(node.value, gast.Call):
      return node

    call_node = node.value
    callee = anno.getanno(call_node.func, STATIC_VALUE, default=None)
    if callee is None:
      return node

    # Note: directive calls are not output in the generated code, hence
    # the removal from the code by returning None.
    if callee is directives.set_element_type:
      self._process_symbol_directive(call_node, callee)
      return None
    if callee is directives.set_loop_options:
      self._process_statement_directive(call_node, callee)
      return None
    return node

  # TODO(mdan): This will be insufficient for other control flow.
  # That means that if we ever have a directive that affects things other than
  # loops, we'll need support for parallel scopes, or have multiple converters.
  def _track_and_visit_loop(self, node):
    """Visits a loop inside a fresh _LoopScope bound to that loop's node."""
    self.state[_LoopScope].enter()
    self.state[_LoopScope].ast_node = node
    visited = self.generic_visit(node)
    # Edge case: a loop with just one directive statement would become empty.
    if not visited.body:
      visited.body = [gast.Pass()]
    self.state[_LoopScope].exit()
    return visited

  def visit_While(self, node):
    """Visits a while loop with loop-scope tracking."""
    return self._track_and_visit_loop(node)

  def visit_For(self, node):
    """Visits a for loop with loop-scope tracking."""
    return self._track_and_visit_loop(node)
def transform(node, ctx):
  """Runs DirectivesTransformer over `node` and returns the visited result.

  Args:
    node: the AST node to process.
    ctx: the context object handed to the DirectivesTransformer constructor.
  Returns:
    Whatever DirectivesTransformer.visit returns for `node`.
  """
  transformer = DirectivesTransformer(ctx)
  return transformer.visit(node)