Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py: 26%
356 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This module contains the user- and codegen-facing API for AutoGraph."""
17import functools
18import importlib
19import inspect
20import os
21import sys
22import textwrap
23import traceback
25from tensorflow.python.autograph import operators
26from tensorflow.python.autograph import utils
27from tensorflow.python.autograph.converters import asserts
28from tensorflow.python.autograph.converters import break_statements
29from tensorflow.python.autograph.converters import call_trees
30from tensorflow.python.autograph.converters import conditional_expressions
31from tensorflow.python.autograph.converters import continue_statements
32from tensorflow.python.autograph.converters import control_flow
33from tensorflow.python.autograph.converters import directives
34from tensorflow.python.autograph.converters import functions
35from tensorflow.python.autograph.converters import lists
36from tensorflow.python.autograph.converters import logical_expressions
37from tensorflow.python.autograph.converters import return_statements
38from tensorflow.python.autograph.converters import slices
39from tensorflow.python.autograph.converters import variables
40from tensorflow.python.autograph.core import ag_ctx
41from tensorflow.python.autograph.core import converter
42from tensorflow.python.autograph.core import function_wrappers
43from tensorflow.python.autograph.core import unsupported_features_checker
44from tensorflow.python.autograph.impl import conversion
45from tensorflow.python.autograph.lang import special_functions
46from tensorflow.python.autograph.operators import py_builtins
47from tensorflow.python.autograph.pyct import anno
48from tensorflow.python.autograph.pyct import cfg
49from tensorflow.python.autograph.pyct import error_utils
50from tensorflow.python.autograph.pyct import errors
51from tensorflow.python.autograph.pyct import inspect_utils
52from tensorflow.python.autograph.pyct import origin_info
53from tensorflow.python.autograph.pyct import qual_names
54from tensorflow.python.autograph.pyct import transpiler
55from tensorflow.python.autograph.pyct.static_analysis import activity
56from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
57from tensorflow.python.autograph.utils import ag_logging as logging
58from tensorflow.python.eager.polymorphic_function import tf_method_target
59from tensorflow.python.framework import errors_impl
60from tensorflow.python.util import tf_decorator
61from tensorflow.python.util import tf_inspect
62from tensorflow.python.util import tf_stack
63from tensorflow.python.util.tf_export import tf_export
def is_autograph_strict_conversion_mode():
  """Reports whether AutoGraph strict conversion mode is switched on.

  The mode is read from the `AUTOGRAPH_STRICT_CONVERSION` environment variable
  on every call. An unset variable counts as '0'. The value is parsed with
  `int()`, so a non-integer value raises ValueError.

  Returns:
    True if the parsed value is greater than zero, False otherwise.
  """
  raw_setting = os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')
  strict_level = int(raw_setting)
  return strict_level > 0
70#
71# Error handling
72#
# TODO(mdan): Export this symbol.
class AutoGraphError(errors.PyCTError):
  """Base class for all AutoGraph exceptions.

  Derives from `errors.PyCTError`. Defines no behavior of its own.
  """
class ConversionError(AutoGraphError):
  """Raised during the conversion process.

  For example, `to_graph` raises it when transforming an entity fails.
  """
class StagingError(AutoGraphError):
  """Raised during the staging (i.e. Python execution) of converted code.

  Also used as the fallback type when an error raised by converted code
  cannot be re-created with its original type.
  """
class _ErrorMetadata(error_utils.ErrorMetadataBase):
  """AutoGraph-specific error metadata. See base class."""

  def create_exception(self, source_error):
    """Builds a replacement exception that carries the rewritten message.

    The type of `source_error` is preserved when it can be rebuilt here:
      * OpError subclasses whose constructor arguments are exactly
        `(self, node_def, op, message)` are rebuilt with the original
        node_def, op and experimental payloads;
      * AutoGraph's own error types and a few TF error types are rebuilt
        from the message alone.
    Otherwise the base class gets a chance. If it returns None, a
    `StagingError` is produced.

    Args:
      source_error: the exception originally raised.

    Returns:
      A new exception instance.
    """
    error_type = type(source_error)

    if issubclass(error_type, errors_impl.OpError):
      # Best-effort unpacking of OpError exceptions.
      # TODO(mdan): Use a mechanism that is more future-proof.
      ctor_spec = tf_inspect.getfullargspec(error_type.__init__)
      new_message = self.get_message()
      # At the time of this writing, TF errors either take 3 or 4 arguments;
      # the argument '*args' may or may not be used.
      if tuple(ctor_spec.args) == ('self', 'node_def', 'op', 'message'):
        return error_type(source_error.node_def, source_error.op,
                          new_message, source_error.experimental_payloads)
    elif error_type in (errors.PyCTError, AutoGraphError, ConversionError,
                        StagingError, errors_impl.InaccessibleTensorError,
                        errors_impl.OperatorNotAllowedInGraphError):
      return error_type(self.get_message())

    fallback = super().create_exception(source_error)
    if fallback is not None:
      return fallback

    # Note: While changing an error's message property to change the message
    # it displays will probably work a lot of times, there is no standard way
    # in Python to do that. The safest way is therefore to create a new
    # exception. For user defined exceptions, we could define an interface
    # that allowed them to work under this mechanism.
    return StagingError(self.get_message())
def _attach_error_metadata(e, f):
  """Stores AutoGraph error metadata on `e` so it can be rewritten later.

  Intended to be called from inside an `except` clause, because the
  traceback is read from `sys.exc_info()`.

  Args:
    e: the exception being handled. Left untouched if it has an
      `ag_pass_through` attribute.
    f: the converted function that raised `e`. Its `ag_source_map` is
      recorded in the metadata.
  """
  if hasattr(e, 'ag_pass_through'):
    return

  previous_metadata = getattr(e, 'ag_error_metadata', None)
  source_map = f.ag_source_map

  message = None
  if previous_metadata is None:
    # First time this error is seen: log it and capture its text.
    logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
    message = '{}: {}'.format(e.__class__.__name__, e)

  # Skip the first traceback entry of the error currently being handled.
  handled_tb = sys.exc_info()[2]
  cause_tb = traceback.extract_tb(handled_tb)[1:]

  e.ag_error_metadata = _ErrorMetadata(
      cause_tb, previous_metadata, message, source_map, __file__)
class StackTraceMapper(tf_stack.StackTraceMapper):
  """Maps locations in generated code back to the code they came from."""

  def __init__(self, converted_fn):
    super().__init__()
    self._source_map = converted_fn.ag_source_map
    # get_effective_source_map may be invoked repeatedly: once on entry, by
    # the superclass, then by each child context manager. Memoize its result.
    self._cached_map = None

  def get_effective_source_map(self):
    """Returns this mapper's source map composed with its parent's.

    Keys are (filename, lineno) pairs; values are
    (filename, lineno, function_name) tuples. Entries from this mapper's own
    source map come first. Each parent entry is then added (or overrides an
    existing key). If a parent entry points at a location this mapper knows
    about, it is translated through this mapper's source map; otherwise it is
    kept as-is. The result is cached after the first call.
    """
    if self._cached_map is not None:
      return self._cached_map

    inherited = self.parent.get_effective_source_map()

    def as_entry(origin):
      return (origin.loc.filename, origin.loc.lineno, origin.function_name)

    merged = {(loc.filename, loc.lineno): as_entry(origin)
              for loc, origin in self._source_map.items()}

    for key, value in inherited.items():
      filename, lineno, _ = value
      target_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
      if target_loc in self._source_map:
        merged[key] = as_entry(self._source_map[target_loc])
      else:
        merged[key] = value

    self._cached_map = merged
    return merged
181#
182# Actual source code transformation
183#
class PyToTF(transpiler.PyToPy):
  """The TensorFlow AutoGraph transformer."""

  def __init__(self):
    super(PyToTF, self).__init__()
    # Built lazily by get_extra_locals and reused on subsequent calls.
    self._extra_locals = None

  def get_transformed_name(self, node):
    """Returns the base class's transformed name, prefixed with 'tf__'."""
    return 'tf__' + super(PyToTF, self).get_transformed_name(node)

  def get_extra_locals(self):
    """Returns the extra local names made available to generated code.

    The result is a dict with a single entry, `ag__`. It is bound to a
    synthetic module named 'autograph' containing this module's namespace,
    a few converter/runtime helpers, and the contents of `special_functions`
    and `operators`. The dict is built on first use and then cached.
    """
    if self._extra_locals is None:
      # `import importlib` alone does not guarantee that the `machinery` and
      # `util` submodules are loaded (importing a package does not import its
      # submodules), so import them explicitly before use.
      import importlib.machinery  # pylint:disable=g-import-not-at-top,redefined-outer-name
      import importlib.util  # pylint:disable=g-import-not-at-top,redefined-outer-name

      # TODO(mdan): Move into core or replace with an actual importable module.
      # Craft a module that exposes the external API as well as certain
      # internal modules.
      module_spec = importlib.machinery.ModuleSpec('autograph', None)
      ag_internal = importlib.util.module_from_spec(module_spec)
      ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
      ag_internal.ConversionOptions = converter.ConversionOptions
      ag_internal.STD = converter.STANDARD_OPTIONS
      ag_internal.Feature = converter.Feature
      ag_internal.utils = utils
      ag_internal.FunctionScope = function_wrappers.FunctionScope
      ag_internal.with_function_scope = function_wrappers.with_function_scope
      # TODO(mdan): Add safeguards against name clashes.
      # We don't want to create a submodule because we want the operators to be
      # accessible as ag__.<operator>
      ag_internal.__dict__.update(special_functions.__dict__)
      ag_internal.__dict__.update(operators.__dict__)

      self._extra_locals = {'ag__': ag_internal}
    return self._extra_locals

  def get_caching_key(self, ctx):
    """Returns the conversion options of `ctx` as the cache key."""
    return ctx.options

  def initial_analysis(self, node, ctx):
    """Runs the static analyses that the converters depend on.

    Builds the control-flow graphs, then resolves qualified names, activity
    and reaching definitions. Finally it applies `anno.dup` with the mapping
    DEFINITIONS -> ORIG_DEFINITIONS (presumably so the pre-conversion
    definitions stay available to later passes).

    Args:
      node: the AST to analyze.
      ctx: the transformation context.

    Returns:
      The annotated AST.
    """
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, graphs)
    anno.dup(
        node,
        {
            anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
        },
    )
    return node

  def transform_ast(self, node, ctx):
    """Applies the AutoGraph converter pipeline to `node`.

    Unsupported features are rejected first, then the initial analysis runs.
    The converters follow in a fixed order. The assert and list converters
    run only when the corresponding optional Feature is enabled.
    """
    unsupported_features_checker.verify(node)
    node = self.initial_analysis(node, ctx)

    node = functions.transform(node, ctx)
    node = directives.transform(node, ctx)
    node = break_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
      node = asserts.transform(node, ctx)
    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = continue_statements.transform(node, ctx)
    node = return_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.LISTS):
      node = lists.transform(node, ctx)
      node = slices.transform(node, ctx)
    node = call_trees.transform(node, ctx)
    node = control_flow.transform(node, ctx)
    node = conditional_expressions.transform(node, ctx)
    node = logical_expressions.transform(node, ctx)
    node = variables.transform(node, ctx)
    return node
