Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/functions.py: 29%
55 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converts function definitions and lambdas by adding necessary boilerplate."""
17import gast
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.pyct import anno
21from tensorflow.python.autograph.pyct import parser
22from tensorflow.python.autograph.pyct import qual_names
23from tensorflow.python.autograph.pyct import templates
24from tensorflow.python.autograph.pyct.static_analysis import activity
25from tensorflow.python.autograph.pyct.static_analysis import annos
28class _Function(object):
30 def __init__(self):
31 self.context_name = None
class FunctionTransformer(converter.Base):
  """Wraps function bodies around autograph-specific boilerplate.

  Function definitions:
    * The body is wrapped in a `with ag__.FunctionScope(...) as <context>:`
      block. A leading docstring is kept outside that block.
    * Top-level definitions (`level <= 2`) have their decorators removed.
    * Nested definitions get an `ag__.autograph_artifact` decorator appended.

  Lambdas:
    * A top-level lambda has its body wrapped in an
      `ag__.with_function_scope(...)` call.
    * A nested lambda (`level > 2`) is wrapped as a whole in
      `ag__.autograph_artifact(...)`.
  """

  def _function_scope_options(self, fn_scope):
    """Returns the options with which to create function scopes.

    Args:
      fn_scope: the state entry for the function being visited; only its
        `level` attribute is read here.

    Returns:
      `self.ctx.user.options` when `fn_scope.level == 2`, otherwise the
      result of `self.ctx.user.options.call_options()`. Callers turn the
      returned object into AST with `.to_ast()`.
    """
    # Top-level functions receive the options that were directly requested.
    # All others receive the options corresponding to a recursive conversion.
    # Note: this mainly controls the user_requested flag, which is important
    # primarily because the FunctionScope context also creates a
    # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
    # function_wrappers.py.
    if fn_scope.level == 2:
      return self.ctx.user.options
    return self.ctx.user.options.call_options()

  def visit_Lambda(self, node):
    """Wraps the body of a top-level lambda in a function scope.

    Args:
      node: the lambda node being visited.

    Returns:
      The lambda with its body wrapped, or, for nested lambdas, an
      `ag__.autograph_artifact(<lambda>)` expression.
    """
    with self.state[_Function] as fn_scope:
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if fn_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      scope = anno.getanno(node, anno.Static.SCOPE)
      function_context_name = self.ctx.namer.new_symbol('lscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      node.body = templates.replace_as_expression(
          template,
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          function_context_name=gast.Constant(function_context_name, kind=None),
          body=node.body)

      return node

  def visit_FunctionDef(self, node):
    """Wraps a function body in an `ag__.FunctionScope` context manager.

    Args:
      node: the function definition node being visited.

    Returns:
      The same node with updated decorators and a wrapped body.
    """
    with self.state[_Function] as fn_scope:
      scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

      # The context name must not collide with any symbol the body references.
      function_context_name = self.ctx.namer.new_symbol('fscope',
                                                        scope.referenced)
      fn_scope.context_name = function_context_name
      anno.setanno(node, 'function_context_name', function_context_name)

      node = self.generic_visit(node)

      if fn_scope.level <= 2:
        # Top-level functions lose their decorator because the conversion is
        # always just-in-time and by the time it happens the decorators are
        # already set to be applied.
        node.decorator_list = []
      else:
        # TODO(mdan): Fix the tests so that we can always add this decorator.
        # Inner functions are converted already, so we insert a decorator to
        # prevent double conversion. Double conversion would work too, but this
        # saves the overhead.
        node.decorator_list.append(
            parser.parse_expression('ag__.autograph_artifact'))

      docstring_node = None
      if node.body:
        first_statement = node.body[0]
        # Only a string literal is a docstring (same rule as
        # `ast.get_docstring`); other constant expressions such as `...`
        # are ordinary statements and stay inside the function scope.
        if (isinstance(first_statement, gast.Expr) and
            isinstance(first_statement.value, gast.Constant) and
            isinstance(first_statement.value.value, str)):
          docstring_node = first_statement
          node.body = node.body[1:]

      # A docstring-only function would otherwise yield a `with` statement
      # with an empty body, which is not valid Python.
      scoped_body = node.body or [gast.Pass()]

      template = """
        with ag__.FunctionScope(
            function_name, context_name, options) as function_context:
          body
      """
      wrapped_body = templates.replace(
          template,
          function_name=gast.Constant(node.name, kind=None),
          context_name=gast.Constant(function_context_name, kind=None),
          options=self._function_scope_options(fn_scope).to_ast(),
          function_context=function_context_name,
          body=scoped_body)

      # Keep the docstring as the first statement so it remains the
      # function's docstring.
      if docstring_node is not None:
        wrapped_body = [docstring_node] + wrapped_body

      node.body = wrapped_body

      return node
def transform(node, ctx):
  """Applies FunctionTransformer to `node`.

  The qualified-name and activity analyses are run first. FunctionTransformer
  reads the scope annotations `anno.Static.SCOPE` and
  `annos.NodeAnno.BODY_SCOPE`, which it presumably relies on
  `activity.resolve` to attach.

  Args:
    node: the AST to convert.
    ctx: the converter context, forwarded to `activity.resolve` and to the
      transformer.

  Returns:
    The result of visiting the resolved AST with FunctionTransformer.
  """
  with_qual_names = qual_names.resolve(node)
  with_activity = activity.resolve(with_qual_names, ctx, None)
  transformer = FunctionTransformer(ctx)
  return transformer.visit(with_activity)