Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/call_trees.py: 28%
95 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles function calls, by generating compiled function names and calls.
17Note: this transformer does not rename the top level object being converted;
18that is the caller's responsibility.
20Requires function_scopes.
21"""
23import gast
25from tensorflow.python.autograph.core import converter
26from tensorflow.python.autograph.pyct import anno
27from tensorflow.python.autograph.pyct import parser
28from tensorflow.python.autograph.pyct import qual_names
29from tensorflow.python.autograph.pyct import templates
30from tensorflow.python.autograph.utils import ag_logging
# TODO(mdan): Rename CallTreeTransformer to FunctionCallsTransformer.
36class _Function(object):
38 no_root = True
40 def __init__(self):
41 self.context_name = None
# Module-wide flag: flipped to True by CallTreeTransformer.visit_Call after it
# logs the pdb/ipdb/breakpoint debugging warning, so that warning is emitted at
# most once per import of this module.
set_trace_warned = False
class _ArgTemplateBuilder(object):
  """Constructs a tuple expression representing the positional args of a call.

  Runs of plain positional arguments become tuple literals, each `*arg` value
  is wrapped in a `tuple(...)` call, and the pieces are concatenated with `+`.
  Example (yes, it's legal Python 3):

    f(*args1, b, *args2, c, d) -> tuple(args1) + (b,) + tuple(args2) + (c, d)

  Usage: call `add_arg` / `add_stararg` in argument order, then `finalize`,
  then `to_ast`.
  """

  def __init__(self):
    # Plain positional args seen since the last starred arg (or the start).
    self._arg_accumulator = []
    # Expression nodes whose `+`-concatenation forms the final tuple.
    self._argspec = []
    self._finalized = False

  def _consume_args(self):
    """Flushes pending plain positional args into a single tuple literal."""
    if self._arg_accumulator:
      self._argspec.append(
          gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load()))
      self._arg_accumulator = []

  def add_arg(self, a):
    """Records a plain positional argument expression `a`."""
    self._arg_accumulator.append(a)

  def add_stararg(self, a):
    """Records the value `a` of a `*a` argument, emitted as `tuple(a)`."""
    self._consume_args()
    # NOTE(review): `tuple` is emitted as a plain name lookup, so it resolves
    # in the scope of the generated code; a user rebinding of `tuple` there
    # would be picked up. Confirm this is acceptable.
    self._argspec.append(
        gast.Call(
            gast.Name(
                'tuple', ctx=gast.Load(), annotation=None, type_comment=None),
            args=[a],
            # AST sequence fields are lists; CPython's compile() rejects a
            # tuple in a Call's `keywords`.
            keywords=[]))

  def finalize(self):
    """Flushes any trailing plain args; must be called before `to_ast`."""
    self._consume_args()
    self._finalized = True

  def to_ast(self):
    """Returns the expression node for the whole positional-arguments tuple.

    Returns:
      A single tuple/`tuple(...)` node, a left-nested chain of `+` BinOps over
      the collected pieces, or an empty tuple literal if there were no
      positional arguments.
    """
    assert self._finalized
    if not self._argspec:
      return gast.Tuple([], gast.Load())
    result = self._argspec[0]
    for part in self._argspec[1:]:
      result = gast.BinOp(result, gast.Add(), part)
    return result
class CallTreeTransformer(converter.Base):
  """Transforms the call tree by renaming transformed symbols.

  Most calls `f(...)` are rewritten into
  `ag__.converted_call(f, <args tuple>, <kwargs dict or None>, <context>)`,
  where <context> is the 'function_context_name' annotation of the enclosing
  function or lambda (presumably added by the function_scopes converter, which
  this module requires).
  """

  def visit_Lambda(self, node):
    """Visits a lambda, entering a function scope if it carries a context."""
    if not anno.hasanno(node, 'function_context_name'):
      # Lambda functions created during the conversion process have no
      # context manager.
      return self.generic_visit(node)
    with self.state[_Function] as fn_scope:
      fn_scope.context_name = anno.getanno(node, 'function_context_name')
      return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    """Visits a function definition, converting calls inside its body.

    Decorators and default values are visited before the function's own scope
    is entered, so calls found there use the enclosing scope's context.
    """
    # Decorators and arg defaults are part of the outer scope.
    node.decorator_list = self.visit_block(node.decorator_list)
    node.args.defaults = self.visit_block(node.args.defaults)
    for i, d in enumerate(node.args.kw_defaults):
      # kw_defaults holds None for keyword-only args without a default.
      if d is not None:
        node.args.kw_defaults[i] = self.visit(d)
    with self.state[_Function] as fn_scope:
      # Note: if the conversion process ever creates helper functions, this
      # assumption will no longer hold.
      assert anno.hasanno(node, 'function_context_name'), (
          'The function_scopes converter always creates a scope for functions.')
      fn_scope.context_name = anno.getanno(node, 'function_context_name')
      node.body = self.visit_block(node.body)
      # NOTE(review): the return annotation is visited inside the function's
      # scope, although Python evaluates it at definition time (unless
      # `from __future__ import annotations` is in effect); argument
      # annotations are not visited at all. Confirm this is intended.
      if node.returns:
        node.returns = self.visit(node.returns)
      return node

  def visit_With(self, node):
    """Visits only the body of a `with` statement."""
    # Context manager calls (in node.items) are not converted.
    node.body = self.visit_block(node.body)
    return node

  def _args_to_tuple(self, node):
    """Ties together all positional and *arg arguments in a single tuple."""
    # TODO(mdan): Consider emitting a single tuple display instead. For
    # example, for
    #   f(a, b, *args)
    # instead of writing:
    #   (a, b) + tuple(args)
    # just write:
    #   (a, b, *args)
    # (`tuple(a, b, *args)` would not work: tuple() takes at most one
    # argument.)
    builder = _ArgTemplateBuilder()
    for a in node.args:
      if isinstance(a, gast.Starred):
        builder.add_stararg(a.value)
      else:
        builder.add_arg(a)
    builder.finalize()
    return builder.to_ast()

  def _kwargs_to_dict(self, node):
    """Ties together all keyword and **kwarg arguments in a single dict.

    Returns:
      A `dict(...)` call carrying the call's keywords, or a `None` expression
      when the call has no keyword arguments. Like the original call,
      `dict(...)` raises TypeError if several `**` unpackings supply the same
      key.
    """
    if not node.keywords:
      return parser.parse_expression('None')
    return gast.Call(
        gast.Name('dict', ctx=gast.Load(), annotation=None, type_comment=None),
        # AST sequence fields are lists, not tuples.
        args=[],
        keywords=node.keywords)

  def visit_Call(self, node):
    """Rewrites a call into `ag__.converted_call`, unless it is exempt.

    Exempt calls are returned with only their sub-expressions visited:
    calls into the `ag__` module, calls on the function context manager,
    `pdb.set_trace` / `ipdb.set_trace` / `breakpoint`, and `print` when the
    BUILTIN_FUNCTIONS feature is not enabled.
    """
    # Read before generic_visit, since visiting children may replace
    # node.func (e.g. when it is itself a call).
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    function_context_name = self.state[_Function].context_name
    node = self.generic_visit(node)

    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    if full_name.startswith('ag__.'):
      return node

    # Calls to the function context manager (inserted by function_scopes) are
    # also safe.
    if full_name.startswith(function_context_name + '.'):
      return node

    # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
    # the normal mechanisms to bypass these literals because they are sensitive
    # to the frame they are being called from.
    # TODO(mdan): Generalize this to a "static allowlist" config.
    if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
      global set_trace_warned
      if not set_trace_warned:
        # TODO(mdan): Update and shorten once available on tensorflow.org.
        ag_logging.warning(
            'Detected `pdb.set_trace()` in user code. The code'
            ' generated by AutoGraph is not optimized for step-by-step'
            ' debugging. See https://github.com/tensorflow/tensorflow/'
            'blob/master/tensorflow/python/autograph/g3doc/reference/'
            'debugging.md.')
        set_trace_warned = True
      return node

    if (full_name == 'print' and
        not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
      return node

    template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
    new_call = templates.replace_as_expression(
        template,
        func=node.func,
        args=self._args_to_tuple(node),
        kwargs=self._kwargs_to_dict(node),
        function_ctx=function_context_name)

    return new_call
def transform(node, ctx):
  """Transform function call to the compiled counterparts.

  Resolves qualified names on `node`, then applies CallTreeTransformer, which
  routes non-exempt calls through `ag__.converted_call`. The function_scopes
  converter is expected to have run first.

  Args:
    node: AST
    ctx: EntityContext

  Returns:
    The transformed AST node.
  """
  node = qual_names.resolve(node)

  node = CallTreeTransformer(ctx).visit(node)
  return node