1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""API for defining graph functions with some additional eager semantics.
17
18tf.function utilizes varying configurations of TracingCompiler to allow
19initializing `tf.Variable`s with subgraphs of the function. For example:
20
21```python
22class M(tf.Module):
23 def __init__(self):
24 self.v_opinit = None
25 self.v_arginit = None
26
27 @tf.function
28 def __call__(self, x):
29 # Variables are only created on the first call to the function. This is a
30 # common pattern in layer libraries.
31 if self.v_opinit is None:
32 # self.v_opinit will outlive the function call, but `tf.ones` is traced as
33 # part of the function body before the `tf.Variable` object is
34 # created. This subgraph is easy to lift out of the function.
35 self.v_opinit = tf.Variable(tf.ones([]))
36
37 # If arguments feed into variable initialization, it can be very tricky to
38 # disentangle from the rest of the function. We don't attempt it.
39 self.v_arginit = tf.Variable(tf.ones(tf.shape(x)) * tf.constant(2.))
40 return self.v_opinit + self.v_arginit + x
41```
42
When "TracingCompiler" is used directly, these patterns throw an error asking
the user to put the variable's initializer in a lambda. With tf.function they
45work with eager semantics either by lifting the subgraph out of the function and
46using it to initialize the variable, or by initializing variables on the first
47call to the function (if they weren't already initialized by something else,
48e.g. a checkpoint API). The latter requires tf.conds, and is not well supported
49by TF-XLA, so we only do it when necessary.
50
51Since these patterns are relatively common in layer libraries, we expose the
52wrapper in this file as `tf.function`. The defun concept in quarantine.py is a
53legacy internal API.
54
55In order to support these variable initialization patterns, tf.function defines
56a variable subtype (UnliftedInitializerVariable) which collects the input
57subgraph. This type of variable replaces the regular variable type on the first
58tf.function trace. To exclude initializers from the function body (the `tf.ones`
59ops above and associated assignment operations), tf.function traces a second
60time if it sees variables on the first call.
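
For example (a rough sketch of the behavior for the class `M` above, assuming
eager execution):

```python
m = M()
# First call: traced with variable creation allowed. `v_opinit`'s initializer
# can be lifted out of the function and run eagerly; `v_arginit`'s initializer
# depends on `x`, so it is guarded by a tf.cond inside the function instead.
m(tf.constant([1., 2.]))
# Later calls run a trace that is not allowed to create variables.
m(tf.constant([3., 4.]))
```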
61"""
62
63import functools
64import os
65import threading
66import types as types_lib
67import weakref
68
69from google.protobuf import text_format as _text_format
70from google.protobuf.message import DecodeError
71from tensorflow.core.framework import attr_value_pb2
72from tensorflow.core.function import trace_type
73from tensorflow.python.distribute.parallel_device import parallel_device
74from tensorflow.python.eager import context
75from tensorflow.python.eager import lift_to_graph
76from tensorflow.python.eager import monitoring
77from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib
78from tensorflow.python.eager.polymorphic_function import autograph_util
79from tensorflow.python.eager.polymorphic_function import compiler_ir
80from tensorflow.python.eager.polymorphic_function import eager_function_run
81from tensorflow.python.eager.polymorphic_function import function_spec as function_spec_lib
82from tensorflow.python.eager.polymorphic_function import tracing_compiler
83from tensorflow.python.framework import composite_tensor
84from tensorflow.python.framework import errors
85from tensorflow.python.framework import func_graph as func_graph_module
86from tensorflow.python.framework import ops
87from tensorflow.python.framework import tensor_spec
88from tensorflow.python.ops import array_ops_stack
89from tensorflow.python.ops import cond
90from tensorflow.python.ops import control_flow_ops
91from tensorflow.python.ops import control_flow_util
92from tensorflow.python.ops import math_ops
93from tensorflow.python.ops import resource_variable_ops
94from tensorflow.python.platform import tf_logging as logging
95from tensorflow.python.profiler import trace
96from tensorflow.python.trackable import base as trackable
97from tensorflow.python.types import core
98from tensorflow.python.util import deprecation
99from tensorflow.python.util import nest
100from tensorflow.python.util import object_identity
101from tensorflow.python.util import tf_decorator
102from tensorflow.python.util import traceback_utils
103from tensorflow.python.util.tf_export import tf_export
104
105FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
106FREQUENT_TRACING_WARNING_THRESHOLD = 5
107FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2
108ALLOW_DYNAMIC_VARIABLE_CREATION = False
109
110
111def set_dynamic_variable_creation(is_allowed):
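  """Toggles allowing tf.Variable creation on non-first tf.function calls."""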
112 global ALLOW_DYNAMIC_VARIABLE_CREATION
113 ALLOW_DYNAMIC_VARIABLE_CREATION = is_allowed
114
115
116_tf_function_counter = monitoring.Counter(
117 "/tensorflow/core/tf_function_counter",
118 "Counter for the number of tf.functions created when Eager execution is "
119 "enabled.",
120 # jit_compile is "0" or "1".
121 "jit_compile")
122
123
124class _FrequentTracingDetector(object):
125 """Class keeping track of how many recent calls triggered tracing."""
126
127 __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]
128
129 def __init__(self):
130 self._calls_per_tracings = []
131 self._total_warning_count = 0
132 self._call_count = 0
133
134 def called_with_tracing(self, function_name, omit_warning):
135 """Updates the list of most recent calls' tracing information.
136
137 Warns the user when recent calls caused retracing too often.
138
139 Args:
140 function_name: the python function being traced.
141 omit_warning: If 'True', this call will not warn the user even if
142 retracing happens too often.
143 """
144 self._call_count += 1
145 self._calls_per_tracings.append(1)
146
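    # Keep only the tracing history for the most recent
    # FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY calls.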
147 while self._calls_per_tracings:
148 if (self._call_count - self._calls_per_tracings[0] >
149 FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
150 self._call_count -= self._calls_per_tracings.pop(0)
151 else:
152 break
153
154 if (omit_warning or self._total_warning_count >=
155 FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
156 return
157 if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
158 self._total_warning_count += 1
159 logging.warning(
160 "{} out of the last {} calls to {} triggered tf.function "
161 "retracing. Tracing is expensive and the excessive number of "
162 "tracings could be due to (1) creating @tf.function repeatedly in "
163 "a loop, (2) passing tensors with different shapes, (3) passing "
164 "Python objects instead of tensors. For (1), please define your "
165 "@tf.function outside of the loop. For (2), @tf.function has "
166 "reduce_retracing=True option that can avoid unnecessary "
167 "retracing. For (3), please refer to "
168 "https://www.tensorflow.org/guide/function#controlling_retracing"
169 " and https://www.tensorflow.org/api_docs/python/tf/function for "
170 " more details.".format(
171 len(self._calls_per_tracings), self._call_count, function_name))
172
173 def called_without_tracing(self):
    # We don't count tracing when users load a concrete function directly or
    # call get_concrete_function, so the first call may not be a tracing call.
176 if not self._calls_per_tracings:
177 self._calls_per_tracings = [0]
178 self._calls_per_tracings[-1] += 1
179 self._call_count += 1
180
181
182class _FrequentTracingDetectorManager(object):
183 """Class for the management of all _FrequentTracingDetector objects."""
184
185 __slots__ = ["_detectors", "_lock"]
186
187 def __init__(self):
188 self._detectors = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock)
189 self._lock = threading.Lock()
190
191 def _get_detector(self, key):
192 if key not in self._detectors:
193 self._detectors[key] = _FrequentTracingDetector()
194 return self._detectors[key]
195
196 def called_without_tracing(self, key):
197 with self._lock:
198 detector = self._get_detector(key)
199 detector.called_without_tracing()
200
201 def called_with_tracing(self, key, function_name, omit_warning):
202 with self._lock:
203 detector = self._get_detector(key)
204 detector.called_with_tracing(function_name, omit_warning)
205
206
207_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
208
209
210class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
211 """Variable which does not lift its initializer out of function context.
212
213 Instances of this variable, when created, build a graph which runs their
214 initializer inside a tf.cond(is_initialized) block.
215
216 This can only be created inside a TracingCompiler called from
217 (eventually) eager mode. That is, non-function-building graphs are not
218 supported.
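
  For example, when `tf.function` traces a function whose variable initializer
  cannot be lifted out of the function graph, an instance of this class records
  the initializer subgraph and emits an assignment that only runs when
  `var_is_initialized_op` reports the variable as uninitialized.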
219 """
220
221 def __init__(
222 self,
223 initial_value=None,
224 trainable=None,
225 caching_device=None,
226 name=None,
227 dtype=None,
228 constraint=None,
229 add_initializers_to=None,
230 synchronization=None,
231 aggregation=None,
232 shape=None,
233 **unused_kwargs,
234 ):
235 """Creates a variable.
236
237 Args:
238 initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
239 which is the initial value for the Variable. The initial value must have
240 a shape specified unless `validate_shape` is set to False. Can also be a
241 callable with no argument that returns the initial value when called.
242 (Note that initializer functions from init_ops.py must first be bound to
243 a shape before being used here.)
244 trainable: If `True`, GradientTapes automatically watch uses of this
245 Variable.
246 caching_device: Optional device string or function describing where the
247 Variable should be cached for reading. Defaults to the Variable's
248 device. If not `None`, caches on another device. Typical use is to
249 cache on the device where the Ops using the Variable reside, to
250 deduplicate copying through `Switch` and other conditional statements.
251 name: Optional name for the variable. Defaults to `'Variable'` and gets
252 uniquified automatically.
253 dtype: If set, initial_value will be converted to the given type. If None,
254 either the datatype will be kept (if initial_value is a Tensor) or
255 float32 will be used (if it is a Python object convertible to a Tensor).
256 constraint: An optional projection function to be applied to the variable
257 after being updated by an `Optimizer` (e.g. used to implement norm
258 constraints or value constraints for layer weights). The function must
259 take as input the unprojected Tensor representing the value of the
260 variable and return the Tensor for the projected value (which must have
261 the same shape). Constraints are not safe to use when doing asynchronous
262 distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        variable and its initializer tensor will be appended to this list in
        addition to adding the assignment to the function.
266 synchronization: Indicates when a distributed variable will be aggregated.
267 Accepted values are constants defined in the class
268 `tf.VariableSynchronization`. By default the synchronization is set to
269 `AUTO` and the current `DistributionStrategy` chooses when to
270 synchronize.
271 aggregation: Indicates how a distributed variable will be aggregated.
272 Accepted values are constants defined in the class
273 `tf.VariableAggregation`.
274 shape: (optional) The shape of this variable. If None, the shape of
275 `initial_value` will be used. When setting this argument to
276 `tf.TensorShape(None)` (representing an unspecified shape), the variable
277 can be assigned with values of different shapes.
278
279 Raises:
280 ValueError: If the initial value is not specified, or does not have a
281 shape and `validate_shape` is `True`.
282 RuntimeError: If called outside of a function definition.
283 """
284 with ops.init_scope():
285 self._in_graph_mode = not context.executing_eagerly()
286 if not ops.inside_function():
      # If we've been init_scope()d out of the function definition, there is
      # nothing to do here; we can't really do the capturing or conditional
      # logic.
289 resource_variable_ops.ResourceVariable.__init__(
290 self, initial_value=initial_value, trainable=trainable,
291 caching_device=caching_device, name=name, dtype=dtype,
292 constraint=constraint)
293 return
294 if initial_value is None:
295 raise ValueError("`initial_value` must be a Tensor or a Python "
296 "object convertible to a Tensor. Got None.")
297 init_from_fn = callable(initial_value)
298
299 if constraint is not None and not callable(constraint):
300 raise ValueError(f"`constraint` with type {type(constraint)} must be a "
301 "callable.")
302
303 with ops.name_scope(name, "Variable", []
304 if init_from_fn else [initial_value]) as scope_name:
305 with ops.name_scope("Initializer"):
306 if init_from_fn:
307 initial_value = initial_value()
308 if isinstance(initial_value, trackable.CheckpointInitialValue):
309 self._maybe_initialize_trackable()
310 self._update_uid = initial_value.checkpoint_position.restore_uid
311 initial_value = initial_value.wrapped_value
312
313 initial_value = ops.convert_to_tensor(initial_value,
314 name="initial_value", dtype=dtype)
315 assert initial_value is not None
316
317 # Don't use `shape or initial_value.shape` since TensorShape has
318 # overridden `__bool__`.
319 if shape is None:
320 shape = initial_value.shape
321
322 # Use the constructor for UninitializedVariable to start. Outside the name
323 # scope so we don't double up the prefix.
324 super().__init__(
325 trainable=trainable,
326 caching_device=caching_device,
327 name=name,
328 shape=shape,
329 dtype=initial_value.dtype,
330 constraint=constraint,
331 synchronization=synchronization,
332 aggregation=aggregation,
333 extra_handle_data=initial_value,
334 **unused_kwargs)
335
336 with ops.name_scope(scope_name):
337 if self._in_graph_mode:
338 with ops.init_scope():
339 outer_graph = ops.get_default_graph()
340 func_graph = ops.get_default_graph()
341 function_placeholders = (
342 func_graph.inputs + func_graph.internal_captures)
343 placeholder_ops = set(
344 [tensor.op for tensor in function_placeholders])
345 lifted_initializer = lift_to_graph.lift_to_graph(
346 [initial_value], outer_graph,
347 disallowed_placeholders=placeholder_ops)[initial_value]
348 with ops.init_scope():
349 self._initial_value = lifted_initializer
350 with ops.name_scope("IsInitialized"):
351 self._is_initialized_op = (
352 resource_variable_ops.var_is_initialized_op(self._handle))
353 if initial_value is not None:
354 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
355 self._initializer_op = resource_variable_ops.assign_variable_op(
356 self._handle, lifted_initializer, name=n)
357 elif context.executing_eagerly():
358 # In this case, both current scope and init scope are eager.
359 # Assign_variable_op will be executed immediately. So we don't need to
360 # add it to "add_initializers_to" to lift it out.
361 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
362 resource_variable_ops.assign_variable_op(
363 self._handle, initial_value, name=n)
364 else:
        # Init scope is eager but current scope is graph. We will lift out this
        # variable by adding it to "add_initializers_to".
367 if add_initializers_to is not None:
368 add_initializers_to.append((self, initial_value))
369
370 def assign_fn():
371 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
372 resource_variable_ops.assign_variable_op(
373 self._handle,
374 initial_value,
375 name=n)
376 # Returning values to keep tf.cond happy.
377 return ops.convert_to_tensor(1)
378 def not_assign_fn():
379 return ops.convert_to_tensor(0)
380 # Note: this cond is always guaranteed to run because we're inside a
381 # TracingCompiler which will insert automatic control dependencies.
382 # It will only execute assign_fn if lifting failed.
383 graph = ops.get_default_graph()
384
        # Capture the handle ahead of time in order to avoid querying the shape
        # of the handle, which helps async execution performance.
387 graph.capture(self._handle, shape=())
388 cond.cond(
389 resource_variable_ops.var_is_initialized_op(self._handle),
390 not_assign_fn, assign_fn)
391
392
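# Default for jit-compiling tf.functions, read once at import time from the
# TF_FUNCTION_JIT_COMPILE_DEFAULT environment variable ("true" or "1" turns it
# on).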
393JIT_COMPILE_FUNCTIONS = (
394 os.getenv("TF_FUNCTION_JIT_COMPILE_DEFAULT", "false").lower()
395 in ("true", "1"))
396
397
398def _evaluate_var_is_initialized(variables):
399 """Compute booleans indicating whether each variable is initialized."""
400 with ops.init_scope():
401 var_is_initialized = []
402 for v in variables:
403 var_is_initialized.append(
404 resource_variable_ops.var_is_initialized_op(v.handle))
405 try:
406 # Stack all the var_is_initialized values into one tensor and interpret
407 # the numpy value. This will reduce the number of RPCs between client and
408 # worker in the remote case.
409 return array_ops_stack.stack(var_is_initialized).numpy()
410 except errors.UnimplementedError:
411 # Some devices do not support implicit copy-off to host. Fall back to
412 # variable-by-variable processing.
413 for index, v in enumerate(variables):
414 try:
415 numpy_value = var_is_initialized[index].numpy()
416 except errors.UnimplementedError:
417 # This is a variable on a parallel device; we'll extract its value on
418 # each replica and assert that they're identical.
419 components = parallel_device.unpack(var_is_initialized[index])
420 with ops.device(None):
421 components = array_ops_stack.stack(components)
422 all_initialized = math_ops.reduce_all(components).numpy()
423 any_initialized = math_ops.reduce_any(components).numpy()
424 if all_initialized != any_initialized:
425 raise NotImplementedError(
426 f"Some but not all components of a parallel variable {v!r} "
427 "were initialized between their creation in a tf.function and "
428 "the function's trace having completed. This is not "
429 "supported; consider initializing either all or none of the "
430 "components, or moving initialization out of the function.")
431 numpy_value = all_initialized
432 var_is_initialized[index] = numpy_value
433 return var_is_initialized
434
435
436class OptionalXlaContext:
437 """Wrapper for XLA context optionally applied under a context manager."""
438
439 def __init__(self, is_compiled):
440 wrap = is_compiled and not control_flow_util.GraphOrParentsInXlaContext( \
441 ops.get_default_graph())
442 self.xla_context = control_flow_ops.XLAControlFlowContext() \
443 if wrap else None
444
445 def __enter__(self):
446 if self.xla_context:
447 self.xla_context.Enter()
448
449 def __exit__(self, t, value, traceback):
450 if self.xla_context:
451 self.xla_context.Exit()
452
453
# TODO(mdan): Consider exposing this type for instance type checking.
455@tf_export("__internal__.function.Function", v1=[])
456class Function(core.GenericFunction, trackable.Trackable):
457 """A `tf.types.experimental.GenericFunction` created by `tf.function`.
458
459 Currently, individual methods/attributes under this class are not guaranteed
460 by the TF API contract, and are subject to future changes.
461 """
462
463 def __init__(self,
464 python_function,
465 name,
466 input_signature=None,
467 autograph=True,
468 jit_compile=None,
469 reduce_retracing=False,
470 experimental_implements=None,
471 experimental_autograph_options=None,
472 experimental_attributes=None,):
473 """Initializes a `Function`.
474
475 Args:
476 python_function: the function to be wrapped.
477 name: the name given to it.
478 input_signature: See the documentation for `tf.function`.
479 autograph: See the documentation for `tf.function`.
480 jit_compile: See the documentation for `tf.function`.
481 reduce_retracing: See the documentation for `tf.function`.
482 experimental_implements: See the documentation for `tf.function`.
483 experimental_autograph_options: See the documentation for `tf.function`.
484 experimental_attributes: See the documentation for `tf.function`.
485
486 Raises:
487 ValueError: if `input_signature` is not None and the `python_function`'s
488 argspec has keyword arguments.
489 """
490 self._lock = threading.RLock()
491 self._python_function = python_function
492 self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature(
493 python_function,
494 input_signature,
495 jit_compile=jit_compile,
496 )
497
498 self._attributes = {}
499 if experimental_implements is not None:
500 self._attributes = self._create_implements_attribute(
501 experimental_implements
502 )
503
504 if experimental_attributes is not None:
505 self._attributes.update(experimental_attributes)
506
507 for attribute in self._attributes:
508 if attribute not in attributes_lib.POLYMORPHIC_FUNCTION_ALLOWLIST:
509 raise ValueError(
510 f"`{attribute} is not supported by tf.function as an attribute."
511 )
512
513 # If `True`, the function uses the rendezvous of the parent. This is only
514 # needed to support code where raw send/recv operations are inserted and
515 # when functions are run in graph mode where they may not be inlined.
516 self._shared_rendezvous = None
517 self._autograph = autograph
518 self._experimental_autograph_options = experimental_autograph_options
519 self._reduce_retracing = reduce_retracing
520 self._jit_compile = jit_compile
521 self._created_variables = None # GUARDED_BY(self._lock)
522 self._variable_creation_fn = None # GUARDED_BY(self._lock)
523 self._no_variable_creation_fn = None # GUARDED_BY(self._lock)
524 self._descriptor_cache = weakref.WeakKeyDictionary()
525 self._name = name
526 self._key_for_call_stats = self._get_key_for_call_stats()
527 self._omit_frequent_tracing_warning = False
528 ops._tf_function_api_gauge.get_cell().set(True) # pylint: disable=protected-access
529
530 @property
531 def name(self):
532 return self._name
533
534 def __getstate__(self):
535 """Custom pickling, to omit unpickleable objects."""
536 result = self.__dict__.copy()
537 del result["_lock"]
538 del result["_descriptor_cache"]
539 del result["_key_for_call_stats"]
540 return result
541
542 def __setstate__(self, state):
543 """Restore from pickled state."""
544 self.__dict__ = state
545 self._lock = threading.RLock()
546 self._descriptor_cache = weakref.WeakKeyDictionary()
547 self._key_for_call_stats = self._get_key_for_call_stats()
548
549 def _get_key_for_call_stats(self):
550 """Returns key instance to track call stats and retracings.
551
    The key instance is a best-effort attempt to preserve global consistency.
553 """
554 target_function = self._python_function
    # `__wrapped__` is a conventional Python attribute in which a higher-order
    # function keeps a reference to the original function it wraps. We also use
    # this attribute directly when dealing with a class method; see
    # `bound_method_wrapper` in `function.py`. If we didn't use `__wrapped__`,
    # this function would return the same `bound_method_wrapper` instance for
    # all class methods.
561 while hasattr(target_function, "__wrapped__"):
562 target_function = target_function.__wrapped__
563
564 if hasattr(target_function, "__func__"):
565 target_function = target_function.__func__
566
567 if hasattr(target_function, "__code__"):
568 return target_function.__code__
569
570 return self._python_function
571
572 def _compiler_with_scope(self, scope):
573 """Creates a TracingCompiler wrapped inside a variable creator scope."""
574
575 weak_wrapped_fn = None
576 compile_with_xla = self._jit_compile
577
578 def wrapped_fn(*args, **kwds):
579 """Wraps `self._python_function` in a variable creator scope."""
580 # We register a variable creator with reduced priority. If an outer
581 # variable creator is just modifying keyword arguments to the variable
582 # constructor, this will work harmoniously. Since the `scope` registered
583 # here actually creates the variable, it taking priority would otherwise
584 # ignore the outer creator.
585 #
586 # If an outer variable creator calls the variable constructor manually,
587 # for example creating a MirroredVariable, then they won't call our
588 # creator. This means we won't be able to trace the initialization graph,
589 # and so variable initializers can't depend on function arguments. This is
590 # better than the alternative, tracing the initialization graph but giving
591 # the user a variable type they didn't want.
592 default_graph = ops.get_default_graph()
593 with default_graph._variable_creator_scope(scope, priority=50): # pylint: disable=protected-access
594 # __wrapped__ allows AutoGraph to swap in a converted function. We give
595 # the function a weak reference to itself to avoid a reference cycle.
596 with OptionalXlaContext(compile_with_xla):
597 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
598 return out
599
600 weak_wrapped_fn = weakref.ref(wrapped_fn)
601
602 return self._compiler(tf_decorator.make_decorator(
603 self._python_function,
604 wrapped_fn))
605
606 def _create_implements_attribute(self, implements_arg):
607 """Creates the attribute value corresponding to attribute_lib.IMPLEMENTS."""
608 attributes = {}
609 if isinstance(implements_arg, str):
610 # First check if the attribute_lib.IMPLEMENTS is specified as a
611 # NameAttrList. This is used when apart from the function name being
612 # implemented, a list of attributes is also being specified.
613 # The attributes are specified as key-value pairs in the NameAttrList
614 # of the corresponding AttrValue. The function name will be in the
615 # 'name' field of the NameAttrList. Else, it is just a string
616 # corresponding to the function name.
617 try:
618 attr_value = attr_value_pb2.AttrValue()
619 nameattrlist = attr_value_pb2.NameAttrList()
620 _text_format.Merge(implements_arg, nameattrlist)
621 attr_value.func.CopyFrom(nameattrlist)
622 attributes[attributes_lib.IMPLEMENTS] = attr_value
623 except (_text_format.ParseError, DecodeError):
624 attributes[attributes_lib.IMPLEMENTS] = implements_arg
625 return attributes
626
627 def _compiler(self, fn):
628 """Returns a TracingCompiler generated from the input function."""
629 attributes = self._attributes.copy()
630
631 share = self._shared_rendezvous
632 if share is not None:
633 attributes[attributes_lib.SHARED_RENDEZVOUS] = share
634
635 if self._jit_compile is not None:
636 attributes[attributes_lib.XLA_COMPILE] = bool(self._jit_compile)
637 if self._jit_compile:
638 attributes[attributes_lib.NO_INLINE] = True
639
640 try:
641 name = fn.__name__
642 except AttributeError:
643 name = "function"
644
645 if self._autograph:
646 fn = autograph_util.py_func_from_autograph(
647 fn, self._experimental_autograph_options)
648
649 return tracing_compiler.TracingCompiler(
650 fn,
651 name,
652 input_signature=self.input_signature,
653 attributes=attributes,
654 autograph=self._autograph,
655 jit_compile=self._jit_compile,
656 reduce_retracing=self._reduce_retracing,
657 autograph_options=self._experimental_autograph_options)
658
659 def _initialize(self, args, kwds, add_initializers_to=None):
660 """Initializes, on the first call.
661
662 Creates two `Function`s, one that will allow creation of variables
663 and one that won't.
664
665 Additionally runs a trace for the `Function` that allows creation
666 of variables.
667
668 Args:
669 args: Arguments to the underlying python callable.
670 kwds: Keyword arguments to the python callable.
671 add_initializers_to: Where to collect variable initializers, if not None.
672 """
673 created_variables = []
674
675 def variable_capturing_scope(next_creator, **kwds):
676 """Creates UnliftedInitializerVariables and saves references to them."""
677 enable_variable_lifting = kwds.get("experimental_enable_variable_lifting")
678 if enable_variable_lifting is None:
679 enable_variable_lifting = True
680 if not enable_variable_lifting:
681 return next_creator(**kwds)
682 v = UnliftedInitializerVariable(
683 add_initializers_to=add_initializers_to, **kwds
684 )
685 created_variables.append(weakref.ref(v))
686 return v
687
688 self._created_variables = created_variables
689 self._variable_creation_fn = self._compiler_with_scope(
690 variable_capturing_scope)
691 self._variable_creation_fn._name = self._name # pylint: disable=protected-access
692 # Force the definition of the function for these arguments
693 self._concrete_variable_creation_fn = (
694 self._variable_creation_fn # pylint: disable=protected-access
695 ._get_concrete_function_internal_garbage_collected(
696 *args, **kwds))
697
698 def invalid_creator_scope(*unused_args, **unused_kwds):
699 """Disables variable creation."""
700 raise ValueError(
701 "tf.function only supports singleton tf.Variables created on the "
702 "first call. Make sure the tf.Variable is only created once or "
703 "created outside tf.function. See "
704 "https://www.tensorflow.org/guide/function#creating_tfvariables "
705 "for more information.")
706
707 self._no_variable_creation_fn = self._compiler_with_scope(
708 invalid_creator_scope)
709 self._no_variable_creation_fn._name = self._name # pylint: disable=protected-access
710
711 def _clone(self, python_function):
712 """Clone the function with different python function."""
713 f = Function(
714 python_function=(self._python_function
715 if python_function is None else python_function),
716 name=self._name,
717 input_signature=self.input_signature,
718 autograph=self._autograph,
719 jit_compile=self._jit_compile,
720 reduce_retracing=self._reduce_retracing,
721 experimental_attributes=self._attributes,
722 experimental_autograph_options=self._experimental_autograph_options)
723
724 if self._shared_rendezvous:
725 f._shared_rendezvous = self._shared_rendezvous # pylint: disable=protected-access
726
727 return f
728
729 def _decorate(self, decorator):
730 """Allows the captured Python function to be decorated in place.
731
732 This method is only safe to call when the Function has not been called by a
733 user. It makes sense to use this method to push a decorator into the
734 function rather than wrapping the function in the decorator.
735
736 We use this in tf.Module to allow user annotated `tf.functions` to remain as
737 `Function` objects but still automatically enter the Module name_scope
738 when they are evaluated like all other methods.
739
740 Args:
741 decorator: A callable accepting a single argument which is the function
742 to decorate and returning a callable result.
743
744 Raises:
      ValueError: If the function has already been called.
746 """
747 if self._variable_creation_fn is not None or self._no_variable_creation_fn is not None:
748 raise ValueError(
749 "Functions cannot be decorated after they have been traced.")
750
751 self._python_function = decorator(self._python_function)
752 self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature(
753 self._python_function, self.input_signature)
754
755 # TODO: Remove this private method after updating all its uses
756 # A good moment to do this could be when the experimental label is removed
757 def _get_tracing_count(self):
758 return self.experimental_get_tracing_count()
759
760 def experimental_get_tracing_count(self):
761 """Returns the number of times the function has been traced.
762
763 For more information on when a function is traced and when it is
764 traced multiple times see https://www.tensorflow.org/guide/function.
765 Example:
766
767 >>> @tf.function
768 ... def double(a):
769 ... return a + a
770 >>> double(tf.constant(1))
771 >>> double(tf.constant(2))
772 >>> double.experimental_get_tracing_count()
773 1
774 >>> double(tf.constant("a"))
775 >>> double.experimental_get_tracing_count()
776 2
777
778
779 The first time experimental_get_tracing_count is called
780 it returns 1, as the function is traced the first
781 time it is called, and the second time the same graph is used
782 since we're calling it with a parameter of the same type.
783
784 The second time experimental_get_tracing_count is called
785 it returns 2, as we called double with a
786 different argument type, and so it was traced again.
787
788 """
789 result = self._no_variable_creation_fn.tracing_count if self._no_variable_creation_fn else 0
790 result += self._variable_creation_fn.tracing_count if self._variable_creation_fn else 0
791 return result
792
793 @property
794 def _run_functions_eagerly(self):
795 return eager_function_run.RUN_FUNCTIONS_EAGERLY
796
797 @traceback_utils.filter_traceback
798 def __call__(self, *args, **kwds):
799 # Implements GenericFunction.__call__.
800 if self._run_functions_eagerly:
801 with trace.Trace(self._name, tf_function_call="eager"):
802 return self._python_function(*args, **kwds)
803
804 # Only count the statistics the first time, before initialization took
805 # place.
806 if self._created_variables is None:
807 compiled = bool(self._jit_compile and
808 not control_flow_util.GraphOrParentsInXlaContext(
809 ops.get_default_graph()))
810 # For nested functions, increment the counter only when a function with
811 # jit_compile=True is called within a function with jit_compile=False. We
812 # count this special case to correctly record that both jit_compile=True
      # and jit_compile=False are being used for parts of the outer function.
814 if ops.executing_eagerly_outside_functions() and (
815 context.executing_eagerly() or compiled):
816 # Labels must be strings in Python, so we convert 'compiled' to a string
817 _tf_function_counter.get_cell(str(int(compiled))).increase_by(1)
818
819 tracing_count = self.experimental_get_tracing_count()
820 with trace.Trace(self._name) as tm:
821 # TODO(cheshire): Do not duplicate the XLAControlFlowContext annotation.
822 compiler = "xla" if self._jit_compile else "nonXla"
823
824 with OptionalXlaContext(self._jit_compile):
825 result = self._call(*args, **kwds)
826
827 new_tracing_count = self.experimental_get_tracing_count()
828 without_tracing = (tracing_count == new_tracing_count)
829 execution_mode = "notTraced" if without_tracing else "traced"
830 tm.set_metadata(tf_function_call=execution_mode + "-" + compiler,
831 tracing_count=new_tracing_count)
832
833 if context.executing_eagerly():
834 if without_tracing:
835 _frequent_tracing_detector_manager.called_without_tracing(
836 self._key_for_call_stats)
837 else:
838 _frequent_tracing_detector_manager.called_with_tracing(
839 self._key_for_call_stats, self._python_function,
840 self._omit_frequent_tracing_warning)
841
842 return result
843
844 def _call(self, *args, **kwds):
845 """Calls the graph function."""
846 self._lock.acquire()
847 if ALLOW_DYNAMIC_VARIABLE_CREATION:
848 condition = self._created_variables and self._variable_creation_fn is None
849 else:
850 condition = self._created_variables
851 if condition:
852 # Release the lock early so that multiple threads can perform the call
853 # in parallel.
854 self._lock.release()
855 # In this case we have created variables on the first call, so we run the
856 # defunned version which is guaranteed to never create variables.
857 return self._no_variable_creation_fn(*args, **kwds) # pylint: disable=not-callable
858 elif self._variable_creation_fn is not None:
859 # Release the lock early so that multiple threads can perform the call
860 # in parallel.
861 self._lock.release()
862 # In this case we have not created variables on the first call. So we can
863 # run the first trace but we should fail if variables are created.
864 results = self._variable_creation_fn(*args, **kwds)
865 if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
866 raise ValueError("Creating variables on a non-first call to a function"
867 " decorated with tf.function.")
868 return results
869
870 try:
871 # This is the first call of __call__, so we have to initialize.
872 initializers = []
873 self._initialize(args, kwds, add_initializers_to=initializers)
874 finally:
875 # At this point we know that the initialization is complete (or less
876 # interestingly an exception was raised) so we no longer need a lock.
877 self._lock.release()
878
879 if self._created_variables:
880 try:
881 # Attempt to initialize variables eagerly and without conds by lifting
882 # out initialization graphs. This is the only initialization strategy
883 # compatible with XLA at the moment.
884 self._initialize_uninitialized_variables(initializers)
885 except lift_to_graph.UnliftableError:
886 pass # Fall through to cond-based initialization.
887 else:
888 # Lifting succeeded, so variables are initialized and we can run the
889 # no_variable_creation function.
890 return self._no_variable_creation_fn(*args, **kwds)
891 else:
892 _, _, filtered_flat_args = (
893 self._variable_creation_fn._function_spec # pylint: disable=protected-access
894 .canonicalize_function_inputs(
895 args, kwds))
896 # If we did not create any variables the trace we have is good enough.
897 return self._concrete_variable_creation_fn._call_flat( # pylint: disable=protected-access
898 filtered_flat_args,
899 self._concrete_variable_creation_fn.captured_inputs)
900
901 def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
902 """Conditionally runs initialization if it's needed."""
903 condition = True
904 for v, _ in initializers:
905 condition = math_ops.logical_and(
906 condition, resource_variable_ops.var_is_initialized_op(
907 v.handle))
908 # We want to call no_variable_creation if possible because it avoids
909 # recomputing potentially expensive initializers.
910 return cond.cond(
911 condition,
912 lambda: self._no_variable_creation_fn(*inner_args, **inner_kwds),
913 functools.partial(
914 self._concrete_variable_creation_fn._call_flat, # pylint: disable=protected-access
915 inner_filtered_flat_args,
916 captured_inputs=self._concrete_variable_creation_fn
917 .captured_inputs))
918
919 # We've created variables and are unable to lift the initialization graphs,
920 # so we fall back to initializing with conds while running the function.
921 # TODO(b/216870587) Note that this path is not currently supported for XLA.
922 if self._jit_compile:
923 raise errors.UnimplementedError(
924 None, None,
925 "We failed to lift variable creations out of this tf.function, "
926 "so this tf.function cannot be run on XLA. A possible workaround is "
927 "to move variable creation outside of the XLA compiled function.")
928 canon_args, canon_kwds, filtered_flat_args = (
929 self._variable_creation_fn._function_spec.canonicalize_function_inputs( # pylint: disable=protected-access
930 args, kwds))
931 return tracing_compiler.TracingCompiler(
932 fn_with_cond, "fn_with_cond")(canon_args, canon_kwds,
933 filtered_flat_args)
934
935 def experimental_get_compiler_ir(self, *args, **kwargs):
936 # Implements GenericFunction.experimental_get_compiler_ir
937 context.ensure_initialized()
938 if not self._jit_compile:
939 raise ValueError("Compiler IR can only be returned for functions marked "
940 "with 'jit_compile=True'")
941
942 is_tensor_spec = lambda x: isinstance(x, tensor_spec.TensorSpec)
943
944 def _check_inputs(args, kwargs):
945 all_inputs = list(args) + list(kwargs.values())
      # Empty input is okay.
947 if not all_inputs:
948 return
949 if any(map(is_tensor_spec, all_inputs)) and any(
950 map(lambda x: not is_tensor_spec(x), all_inputs)
951 ):
952 raise ValueError(
953 "experimental_get_compiler_ir supports either "
954 "(1) all inputs are TensorSpec or "
955 "(2) all inputs are tf.Tensor/python variables"
956 )
957
958 _check_inputs(args, kwargs)
959 if (
960 len(args) + len(kwargs.values()) > 0
961 and all(map(is_tensor_spec, args))
962 and all(map(is_tensor_spec, kwargs.values()))
963 ):
      # The case where inputs are not empty and all of them are tf.TensorSpec.
965 concrete_fn = self.get_concrete_function(*args, **kwargs)
966 return compiler_ir.from_concrete_function(concrete_fn)
967
968 concrete_fn = self.get_concrete_function(*args, **kwargs)
969 fn_name = concrete_fn.name
970
971 # pylint: disable=protected-access
972 _, _, filtered_flat_args = (
973 concrete_fn._function_spec.canonicalize_function_inputs(args, kwargs))
974
975 def compiler_ir_generator(stage="hlo", device_name=None):
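      """Returns compiler IR: bytes for serialized stages, else UTF-8 text."""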
976 device_name = compiler_ir.maybe_get_device_name(device_name)
977 res_bytes = context.context().get_compiler_ir(
978 device_name=device_name,
979 function_name=fn_name,
980 flat_args=list(filtered_flat_args),
981 captured_inputs=concrete_fn.captured_inputs,
982 stage=stage,
983 )
984 if stage in ("hlo_serialized", "optimized_hlo_serialized",
985 "optimized_hlo_proto_serialized"):
986 return res_bytes
987 else:
988 return res_bytes.decode("utf-8")
989
990 return compiler_ir_generator
991
992 @property
993 def python_function(self):
994 """The python function wrapped in this tf.function."""
995 return self._python_function
996
997 @property
998 def input_signature(self):
999 return self._function_spec.input_signature
1000
1001 @property
1002 def function_spec(self):
1003 return self._function_spec
1004
1005 def pretty_printed_concrete_signatures(self, verbose=True):
1006 joiner = "\n\n" if verbose else "\n"
1007 return joiner.join([
1008 c.pretty_printed_signature(verbose=verbose)
1009 for c in self._list_all_concrete_functions()
1010 ])
1011
1012 def _initialize_uninitialized_variables(self, initializers):
1013 """Make and call a `ConcreteFunction` which initializes variables."""
1014
1015 if not initializers:
1016 return
1017
1018 var_is_initialized = _evaluate_var_is_initialized(
1019 [v for v, _ in initializers])
1020
1021 def initialize_variables():
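      """Assigns collected initializers to variables not yet initialized."""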
1022 op_map = object_identity.ObjectIdentityDictionary()
1023
1024 inits = []
1025 for (v, init), is_initialized in zip(initializers, var_is_initialized):
1026 with ops.init_scope():
1027 if is_initialized:
1028 continue
1029 inits.append(init)
1030
1031 if inits:
1032 op_map = lift_to_graph.lift_to_graph(
1033 inits, ops.get_default_graph(), op_map=op_map)
1034 for (v, init), is_initialized in zip(initializers, var_is_initialized):
1035 with ops.init_scope():
1036 if is_initialized:
1037 continue
1038 v.assign(op_map[init], read_value=False)
1039
1040 with ops.init_scope():
1041 # Note: using TracingCompiler here avoids an infinite recursion.
1042 # Most of the code in this function runs eagerly with init_scope, where
1043 # autograph is not necessary.
1044 return tracing_compiler.TracingCompiler(
1045 initialize_variables, "initialize_variables",
1046 autograph=False).get_concrete_function()()
1047
1048 def get_initialization_function(self, *args, **kwargs):
1049 """Returns a `ConcreteFunction` which initializes this function's variables.
1050
1051 Requires that this function hasn't been accessed yet through either calling
1052 it or calling get_concrete_function. Fails if we cannot build an initializer
1053 function which does not depend on the concrete values of the inputs to this
1054 function.
1055
1056 Note that running this function will overwrite any values currently assigned
1057 to variables, for example restores from a checkpoint.
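
    A rough usage sketch (here `f` stands for a hypothetical `tf.function` that
    creates its variables on the first call):

    ```python
    init_fn = f.get_initialization_function(tf.constant(1.0))
    init_fn()  # Runs the collected variable initializers.
    f(tf.constant(1.0))
    ```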
1058
1059 Args:
1060 *args: arguments to the underlying python callable.
1061 **kwargs: keyword arguments to the python callable.
1062
1063 Returns:
1064 A `ConcreteFunction` object which initializes the variables of this
1065 function.
1066
1067 Raises:
1068 RuntimeError: if called after the variables have been initialized.
1069 """
1070 with self._lock:
1071 if self._variable_creation_fn is not None:
1072 raise RuntimeError(
1073 "get_initialization_function cannot be called after the function "
1074 "has been used")
1075 # Here we trace the function, collect the initializers, and attempt to
1076 # extract them and run them eagerly. Fail only if we cannot do so.
1077 initializers = []
1078 self._initialize(args, kwargs, add_initializers_to=initializers)
1079
1080 def initialize_variables():
1081 for v, init in initializers:
1082 v.assign(
1083 lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
1084 read_value=False)
1085
1086 # Note: using TracingCompiler here avoids an infinite recursion.
1087 return tracing_compiler.TracingCompiler(
1088 initialize_variables, "initialize_variables").get_concrete_function()
1089
1090 def _list_all_concrete_functions(self):
1091 """Returns all concrete functions."""
1092 if self.input_signature is not None:
1093 self.get_concrete_function()
1094 concrete_functions = []
1095 # pylint: disable=protected-access
1096 if self._variable_creation_fn:
1097 concrete_functions.extend(
1098 self._variable_creation_fn._list_all_concrete_functions())
1099 if self._no_variable_creation_fn:
1100 concrete_functions.extend(
1101 self._no_variable_creation_fn._list_all_concrete_functions())
1102 # pylint: enable=protected-access
1103 return concrete_functions
1104
1105 def _list_all_concrete_functions_for_serialization(self):
1106 """Returns all concrete functions for serialization.
1107
1108 Returns:
1109 A list of instances of `ConcreteFunction`.
1110 """
1111 seen_signatures = []
1112 if self.input_signature is not None:
1113 seen_signatures.append((self.input_signature, {}))
1114 else:
1115 concrete_functions = self._list_all_concrete_functions()
1116 for concrete_function in concrete_functions:
1117 signature = concrete_function.structured_input_signature
1118 flattened = nest.flatten(signature)
1119 if any(
1120 isinstance(arg, func_graph_module.UnknownArgument)
1121 for arg in flattened):
1122 logging.info("Unsupported signature for serialization: %s.",
1123 signature)
1124 continue
1125 equal_to_signature = functools.partial(
1126 function_spec_lib.is_same_structure, signature, check_values=True)
1127 if not any(equal_to_signature(s) for s in seen_signatures):
1128 seen_signatures.append(signature)
1129
1130 # Re-create concrete functions for these signatures. Re-creating ensures
1131 # that if the cache key has changed, the function will be traced again.
1132 concrete_functions = []
1133 for args, kwargs in seen_signatures:
1134 concrete_functions.append(self.get_concrete_function(*args, **kwargs))
1135 return concrete_functions
1136
1137 def _trackable_children(self, save_type="checkpoint", **kwargs):
1138 """For implementing `Trackable`."""
1139 if save_type == "checkpoint":
1140 return {}
1141 return {f"trace_{n}": fn for n, fn in
1142 enumerate(self._list_all_concrete_functions_for_serialization())}
1143
1144 def _deserialization_dependencies(self, children):
1145 """Returns concrete functions which must be loaded before this object."""
1146 return children
1147
1148 def _get_concrete_function_garbage_collected(self, *args, **kwargs):
1149 """Returns a `ConcreteFunction` specialized to inputs and execution context.
1150
1151 Unlike `get_concrete_function(...)`, the graph will be deleted when the
1152 returned function is deleted. It's useful to avoid creating a reference
    cycle when you know for sure that the graph will no longer be used without
1154 the returned function.
1155
1156 Args:
1157 *args: inputs to specialize on.
1158 **kwargs: inputs to specialize on.
1159
1160 Returns:
1161 A TensorFlow function which takes exactly one `tf.Tensor` per argument.
1162
1163 Raises:
1164 ValueError: if this object has not yet been called on concrete values.
1165 """
1166 with self._lock:
1167 if self._variable_creation_fn is None:
1168 initializers = []
1169 self._initialize(args, kwargs, add_initializers_to=initializers)
1170 self._initialize_uninitialized_variables(initializers)
1171
1172 if self._created_variables:
1173 # In this case we have created variables on the first call, so we run the
1174 # version which is guaranteed to never create variables.
1175 return self._no_variable_creation_fn._get_concrete_function_garbage_collected( # pylint: disable=protected-access
1176 *args, **kwargs)
1177 elif self._variable_creation_fn is not None:
1178 # In this case we have not created variables on the first call. So we can
1179 # run the first trace but we should fail if variables are created.
1180 concrete = self._variable_creation_fn._get_concrete_function_garbage_collected( # pylint: disable=protected-access
1181 *args, **kwargs)
1182 if self._created_variables:
1183 raise ValueError("Creating variables on a non-first call to a function"
1184 " decorated with tf.function.")
1185 return concrete
1186
1187 def get_concrete_function(self, *args, **kwargs):
1188 # Implements GenericFunction.get_concrete_function.
1189 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1190 concrete._garbage_collector.release() # pylint: disable=protected-access
1191 return concrete
1192
1193 def __tf_tracing_type__(self, _):
1194 return trace_type.Weakref(weakref.ref(self))
1195
1196 def __get__(self, instance, owner):
1197 """Makes it possible to decorate instance methods."""
1198 del owner
1199 # `instance` here is the instance that this `Function` was accessed through
1200 # e.g., for
1201 #
1202 # class Foo:
1203 #
1204 # @tf.function
1205 # def bar(self):
1206 # ...
1207 #
1208 # foo = Foo()
1209 # foo.bar() # `foo.bar` is a `Function` instance
1210 #
1211 # then `instance` will be `foo` (and `owner` will be `Foo`). For composite
1212 # tensors, we can just treat `instance` as a normal parameter. But for
1213 # other types, we create a new instance of `Function` here to allow
1214 # different instances each to create variables once, thereby allowing
1215 # methods to be decorated with tf.function. Keeps a cache to avoid retracing
1216 # the function every time the descriptor is accessed.
1217 # TODO(mdan): Identify types which can just be parameters more generically.
1218 #
    # The check that `instance._type_spec` is not None is used because certain
    # classes (including subclasses of tf.linalg.LinearOperator) are subclasses
    # of CompositeTensor but do not actually implement the required APIs.
    # TODO(b/199278478): Fix those classes, then remove the check for
    # `instance._type_spec is not None`.
1224 if (isinstance(instance, composite_tensor.CompositeTensor) and
1225 instance._type_spec is not None): # pylint: disable=protected-access
1226 return types_lib.MethodType(self, instance)
1227 if instance not in self._descriptor_cache:
1228 if instance is None:
1229 return self
1230 # TODO(mdan): If the CompositeTensor path works, do the same here.
1231 # It's unclear whether we need the tf-decorator, or could just call
1232 # MethodType(self.clone(), instance)
1233 self._descriptor_cache[instance] = (
1234 tracing_compiler.class_method_to_instance_method(self, instance))
1235 return self._descriptor_cache[instance]
1236
1237
1238@tf_export("function")
1239@deprecation.deprecated_args(None,
1240 "experimental_compile is deprecated, use "
1241 "jit_compile instead", "experimental_compile")
1242@deprecation.deprecated_args(None,
1243 "experimental_relax_shapes is deprecated, use "
1244 "reduce_retracing instead",
1245 "experimental_relax_shapes")
1246@deprecation.deprecated_args(None,
1247 "experimental_follow_type_hints is deprecated",
1248 "experimental_follow_type_hints")
1249def function(
1250 func=None,
1251 input_signature=None,
1252 autograph=True,
1253 jit_compile=None,
1254 reduce_retracing=False,
1255 experimental_implements=None,
1256 experimental_autograph_options=None,
1257 experimental_attributes=None,
1258 experimental_relax_shapes=None,
1259 experimental_compile=None,
1260 experimental_follow_type_hints=None # pylint: disable=unused-argument
1261) -> core.GenericFunction:
1262 """Compiles a function into a callable TensorFlow graph.
1263
1264 `tf.function` constructs a `tf.types.experimental.GenericFunction` that
1265 executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the
1266 TensorFlow operations in `func`. More information on the topic can be found
1267 in [Introduction to Graphs and tf.function]
1268 (https://www.tensorflow.org/guide/intro_to_graphs).
1269
1270 See [Better Performance with tf.function]
1271 (https://www.tensorflow.org/guide/function) for tips on performance and
1272 known limitations.
1273
1274 Example usage:
1275
1276 >>> @tf.function
1277 ... def f(x, y):
1278 ... return x ** 2 + y
1279 >>> x = tf.constant([2, 3])
1280 >>> y = tf.constant([3, -2])
1281 >>> f(x, y)
1282 <tf.Tensor: ... numpy=array([7, 7], ...)>
1283
1284 The trace-compilation allows non-TensorFlow operations to execute, but under
1285 special conditions. In general, only TensorFlow operations are guaranteed to
1286 run and create fresh results whenever the `GenericFunction` is called.
1287
1288 ## Features
1289
1290 `func` may use data-dependent Python control flow statements, including `if`,
  `for`, `while`, `break`, `continue` and `return`:
1292
1293 >>> @tf.function
1294 ... def f(x):
1295 ... if tf.reduce_sum(x) > 0:
1296 ... return x * x
1297 ... else:
1298 ... return -x // 2
1299 >>> f(tf.constant(-2))
1300 <tf.Tensor: ... numpy=1>
1301
1302 `func`'s closure may include `tf.Tensor` and `tf.Variable` objects:
1303
1304 >>> @tf.function
1305 ... def f():
1306 ... return x ** 2 + y
1307 >>> x = tf.constant([-2, -3])
1308 >>> y = tf.Variable([3, -2])
1309 >>> f()
1310 <tf.Tensor: ... numpy=array([7, 7], ...)>
1311
1312 `func` may also use ops with side effects, such as `tf.print`, `tf.Variable`
1313 and others:
1314
1315 >>> v = tf.Variable(1)
1316 >>> @tf.function
1317 ... def f(x):
1318 ... for i in tf.range(x):
1319 ... v.assign_add(i)
1320 >>> f(3)
1321 >>> v
1322 <tf.Variable ... numpy=4>
1323
1324 Important: Any Python side-effects (appending to a list, printing with
1325 `print`, etc) will only happen once, when `func` is traced. To have
  side-effects executed each time your `tf.function` is called, they need to be
  written as TF ops:
1328
1329 >>> l = []
1330 >>> @tf.function
1331 ... def f(x):
1332 ... for i in x:
1333 ... l.append(i + 1) # Caution! Will only happen once when tracing
1334 >>> f(tf.constant([1, 2, 3]))
1335 >>> l
1336 [<tf.Tensor ...>]
1337
1338 Instead, use TensorFlow collections like `tf.TensorArray`:
1339
1340 >>> @tf.function
1341 ... def f(x):
1342 ... ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
1343 ... for i in range(len(x)):
1344 ... ta = ta.write(i, x[i] + 1)
1345 ... return ta.stack()
1346 >>> f(tf.constant([1, 2, 3]))
1347 <tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
1348
1349 ## `tf.function` creates polymorphic callables
1350
1351 Internally, `tf.types.experimental.GenericFunction` may contain multiple
1352 `tf.types.experimental.ConcreteFunction`s, each specialized to arguments with
1353 different data types or shapes, since TensorFlow can perform more
1354 optimizations on graphs of specific shapes, dtypes and values of constant
1355 arguments. `tf.function` treats any pure Python values as opaque objects (best
1356 thought of as compile-time constants), and builds a separate `tf.Graph` for
1357 each set of Python arguments that it encounters.
1358 For more information, see the
1359 [tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
1360
1361 Executing a `GenericFunction` will select and execute the appropriate
1362 `ConcreteFunction` based on the argument types and values.
1363
1364 To obtain an individual `ConcreteFunction`, use the
1365 `GenericFunction.get_concrete_function` method. It can be called with the
1366 same arguments as `func` and returns a
1367 `tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by a
1368 single `tf.Graph`:
1369
1370 >>> @tf.function
1371 ... def f(x):
1372 ... return x + 1
1373 >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
1374 True
1375
1376 `ConcreteFunction`s can be executed just like `GenericFunction`s, but their
  input is restricted to the types to which they're specialized.
1378
1379 ## Retracing
1380
1381 `ConcreteFunctions` are built (traced) on the fly, as the `GenericFunction` is
1382 called with new TensorFlow types or shapes, or with new Python values as
1383 arguments. When `GenericFunction` builds a new trace, it is said that `func`
1384 is retraced. Retracing is a frequent performance concern for `tf.function` as
1385 it can be considerably slower than executing a graph that's already been
1386 traced. It is ideal to minimize the amount of retracing in your code.
1387
  Caution: Passing Python scalars or lists as arguments to `tf.function` will
1389 usually retrace. To avoid this, pass numeric arguments as Tensors whenever
1390 possible:
1391
1392 >>> @tf.function
1393 ... def f(x):
1394 ... return tf.abs(x)
1395 >>> f1 = f.get_concrete_function(1)
1396 >>> f2 = f.get_concrete_function(2) # Slow - compiles new graph
1397 >>> f1 is f2
1398 False
1399 >>> f1 = f.get_concrete_function(tf.constant(1))
1400 >>> f2 = f.get_concrete_function(tf.constant(2)) # Fast - reuses f1
1401 >>> f1 is f2
1402 True
1403
1404 Python numerical arguments should only be used when they take few distinct
1405 values, such as hyperparameters like the number of layers in a neural network.
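
The sketch below illustrates this pattern; `apply_n_times` and `n` are
illustrative names. `n` behaves like a compile-time constant, so each distinct
value of `n` produces one trace:

>>> @tf.function
... def apply_n_times(x, n):
...   for _ in range(n):  # Unrolled at trace time; `n` is a Python int.
...     x = x + 1
...   return x
>>> apply_n_times(tf.constant(1), 3)
<tf.Tensor: shape=(), dtype=int32, numpy=4>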
1406
1407 ## Input signatures
1408
For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for
every unique set of input shapes and datatypes. The example below creates two
separate `ConcreteFunction`s, each specialized to a different shape:
1412
1413 >>> @tf.function
1414 ... def f(x):
1415 ... return x + 1
1416 >>> vector = tf.constant([1.0, 1.0])
1417 >>> matrix = tf.constant([[3.0]])
1418 >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1419 False
1420
An "input signature" can optionally be provided to `tf.function` to control
this process. The input signature specifies the shape and type of each
Tensor argument to the function using a `tf.TensorSpec` object. More general
shapes can be used, for example by leaving some dimensions (or the whole
shape) unspecified as `None`. This ensures only one `ConcreteFunction` is
created, and restricts the `GenericFunction` to the specified shapes and
types. It is an effective way to limit retracing when Tensors have dynamic
shapes.
1427
1428 >>> @tf.function(
1429 ... input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
1430 ... def f(x):
1431 ... return x + 1
1432 >>> vector = tf.constant([1.0, 1.0])
1433 >>> matrix = tf.constant([[3.0]])
1434 >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1435 True
1436
1437 ## Variables may only be created once
1438
1439 `tf.function` only allows creating new `tf.Variable` objects when it is called
1440 for the first time:
1441
1442 >>> class MyModule(tf.Module):
1443 ... def __init__(self):
1444 ... self.v = None
1445 ...
1446 ... @tf.function
1447 ... def __call__(self, x):
1448 ... if self.v is None:
1449 ... self.v = tf.Variable(tf.ones_like(x))
1450 ... return self.v * x
1451
1452 In general, it is recommended to create `tf.Variable`s outside of
1453 `tf.function`.
1454 In simple cases, persisting state across `tf.function` boundaries may be
1455 implemented using a pure functional style in which state is represented by
1456 `tf.Tensor`s passed as arguments and returned as return values.
1457
1458 Contrast the two styles below:
1459
1460 >>> state = tf.Variable(1)
1461 >>> @tf.function
1462 ... def f(x):
1463 ... state.assign_add(x)
1464 >>> f(tf.constant(2)) # Non-pure functional style
1465 >>> state
1466 <tf.Variable ... numpy=3>
1467
1468 >>> state = tf.constant(1)
1469 >>> @tf.function
1470 ... def f(state, x):
1471 ... state += x
1472 ... return state
1473 >>> state = f(state, tf.constant(2)) # Pure functional style
1474 >>> state
1475 <tf.Tensor: ... numpy=3>
1476
1477 ## Python operations execute only once per trace
1478
1479 `func` may contain TensorFlow operations mixed with pure Python operations.
1480 However, when the function is executed, only the TensorFlow operations will
1481 run. The Python operations run only once, at trace time. If TensorFlow
1482 operations depend on results from Python operations, those results will be
1483 frozen into the graph.
1484
1485 >>> @tf.function
1486 ... def f(a, b):
1487 ... print('this runs at trace time; a is', a, 'and b is', b)
1488 ... return b
1489 >>> f(1, tf.constant(1))
1490 this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
1491 <tf.Tensor: shape=(), dtype=int32, numpy=1>
1492
1493 >>> f(1, tf.constant(2))
1494 <tf.Tensor: shape=(), dtype=int32, numpy=2>
1495
1496 >>> f(2, tf.constant(1))
1497 this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
1498 <tf.Tensor: shape=(), dtype=int32, numpy=1>
1499
1500 >>> f(2, tf.constant(2))
1501 <tf.Tensor: shape=(), dtype=int32, numpy=2>
1502
1503 Args:
func: The function to be compiled. If `func` is None, `tf.function` returns
a decorator that can be invoked with a single argument, `func`. In other
words, `tf.function(input_signature=...)(func)` is equivalent to
`tf.function(func, input_signature=...)`. The former can be used as a
decorator.
1509 input_signature: A possibly nested sequence of `tf.TensorSpec` objects
1510 specifying the shapes and dtypes of the Tensors that will be supplied to
1511 this function. If `None`, a separate function is instantiated for each
inferred input signature. If `input_signature` is specified, every input to
`func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
1514 autograph: Whether autograph should be applied on `func` before tracing a
1515 graph. Data-dependent Python control flow statements require
1516 `autograph=True`. For more information, see the
1517 [tf.function and AutoGraph guide](
1518 https://www.tensorflow.org/guide/function#autograph_transformations).
1519 jit_compile: If `True`, compiles the function using
1520 [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
such as fusion, and attempts to emit more efficient code. This may
drastically improve performance. If set to `True`,
1523 the whole function needs to be compilable by XLA, or an
1524 `errors.InvalidArgumentError` is thrown.
1525 If `None` (default), compiles the function with XLA when running on TPU
1526 and goes through the regular function execution path when running on
1527 other devices.
1528 If `False`, executes the function without XLA compilation. Set this value
1529 to `False` when directly running a multi-device function on TPUs (e.g. two
1530 TPU cores, one TPU core and its host CPU).
Not all functions are compilable; see a list of
[sharp corners](https://tensorflow.org/xla/known_issues).
reduce_retracing: When `True`, `tf.function` attempts to reduce the
1534 amount of retracing, for example by using more generic shapes. This
1535 can be controlled for user objects by customizing their associated
1536 `tf.types.experimental.TraceType`.
experimental_implements: If provided, contains the name of a "known" function
this implements. For example "mycompany.my_recurrent_cell".
This is stored as an attribute of the inference function,
which can then be detected when processing the serialized function.
1541 See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md) # pylint: disable=line-too-long
for details. For an example of utilizing this attribute, see this
[example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc),
which automatically detects functions that implement "embedded_matmul"
and allows TFLite to substitute its own implementation. For instance, a
TensorFlow user can use this attribute to mark that their function also
implements `embedded_matmul` (perhaps more efficiently!)
1549 by specifying it using this parameter:
1550 `@tf.function(experimental_implements="embedded_matmul")`
This can be specified either as just the string name of the function or as
a NameAttrList corresponding to a list of key-value attributes associated
with the function name. The name of the function will be in the 'name'
field of the NameAttrList. To define a formal TF op that this function
implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
project.
1557 experimental_autograph_options: Optional tuple of
1558 `tf.autograph.experimental.Feature` values.
1559 experimental_attributes: Optional dictionary of attributes to include in the
1560 generated FunctionDefs.
1561 experimental_relax_shapes: Deprecated. Use `reduce_retracing`
1562 instead.
experimental_compile: Deprecated alias for `jit_compile`.
experimental_follow_type_hints: Deprecated. Please use `input_signature` or
`reduce_retracing` instead.
1566
1567 Returns:
1568 If `func` is not None, returns a `tf.types.experimental.GenericFunction`.
1569 If `func` is None, returns a decorator that, when invoked with a single
1570 `func` argument, returns a `tf.types.experimental.GenericFunction`.
1571
1572 Raises:
1573 `ValueError` when attempting to use `jit_compile=True`, but XLA support is
1574 not available.
1575 """
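  # If `jit_compile` was not specified, fall back to the module-level
  # JIT_COMPILE_FUNCTIONS default.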
1576 if jit_compile is None and JIT_COMPILE_FUNCTIONS:
1577 jit_compile = True
1578
1579 # TODO(b/224808187): Remove after renaming usages.
1580 if experimental_relax_shapes:
1581 reduce_retracing = True
1582
1583 def decorated(inner_function):
1584 try:
1585 name = inner_function.__name__
1586 except AttributeError:
1587 name = "function"
1588 return tf_decorator.make_decorator(
1589 inner_function,
1590 decorator_name="tf.function",
1591 decorator_func=Function(
1592 inner_function,
1593 name,
1594 input_signature=input_signature,
1595 autograph=autograph,
1596 experimental_autograph_options=experimental_autograph_options,
1597 reduce_retracing=reduce_retracing,
1598
1599 # TODO(b/171825496): Update once `experimental_compile` is removed
1600 # entirely in favor of 'jit_compile'.
1601 jit_compile=deprecation.deprecated_argument_lookup(
1602 "jit_compile",
1603 jit_compile,
1604 "experimental_compile",
1605 experimental_compile),
1606 experimental_implements=experimental_implements,
1607 experimental_attributes=experimental_attributes))
1608
1609 # This code path is for the `foo = tf.function(foo, ...)` use case
1610 if func is not None:
1611 return decorated(func)
1612
1613 # This code path is for the
1614 #
1615 # @tf.function(...)
1616 # def foo(...):
1617 # ...
1618 #
1619 # use case, which is equivalent to `foo = tf.function(...)(foo)`
1620 return decorated