def _convert_actual(entity, program_ctx):
  """Transforms `entity` with AutoGraph and tags the result.

  Args:
    entity: the object to convert. It must expose a `__code__` attribute.
    program_ctx: the converter.ProgramContext passed to the transpiler.

  Returns:
    The transformed function. The generated module and the source map are
    attached as its `ag_module` and `ag_source_map` attributes.

  Raises:
    ValueError: if `entity` has no `__code__` attribute.
  """
  # TODO(mdan): Put these extra fields inside __autograph_info__.
  has_code = hasattr(entity, '__code__')
  if not has_code:
    raise ValueError('Cannot apply autograph to a function that doesn\'t '
                     'expose a __code__ object. If this is a @tf.function,'
                     ' try passing f.python_function instead.')

  transformed, generated_module, source_map = _TRANSPILER.transform(
      entity, program_ctx)

  # Guard against clobbering attributes already present on the result.
  for attr_name in ('ag_module', 'ag_source_map'):
    assert not hasattr(transformed, attr_name)
  transformed.ag_module = generated_module
  transformed.ag_source_map = source_map
  return transformed
278#
279# Generated code support
280#
def autograph_artifact(entity, extras=None):
  """Tags `entity` as an AutoGraph artifact and returns it unchanged.

  The tag is an `autograph_info__` attribute whose value is `extras`. For
  bound methods the attribute is placed on the underlying `__func__`.
  """
  carrier = entity.__func__ if inspect.ismethod(entity) else entity
  setattr(carrier, 'autograph_info__', extras)
  return entity
