Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/core/converter.py: 47%
95 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converter construction support.
17This module contains a base class for all converters, as well as supporting
18structures. These structures are referred to as contexts.
20The class hierarchy is as follows:
22 <your converter>
23 [extends] converter.Base
24 [extends] transformer.Base
25 [extends] gast.NodeTransformer
26 [uses] transformer.SourceInfo
27 [uses] converter.EntityContext
28 [uses] converter.ProgramContext
29 [uses] transformer.SourceInfo
31converter.Base is a specialization of transformer.Base for AutoGraph. It's a
32very lightweight subclass that adds a `ctx` attribute holding the corresponding
33EntityContext object (see below). Note that converters are not reusable, and
34`visit` will raise an error if called more than once.
36converter.EntityContext contains mutable state associated with an entity that
37the converter processes.
39converter.ProgramContext contains mutable state across related entities. For
40example, when converting several functions that call one another, the
41ProgramContext should be shared across these entities.
43Below is the overall flow at conversion:
45 program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
46 while <program_ctx has more entities to convert>:
47 entity, source_info = <get next entity from program_ctx>
48 entity_ctx = EntityContext(program_ctx, source_info)
49 for <each ConverterClass>:
50 converter = ConverterClass(entity_ctx)
52 # May update entity_ctx and program_ctx
53 entity = converter.visit(entity)
55 <add entity's dependencies to program_ctx>
57Note that pyct contains a small number of transformers used for static analysis.
58These implement transformer.Base, rather than converter.Base, to avoid a
59dependency on AutoGraph.
60"""
62import enum
64from tensorflow.python.autograph.pyct import anno
65from tensorflow.python.autograph.pyct import ast_util
66from tensorflow.python.autograph.pyct import parser
67from tensorflow.python.autograph.pyct import templates
68from tensorflow.python.autograph.pyct import transformer
69from tensorflow.python.util.tf_export import tf_export
71# TODO(mdan): These contexts can be refactored into first class objects.
72# For example, we could define Program and Entity abstractions that hold on
73# to the actual entity and have conversion methods.
75# TODO(mdan): Add a test specific to this converter.
@tf_export('autograph.experimental.Feature')
class Feature(enum.Enum):
  """This enumeration represents optional conversion options.

  These conversion options are experimental. They are subject to change without
  notice and offer no guarantees.

  _Example Usage_

  ```python
  optionals = tf.autograph.experimental.Feature.EQUALITY_OPERATORS
  @tf.function(experimental_autograph_options=optionals)
  def f(i):
    if i == 0:  # EQUALITY_OPERATORS allows the use of == here.
      tf.print('i is zero')
  ```

  Attributes:
    ALL: Enable all features.
    AUTO_CONTROL_DEPS: Insertion of control dependencies in the generated code.
    ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
    BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
      their TF counterparts.
    EQUALITY_OPERATORS: Whether to convert the equality operator ('==') to
      tf.math.equal.
    LISTS: Convert list idioms, like initializers, slices, append, etc.
    NAME_SCOPES: Insert name scopes that name ops according to context, like the
      function they were defined in.
  """

  ALL = 'ALL'

  AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
  ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
  BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
  EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
  LISTS = 'LISTS'
  NAME_SCOPES = 'NAME_SCOPES'

  @classmethod
  def all(cls):
    """Returns a tuple that enables all options."""
    return tuple(cls.__members__.values())

  @classmethod
  def all_but(cls, exclude):
    """Returns a tuple that enables all but the excluded options.

    Args:
      exclude: A single option, or a list, tuple, set or frozenset of options
        to leave out.

    Returns:
      A tuple with every member except the excluded ones and `ALL`. The order
      of the elements is unspecified because it is built from a set.
    """
    # frozenset is accepted as a collection too; otherwise it would be wrapped
    # as a single (never matching) element and nothing would be excluded.
    if not isinstance(exclude, (list, tuple, set, frozenset)):
      exclude = (exclude,)
    # ALL is always removed: keeping it would re-enable the excluded options.
    return tuple(set(cls.all()) - set(exclude) - {cls.ALL})
# Placeholder binding: ConversionOptions.to_ast compares against this name, so
# it is declared before the class and rebound to the real default options once
# the class has been defined.
130STANDARD_OPTIONS = None # Forward definition.
class ConversionOptions(object):
  """Immutable container for global conversion flags.

  Attributes:
    recursive: bool, whether to recursively convert any user functions or
      classes that the converted function may use.
    user_requested: bool, whether the conversion was explicitly requested by
      the user, as opposed to being performed as a result of other logic. This
      value always auto-resets to False in child conversions.
    internal_convert_user_code: bool, stored as given (defaults to True). For
      child conversions, `call_options` sets it to the value of `recursive`.
    optional_features: FrozenSet[Feature], controls the use of optional
      features in the conversion process. The constructor accepts None (no
      features), a single Feature, or an iterable of Features and normalizes
      the value to a frozenset. See Feature for available options.
  """

  def __init__(self,
               recursive=False,
               user_requested=False,
               internal_convert_user_code=True,
               optional_features=Feature.ALL):
    self.recursive = recursive
    self.user_requested = user_requested
    # TODO(mdan): Rename to conversion_recursion_depth?
    self.internal_convert_user_code = internal_convert_user_code

    # Normalize to a frozenset so the options object stays hashable.
    if optional_features is None:
      optional_features = ()
    elif isinstance(optional_features, Feature):
      optional_features = (optional_features,)
    optional_features = frozenset(optional_features)
    self.optional_features = optional_features

  def as_tuple(self):
    """Returns all fields as a tuple; used for hashing and equality."""
    return (self.recursive, self.user_requested,
            self.internal_convert_user_code, self.optional_features)

  def __hash__(self):
    return hash(self.as_tuple())

  def __eq__(self, other):
    # Returning NotImplemented (rather than asserting) lets comparisons with
    # unrelated objects, e.g. None, evaluate to False instead of raising.
    if not isinstance(other, ConversionOptions):
      return NotImplemented
    return self.as_tuple() == other.as_tuple()

  def __str__(self):
    # Features are sorted so the text does not depend on set iteration order.
    features = ', '.join(sorted(str(f) for f in self.optional_features))
    return 'ConversionOptions[{}]'.format(
        'recursive={}, user_requested={}, internal_convert_user_code={}, '
        'optional_features=({})'.format(
            self.recursive, self.user_requested,
            self.internal_convert_user_code, features))

  def uses(self, feature):
    """Returns True if `feature` is enabled, directly or through Feature.ALL."""
    return (Feature.ALL in self.optional_features or
            feature in self.optional_features)

  def call_options(self):
    """Returns the corresponding options to be used for recursive conversion."""
    return ConversionOptions(
        recursive=self.recursive,
        user_requested=False,
        internal_convert_user_code=self.recursive,
        optional_features=self.optional_features)

  def to_ast(self):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Returns:
      ast.Node
    """
    # Options equal to the module default are emitted as the short alias.
    if self == STANDARD_OPTIONS:
      return parser.parse_expression('ag__.STD')

    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
      # Builds the source text '(ag__.<str(v)>, ...)'. With exactly one value
      # the parentheses do not form a tuple; the constructor accepts a bare
      # Feature, so the result is still valid.
      return parser.parse_expression('({})'.format(', '.join(
          'ag__.{}'.format(str(v)) for v in values)))

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        user_requested_val=parser.parse_expression(str(self.user_requested)),
        internal_convert_user_code_val=parser.parse_expression(
            str(self.internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    # The template expands to a single expression statement; return its value.
    return expr_ast[0].value
# Module-wide default options. ConversionOptions.to_ast emits `ag__.STD` for
# any options object equal to this one. Passing optional_features=None means
# no optional features are enabled.
224STANDARD_OPTIONS = ConversionOptions(
225 recursive=True,
226 user_requested=False,
227 internal_convert_user_code=True,
228 optional_features=None)
class ProgramContext(object):
  """ProgramContext keeps track of converting function hierarchies.

  It is a plain holder: the constructor stores both arguments unchanged.

  Attributes:
    options: ConversionOptions
    autograph_module: Deprecated. Do not use.
  """

  def __init__(self, options, autograph_module=None):
    self.autograph_module = autograph_module
    self.options = options
class Base(transformer.Base):
  """All converters should inherit from this class.

  An instance may run a single top-level `visit`; a second top-level call
  raises ValueError.

  Attributes:
    ctx: EntityContext
  """

  def __init__(self, ctx):
    super(Base, self).__init__(ctx)

    # Current nesting level of `visit` calls; zero means no visit is running.
    self._ast_depth = 0
    # Becomes True once a top-level `visit` has started.
    self._used = False

  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz', default)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when the node has no definitions or none of them
        carries the requested directive argument.

    Returns:
      The value of the directive argument, or `default` if it was not found.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not definitions:
      return default

    # Collect the argument value from every definition that carries it.
    candidates = [
        definition.directives[directive][arg]
        for definition in definitions
        if directive in definition.directives
        and arg in definition.directives[directive]
    ]
    if not candidates:
      return default

    # If multiple annotations reach the symbol, they must all match. If they
    # do, the first one is returned. A single candidate trivially passes.
    reference = candidates[0]
    for candidate in candidates[1:]:
      if ast_util.matches(reference, candidate):
        continue
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError(
          '%s has ambiguous annotations for %s(%s): %s, %s' %
          (qn, directive.__name__, arg, parser.unparse(candidate).strip(),
           parser.unparse(reference).strip()))
    return reference

  def visit(self, node):
    """Visits `node`, allowing only one top-level visit per converter.

    Args:
      node: the node to visit; it is forwarded to the parent class' `visit`.

    Returns:
      Whatever the parent class' `visit` returns.

    Raises:
      ValueError: if a top-level visit has already been started on this
        object.
    """
    is_top_level = not self._ast_depth
    if is_top_level and self._used:
      raise ValueError('converter objects cannot be reused')
    if is_top_level:
      self._used = True

    self._ast_depth += 1
    try:
      result = super(Base, self).visit(node)
    finally:
      # Always unwind the depth, even when the nested visit raises.
      self._ast_depth -= 1
    return result