Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/control_flow.py: 19%
173 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles control flow statements: while, for, if."""
17import gast
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.lang import directives
21from tensorflow.python.autograph.pyct import anno
22from tensorflow.python.autograph.pyct import cfg
23from tensorflow.python.autograph.pyct import origin_info
24from tensorflow.python.autograph.pyct import parser
25from tensorflow.python.autograph.pyct import qual_names
26from tensorflow.python.autograph.pyct import templates
27from tensorflow.python.autograph.pyct.static_analysis import activity
28from tensorflow.python.autograph.pyct.static_analysis import annos
29from tensorflow.python.autograph.pyct.static_analysis import liveness
30from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
31from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
class _Function:
  """State record for the function or lambda currently being converted.

  The transformer keeps one of these per nesting level through
  `self.state[_Function]`, and fills it in when it enters a function
  definition or a lambda.
  """

  # Scope annotation of the enclosing function/lambda. Stays None until
  # visit_FunctionDef or visit_Lambda assigns it.
  scope = None
39class ControlFlowTransformer(converter.Base):
40 """Transforms control flow structures like loops and conditionals."""
def visit_Lambda(self, node):
  """Visits a lambda with its scope recorded as the current function scope.

  Args:
    node: the lambda node; must carry the `anno.Static.SCOPE` annotation.

  Returns:
    Whatever `generic_visit` returns for the node.
  """
  with self.state[_Function] as fn_state:
    # Nested control flow statements read this scope while being converted.
    lambda_scope = anno.getanno(node, anno.Static.SCOPE)
    fn_state.scope = lambda_scope
    return self.generic_visit(node)
def visit_FunctionDef(self, node):
  """Visits a function definition with its body scope as the current scope.

  Args:
    node: the function definition node; must carry the
      `annos.NodeAnno.BODY_SCOPE` annotation.

  Returns:
    Whatever `generic_visit` returns for the node.
  """
  with self.state[_Function] as fn_state:
    # Nested control flow statements read this scope while being converted.
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    fn_state.scope = body_scope
    return self.generic_visit(node)
def _create_nonlocal_declarations(self, vars_):
  """Builds the `global` / `nonlocal` statements covering `vars_`.

  Args:
    vars_: iterable of hashable symbols that provide `is_composite()` and
      a `str()` form.

  Returns:
    A list holding at most one `gast.Global` node (symbols found in the
    current function scope's `globals`) followed by at most one
    `gast.Nonlocal` node (the remaining non-composite symbols). The names
    inside each statement follow the order in which `vars_` yields them.
  """
  # De-duplicate while keeping the caller's order. Iterating a plain set here
  # would make the order of the emitted names depend on hash ordering, so the
  # generated code would not be reproducible.
  ordered_vars = list(dict.fromkeys(vars_))
  fn_globals = self.state[_Function].scope.globals

  global_names = [str(v) for v in ordered_vars if v in fn_globals]
  nonlocal_names = [
      str(v) for v in ordered_vars
      if not v.is_composite() and v not in fn_globals]

  results = []
  if global_names:
    results.append(gast.Global(global_names))
  if nonlocal_names:
    results.append(gast.Nonlocal(nonlocal_names))
  return results
def _create_state_functions(
    self, block_vars, nonlocal_declarations, getter_name, setter_name):
  """Generates the getter/setter function pair for a block's state.

  Args:
    block_vars: the symbols that make up the block's state; may be empty.
    nonlocal_declarations: statements placed at the top of the setter.
    getter_name: name given to the generated getter function.
    setter_name: name given to the generated setter function.

  Returns:
    The nodes produced by `templates.replace` for the two function
    definitions. With no state, the getter returns `()` and the setter is a
    no-op. Otherwise the getter returns a tuple of the variables and the
    setter unpacks its argument back into them.
  """
  if not block_vars:
    empty_state_template = """
      def getter_name():
        return ()
      def setter_name(block_vars):
        pass
    """
    return templates.replace(
        empty_state_template,
        getter_name=getter_name,
        setter_name=setter_name)

  # Simple symbols are read directly; anything else is read through an
  # `ag__.ldu(lambda: <var>, '<name>')` call.
  guarded_block_vars = [
      v if v.is_simple() else templates.replace_as_expression(
          'ag__.ldu(lambda: var_, name)',
          var_=v,
          name=gast.Constant(str(v), kind=None))
      for v in block_vars
  ]

  state_template = """
    def getter_name():
      return guarded_state_vars,
    def setter_name(vars_):
      nonlocal_declarations
      state_vars, = vars_
  """
  return templates.replace(
      state_template,
      nonlocal_declarations=nonlocal_declarations,
      getter_name=getter_name,
      guarded_state_vars=guarded_block_vars,
      setter_name=setter_name,
      state_vars=tuple(block_vars))
def _create_loop_options(self, node):
  """Builds the options dict literal passed to the loop operator.

  Args:
    node: the loop node, optionally annotated with `anno.Basic.DIRECTIVES`.

  Returns:
    A `gast.Dict` whose keys are string constants and whose values are the
    nodes recorded for the `directives.set_loop_options` directive. The dict
    is empty when the node has no directives, when the directive is absent,
    or when the directive was recorded without any options.
  """
  if not anno.hasanno(node, anno.Basic.DIRECTIVES):
    return gast.Dict([], [])

  loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
  if directives.set_loop_options not in loop_directives:
    return gast.Dict([], [])

  opts_dict = loop_directives[directives.set_loop_options]
  # Build keys and values separately rather than through
  # `zip(*opts_dict.items())`: unpacking that into two names raises
  # ValueError when the directive carries no options.
  keys = [gast.Constant(name, kind=None) for name in opts_dict]
  # Plain lists are used because ast and gast don't play well with tuples.
  values = list(opts_dict.values())
  return gast.Dict(keys, values)
def _create_undefined_assigns(self, undefined_symbols):
  """Generates `<symbol> = ag__.Undefined('<name>')` for each symbol.

  Args:
    undefined_symbols: iterable of symbols; each must provide `ssf()`,
      whose result is used as the string argument of `ag__.Undefined`.

  Returns:
    A flat list with the statements generated for every symbol, in
    iteration order.
  """
  assign_template = '''
    var = ag__.Undefined(symbol_name)
  '''
  assignments = []
  for symbol in undefined_symbols:
    assignments.extend(
        templates.replace(
            assign_template,
            var=symbol,
            symbol_name=gast.Constant(symbol.ssf(), kind=None)))
  return assignments
def _get_block_basic_vars(self, modified, live_in, live_out):
  """Selects the non-composite modified symbols that escape the block.

  Args:
    modified: symbols written inside the block.
    live_in: symbols live on entry to the block.
    live_out: symbols live on exit from the block.

  Returns:
    A frozenset of the non-composite symbols in `modified` that are live in,
    live out, or declared nonlocal in the current function scope. Symbols
    that are none of these are treated as local to the block and left out.
  """
  nonlocals = self.state[_Function].scope.nonlocals
  # TODO(mdan): Raise an error when a composite shows up for a TF scope.
  return frozenset(
      s for s in modified
      if not s.is_composite()
      and (s in live_in or s in live_out or s in nonlocals))
def _get_block_composite_vars(self, modified, live_in):
  """Selects the composite modified symbols (e.g. `self.x`) tied to live state.

  A write to a composite symbol may really be a mutation of an object that
  was created inside the block:

    while cond:
      x = Foo()
      x.foo = 2 * x.foo  # x.foo is live into the scope, but x is not.

  Such symbols are excluded by requiring every symbol in the composite's
  support set to be live into the block. Support set entries that are not
  symbols (for example the literal 'foo' in x['foo']) are not checked.

  Args:
    modified: symbols written inside the block.
    live_in: symbols live on entry to the block.

  Returns:
    A frozenset of the qualifying composite symbols.
  """

  def support_is_live(symbol):
    return all(
        parent in live_in
        for parent in symbol.support_set
        if parent.is_symbol())

  return frozenset(
      s for s in modified if s.is_composite() and support_is_live(s))
def _get_block_vars(self, node, modified):
  """Determines the variables affected inside a control flow statement.

  Args:
    node: the statement node; must carry the `DEFINED_VARS_IN`,
      `LIVE_VARS_IN` and `LIVE_VARS_OUT` static annotations.
    modified: set of symbols written inside the statement.

  Returns:
    A tuple `(scope_vars, undefined, nouts)`:
      scope_vars: sorted list of the statement's state variables, with the
        input-only ones placed last.
      undefined: tuple of the non-composite modified symbols that are
        neither defined on entry nor global/nonlocal in the function.
      nouts: number of entries in `scope_vars` that are not input-only.
  """
  defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
  live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
  live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
  fn_scope = self.state[_Function].scope

  basic_vars = self._get_block_basic_vars(modified, live_in, live_out)
  composite_vars = self._get_block_composite_vars(modified, live_in)

  # Modified inside the statement but not defined before entering it. Only
  # simple variables must be defined; composite ones are implicitly checked
  # at runtime.
  possibly_undefined = (
      modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
  undefined = tuple(v for v in possibly_undefined if not v.is_composite())

  # Modified inside the statement and dependent on values from outside it,
  # but not read afterwards.
  input_only = basic_vars & (live_in - live_out)

  # False sorts before True, so the outputs come first; ties are broken by
  # the symbols' own ordering.
  scope_vars = sorted(
      basic_vars | composite_vars, key=lambda v: (v in input_only, v))
  nouts = len(scope_vars) - len(input_only)

  return scope_vars, undefined, nouts
def visit_If(self, node):
  """Rewrites an `if` statement as a call to `ag__.if_stmt`.

  The two branches become local functions. The variables they affect are
  exposed to the operator through a generated state getter/setter pair, and
  symbols that may be undefined on entry are pre-assigned
  `ag__.Undefined(...)`.

  Args:
    node: the `if` node; must carry the `BODY_SCOPE` and `ORELSE_SCOPE`
      annotations plus the annotations read by `_get_block_vars`.

  Returns:
    The list of statements that replaces the original node. Origin info from
    `node` is copied onto the last one.
  """
  node = self.generic_visit(node)
  then_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  else_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

  cond_vars, undefined, nouts = self._get_block_vars(
      node, then_scope.bound | else_scope.bound)
  undefined_assigns = self._create_undefined_assigns(undefined)
  nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

  # Generated names must not clash with anything the branches reference.
  reserved = then_scope.referenced | else_scope.referenced
  namer = self.ctx.namer
  state_getter_name = namer.new_symbol('get_state', reserved)
  state_setter_name = namer.new_symbol('set_state', reserved)
  state_functions = self._create_state_functions(
      cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  # A function body cannot be empty, so a missing else branch becomes `pass`.
  orelse_body = node.orelse or [gast.Pass()]

  body_name = namer.new_symbol('if_body', reserved)
  orelse_name = namer.new_symbol('else_body', reserved)
  symbol_names = tuple(
      gast.Constant(str(s), kind=None) for s in cond_vars)

  template = """
    state_functions
    def body_name():
      nonlocal_declarations
      body
    def orelse_name():
      nonlocal_declarations
      orelse
    undefined_assigns
    ag__.if_stmt(
      test,
      body_name,
      orelse_name,
      state_getter_name,
      state_setter_name,
      (symbol_names,),
      nouts)
  """
  new_nodes = templates.replace(
      template,
      body=node.body,
      body_name=body_name,
      orelse=orelse_body,
      orelse_name=orelse_name,
      nonlocal_declarations=nonlocal_declarations,
      nouts=gast.Constant(nouts, kind=None),
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      symbol_names=symbol_names,
      test=node.test,
      undefined_assigns=undefined_assigns)
  origin_info.copy_origin(node, new_nodes[-1])
  return new_nodes
def visit_While(self, node):
  """Rewrites a `while` loop as a call to `ag__.while_stmt`.

  The loop test and the loop body become local functions. The loop variables
  are exposed to the operator through a generated state getter/setter pair,
  and symbols that may be undefined on entry are pre-assigned
  `ag__.Undefined(...)`.

  Args:
    node: the `while` node; must carry the `BODY_SCOPE` annotation plus the
      annotations read by `_get_block_vars`.

  Returns:
    The list of statements that replaces the original node. Origin info from
    `node` is copied onto the last one.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

  # The output count is not needed by the while operator.
  loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)
  undefined_assigns = self._create_undefined_assigns(undefined)
  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  # Generated names must not clash with anything the body references.
  reserved = body_scope.referenced
  namer = self.ctx.namer
  state_getter_name = namer.new_symbol('get_state', reserved)
  state_setter_name = namer.new_symbol('set_state', reserved)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  opts = self._create_loop_options(node)

  body_name = namer.new_symbol('loop_body', reserved)
  symbol_names = tuple(
      gast.Constant(str(s), kind=None) for s in loop_vars)
  test_name = namer.new_symbol('loop_test', reserved)

  template = """
    state_functions
    def body_name():
      nonlocal_declarations
      body
    def test_name():
      return test
    undefined_assigns
    ag__.while_stmt(
        test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  new_nodes = templates.replace(
      template,
      body=node.body,
      body_name=body_name,
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      symbol_names=symbol_names,
      test=node.test,
      test_name=test_name,
      undefined_assigns=undefined_assigns)
  origin_info.copy_origin(node, new_nodes[-1])
  return new_nodes
def visit_For(self, node):
  """Rewrites a `for` loop as a call to `ag__.for_stmt`.

  The loop body becomes a local function that takes the current iterate as
  its only argument and first assigns it to the original loop target. When
  the node carries an `EXTRA_LOOP_TEST` annotation, that expression is
  wrapped in its own function and passed to the operator; otherwise `None`
  is passed in its place.

  Args:
    node: the `for` node; must carry the `BODY_SCOPE` and `ITERATE_SCOPE`
      annotations plus the annotations read by `_get_block_vars`.

  Returns:
    The list of statements that replaces the original node. Origin info from
    `node` is copied onto the last one.
  """
  node = self.generic_visit(node)
  body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
  iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

  # The output count is not needed by the for operator.
  loop_vars, undefined, _ = self._get_block_vars(
      node, body_scope.bound | iter_scope.bound)
  undefined_assigns = self._create_undefined_assigns(undefined)
  nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

  # Generated names must not clash with anything the loop references.
  reserved = body_scope.referenced | iter_scope.referenced
  namer = self.ctx.namer
  state_getter_name = namer.new_symbol('get_state', reserved)
  state_setter_name = namer.new_symbol('set_state', reserved)
  state_functions = self._create_state_functions(
      loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

  # Besides any user-provided options, the operator receives the source text
  # of the loop target under the 'iterate_names' key.
  opts = self._create_loop_options(node)
  target_source = parser.unparse(node.target, include_encoding_marker=False)
  opts.keys.append(gast.Constant('iterate_names', kind=None))
  opts.values.append(gast.Constant(target_source, kind=None))

  if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
    extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
    extra_test_name = namer.new_symbol('extra_test', reserved)
    extra_test_template = """
      def extra_test_name():
        nonlocal_declarations
        return extra_test_expr
    """
    # `loop_vars` has no placeholder in the template above; it is passed
    # through unchanged to keep the call identical.
    extra_test_function = templates.replace(
        extra_test_template,
        extra_test_expr=extra_test,
        extra_test_name=extra_test_name,
        loop_vars=loop_vars,
        nonlocal_declarations=nonlocal_declarations)
  else:
    extra_test_name = parser.parse_expression('None')
    extra_test_function = []

  # iterate_arg_name holds a single arg with the iterates, which may be a
  # tuple.
  iterate_arg_name = namer.new_symbol('itr', reserved)
  expansion_template = """
    iterates = iterate_arg_name
  """
  iterate_expansion = templates.replace(
      expansion_template,
      iterate_arg_name=iterate_arg_name,
      iterates=node.target)
  origin_info.copy_origin(node, iterate_expansion)

  body_name = namer.new_symbol('loop_body', reserved)
  symbol_names = tuple(
      gast.Constant(str(s), kind=None) for s in loop_vars)

  template = """
    state_functions
    def body_name(iterate_arg_name):
      nonlocal_declarations
      iterate_expansion
      body
    extra_test_function
    undefined_assigns
    ag__.for_stmt(
        iterated,
        extra_test_name,
        body_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        opts)
  """
  new_nodes = templates.replace(
      template,
      body=node.body,
      body_name=body_name,
      extra_test_function=extra_test_function,
      extra_test_name=extra_test_name,
      iterate_arg_name=iterate_arg_name,
      iterate_expansion=iterate_expansion,
      iterated=node.iter,
      nonlocal_declarations=nonlocal_declarations,
      opts=opts,
      symbol_names=symbol_names,
      state_functions=state_functions,
      state_getter_name=state_getter_name,
      state_setter_name=state_setter_name,
      undefined_assigns=undefined_assigns)
  origin_info.copy_origin(node, new_nodes[-1])
  return new_nodes
class AnnotatedDef(reaching_definitions.Definition):
  """A reaching definition that also carries a `directives` mapping."""

  def __init__(self):
    super().__init__()
    # Starts out empty; nothing in this class fills it in.
    self.directives = {}
def transform(node, ctx):
  """Runs the static analyses, then converts control flow statements.

  Args:
    node: the AST to convert.
    ctx: the conversion context, passed to the analyses and the transformer.

  Returns:
    The node returned by `ControlFlowTransformer(ctx).visit`.
  """
  # The graphs are built from the incoming node, before any resolve step.
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  # These three analyses share one signature and run in this exact order.
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    node = analysis.resolve(node, ctx, graphs)

  return ControlFlowTransformer(ctx).visit(node)