def is_autograph_artifact(entity):
  """Returns True if `entity` has the `autograph_info__` tag attribute."""
  try:
    getattr(entity, 'autograph_info__')
  except AttributeError:
    return False
  return True
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
  """Converts a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  The checks below run in a fixed order, and the first one that applies
  decides how `f` is invoked:
    1. cached allowlist / disabled context / AutoGraph artifact -> unconverted;
    2. `functools.partial` -> unwrapped and re-dispatched recursively;
    3. builtins -> routed to `py_builtins` overloads;
    4. unsupported, allowlisted, or user-code conversion turned off
       -> unconverted;
    5. otherwise the target is converted and the converted version is called.
  If conversion itself fails, the call falls back to running `f` unconverted,
  unless strict conversion mode is enabled, in which case the error is
  re-raised.

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
    arguments.

  Raises:
    ValueError: if both `options` and `caller_fn_scope` are None.
  """
  logging.log(1, 'Converted call: %s\n args: %s\n kwargs: %s\n', f, args,
              kwargs)

  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  # Fast path: `f` was already recorded as allowlisted for these options, so
  # there is no need to record it again (update_cache=False).
  if conversion.is_in_allowlist_cache(f, options):
    logging.log(2, 'Allowlisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  # Disabled by the surrounding context. This decision depends on the current
  # context, so it is not recorded in the allowlist cache.
  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  # Objects produced by AutoGraph itself (see autograph_artifact) are never
  # converted again.
  if is_autograph_artifact(f):
    logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      # Call-site keywords take precedence over the partial's bound keywords.
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  if inspect_utils.isbuiltin(f):
    # eval/super/globals/locals receive the caller's function scope,
    # presumably because their result depends on the calling frame.
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if f is globals:
      return py_builtins.globals_in_original_context(caller_fn_scope)
    if f is locals:
      return py_builtins.locals_in_original_context(caller_fn_scope)
    # Other builtins are dispatched to their AutoGraph overload.
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  if conversion.is_unsupported(f):
    return _call_unconverted(f, args, kwargs, options)

  # The allowlist is only honored when the user did not explicitly request
  # conversion.
  if not options.user_requested and conversion.is_allowlisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # Determine the entity to convert (`target_entity`) and the positional
  # arguments to pass to it (`effective_args`). For methods and callable
  # objects, the receiver is prepended explicitly.
  try:
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        # Methods bound to a TfMethodTarget are called on the wrapped target.
        if isinstance(f_self, tf_method_target.TfMethodTarget):
          f_self = f_self.target
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      # TODO(mdan): Recurse into converted_call to simplify other verifications.
      # This should be handled in the same way as partials.
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      # Assigned before raising so the handler below can log it.
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  # Entities without Python bytecode, or whose code came from a '<string>'
  # source, are called unconverted.
  if not hasattr(target_entity, '__code__'):
    logging.log(2, 'Permanently allowed: %s: native binding', target_entity)
    return _call_unconverted(f, args, kwargs, options)
  elif (hasattr(target_entity.__code__, 'co_filename') and
        target_entity.__code__.co_filename == '<string>'):
    # TODO(mdan): __globals__['txt'] might work in Py3.
    logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)

  # Conversion step. Failures fall back to the unconverted call (or re-raise in
  # strict mode), mirroring the handler above.
  try:
    program_ctx = converter.ProgramContext(options=options)
    converted_f = _convert_actual(target_entity, program_ctx)
    if logging.has_verbosity(2):
      _log_callargs(converted_f, effective_args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  # Execution step. Errors raised by the converted code are NOT caught for
  # fallback; they are annotated with source-mapping metadata and re-raised.
  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      _attach_error_metadata(e, converted_f)
      raise

  return result
def _call_unconverted(f, args, kwargs, options, update_cache=True):
  """Invokes `f` directly, without AutoGraph conversion.

  Args:
    f: the callable to invoke.
    args: tuple of positional arguments.
    kwargs: dict of keyword arguments, or None.
    options: conversion options. Used as the key when `f` is recorded in the
      allowlist cache.
    update_cache: whether to record `f` in the allowlist cache before calling.

  Returns:
    Whatever `f` returns.
  """
  if update_cache:
    conversion.cache_allowlisted(f, options)

  # Methods bound to a TfMethodTarget are dispatched through the target's
  # `call` method, which receives args and kwargs unexpanded.
  bound_to_method_target = inspect.ismethod(f) and isinstance(
      f.__self__, tf_method_target.TfMethodTarget)
  if bound_to_method_target:
    return f.__self__.call(args, kwargs)

  return f(*args) if kwargs is None else f(*args, **kwargs)
def _fall_back_unconverted(f, args, kwargs, options, exc):
  """Runs `f` unconverted after AutoGraph failed to transform it.

  A warning citing `exc` is logged first. There are two exceptions:
    * inaccessible source code, when source inspection is not supported;
    * unsupported language elements, when `f` is already in the allowlist
      cache.
  Any other error additionally asks the user to file a bug report.

  Args:
    f: the callable that could not be converted.
    args: tuple of positional arguments.
    kwargs: dict of keyword arguments, or None.
    options: the conversion options in effect.
    exc: the exception raised while transforming `f`.

  Returns:
    The result of calling `f` unconverted.
  """
  # TODO(mdan): Consider adding an internal metric.
  warning_template = (
      'AutoGraph could not transform %s and will run it as-is.\n'
      '%s'
      'Cause: %s\n'
      'To silence this warning, decorate the function with'
      ' @tf.autograph.experimental.do_not_convert')

  if isinstance(exc, errors.InaccessibleSourceCodeError):
    should_warn = ag_ctx.INSPECT_SOURCE_SUPPORTED
    extra_hint = ''
  elif isinstance(exc, errors.UnsupportedLanguageElementError):
    should_warn = not conversion.is_in_allowlist_cache(f, options)
    extra_hint = ''
  else:
    should_warn = True
    extra_hint = (
        'Please report this to the TensorFlow team. When filing the bug, set'
        ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
        ' attach the full output.\n')

  if should_warn:
    logging.warning(warning_template, f, extra_hint, exc)

  return _call_unconverted(f, args, kwargs, options)
488#
489# TensorFlow integration
490#
@tf_export('__internal__.autograph.tf_convert', v1=[])
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
  """Decorator that applies AutoGraph to a function.

  Use in internal APIs.

  This API is suitable for high order functions internal to the TensorFlow API,
  and more generally any function to which AutoGraph is not applied.

  Guidance: `convert` was a decorator meant for use directly by developers, but
  most of today's uses go through `tf.function`. `tf_convert` is to be called
  from high order functions internal to TF. By default, all the internal
  TensorFlow functions are skipped when AutoGraph processes the code. This may
  cause user-supplied functions to be incorrectly skipped as well.
  `tf_convert` helps avoid that. See the following example for more details.

  ```
  =====tf_internal_module.py=====

  def unconverted(input_fn):
    return input_fn()

  def converted(input_fn):
    return tf.__internal__.autograph.tf_convert(
       input_fn, ctx=tf.__internal__.autograph.control_status_ctx())()

  ======user_module.py======

  @tf.function
  def foo(input_fn):
    return unconverted(input_fn)

  @tf.function
  def bar(input_fn):
    return converted(input_fn)

  @tf.function(autograph=False)
  def baz(input_fn):
    return converted(input_fn)
  ```

  The `foo` method above will execute the `input_fn` without autograph
  conversion, while the `bar` method will run an autographed `input_fn`. The
  `baz` method will run an unconverted `input_fn`, since `tf_convert` respects
  the control status context.

  Note that both methods in `tf_internal_module` are skipped by autograph when
  tracing the `tf.function`. The configuration of whether a module/package
  should be skipped by autograph is controlled in
  tensorflow/python/autograph/core/config.py.

  Args:
    f: Callable.
    ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
    convert_by_default: bool, whether to use AutoGraph when the context doesn't
      specify.
    user_requested: bool, whether to ignore the conversion allowlist. See
      ConversionOptions.user_requested.

  Returns:
    Either `f` (if it is already an AutoGraph artifact) or a wrapper of `f`
    that applies the conversion policy selected by `ctx.status`.

  Raises:
    AssertionError: if `ctx.status` is not ENABLED, DISABLED or UNSPECIFIED.
  """

  if is_autograph_artifact(f):
    return f
  f_wrapper = f
  decorators, f = tf_decorator.unwrap(f)

  # TODO(mdan): Grab features from context.
  # Note: we pass the original context through to convert to properly handle the
  # following scenario, which can be used inside TF implementations:
  #
  #   ctx = ag_ctx.control_status_ctx()
  #   @function(autograph=False)  # Low-level graph code
  #   def inner_fn():
  #     # The context is disabled here, but should be enabled in user_fn
  #     tf_convert(user_fn, ctx=ctx)
  if ctx.status == ag_ctx.Status.ENABLED:
    wrapper_factory = convert(
        recursive=True, user_requested=user_requested, conversion_ctx=ctx)
  elif ctx.status == ag_ctx.Status.DISABLED:
    wrapper_factory = do_not_convert
  elif ctx.status == ag_ctx.Status.UNSPECIFIED:
    if convert_by_default:
      wrapper_factory = convert(
          recursive=True, user_requested=user_requested, conversion_ctx=ctx)
    else:
      wrapper_factory = call_with_unspecified_conversion_status
  else:
    # An explicit raise, unlike `assert False`, is not stripped under
    # `python -O`. With the assert stripped, `wrapper_factory` would be
    # unbound below and fail with a confusing UnboundLocalError.
    raise AssertionError('This switch contains all possible cases!')
  wrapper = wrapper_factory(f)

  # Re-apply any TF decorators that were peeled off by unwrap above.
  if decorators:
    wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)

  return autograph_artifact(wrapper)
def call_with_unspecified_conversion_status(func):
  """Decorator that resets the conversion context to the unspecified status.

  Args:
    func: the callable to wrap.

  Returns:
    An AutoGraph artifact that calls `func` inside an `ag_ctx.ControlStatusCtx`
    whose status is UNSPECIFIED.
  """

  def wrapper(*args, **kwargs):
    unspecified_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)
    with unspecified_ctx:
      return func(*args, **kwargs)

  # Name/docstring metadata is copied only from plain functions and methods.
  copy_metadata = inspect.isfunction(func) or inspect.ismethod(func)
  if copy_metadata:
    functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)
def _log_callargs(f, args, kwargs):
  """Logs, at verbosity 2, the defaults of `f` and how a call would bind.

  Args:
    f: the function about to be called.
    args: tuple of positional arguments.
    kwargs: dict of keyword arguments, or None.
  """
  logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
  logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)

  if kwargs is None:
    bound_args = tf_inspect.getcallargs(f, *args)
  else:
    bound_args = tf_inspect.getcallargs(f, *args, **kwargs)

  arg_lines = [' {}: {}'.format(name, value)
               for name, value in bound_args.items()]
  logging.log(2, 'Calling %s with\n%s\n', f, '\n'.join(arg_lines))
619#
620# Public API
621#
@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None):
  """Decorator that suppresses the conversion of a function.

  Args:
    func: function to decorate.

  Returns:
    If `func` is not None, returns a `Callable` which is equivalent to
    `func`, but is not converted by AutoGraph.
    If `func` is None, returns a decorator that, when invoked with a
    single `func` argument, returns a `Callable` equivalent to the
    above case.
  """
  if func is None:
    # Invoked as `@do_not_convert()`: hand back the decorator itself.
    return do_not_convert

  def wrapper(*args, **kwargs):
    disabled_ctx = ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED)
    with disabled_ctx:
      return func(*args, **kwargs)

  # Name/docstring metadata is copied only from plain functions and methods.
  copy_metadata = inspect.isfunction(func) or inspect.ismethod(func)
  if copy_metadata:
    functools.update_wrapper(wrapper, func)

  return autograph_artifact(wrapper)
# TODO(mdan): Make private.
def convert(recursive=False,
            optional_features=None,
            user_requested=True,
            conversion_ctx=ag_ctx.NullCtx()):
  """Decorator that compiles a function to use TensorFlow ops.

  The decorator is dynamic - it recompiles the target whenever the decorated
  function is called. This means the parameter values are known at conversion.
  It also means that repeated calls with different types of parameters will be
  correctly processed.

  Args:
    recursive: bool, whether to recursively convert any functions or classes
      that the converted function may use.
    optional_features: converter.Feature, allows toggling optional or
      experimental features. When set to None, only the core features are
      enabled.
    user_requested: bool, whether this is a function that the user explicitly
      asked to be converted. See ConversionOptions.user_requested.
    conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
      which `f` is used. The wrapper enters this context object on every call.
      Note that the default `NullCtx` instance is created once, when this
      module is loaded, and is shared by all calls that omit the argument.

  Returns:
    Callable, a decorator that converts the given function into an equivalent
    function that uses TensorFlow ops.
  """

  def decorator(f):
    """Decorator implementation."""

    def wrapper(*args, **kwargs):
      """Wrapper that calls the converted version of f.

      If the error raised carries `ag_error_metadata`, it is replaced by the
      exception built from that metadata (`to_exception`). Otherwise the error
      is re-raised unchanged.
      """
      options = converter.ConversionOptions(
          recursive=recursive,
          user_requested=user_requested,
          optional_features=optional_features)
      try:
        with conversion_ctx:
          return converted_call(f, args, kwargs, options=options)
      except Exception as e:  # pylint:disable=broad-except
        if hasattr(e, 'ag_error_metadata'):
          raise e.ag_error_metadata.to_exception(e)
        else:
          raise

    if inspect.isfunction(f) or inspect.ismethod(f):
      wrapper = functools.update_wrapper(wrapper, f)

    decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
    return autograph_artifact(decorated_wrapper)

  return decorator
# pylint:disable=line-too-long
@tf_export('autograph.to_graph', v1=[])
def to_graph(entity, recursive=True, experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  Example usage:

  >>> def f(x):
  ...   if x > 0:
  ...     y = x * x
  ...   else:
  ...     y = -x
  ...   return y
  ...
  >>> converted_f = to_graph(f)
  >>> x = tf.constant(2)
  >>> converted_f(x)  # converted_f is like a TensorFlow Op.
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  Only entities that expose a `__code__` object, such as functions and object
  methods, can be converted. Other entities, including classes, are rejected
  with a `ConversionError`. If a function has been decorated with
  `tf.function`, pass its `python_function` attribute instead.

  Functions are converted into new functions with converted code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  For a tutorial, see the
  [tf.function and AutoGraph guide](https://www.tensorflow.org/guide/function).
  For more detailed information, see the
  [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md).

  Args:
    entity: Python callable to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted Python function, marked as an AutoGraph artifact.

  Raises:
    ConversionError: If the entity could not be converted, i.e. conversion
      failed with a ValueError, AttributeError, KeyError, NameError or
      AssertionError. The original error is chained as the cause.
  """
  try:
    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            user_requested=True,
            optional_features=experimental_optional_features))
    return autograph_artifact(_convert_actual(entity, program_ctx))
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    logging.error(1, 'Error converting %s', entity, exc_info=True)
    # Chain explicitly so tracebacks show the original error as the cause.
    raise ConversionError('converting {}: {}: {}'.format(
        entity, e.__class__.__name__, str(e))) from e
@tf_export(v1=['autograph.to_graph'])
def to_graph_v1(entity,
                recursive=True,
                arg_values=None,
                arg_types=None,
                experimental_optional_features=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo behaves like a TensorFlow Op.
    assert tf.is_tensor(y)
  ```

  Only entities that expose a `__code__` object, such as functions and object
  methods, can be converted. Other entities, including classes, are rejected
  with a `ConversionError`.

  Functions are converted into new functions with converted code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable to convert.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted Python function, as returned by `to_graph`.

  Raises:
    ConversionError: If the entity could not be converted (raised by
      `to_graph`).
  """
  # The deprecated arguments are accepted for compatibility and ignored.
  del arg_types
  del arg_values
  return to_graph(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
@tf_export(v1=['autograph.to_code'])
def to_code_v1(entity,
               recursive=True,
               arg_values=None,
               arg_types=None,
               indentation=' ',
               experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable to convert. It must expose a `__code__` object.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    arg_values: Deprecated.
    arg_types: Deprecated.
    indentation: Deprecated.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string.

  Raises:
    ConversionError: If the entity could not be converted (raised by
      `to_graph`).
  """
  # The deprecated arguments are accepted for compatibility and ignored.
  del arg_values
  del arg_types
  del indentation
  return to_code(
      entity,
      recursive=recursive,
      experimental_optional_features=experimental_optional_features)
@tf_export('autograph.to_code', v1=[])
def to_code(entity, recursive=True, experimental_optional_features=None):
  """Returns the source code generated by AutoGraph, as a string.

  Example usage:

  >>> def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f)
  "...def tf__f(x):..."

  Also see: `tf.autograph.to_graph`.

  Note: If a function has been decorated with `tf.function`, pass its
  underlying Python function, rather than the callable that `tf.function`
  creates:

  >>> @tf.function
  ... def f(x):
  ...   if x < 0:
  ...     x = -x
  ...   return x
  >>> tf.autograph.to_code(f.python_function)
  "...def tf__f(x):..."

  Args:
    entity: Python callable to convert. It must expose a `__code__` object.
    recursive: Whether to recursively convert any functions that the converted
      function may call.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value.

  Returns:
    The converted code as string, with common leading indentation removed.

  Raises:
    ConversionError: If the entity could not be converted (raised by
      `to_graph`).
  """
  source = tf_inspect.getsource(
      to_graph(
          entity,
          recursive=recursive,
          experimental_optional_features=experimental_optional_features))
  return textwrap.dedent(source)
# Module-wide transpiler instance used by `_convert_actual`. It is created at
# the end of the module, after the `PyToTF` class has been defined.
_TRANSPILER = PyToTF()