1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Implementation for Monomorphic Functions (including Differentiable ones)."""
17
18import collections
19import pprint
20
21from tensorflow.core.framework import attr_value_pb2
22from tensorflow.core.function.polymorphism import function_type as function_type_lib
23from tensorflow.python import pywrap_tfe
24from tensorflow.python.eager import backprop_util
25from tensorflow.python.eager import context
26from tensorflow.python.eager import forwardprop_util
27from tensorflow.python.eager import record
28from tensorflow.python.eager.graph_only_ops import graph_placeholder
29from tensorflow.python.eager.polymorphic_function import atomic_function
30from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib
31from tensorflow.python.eager.polymorphic_function import function_spec
32from tensorflow.python.eager.polymorphic_function import saved_model_exported_concrete
33from tensorflow.python.framework import composite_tensor
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import errors
36from tensorflow.python.framework import func_graph as func_graph_module
37from tensorflow.python.framework import indexed_slices
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.framework import tensor_spec
41from tensorflow.python.framework import type_spec
42from tensorflow.python.ops import array_ops
43from tensorflow.python.ops import default_gradient
44from tensorflow.python.ops import gradients_util
45from tensorflow.python.ops import handle_data_util
46from tensorflow.python.ops import resource_variable_ops
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.profiler import trace
49from tensorflow.python.trackable import base as trackable
50from tensorflow.python.types import core
51from tensorflow.python.util import _pywrap_utils
52from tensorflow.python.util import compat
53from tensorflow.python.util import nest
54from tensorflow.python.util import object_identity
55
56
57def _is_type_subset(a, b):
58 """Returns true if `b` is a subset of type `a` (or if a is not a TypeSpec.)"""
59 if isinstance(a, type_spec.TypeSpec):
60 return a.most_specific_compatible_type(b) == a
61 return True
62
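# Example (illustrative sketch): with `a = tensor_spec.TensorSpec([None],
# dtypes.float32)` and `b = tensor_spec.TensorSpec([3], dtypes.float32)`,
# `_is_type_subset(a, b)` is True because the most specific type compatible
# with both is `a` itself, while `_is_type_subset(b, a)` is False.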
63
64def _parse_func_attrs(attributes):
65 """Convert the keyword arguments into function_def attributes.
66
  Currently only supports primitive types: bool, int, float, and string.
68
69 Args:
70 attributes: the dictionary of attributes.
71 Returns:
72 A dict of attributes where the key is the name of attribute and the value
73 is the AttrValue proto.
74 Raises:
    ValueError: If `attributes` contains an unallowlisted name or an
      unsupported value type.
77 """
78 attrs = {}
79 for key, value in attributes.items():
80 if key not in attributes_lib.MONOMORPHIC_FUNCTION_ALLOWLIST:
81 raise ValueError(
82 f"ConcreteFunction does not support `{key}` as an attribute.")
83 if isinstance(value, attr_value_pb2.AttrValue):
84 attrs[key] = value
85 # bool type check has to happen before int since bool is a subclass of int.
86 elif isinstance(value, bool):
87 attrs[key] = attr_value_pb2.AttrValue(b=value)
88 elif isinstance(value, int):
89 attrs[key] = attr_value_pb2.AttrValue(i=value)
90 elif isinstance(value, float):
91 attrs[key] = attr_value_pb2.AttrValue(f=value)
92 elif isinstance(value, (str, bytes)):
93 attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
94 else:
95 raise ValueError(f"Attribute {key} must be bool, int, float, string, or "
96 f"AttrValue. Got {type(value)}.")
97 return attrs
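
# Example (illustrative sketch): assuming attributes_lib.IMPLEMENTS is in
# MONOMORPHIC_FUNCTION_ALLOWLIST,
#   _parse_func_attrs({attributes_lib.IMPLEMENTS: "embedding_matmul"})
# returns {attributes_lib.IMPLEMENTS:
#          attr_value_pb2.AttrValue(s=b"embedding_matmul")},
# while a value of an unsupported type (e.g. a list) raises ValueError.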
98
99_FORWARD_PREFIX = "__forward_"
100_BACKWARD_PREFIX = "__backward_"
101_INFERENCE_PREFIX = "__inference_"
102
103
104def _forward_name(n):
105 """The name of a generated forward defun named n."""
106 return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid())
107
108
109def _backward_name(n):
110 """The name of a generated backward defun named n."""
111 return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid())
112
113
114def _inference_name(n):
115 """The name of a forward-but-no-gradient defun named n."""
116 return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid())
117
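# Example (illustrative): the generated names embed a process-unique id from
# ops.uid(), e.g. _inference_name("f") might return "__inference_f_42".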
118
119def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
120 """Creates forward and backward functions from the function graphs."""
121 forward_function_name = _forward_name(forward_graph.name)
122 common_attributes = dict(attrs)
  # NB: forward and backward functions need to drop the "_implements"
  # attribute, because their signatures contain all the intermediate tensors
  # that they compute. Thus they don't have a stable signature which can
  # be directly optimized downstream.
127 # See for more details:
128 # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
129 common_attributes.pop(attributes_lib.IMPLEMENTS, None)
130 backward_function_attr = _parse_func_attrs(
131 {attributes_lib.FORWARD_FUNCTION: forward_function_name})
132 backward_function_attr.update(common_attributes)
133 backward_function = ConcreteFunction(
134 backwards_graph, attrs=backward_function_attr)
135 forward_function_attr = _parse_func_attrs({
136 attributes_lib.BACKWARD_FUNCTION:
137 backward_function.name})
138 forward_function_attr.update(common_attributes)
139 forward_function = atomic_function.from_func_graph(
140 forward_function_name, forward_graph, forward_graph.inputs,
141 forward_graph.outputs, forward_function_attr)
142 return forward_function, backward_function
143
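# Illustrative sketch of what _create_forward_backward_with_graph produces for
# a forward graph named "f":
#   forward:  an AtomicFunction named "__forward_f_<uid>" carrying an
#             attributes_lib.BACKWARD_FUNCTION attr naming the backward
#             function.
#   backward: a ConcreteFunction over a graph named "__backward_f_<uid>"
#             carrying an attributes_lib.FORWARD_FUNCTION attr naming the
#             forward function.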
144
145class _DelayedRewriteGradientFunctions(object):
146 """Caches forward/backward functions with a delayed forward rewrite."""
147
148 def __init__(self, func_graph, attrs, func_graph_deleter):
149 """Construct an inference function and initialize caches."""
150 # A map from the number of forward function outputs with accepted gradients
151 # to forward and backward functions, used to cache non-tape backward
152 # function generation.
153 self._cached_function_pairs = {}
154 self._func_graph = func_graph
155 self._inference_function = atomic_function.from_func_graph(
156 _inference_name(self._func_graph.name), self._func_graph,
157 self._func_graph.inputs, self._func_graph.outputs, attrs)
158 self._attrs = attrs
159 self._gradient_name = None
160 # Note that the FuncGraph is mutated later, so we need to inspect it now to
161 # figure out the user-specified outputs of the inference function.
162 self._num_inference_outputs = len(self._func_graph.outputs)
163 self._func_graph_deleter = func_graph_deleter
164
165 def forward_backward(self, num_doutputs=None):
166 """A possibly-cached pair of forward and backward functions."""
167 if num_doutputs is None:
168 num_doutputs = self._num_inference_outputs
169 forward_backward = self._cached_function_pairs.get(num_doutputs)
170 if forward_backward is not None:
171 return forward_backward
172 forward, backward = self._construct_forward_backward(num_doutputs)
173 self._cached_function_pairs[num_doutputs] = (forward, backward)
174 return forward, backward
175
176 def _construct_forward_backward(self, num_doutputs):
177 """Constructs a pair of forward and backward functions.
178
179 Args:
180 num_doutputs: The constructed backprop function will take output gradients
181 for the first `num_doutputs` outputs of the forward function. Defaults
182 to the number of outputs for the inference function, but when
183 higher-order gradients are computed this will increase to include side
184 outputs.
185
186 Returns:
187 A pair of (forward_function, backward_function):
188 forward_function: A re-generated inference function (an
189 AtomicFunction) to account for new side outputs, if any extra
190 were required when building the backward pass.
      backward_function: A ConcreteFunction that takes `num_doutputs`
192 arguments and returns gradients with respect to inputs of the forward
193 function.
194 """
195 trainable_outputs = [
196 output for output in self._func_graph.outputs[:num_doutputs]
197 if backprop_util.IsTrainable(output)]
198
199 signature = []
200 for t in trainable_outputs:
201 signature.append(
202 tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
203
204 def _backprop_function(*grad_ys):
205 with ops.device(None):
206 return gradients_util._GradientsHelper( # pylint: disable=protected-access
207 trainable_outputs,
208 self._func_graph.inputs,
209 grad_ys=grad_ys,
210 src_graph=self._func_graph)
211
212 with self._func_graph.as_default():
213 backwards_graph = func_graph_module.FuncGraph(
214 _backward_name(self._func_graph.name))
215 func_graph_module.func_graph_from_py_func(
216 name=backwards_graph.name,
217 python_func=_backprop_function,
218 args=[], kwargs={},
219 signature=signature,
220 func_graph=backwards_graph)
221 backwards_graph_captures = backwards_graph.external_captures
222 captures_from_forward = [
223 c for c in backwards_graph_captures if
224 not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
225
226 existing_outputs = object_identity.ObjectIdentitySet(
227 self._func_graph.outputs)
228 for capture in captures_from_forward:
229 if capture not in existing_outputs:
230 existing_outputs.add(capture)
231 self._func_graph.outputs.append(capture)
232
233 forward_function, backward_function = _create_forward_backward_with_graph(
234 self._attrs, self._func_graph, backwards_graph)
235 return forward_function, backward_function
236
237 def _rewrite_forward_and_call_backward(self, op, *doutputs):
238 """Add outputs to the forward call and feed them to the grad function."""
239 forward_function, backwards_function = self.forward_backward(len(doutputs))
240 if not backwards_function.outputs:
241 return backwards_function.structured_outputs
242
243 op.graph._add_function_recursive(forward_function) # pylint: disable=protected-access
244
245 # pylint: disable=protected-access
246 # Rewrite an inference call op to be a forward call op
247 op._set_func_attr("f", forward_function.name)
248 op._set_type_list_attr(
249 "Tout",
250 [
251 o.dtype.as_datatype_enum
252 for o in forward_function.function_type.flat_outputs
253 ],
254 )
255 truncated_outputs = forward_function.function_type.flat_outputs[
256 len(op.outputs) :
257 ]
258 op._add_outputs(
259 [o.dtype.as_datatype_enum for o in truncated_outputs],
260 [o.shape for o in truncated_outputs],
261 )
262 for i in range(len(op.outputs)):
263 output_type = forward_function.function_type.flat_outputs[i]
264 handle_data = output_type.dtype._handle_data
265 if handle_data:
266 handle_data_util.set_handle_data(op.outputs[i], handle_data)
267 # pylint: enable=protected-access
268
269 capture_mapping = dict(
270 zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
271 remapped_captures = [
272 capture_mapping.get(ops.tensor_id(capture), capture)
273 for capture in backwards_function.captured_inputs
274 ]
275
276 # Replace Nones with zeros since we're calling a graph function which
277 # expects numeric inputs.
278 cleaned_doutputs = []
279 for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
280 if backprop_util.IsTrainable(placeholder):
281 if isinstance(doutput, indexed_slices.IndexedSlices):
282 # Gradient passed to a backward ConcreteFunction must be tf.Tensor,
283 # so we convert tf.IndexedSlices to tf.Tensor.
284 cleaned_doutputs.append(ops.convert_to_tensor(doutput))
285 elif doutput is not None:
286 cleaned_doutputs.append(doutput)
287 else:
288 cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
289
290 # Compute the gradients using the side outputs
291 return backwards_function._call_flat( # pylint: disable=protected-access
292 cleaned_doutputs, remapped_captures)
293
294 def get_gradient_function(self):
295 """Returns gradient function.
296
297 The gradient rewrites an inference call op to a forward call op, but does
298 not modify a pre-existing forward call op. It then computes the gradient
299 from the output's gradients and the side outputs of the forward op.
300 """
301 return self._rewrite_forward_and_call_backward
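
  # How the gradient function is wired up (illustrative): while graph building,
  # ConcreteFunction._call_flat installs the function returned here as the
  # Python gradient for the generated call ops, roughly:
  #   with graph._override_gradient_function(
  #       {"PartitionedCall": gradient_fn,
  #        "StatefulPartitionedCall": gradient_fn}):
  #     flat_outputs = forward_function(*args)
  # so taking symbolic gradients of the call op triggers the forward rewrite.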
302
303 def forward(self, inference_args=None, input_tangents=None):
304 """A forward function with only user-specified outputs.
305
    The call operation for the returned inference function can be rewritten
    into a call to a forward function. This only happens if the backward
    function (from the `_backward` method) ends up being used to compute
    gradients.

    This approach avoids constructing unnecessary graphs, but it only works
    when this function is called while graph building, not when executing
    eagerly.
312
313 Args:
314 inference_args: A flat list of Tensors, arguments to the inference
315 function. Unused, but taken for compatibility with
316 _TapeGradientFunctions.
317 input_tangents: A flat list of Tensors, jvps associated with
318 `inference_args`. Unused; if required, tape functions must be used
319 instead.
320
321 Returns:
322 An atomic_function.AtomicFunction.
323 """
324 del inference_args # unused
325 if input_tangents:
326 # This class does not support special-cased forwardprop. The arguments are
327 # here for compatibility with _TapeGradientFunctions.
328 raise errors.InternalError("unexpectedly got forwardprop information in "
329 "a class that does not support forwardprop.")
330 return self._inference_function
331
332 def _backward(self, outputs):
333 """Fetch a backward function for `outputs` from the forward function."""
334 def _backward_function(*args):
335 call_op = outputs[0].op
336 return self._rewrite_forward_and_call_backward(call_op, *args)
337 return _backward_function, outputs
338
339 def record(self, flat_outputs, inference_args, input_tangents):
340 """Record the function call operation.
341
342 _DelayedRewriteGradientFunctions supports only first-order backprop tape
343 gradients (and then only when graph building). It does not work with
344 higher-order tape gradients or forward autodiff, but does work with
345 higher-order symbolic gradients (tf.gradients).
346
347 Args:
348 flat_outputs: The result of running `forward`.
349 inference_args: A flat list of Tensors with inference inputs to the
350 operation.
351 input_tangents: A flat list of Tensors with input tangents consumed by the
352 operation.
353 """
354 backward_function, to_record = self._backward(flat_outputs)
355 record.record_operation(
356 self._inference_function.cached_definition.signature.name,
357 to_record,
358 inference_args + input_tangents,
359 backward_function,
360 )
361
362
363# Contains information about a forward function wrapped to compute jvps.
364_ForwardWrapper = collections.namedtuple(
365 "_ForwardWrapper", (
366 # The wrapper Graph.
367 "graph",
368 # A flat list of non-tangent Tensor outputs from the wrapped forward
369 # function.
370 "outputs",
371 # Indices for output tangents, same format as
372 # forwardprop_util.pack_tangents.
373 "output_indices",
374 # A flat list of tangents for `outputs`.
375 "output_tangents"))
376
377
378class _TapeGradientFunctions(object):
379 """Caches forward and backward functions compatible with eager gradients.
380
381 In contrast to the delayed-rewrite approach in
382 `_DelayedRewriteGradientFunctions` which only works with delayed execution,
383 the forward function generated by this class has a fixed set of outputs which
384 may be preserved by a tape in order to compute gradients later.
385
386 This class is abstract; its child classes differ in how many side outputs of
387 the forward function their backward function accepts gradients for, which
388 determines whether higher-order tape gradients are possible.
389 """
390
391 def __init__(self, func_graph, attrs, func_graph_deleter,
392 forwardprop_input_indices, delayed_rewrite_functions,
393 need_gradients_for_jvps):
394 self._func_graph = func_graph
395 self._forward_graph = None
396 self._attrs = attrs
397 self._forward = None
398 self._backward = None
399 self._num_outputs = len(func_graph.outputs)
400 self._func_graph_deleter = func_graph_deleter
401 self._forwardprop_input_indices = forwardprop_input_indices
402 self._forwardprop_output_indices = None
403 self._num_forwardprop_outputs = 0
404 self._num_inference_outputs = len(func_graph.outputs)
405 self._num_trainable_inference_outputs = len(
406 [t for t in func_graph.outputs if backprop_util.IsTrainable(t)])
407 self._delayed_rewrite_functions = delayed_rewrite_functions
408 self._need_gradients_for_jvps = need_gradients_for_jvps
409
410 def _build_functions_for_outputs(
411 self, outputs, inference_args, input_tangents):
412 """Forward+backward functions where the backward function sees `outputs`."""
413 # First figure out which of `outputs` are trainable. We'll accept gradients
414 # for each of these in the backward function.
415 trainable_outputs = []
416 trainable_indices = []
    for index, output in enumerate(outputs):
      if backprop_util.IsTrainable(output):
420 trainable_outputs.append(output)
421 trainable_indices.append(index)
422
423 backwards_graph = func_graph_module.FuncGraph(
424 _backward_name(self._func_graph.name))
425 with backwards_graph.as_default():
426 gradients_wrt_outputs = []
427 for output in trainable_outputs:
428 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
429 output)
430 gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
431 handle_data_util.copy_handle_data(output, gradient_placeholder)
432 gradients_wrt_outputs.append(gradient_placeholder)
433 with ops.device(None):
434 gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access
435 trainable_outputs,
436 self._func_graph.inputs,
437 grad_ys=gradients_wrt_outputs,
438 src_graph=self._func_graph)
439
440 if input_tangents:
441 # Convert IndexedSlices to dense tensors (as we do elsewhere for
442 # function gradients). Our C++ bindings don't know how to handle them
443 # currently.
444 gradients_wrt_inputs = nest.map_structure(
445 lambda x: ops.convert_to_tensor(x) if x is not None else None,
446 gradients_wrt_inputs)
447 captures_from_forward = [
448 c for c in backwards_graph.external_captures
449 if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
450 ]
451 existing_outputs = object_identity.ObjectIdentitySet(
452 self._func_graph.outputs)
453 for capture in captures_from_forward:
454 if capture not in existing_outputs:
455 existing_outputs.add(capture)
456 self._func_graph.outputs.append(capture)
457
458 # The ordering of `backwards_graph.inputs` is important: inputs of
459 # `backward_function` correspond to outputs (including
460 # side outputs) of `self._tape_forward_function`.
461 backwards_graph.inputs = (
462 gradients_wrt_outputs + backwards_graph.internal_captures)
463 backwards_graph.outputs.extend(
464 grad
465 for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
466 if grad is not None)
467 backwards_graph.structured_outputs = gradients_wrt_inputs
468
469 forward_function, backward_function = _create_forward_backward_with_graph(
470 self._attrs, self._func_graph, backwards_graph)
471
472 if not input_tangents:
473 # There is no need to special-case forwardprop, so we can return the
474 # forward+backward pair we've created without further wrapping.
475 return (forward_function, self._func_graph, backward_function,
476 # No forwardprop outputs.
477 None, 0)
478 forward_wrapper = self._wrap_forward_function_with_jvps(
479 forward_function, backward_function, inference_args, input_tangents)
480 (wrapped_backwards_graph,
481 forward_wrapper) = self._wrap_backward_function_with_jvp_backprop(
482 backward_function, gradients_wrt_outputs, forward_wrapper)
483 # Now that we've added new captures, we need to make sure forward outputs
484 # are in the same order the backward function expects them to be in:
485 # [inference outputs] + [jvps] + [side outputs] + [captures].
486 forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
487 (wrapped_forward_function,
488 wrapped_backward_function) = _create_forward_backward_with_graph(
489 self._attrs, forward_wrapper.graph, wrapped_backwards_graph)
490 if (len(inference_args) + len(input_tangents)
491 != len(forward_wrapper.graph.inputs)):
492 raise errors.InternalError(
493 f"The forward graph had {len(forward_wrapper.graph.inputs)} inputs, "
494 f"but we expected {len(inference_args) + len(input_tangents)} "
495 f"({len(inference_args)} inference inputs and "
496 f"{len(input_tangents)} input tangents).")
497 return (wrapped_forward_function, forward_wrapper.graph,
498 wrapped_backward_function, forward_wrapper.output_indices,
499 len(forward_wrapper.output_tangents))
500
501 def _wrap_forward_function_with_jvps(
502 self, forward_function, backward_function,
503 inference_args, input_tangents):
504 """Adds inline JVP computation to a forward function."""
505 forward_wrapper_graph = func_graph_module.FuncGraph(
506 _forward_name(self._func_graph.name))
507 with forward_wrapper_graph.as_default():
508 # Tell forward accumulators to free up space for new JVP computations,
509 # since one may be in the process of computing a JVP (if that computation
510 # triggered this function building).
511 #
512 # We'll make symbolic versions of input JVPs, run the forward function
513 # under forward accumulators to get symbolic output JVPs, then set those
514 # as outputs of the new wrapped forward function.
515 with forwardprop_util.push_forwardprop_state():
516 forward_captures = {
517 ops.tensor_id(internal): external
518 for external, internal in self._func_graph.captures}
519 for input_index, real_input in enumerate(self._func_graph.inputs):
520 # This loop is more or less equivalent to running tf.identity on each
521 # of self._func_graph.inputs. However, doing that also captures jvps
522 # for resource handles, which confuses the jvp capturing code below
523 # (since primal inputs are interwoven with jvp inputs).
524 input_placeholder = array_ops.placeholder(
525 dtype=real_input.dtype,
526 shape=real_input.shape)
527 capture = forward_captures.get(ops.tensor_id(real_input))
528 if capture is not None:
529 forward_wrapper_graph.add_capture(capture, input_placeholder)
530 if capture.dtype == dtypes.resource:
531 handle_data_util.copy_handle_data(capture, input_placeholder)
532 else:
533 forward_wrapper_graph.inputs.append(input_placeholder)
534 for inp, arg in zip(forward_wrapper_graph.inputs, inference_args):
535 record.record_operation(
536 "captured_value", [inp], [arg],
537 backward_function=lambda x: [x],
538 forward_function=lambda x: [x])
539 num_inference_inputs = len(inference_args)
540 for tape_indices in self._forwardprop_input_indices:
541 for input_index, jvp_index in tape_indices:
542 input_placeholder = forward_wrapper_graph.inputs[input_index]
543 if len(forward_wrapper_graph.inputs) != jvp_index:
544 raise errors.InternalError(
545 f"Expected {jvp_index} forward graph inputs, "
546 f"got {len(forward_wrapper_graph.inputs)}.")
547 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
548 input_placeholder)
549 jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
550 external_jvp = input_tangents[jvp_index - num_inference_inputs]
551 forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder)
552 tensor_shape.TensorShape(
553 external_jvp.shape).assert_is_compatible_with(
554 jvp_placeholder.shape)
555 record.record_operation(
556 "captured_value",
557 [jvp_placeholder],
558 [external_jvp],
559 backward_function=lambda x: [x],
560 forward_function=lambda x: [x])
561 forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs]
562 gradient_function = (
563 self._delayed_rewrite_functions._rewrite_forward_and_call_backward) # pylint: disable=protected-access
564 with ops.get_default_graph()._override_gradient_function( # pylint: disable=protected-access
565 {"PartitionedCall": gradient_function,
566 "StatefulPartitionedCall": gradient_function}):
567 forward_outputs = forward_function(*forward_inputs)
568 if isinstance(forward_outputs, ops.Operation):
569 # _wrapped_backward_function expects a list, but if the function has
570 # no outputs its call() returns an Operation. We need to undo that
571 # so we don't cause problems later.
572 forward_outputs = []
573 py_backward, _ = self._wrap_backward_function(
574 self._func_graph, backward_function, forward_outputs)
575 # We will never request backward tape gradients for this operation
576 # directly since we're wrapping the call; forwardprop will call the
577 # backward function (and nested forward accumulators may build
578 # higher-order gradients), but any watching GradientTapes should ignore
579 # it.
580 #
581 # TODO(allenl): It might be better to explicitly stop backward recording
582 # so we don't use the second-order tape cases unnecessarily.
583 record.record_operation_forwardprop_only(
584 forward_function.cached_definition.signature.name,
585 forward_outputs, forward_inputs, py_backward, None)
586 output_indices, output_tangents = (
587 pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
588 output_tangents = [forward_wrapper_graph.capture(t)
589 for t in output_tangents]
590 return _ForwardWrapper(
591 graph=forward_wrapper_graph, outputs=forward_outputs,
592 output_indices=output_indices, output_tangents=output_tangents)
593
594 def _wrap_backward_function_with_jvp_backprop(
595 self, backward_function, gradients_wrt_outputs, forward_wrapper):
596 """Wraps `backward_function` to include gradients for JVPs."""
597 wrapped_backwards_graph = func_graph_module.FuncGraph(
598 _backward_name(self._func_graph.name))
599 with wrapped_backwards_graph.as_default():
600 py_backward, recorded_outputs = self._wrap_backward_function(
601 self._func_graph, backward_function, forward_wrapper.outputs)
602 trainable_index = 0
603 forward_doutputs = []
604 doutput_args = []
605 for output in recorded_outputs:
606 if backprop_util.IsTrainable(output):
607 doutput = gradients_wrt_outputs[trainable_index]
608 doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape)
609 doutput_args.append(doutput_placeholder)
610 forward_doutputs.append(doutput_placeholder)
611 trainable_index += 1
612 else:
613 doutput_args.append(None)
614
615 dinputs = py_backward(*doutput_args)
616 existing_outputs = object_identity.ObjectIdentitySet(
617 forward_wrapper.outputs + forward_wrapper.output_tangents)
618 num_processed_output_tangents = 0
619 gradients_wrt_output_tangents = []
620 tangent_doutputs = []
621 output_tangents = forward_wrapper.output_tangents
622 output_indices = forward_wrapper.output_indices
623 if self._need_gradients_for_jvps:
624 # TODO(allenl): Consider using a throwaway graph to avoid extra gradient
625 # evaluations; gradients for jvps may have common subgraphs.
626 while num_processed_output_tangents != len(output_tangents):
627 for output in output_tangents[num_processed_output_tangents:]:
628 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
629 output)
630 placeholder = graph_placeholder(gradient_dtype, gradient_shape)
631 gradients_wrt_output_tangents.append(placeholder)
632 tangent_doutputs.append(placeholder)
633 num_processed_output_tangents = len(output_tangents)
634 with ops.device(None):
635 gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access
636 output_tangents,
637 forward_wrapper.graph.inputs,
638 grad_ys=gradients_wrt_output_tangents,
639 src_graph=forward_wrapper.graph)
640 dinputs = [
641 backprop_util.AggregateIndexedSlicesGradients((existing, new))
642 for existing, new in zip(dinputs, gradients_wrt_inputs)
643 if existing is not None or new is not None]
644 dinputs.extend(gradients_wrt_inputs[len(dinputs):])
645 captures_from_forward = [
646 c for c in wrapped_backwards_graph.external_captures
647 if (not isinstance(c, ops.EagerTensor)
648 and c.graph is forward_wrapper.graph)]
649 for capture in captures_from_forward:
650 if capture not in existing_outputs:
651 existing_outputs.add(capture)
652 forward_wrapper.outputs.append(capture)
653 output_indices, output_tangents = (
654 forwardprop_util.pack_tangents(forward_wrapper.outputs))
655 output_tangents = [forward_wrapper.graph.capture(t)
656 for t in output_tangents]
657 for t in output_tangents:
658 existing_outputs.add(t)
659 wrapped_backwards_graph.inputs = (
660 forward_doutputs[:self._num_trainable_inference_outputs]
661 + tangent_doutputs
662 + forward_doutputs[self._num_trainable_inference_outputs:]
663 + wrapped_backwards_graph.internal_captures)
664 wrapped_backwards_graph.structured_outputs = dinputs
665 wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None]
666 return (wrapped_backwards_graph,
667 forward_wrapper._replace(output_indices=output_indices,
668 output_tangents=output_tangents))
669
670 def _shuffle_forward_outputs(self, forward_wrapper):
671 """Reorders function outputs so captures are last."""
672 def _index_map(original):
673 if original < self._num_inference_outputs:
674 return original
675 if original >= len(forward_wrapper.outputs):
676 return (original - len(forward_wrapper.outputs)
677 + self._num_inference_outputs)
678 return original + len(forward_wrapper.output_tangents)
679 output_indices = nest.map_structure(
680 _index_map, forward_wrapper.output_indices)
681 forward_wrapper.graph.outputs = (
682 forward_wrapper.outputs[:self._num_inference_outputs]
683 + forward_wrapper.output_tangents
684 + forward_wrapper.outputs[self._num_inference_outputs:])
685 return forward_wrapper._replace(output_indices=output_indices)
686
687 def forward(self, inference_args, input_tangents):
688 """Construct or fetch a forward function with side-outputs.
689
690 When graph building without a tape active, symbolic gradients rely on
691 regenerating the backward function for higher-order gradients (to account
692 for new side outputs of the rewritten forward function call). Thus there is
693 no fixed backward function for this case. However, when a tape is active
694 (eager or graph building), we generate fixed backward and forward functions
695 at forward function call time.
696
697 This difference between the tape and non-tape cases is to avoid building
698 unneeded backward functions while graph building (where we may or may not
699 eventually need gradients).
700
701 Args:
702 inference_args: A flat list of Tensors, arguments to the inference
703 function.
704 input_tangents: A flat list of Tensors, jvps associated with
705 `inference_args`.
706
707 Returns:
708 A forward atomic_function.AtomicFunction.
709 """
710 if self._forward is None:
711 (self._forward, self._forward_graph, self._backward,
712 self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
713 self._forward_and_backward_functions(inference_args, input_tangents))
714 return self._forward
715
716 def _wrap_backward_function(self, forward_graph, backward, outputs):
717 """Create a backward function given `outputs` from the forward function."""
718 capture_mapping = dict(
719 zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
720 captured_inputs = backward.captured_inputs
721 remapped_captures = [
722 capture_mapping.get(ops.tensor_id(capture), capture)
723 for capture in captured_inputs
724 ]
725 if any(t.graph is forward_graph for t in remapped_captures
726 if not isinstance(t, ops.EagerTensor)):
      incorrect_mapping = [t for t in remapped_captures
                           if (not isinstance(t, ops.EagerTensor) and
                               t.graph is forward_graph)]
730 raise errors.InternalError("Failed to map all backward graph captures to "
731 "the forward graph. Incorrectly mapped: "
732 f"{incorrect_mapping}.")
733 # We may need to use zeros_like to get a zero for variant Tensors with
734 # unconnected gradients. We do that in advance so we don't have to hold on
735 # to the outputs themselves, which may not be needed otherwise.
736 variant_zeros_like = {}
737 backward_function_inputs = (len(backward.inputs) - len(captured_inputs))
738 recorded_outputs = []
739 trainable_recorded_outputs = 0
740 skip_positions = []
741 if self._num_forwardprop_outputs and not self._need_gradients_for_jvps:
742 relevant_outputs = (
743 outputs[:self._num_inference_outputs]
744 + outputs[self._num_inference_outputs
745 + self._num_forwardprop_outputs:])
746 else:
747 relevant_outputs = outputs
748 for output_index, output in enumerate(relevant_outputs):
749 if trainable_recorded_outputs < backward_function_inputs:
750 recorded_outputs.append(output)
751 if backprop_util.IsTrainable(output):
752 trainable_recorded_outputs += 1
753 else:
754 skip_positions.append(output_index)
755 if output.dtype == dtypes.variant:
756 variant_zeros_like[output_index] = default_gradient.zeros_like(output)
757
758 def _backward_function_wrapper(*args):
759 """Process output gradients and call the backward function."""
760 if not backward.outputs:
761 return backward.structured_outputs
762
763 processed_args = []
764 input_index = 0
765 for output_index, arg in enumerate(args):
766 # Convert IndexedSlices to dense tensors. The IndexedSlices optimization
767 # is only really effective when doing tf.gather(variable) as the
768 # adjoint functions for most operations are unlikely to preserve the
769 # sparsity in IndexedSlices.
770 if isinstance(arg, indexed_slices.IndexedSlices):
771 arg = ops.convert_to_tensor(arg)
772 if output_index in skip_positions:
773 continue
774 if arg is None:
775 # We're calling a (non-polymorphic) ConcreteFunction, so we need to
776 # have a Tensor value for each Tensor we thought would be trainable
777 # based on its dtype, even if it ended up being unconnected.
778 input_placeholder = backward.inputs[
779 input_index]
780 if input_placeholder.dtype == dtypes.variant:
781 arg = variant_zeros_like[output_index]
782 else:
783 arg = array_ops.zeros(
784 *default_gradient.shape_and_dtype(input_placeholder))
785 processed_args.append(arg)
786 input_index += 1
787 if input_index >= backward_function_inputs:
788 break
789 return backward._call_flat( # pylint: disable=protected-access
790 processed_args, remapped_captures)
791
792 return _backward_function_wrapper, recorded_outputs
793
794 def record(self, flat_outputs, inference_args, input_tangents):
795 """Record the function call operation.
796
797 For backprop, indicates the backward function to use and which new Tensors
798 must be watched. For forwardprop from eager, the function call itself will
799 have produced tangents which need to be recorded.
800
801 Args:
802 flat_outputs: The result of running `forward`.
803 inference_args: A flat list of Tensors with inference inputs to the
804 operation.
805 input_tangents: A flat list of Tensors with input tangents consumed by the
806 operation.
807 """
808 backward_function, to_record = self._wrap_backward_function(
809 self._forward_graph, self._backward, flat_outputs)
810 if self._forwardprop_output_indices:
811 record.record_operation_backprop_only(
812 self._forward.cached_definition.signature.name,
813 to_record, inference_args,
814 backward_function)
815 record.record_operation_forwardprop_only(
816 self._forward.cached_definition.signature.name,
817 flat_outputs, inference_args + input_tangents,
818 backward_function,
819 self._forwardprop_output_indices)
820 else:
821 record.record_operation(self._forward.cached_definition.signature.name,
822 to_record, inference_args + input_tangents,
823 backward_function)
824
825
826class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
827 """Caches tape-friendly functions for first-order gradients."""
828
829 def __init__(self, func_graph, attrs, func_graph_deleter,
830 forwardprop_input_indices, delayed_rewrite_functions,
831 need_gradients_for_jvps):
832 super().__init__(func_graph, attrs, func_graph_deleter,
833 forwardprop_input_indices, delayed_rewrite_functions,
834 need_gradients_for_jvps)
835 self._func_graph_deleter = func_graph_deleter
836 self._forwardprop_input_indices = forwardprop_input_indices
837
838 def _forward_and_backward_functions(self, inference_args, input_tangents):
839 """Shortcut for when only first-order gradients are required.
840
    The returned backward function does not accept gradients with respect to
    side outputs of forward_function. This is fine as long as the user can't
    possibly request second-order tape gradients, as when they've used a
    single non-persistent GradientTape. Since we don't need the backward
    function to take gradients with respect to side outputs, we can skip some
    potentially slow graph building.
847
848 Args:
849 inference_args: A flat list of Tensors, arguments to the inference
850 function.
851 input_tangents: A flat list of Tensors, jvps associated with
852 `inference_args`.
853
854 Returns:
855 A tuple of (forward_function, backward_function):
856 forward_function: Takes the same inputs as the inference function, but
857 returns side outputs used by backward_function in addition to the
858 inference function's outputs.
859 backward_function: Takes side outputs from forward_function and
860 gradients with respect to the "real" outputs of forward_function and
861 returns gradients with respect to the inputs.
862 """
863 outputs = self._func_graph.outputs[:self._num_inference_outputs]
864 return self._build_functions_for_outputs(
865 outputs, inference_args, input_tangents)
866
867
868class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
869 """Caches tape-friendly functions for higher-order gradients."""
870
871 # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
872 # generalizing if so.
873 def _forward_and_backward_functions(self, inference_args, input_tangents):
874 """Forward and backward functions suitable for higher-order gradients.
875
876 Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by
877 this method accepts gradients for all of the outputs of the returned forward
878 function, including side outputs.
879
880 Args:
881 inference_args: A flat list of Tensors, arguments to the inference
882 function.
883 input_tangents: A flat list of Tensors, jvps associated with
884 `inference_args`.
885
886 Returns:
887 A tuple of (forward_function, backward_function):
888 forward_function: Takes the same inputs as the inference function, but
889 returns side outputs used by backward_function in addition to the
890 inference function's outputs.
891 backward_function: Takes side outputs from forward_function and
892 gradients with respect to all of its outputs, real and side. Returns
893 gradients with respect to the inputs.
894 """
895 outputs = []
896 iteration_count = 0
897 # First we need to figure out how many side outputs from the forward pass
898 # will be required. We do this in a temporary graph to avoid actually
899 # running multiple copies of the backward pass (one per _GradientsHelper
900 # call).
901 #
902 # While computing gradients, the backward function captures Tensors from
903 # the forward function. We add these as side outputs of the original
904 # function. However, we then need to accept output gradients with respect
905 # to these side outputs for higher order gradients to work. Thus we loop
906 # until the number of outputs of the function stabilizes. Note that this
907 # is only required for tape gradients, where we need to declare in advance
908 # all of the forward op's outputs: symbolic gradients with tf.gradients
909 # instead rely on regenerating backward functions when higher-order
910 # gradients are requested.
911 while (len(outputs) < len(self._func_graph.outputs)
912 # It's possible for gradient generation to add new ops to the forward
913 # pass. If all of the new outputs are non-trainable, there's no
914 # reason to continue.
915 and any(backprop_util.IsTrainable(output)
916 for output in self._func_graph.outputs[len(outputs):])):
917 iteration_count += 1
918 if iteration_count >= 20 and iteration_count % 5 == 0:
919 new_op_with_trainable_output = None
920 num_new_trainable_outputs = 0
921 for output in self._func_graph.outputs[len(outputs):]:
922 if backprop_util.IsTrainable(output):
923 num_new_trainable_outputs += 1
924 new_op_with_trainable_output = output.op
925 logging.warning(
926 ("Determining side outputs for the function '{}' is taking longer "
927 "than expected ({} iterations, typically this converges in 5 or "
928 "so). This could indicate that a gradient registration is adding "
929 "new ops to the forward pass every time gradients are generated. "
930 "{} new trainable output(s) were added this iteration, one from "
931 "the following op:\n {}\nThis may indicate a TensorFlow bug, or "
932 "an issue in a tf.custom_gradient.")
933 .format(
934 self._func_graph.name, iteration_count,
935 num_new_trainable_outputs, new_op_with_trainable_output))
936 outputs = list(self._func_graph.outputs)
937 self._build_functions_for_outputs(
938 outputs, inference_args, input_tangents)
939
940 (forward_function, forward_graph,
941 backward_function, output_indices, num_output_tangents) = (
942 self._build_functions_for_outputs(
943 outputs, inference_args, input_tangents))
944 if (len(self._func_graph.outputs) > len(outputs)
945 and any(backprop_util.IsTrainable(output)
946 for output in self._func_graph.outputs[len(outputs):])):
947 raise errors.InternalError(
948 "Unexpectedly added new outputs to the forward function when "
949 "building the backward function: "
950 f"{self._func_graph.outputs[len(outputs):]}.")
951 return (forward_function, forward_graph, backward_function, output_indices,
952 num_output_tangents)
953
954
955class _ForwardBackwardCall(object):
956 """Holds the state of a function call between execution and recording."""
957
958 __slots__ = [
959 "_functions", "_inference_args", "_input_tangents", "_tape_watching"
960 ]
961
962 def __init__(self, functions, inference_args, input_tangents, tape_watching):
963 """Collects information about the function call.
964
965 Args:
966 functions: An object which produces forward and backward functions, either
967 a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object.
968 inference_args: A flat list of Tensors, arguments to the inference
969 function.
970 input_tangents: A flat list of Tensors, jvps associated with
971 `inference_args`.
972 tape_watching: Boolean, with True indicating that recording is necessary.
973 """
974 self._functions = functions
975 self._inference_args = inference_args
976 self._input_tangents = input_tangents
977 self._tape_watching = tape_watching
978
979 def forward(self):
980 """Builds or retrieves a forward function for this call."""
981 forward_function = self._functions.forward(
982 self._inference_args, self._input_tangents)
983 return forward_function, self._inference_args + self._input_tangents
984
985 def record(self, flat_outputs):
986 """Given outputs from the execution of `forward`, records the operation."""
987 if (self._tape_watching
988 and not isinstance(flat_outputs, ops.Operation)
989 and flat_outputs is not None):
990 # We only record function calls which have outputs, and then only when a
991 # tape is watching.
992 self._functions.record(
993 flat_outputs, self._inference_args, self._input_tangents)
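
  # Typical usage (illustrative sketch, mirroring ConcreteFunction._call_flat):
  #   call = _ForwardBackwardCall(functions, args, tangents, tape_watching=True)
  #   forward_function, args_with_tangents = call.forward()
  #   flat_outputs = forward_function(*args_with_tangents)
  #   call.record(flat_outputs)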
994
995
996class ConcreteFunction(core.ConcreteFunction, trackable.Trackable):
997 """A `tf.types.experimental.ConcreteFunction` created from `tf.function`."""
998
999 def __init__(self, func_graph, attrs=None, shared_func_graph=True, spec=None):
1000 """Initialize a `ConcreteFunction`.
1001
1002 Args:
1003 func_graph: An instance of FuncGraph: the function body to wrap.
1004 attrs: (optional) dict mapping names of attributes to their AttrValue
1005 values. Attributes in `attrs` will be included in this function's
1006 definition.
1007 shared_func_graph: If False, the ConcreteFunction takes ownership of
1008 `func_graph` and will break reference cycles when it is deleted. This
1009 makes the FuncGraph inoperable.
1010 spec: FunctionSpec for the original function. If not specified, then this
1011 ConcreteFunction may only be called using the flat signature.
1012
1013 Raises:
1014 ValueError: If number of input_placeholders is not equal to the number
1015 of function inputs.
1016 """
1017 # _arg_keywords and _num_positional_args define the flat signature. They
1018 # are assigned after construction.
1019 self._arg_keywords = None
1020 self._num_positional_args = None
1021
1022 self._func_graph = func_graph
1023 self._captured_inputs = self._func_graph.external_captures + self._func_graph.deferred_external_captures
1024
1025 # spec defines the structured signature.
1026 self._set_function_spec(spec)
1027
1028 if attrs and attributes_lib.IMPLEMENTS in attrs:
      # The alternative is to silently drop the "implements" tag, but it seems
      # likely that would lead to hard-to-catch bugs. Another alternative is to
      # make func_body preserve the order of arguments if variables are
      # present. Yet another option is to automatically replace variables
      # passed as arguments with v.read_value() whenever the "implements" tag
      # is present. Any time we annotate an existing function we probably want
      # to wrap it with a safe read_value for backward compatibility.
1037 has_resource_vars = any(
1038 inp.dtype == dtypes.resource for inp in self.inputs)
1039
1040 assert not any((has_resource_vars, self._captured_inputs)), (
1041 'Function {name} has "{attr}={value}" attribute and thus can not '
1042 "depend on any tensors outside of its signature or modify variables. "
1043 "\n\nNote: variables are always captured and cause function "
1044 "re-tracing for every variable called.\n"
1045 " inputs: {inputs}\n captures: {captured}\n\n"
1046 "To pass a variable to such function use "
1047 "use variable.read_value().".format(
1048 name=func_graph.name,
1049 attr=attributes_lib.IMPLEMENTS,
1050 value=attrs[attributes_lib.IMPLEMENTS],
1051 inputs=self.inputs,
1052 captured=self._captured_inputs))
1053 self._output_shapes = tuple(
1054 output.shape for output in self._func_graph.outputs)
1055 self._attrs = _parse_func_attrs(attrs or {})
1056
1057 if shared_func_graph:
1058 self._garbage_collector = None
1059 else:
1060 self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph)
1061
1062 # Pairs of forward and backward functions used for computing gradients.
1063 #
1064 # These each get a reference to the FuncGraph deleter since they use the
1065 # FuncGraph directly.
1066 self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions(
1067 func_graph, self._attrs, self._garbage_collector)
1068 self._first_order_tape_functions = {}
1069 self._higher_order_tape_functions = {}
1070 # Cache the inference function to avoid a (Python) function call when not
1071 # building gradients.
1072 self._inference_function = self._delayed_rewrite_functions.forward()
1073
1074 def _set_function_spec(self, spec):
1075 """Enables the structured signature by supplying a spec."""
1076 self._function_spec = None
1077 self._pre_initialized_function_spec = spec
1078 self._initialize_function_spec()
1079
1080 def _initialize_function_spec(self):
1081 """Updates `self._function_spec` to include varargs and bound variables.
1082
1083 Adds new positional arguments for any varargs (i.e., for args that are
1084 in `structured_input_signature`, but not in the original fullargspec.args).
1085
1086 Replaces `defaults` and `kwonlydefaults` with the `BOUND_VALUE`, for
1087 all args and kwargs in `structured_input_signature`.
1088
1089 Sets `varkw` and `varargs` to None.
1090 """
1091 if self._pre_initialized_function_spec is None:
1092 return # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
1093 assert not self._function_spec, "already initialized"
1094 spec = self._pre_initialized_function_spec
    unconstrained_poly_type = function_type_lib.FunctionType(
1096 [
1097 function_type_lib.Parameter(p.name, p.kind, p.optional, None)
1098 for p in spec.function_type.parameters.values()
1099 ]
1100 )
1101 arg_specs, kwarg_specs = self.structured_input_signature
1102
1103 _, func_type, _ = function_type_lib.canonicalize_to_monomorphic(
1104 arg_specs,
1105 {
1106 function_type_lib.sanitize_arg_name(k): v
1107 for k, v in kwarg_specs.items()
1108 },
1109 self._pre_initialized_function_spec.default_values,
1110 {},
        unconstrained_poly_type,
1112 )
1113
1114 self._function_spec = function_spec.FunctionSpec(
1115 func_type,
1116 {d: function_spec.BOUND_VALUE for d in spec.default_values},
1117 spec.is_pure,
1118 name=self._func_graph.name,
1119 )
1120
1121 @property
1122 def variables(self):
1123 """Sequence of variables for this function."""
1124 return tuple(self._func_graph.variables)
1125
1126 def set_variables(self, variables):
1127 self._func_graph.variables = variables
1128
1129 @property
1130 def trainable_variables(self):
1131 """Sequence of trainable variables for this function."""
1132 return tuple(self._func_graph.trainable_variables)
1133
1134 def __call__(self, *args, **kwargs):
1135 """Executes the wrapped function.
1136
1137 ConcreteFunctions have two signatures:
1138
1139 * The signature of the original function wrapped by this ConcreteFunction.
1140 * A flat signature, where each argument accepts a single Tensor.
1141
1142 The original function signature is generally preferred, but the flat input
1143 signature is supported for backward compatibility.
1144
1145 ### Original Function Signature
1146
1147 When calling a ConcreteFunction with the signature of the original function,
1148 each argument must match the type or value that was used when the
1149 ConcreteFunction's graph was traced. In particular:
1150
1151 * Tensor arguments (including CompositeTensors, such as RaggedTensor) must
1152 have matching `TypeSpec`s.
1153 * Non-Tensor arguments (such as booleans or ints) must have equal values.
1154 * Nested arguments (such as lists, tuples, or dictionaries) must have the
1155 same nesting structure; and each nested value must have a matching type
1156 or value.
1157
1158 The default value for any arguments that were traced with non-Tensor values
1159 is the value that was used in the trace. Arguments that were traced with
1160 tensor arguments do not have a default value (even if the original function
1161 had a default value for that argument).
1162
1163 ### Flat Signature
1164
1165 When calling a ConcreteFunction with the flat signature, the arguments
1166 correspond to the flattened component tensors of the arguments that were
1167 used to construct the ConcreteFunction. Parameter names are assigned based
1168 on `TensorSpec.name` (when specified) or the original argument names (with
1169 suffixes automatically added for nested arguments or composite tensors with
1170 multiple components).
1171
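    ### Example

    An illustrative sketch (assumes a function traced with two scalar float
    inputs; the names are hypothetical):

      @tf.function
      def f(x, y):
        return x + y

      cf = f.get_concrete_function(
          tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32))
      cf(tf.constant(1.0), tf.constant(2.0))      # Structured signature.
      cf(x=tf.constant(1.0), y=tf.constant(2.0))  # Structured, by keyword.
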
1172 Args:
1173 *args: Positional arguments to the concrete function.
1174 **kwargs: Keyword arguments to the concrete function.
1175
1176 Returns:
1177 The result of applying the TF function on the given Tensors.
1178
1179 Raises:
1180 AssertionError: If this `ConcreteFunction` was not created through
1181 `get_concrete_function`.
1182 TypeError: If the arguments do not match the function's signature.
1183 """
1184 return self._call_impl(args, kwargs)
1185
1186 def _call_impl(self, args, kwargs):
1187 """See `__call__` for details."""
1188 with trace.Trace(self._func_graph.name, tf_function_call="concrete"):
1189 # Construct the list of input tensors: check if the structured signature
1190 # applies first; and if not, then use the flat signature.
1191 if self._function_spec is not None:
1192 try:
1193 return self._call_with_structured_signature(args, kwargs)
1194 except TypeError as structured_err:
1195 try:
1196 return self._call_with_flat_signature(args, kwargs)
1197 except TypeError:
1198 raise structured_err
1199
1200 return self._call_with_flat_signature(args, kwargs)
1201
1202 def _call_with_flat_signature(self, args, kwargs):
1203 """Executes the wrapped function with the flat signature.
1204
1205 Args:
1206 args: Positional arguments to the concrete function.
1207 kwargs: Keyword arguments to the concrete function.
1208
1209 Returns:
1210 The result of applying the function on the Tensors/Variables contained in
1211 `args` and `kwargs`.
1212 Raises:
1213 TypeError: if `args` and `kwargs` do not match the flat signature of this
1214 `ConcreteFunction`.
1215 """
1216 if len(args) > self._num_positional_args:
1217 raise TypeError(
1218 f"{self._flat_signature_summary()} takes {self._num_positional_args} "
1219 f"positional arguments, got {len(args)}.")
1220 args = list(args)
1221 kwargs = dict(kwargs)
1222 kwargs = {
1223 function_type_lib.sanitize_arg_name(k): v for k, v in kwargs.items()
1224 }
1225 for keyword in self._arg_keywords[len(args):]:
1226 try:
1227 args.append(
1228 kwargs.pop(
1229 function_type_lib.sanitize_arg_name(compat.as_str(keyword))))
1230 except KeyError:
1231 specified_keywords = (
1232 list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
1233 missing_required_args = sorted(
1234 set(self._arg_keywords) - set(specified_keywords))
1235 raise TypeError(f"{self._flat_signature_summary()} missing required "
1236 f"arguments: {', '.join(missing_required_args)}.")
1237 if kwargs:
1238 positional_arg_keywords = set(self._arg_keywords[:len(args)])
1239 for unused_key in kwargs:
1240 if unused_key in positional_arg_keywords:
1241 raise TypeError(f"{self._flat_signature_summary()} got two values "
1242 f"for '{unused_key}'.")
1243 raise TypeError(f"{self._flat_signature_summary()} got unexpected "
1244 f"keyword arguments: {', '.join(sorted(kwargs))}.")
1245
1246 for i, arg in enumerate(args):
1247 if not isinstance(
1248 arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
1249 raise TypeError(f"{self._flat_signature_summary()}: expected argument "
1250 f"#{i}(zero-based) to be a Tensor; "
1251 f"got {type(arg).__name__} ({arg}).")
1252 return self._call_flat(args, self.captured_inputs)
1253
1254 def _call_with_structured_signature(self, args, kwargs):
1255 """Executes the wrapped function with the structured signature.
1256
1257 Args:
1258 args: Positional arguments to the concrete function.
1259 kwargs: Keyword arguments to the concrete function.
1260
1261 Returns:
1262 The result of applying the function on the Tensors/Variables contained in
1263 `args` and `kwargs`.
1264 Raises:
1265 TypeError: if `args` and `kwargs` do not match the structured signature
1266 of this `ConcreteFunction`.
1267 """
1268 args, kwargs, filtered_flat_args = (
1269 self._function_spec.canonicalize_function_inputs(args, kwargs))
1270 return self._call_flat(
1271 filtered_flat_args,
1272 captured_inputs=self.captured_inputs)
1273
1274 def _call_flat(self, args, captured_inputs):
1275 """Executes the wrapped function.
1276
1277 Args:
1278 args: a list of Tensors or Variables. Arguments from the Python function
1279 should be filtered before calling this method: objects aside from
1280 Tensors, CompositeTensors, and Variables are ignored. Any
1281 CompositeTensors other than ResourceVariables should be expanded before
1282 calling this method.
1283 captured_inputs: the captured inputs that are also part of the input args
1284 to the actual execution. By default, it should be self._captured_inputs.
1285 Returns:
1286 The result of applying the TF function to `args`.
1287
1288 Raises:
1289 ValueError: If `args` contains anything other than Tensors or Variables.
1290 """
1291 ctx = context.context()
1292 executing_eagerly = ctx.executing_eagerly()
1293
1294 # Copy saveable status of function's graph to current FuncGraph.
1295 default_graph = ops.get_default_graph()
1296 if default_graph.building_function and not self._func_graph.saveable:
1297 default_graph.mark_as_unsaveable(self._func_graph.saving_errors)
1298
1299 if (record.could_possibly_record() or
1300 hasattr(default_graph, "watch_variable")):
1301 for v in self._func_graph.variables:
1302 resource_variable_ops.variable_accessed(v)
1303
1304 tensor_inputs = []
1305 variables_used = set([])
1306 for i, arg in enumerate(args):
1307 if isinstance(arg, resource_variable_ops.BaseResourceVariable):
1308 # We can pass a variable more than once, and in this case we need to
1309 # pass its handle only once.
1310 if id(arg.handle) in variables_used:
1311 continue
1312 resource_variable_ops.variable_accessed(arg)
1313 tensor_inputs.append(arg.handle)
1314 variables_used.add(id(arg.handle))
1315 elif isinstance(arg, ops.Tensor):
1316 tensor_inputs.append(arg)
1317 else:
1318 raise ValueError(f"{i:d}-th input {arg} must be a Tensor, got "
1319 f"{type(arg)} when calling {self._func_graph.name}.")
1320
1321 if not executing_eagerly:
1322 for i, tensor_input in enumerate(tensor_inputs):
        # Cannot compare shapes in these cases.
1324 # TODO(b/216506654): Consider moving this check elsewhere and making it
1325 # work for all types (e.g. by including shape for Variables).
1326 if (tensor_input.dtype == dtypes.resource or
1327 tensor_input.dtype == dtypes.variant):
1328 continue
1329
1330 # If we're graph building, shape inference is on. We check for input
1331 # compatibility up front to avoid hard to debug incompatibilities
1332 # later.
1333 graph_input_shape = tensor_shape.TensorShape(
1334 self._func_graph.inputs[i].shape)
1335 if not graph_input_shape.is_compatible_with(tensor_input.shape):
1336 raise ValueError(
1337 f"Tensor {tensor_input} is not compatible with the shape this "
1338 f"function was traced with. Expected shape "
1339 f"{self._func_graph.inputs[i].shape}, but got shape "
1340 f"{tensor_input.shape}.\n\nIf you called get_concrete_function, "
1341 f"you may need to pass a tf.TensorSpec(..., shape=...) with a "
1342 f"less specific shape, having None on axes which can vary.")
1343
1344 args = tensor_inputs + captured_inputs
1345 possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
1346 if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
1347 and executing_eagerly):
1348 # No tape is watching; skip to running the function.
1349 return self._build_call_outputs(self._inference_function(*args))
1350 forward_backward = self._select_forward_and_backward_functions(
1351 args,
1352 possible_gradient_type,
1353 executing_eagerly)
1354 forward_function, args_with_tangents = forward_backward.forward()
1355 if executing_eagerly:
1356 flat_outputs = forward_function(*args_with_tangents)
1357 else:
1358 with default_graph._override_gradient_function( # pylint: disable=protected-access
1359 {"PartitionedCall": self._get_gradient_function(),
1360 "StatefulPartitionedCall": self._get_gradient_function()}):
1361 flat_outputs = forward_function(*args_with_tangents)
1362 forward_backward.record(flat_outputs)
1363 return self._build_call_outputs(flat_outputs)
1364
1365 @property
1366 def name(self):
1367 """`ConcreteFunction` name."""
1368 return self._delayed_rewrite_functions.forward().name
1369
1370 @property
1371 def graph(self):
1372 """Returns the graph from which this function was constructed."""
1373 return self._func_graph
1374
1375 @property
1376 def inputs(self):
1377 """Returns tensors in `self.graph` corresponding to arguments."""
1378 return self._func_graph.inputs
1379
1380 @property
1381 def structured_input_signature(self):
1382 """Returns structured signature for this concrete function.
1383
1384 Returns:
1385 A tuple `(args, kwargs)`, where:
1386
      * `args` is a tuple that specifies the expected type or value for each
        positional argument.
1389 * `kwargs` is a dictionary that specifies the expected type or value
1390 for each keyword-only argument.
1391
1392 The type or value for each argument is specified using one of the
1393 following:
1394
1395 * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
1396 value is expected.
1397 * A Python value, such as an integer, indicating that an equal value
1398 is expected.
1399 * A nested structure of `tf.TypeSpec`s and Python values, indicating
1400 that a corresponding nested structure is expected.
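
    For example, a minimal sketch using the public `tf.function` API (the
    names `f`, `x`, and `cf` below are illustrative):

    ```python
    @tf.function
    def f(x, scale=2.0):
      return x * scale

    cf = f.get_concrete_function(tf.TensorSpec([None], tf.float32), 3.0)
    args, kwargs = cf.structured_input_signature
    # `args` is roughly (TensorSpec(shape=(None,), dtype=tf.float32, name='x'),
    # 3.0); `kwargs` is {} because `f` has no keyword-only arguments.
    ```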
1401 """
1402 return self._func_graph.structured_input_signature
1403
1404 @property
1405 def outputs(self):
1406 """Returns tensors in `self.graph` corresponding to returned tensors."""
1407 return self._func_graph.outputs
1408
1409 @property
1410 def structured_outputs(self):
1411 """Returns outputs in `self.graph` as returned by the original function."""
1412 return self._func_graph.structured_outputs
1413
1414 def set_external_captures(self, captures):
1415 """Updates the function capture values.
1416
1417 The new values must have tensor types and shapes consistent with the
    original captures of the concrete function, but a by-value capture may be
    replaced with a deferred one and vice versa.
1420
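    For example, a minimal sketch that swaps a single by-value capture for a
    new tensor of the same dtype and shape (`cf` is an illustrative
    `ConcreteFunction` whose only capture is a float32 scalar):

    ```python
    cf.set_external_captures([tf.constant(4.0)])
    ```
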
1421 Args:
      captures: A list of tensors or closures. Tensors are by-value captures,
        and closures are call-time (deferred) captures.
1424 """
1425 # TODO(wxinyi): 1. verify that the new captures' type spec is compatible
1426 # with the original's. However, doing so requires MirroredVariable captures
1427 # initialized. 2. replace the original/new captures/deferred
1428 # captures in the wrapped graph. Doing such for a capture-to-deferred
1429 # capture replacement requires more arguments than the deferred capture
1430 # itself, e.g. default value, spec.
1431 self._captured_inputs = captures
1432
1433 def replace_capture_with_deferred_capture(self,
1434 tensor,
1435 closure,
1436 spec,
1437 placeholder=None,
1438 default_value=None):
1439 """Replaces existing capture `tensor` with a deferred capture `closure`.
1440
    This API removes the capture `tensor` from the concrete function's captured
    inputs list and places the deferred capture `closure` in its spot, so the
    order of captured inputs is preserved. This matters because the old
    `tensor` and the new `closure` share the same internal placeholder, which
    can either be passed in via the `placeholder` argument or omitted, in which
    case the placeholder is looked up among the internal inputs using the index
    of `tensor` in the external captured inputs list. Consequently, the new
    deferred capture's output spec (the `spec` argument) must be compatible
    with both the internal placeholder (`placeholder`) and the original capture
    (`tensor`).
1451
1452 For example,
1453
1454 ```python
1455 bool_captured_tensor = tf.constant(True)
1456 float_captured_tensor = tf.constant([3.], dtype=tf.float32)
1457 value = tf.constant([2.], dtype=tf.float32)
1458
1459 @tf.function
1460 def fn():
1461 deferred_tensor = ops.get_default_graph().capture_call_time_value(
1462 lambda: value,
1463 tf.TensorSpec(shape=(1,), dtype=tf.float32))
1464 if bool_captured_tensor:
1465 return deferred_tensor
1466 else:
1467 return deferred_tensor + float_captured_tensor
1468
1469 concrete_fn = fn.get_concrete_function()
1470 print(concrete_fn()) # tf.Tensor([2.], shape=(1,), dtype=float32)
1471
    new_bool_captured_tensor = tf.constant(False)
1473 def bool_closure():
1474 return new_bool_captured_tensor
1475
1476 concrete_fn.replace_capture_with_deferred_capture(
1477 bool_captured_tensor,
1478 bool_closure,
        spec=tf.TensorSpec(shape=(), dtype=tf.bool))
1480
1481 print(concrete_fn()) # tf.Tensor([5.], shape=(1,), dtype=float32)
1482 ```
1483
1484 Args:
1485 tensor: Tensor already captured. This `tensor` should be listed in
        concrete_function.captured_inputs, except when that list is empty, such
        as when the concrete function is restored from a SavedModel.
1488 closure: function which takes no arguments, to be evaluated at function
1489 call time, returning a nest of tensors compatible with `spec`.
1490 spec: nest of TypeSpec for the value to capture.
1491 placeholder: optional. The internal placeholder corresponding to the
1492 captured `tensor` and the new `closure`.
1493 default_value: optional value to use in environments that cannot safely
        evaluate `closure`.
1495 """
1496 capture_index = None
1497 for i, capture in enumerate(self._captured_inputs):
1498 if id(tensor) == id(capture):
1499 capture_index = i
1500 break
1501
1502 if placeholder is None:
1503 if capture_index is None:
1504 raise ValueError(
1505 f"Did not find `tensor` argument {tensor} in the ConcreteFunction's"
1506 " captured inputs list, and did not receive a placeholder argument."
            " Thus we're unable to infer the internal placeholder.")
1508
1509 placeholder = self.inputs[-len(self._captured_inputs) + capture_index]
1510
1511 if not (spec.is_compatible_with(tensor) or
1512 spec.is_compatible_with(placeholder)):
1513 raise ValueError(
1514 f"Attempting to substitute closure with spec {spec} that's "
1515 f"incompatible with the original capture {tensor} or the internal "
1516 f"placeholder {placeholder}.")
1517
1518 self._func_graph.replace_capture_with_deferred_capture(
1519 tensor=tensor,
1520 closure=closure,
1521 spec=spec,
1522 placeholder=placeholder,
1523 default_value=default_value)
1524
1525 if capture_index is not None:
1526 self._captured_inputs[capture_index] = closure
1527
1528 @property
1529 def captured_inputs(self):
1530 """Returns external Tensors captured by this function.
1531
1532 self.__call__(*args) passes `args + self.captured_inputs` to the function.
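
    For example, a minimal sketch (the names `cap`, `f`, and `cf` are
    illustrative; the printed value is abbreviated):

    ```python
    cap = tf.constant(3.0)

    @tf.function
    def f(x):
      return x + cap

    cf = f.get_concrete_function(tf.TensorSpec([], tf.float32))
    print(cf.captured_inputs)  # [<tf.Tensor: ... numpy=3.0>]
    ```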
1533 """
1534 return nest.flatten(
1535 [x() if callable(x) else x for x in self._captured_inputs],
1536 expand_composites=True)
1537
1538 @property
1539 def function_def(self):
1540 """Returns a `FunctionDef` object representing this function."""
1541 return self._delayed_rewrite_functions.forward().cached_definition
1542
1543 @property
1544 def output_shapes(self):
1545 """The function's output shapes."""
1546 return nest.map_structure(
1547 lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)),
1548 composite_tensor.replace_composites_with_components(
1549 self._func_graph.structured_outputs),
1550 expand_composites=False)
1551
1552 @property
1553 def output_dtypes(self):
1554 # TODO(akshayka): Consider removing this.
1555 return nest.map_structure(
1556 lambda x: x.dtype if x is not None else None,
1557 composite_tensor.replace_composites_with_components(
1558 self._func_graph.structured_outputs),
1559 expand_composites=False)
1560
1561 def add_to_graph(self, g=None, overwrite=False):
    """Registers the function and adds it to graph `g` or the default graph.
1563
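    A minimal sketch of registering an already-traced `ConcreteFunction` into a
    specific graph (`cf` is an illustrative, previously obtained
    `ConcreteFunction`):

    ```python
    g = tf.Graph()
    cf.add_to_graph(g)
    ```
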
1564 Args:
1565 g: If specified, registers the function with this graph. Defaults to the
1566 current context (either the default graph or the eager context).
      overwrite: A bool. If True, this function's forward definition will
        overwrite any existing function with the same name in the graph `g`.
1569 """
    # When not executing eagerly, add the function to the default graph if no
    # graph is specified.
    # Under eager execution, the function definition is already added to the
    # eager context during construction, so there is nothing to do here.
1574
1575 if not context.executing_eagerly() and not g:
1576 g = ops.get_default_graph()
1577
1578 if g is not None:
1579 g._add_function_recursive(self._delayed_rewrite_functions.forward()) # pylint: disable=protected-access
1580
1581 def add_gradient_functions_to_graph(self, g=None):
1582 """Add forward/backward functions to graph `g` or the current context."""
1583 if not context.executing_eagerly() and not g:
1584 g = ops.get_default_graph()
1585 g._add_function_recursive(self._delayed_rewrite_functions.forward()) # pylint: disable=protected-access
1586 forward_function, backward_function = (
1587 self._delayed_rewrite_functions.forward_backward())
1588 g._add_function_recursive(forward_function) # pylint: disable=protected-access
1589 backward_function.add_to_graph(g)
1590
1591 def _get_gradient_function(self):
    """Returns the gradient function; it is created lazily on first call."""
1593 return self._delayed_rewrite_functions._rewrite_forward_and_call_backward # pylint: disable=protected-access
1594
1595 def _select_forward_and_backward_functions(
1596 self, args, possible_gradient_type, executing_eagerly):
1597 """Selects forward and backward functions based on the calling context.
1598
1599 The forward function computes the "real" function outputs, `self._outputs`,
1600 and any extra values needed by the corresponding backward function.
1601
1602 Args:
1603 args: A flat list of Tensors with all of the inputs to the forward
1604 function (including user-specified and captured inputs).
1605 possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
1606 executing_eagerly: Boolean, the value of context.executing_eagerly().
1607
1608 Returns:
1609 An object with a `forward` method returning a tuple of (forward_function :
1610 AtomicFunction, augmented_arguments : List), and a corresponding
1611 `record` method which takes outputs from the forward function and records
1612 the operation. forward_function should be called with augmented_arguments.
1613 """
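    # Roughly, `possible_gradient_type` reflects the calling context (an
    # illustrative summary of the branches below):
    #   * a single non-persistent tf.GradientTape is watching the inputs
    #     -> POSSIBLE_GRADIENT_TYPES_FIRST_ORDER
    #   * a persistent tape, or multiple nested tapes, may request higher-order
    #     gradients -> POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
    #   * no tape is recording -> POSSIBLE_GRADIENT_TYPES_NONE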
1614 if executing_eagerly:
1615 input_tangents = forwardprop_util.pack_tangents(args)
1616 else:
1617 input_tangents = forwardprop_util.TangentInfo()
1618 need_gradients_for_jvps = record.should_record_backprop(
1619 input_tangents.tangents)
    # Allows reuse of forward/backward function pairs depending on the tapes
    # and forward accumulators watching this function's inputs.
1622 cache_key = (need_gradients_for_jvps, input_tangents.indices)
1623 if (possible_gradient_type
1624 == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
1625 if input_tangents.indices or executing_eagerly:
1626 # There is a single non-persistent tape active, so the user can only
1627 # request first-order gradients from a tape. We can spend less time
1628 # graph building since we know this.
1629 #
1630 # We may still end up computing higher-order gradients, but that'd be
1631 # through `tf.gradients`, which can re-write the forward pass and so
1632 # needs no preparation here.
1633 functions = self._first_order_tape_functions.get(cache_key, None)
1634 if functions is None:
1635 functions = _FirstOrderTapeGradientFunctions(
1636 self._func_graph, self._attrs, self._garbage_collector,
1637 forwardprop_input_indices=input_tangents.indices,
1638 delayed_rewrite_functions=self._delayed_rewrite_functions,
1639 need_gradients_for_jvps=need_gradients_for_jvps)
1640 self._first_order_tape_functions[cache_key] = functions
1641 return _ForwardBackwardCall(
1642 functions, args, input_tangents.tangents, tape_watching=True)
1643 else:
1644 # We can avoid computing second-order gradients in some cases by doing a
1645 # delayed rewrite when graph building. Since we know we'll only compute
1646 # first-order tape gradients, the delayed rewrite is safe: we won't need
1647 # to tell the tape about side outputs.
1648 #
1649 # TODO(allenl): This case is really dirty. It would be better if we
1650 # could temporarily pop all of the current tapes to avoid
1651 # accidentally taking second-order gradients.
1652 return _ForwardBackwardCall(
1653 self._delayed_rewrite_functions, args, input_tangents.tangents,
1654 tape_watching=True)
1655 elif (possible_gradient_type
1656 == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
1657 # Either there's a persistent tape watching, or there are multiple nested
1658 # tapes. Either way, the user may request higher-order gradients. We'll
1659 # spend a bit more time and make sure higher-order gradients are correct.
1660 functions = self._higher_order_tape_functions.get(
1661 cache_key, None)
1662 if functions is None:
1663 functions = _HigherOrderTapeGradientFunctions(
1664 self._func_graph, self._attrs, self._garbage_collector,
1665 forwardprop_input_indices=input_tangents.indices,
1666 delayed_rewrite_functions=self._delayed_rewrite_functions,
1667 need_gradients_for_jvps=need_gradients_for_jvps)
1668 self._higher_order_tape_functions[cache_key] = functions
1669 return _ForwardBackwardCall(functions, args, input_tangents.tangents,
1670 tape_watching=True)
1671 # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
1672 # tape is recording.
1673 return _ForwardBackwardCall(
1674 self._delayed_rewrite_functions, args, input_tangents.tangents,
1675 tape_watching=False)
1676
1677 def _build_call_outputs(self, result):
    """Maps the fdef output list to the actual output structure.
1679
1680 Args:
1681 result: Output lists defined by FunctionDef.
1682 Returns:
1683 The actual call output.
1684 """
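    # For example, if the traced Python function returned
    # {"a": t0, "b": (None, t1)}, then `structured_outputs` mirrors that
    # structure, `result` is the flat list [r0, r1], and the loop below packs
    # it back into {"a": r0, "b": (None, r1)}, leaving the `None` in place.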
1685 # TODO(jlchu): call C++ version in function.cc when speed is improved
1686 if self._func_graph.structured_outputs is None:
1687 return result
1688
1689 # Replace outputs with results, skipping over any 'None' values.
1690 outputs_list = nest.flatten(
1691 self._func_graph.structured_outputs, expand_composites=True)
1692 j = 0
1693 for i, o in enumerate(outputs_list):
1694 if o is not None:
1695 handle_data_util.copy_handle_data(self.outputs[j], result[j])
1696 outputs_list[i] = result[j]
1697 j += 1
1698 ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
1699 outputs_list, expand_composites=True)
1700 return ret
1701
1702 @property
1703 def _as_name_attr_list(self):
1704 """Returns a `NameAttrList` representing this function."""
1705 ret = attr_value_pb2.NameAttrList(name=self.name)
1706 for name, value in self._attrs.items():
1707 ret.attr[name].CopyFrom(value)
1708 return ret
1709
1710 def _structured_signature_summary(self, default_values=False):
1711 """Returns a string summarizing this function's structured signature.
1712
1713 Args:
1714 default_values: If true, then include default values in the signature.
1715
1716 Returns:
1717 A `string`.
1718 """
    # Note: we can't just use self._function_spec.signature_summary(), because
1720 # that would show "BOUND_VALUE" as the default value for all arguments.
1721 assert self._function_spec is not None
1722 arg_specs, kwarg_specs = self.structured_input_signature
1723 arg_names = list(self._function_spec.arg_names)
1724
1725 # If an explicit input_signature is provided to @tf.function, then any
1726 # arguments with defaults that are not covered by that explicit signature
1727 # are simply dropped from the signature.
1728 # TODO(b/159639913) Look into whether dropping arguments with default values
1729 # from the signature is the right thing to do.
1730 arg_names = arg_names[:len(arg_specs)]
1731
1732 if default_values:
1733 for i in range(len(arg_names)):
1734 if not _contains_type_spec(arg_specs[i]):
1735 arg_names[i] += "={}".format(arg_specs[i])
1736 if kwarg_specs:
1737 arg_names.append("*")
1738 for name, spec in kwarg_specs.items():
1739 arg_names.append(name)
1740 if default_values and not _contains_type_spec(spec):
1741 arg_names[-1] += "={}".format(spec)
1742 signature = f"{self._func_graph.name}({', '.join(arg_names)})"
1743
1744 return signature
1745
1746 def _flat_signature_summary(self):
1747 """Returns a string summarizing this function's flat signature."""
1748 assert self._arg_keywords is not None
1749 assert self._num_positional_args is not None
    # Copy so that extending below doesn't mutate `self._arg_keywords`.
    arg_names = list(self._arg_keywords)
1751 if self._num_positional_args > len(arg_names):
1752 arg_names.extend(
1753 "<arg{}>".format(i + 1)
1754 for i in range(len(arg_names), self._num_positional_args))
1755 return f"{self._func_graph.name}({', '.join(arg_names)})"
1756
1757 def pretty_printed_signature(self, verbose=True):
1758 """Returns a string summarizing the signature of this concrete function."""
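    # Example output (illustrative; roughly of the form produced for a function
    # traced from `def f(x)` with a float32 TensorSpec of shape (None,)):
    #
    #   f(x)
    #     Args:
    #       x: float32 Tensor, shape=(None,)
    #     Returns:
    #       float32 Tensor, shape=(None,)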
1759 if not verbose:
1760 return self._structured_signature_summary(default_values=True)
1761
1762 def pretty_print_spec(spec):
1763 """Returns a string describing the spec for a single argument."""
1764 if isinstance(spec, tensor_spec.TensorSpec):
1765 return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
1766 elif nest.is_nested(spec):
1767 pieces = nest.flatten(spec, expand_composites=False)
1768 markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
1769 structure = nest.pack_sequence_as(spec, markers)
1770 # Ensure dictionaries are sorted by key (for determinism)
1771 result = pprint.pformat(structure, width=10000)
1772 for (marker, piece) in zip(markers, pieces):
1773 result += "\n {}: {}".format(marker, pretty_print_spec(piece))
1774 return result
1775 else:
1776 return repr(spec)
1777
1778 lines = [self._structured_signature_summary(default_values=True)]
1779 arg_specs, kwarg_specs = self.structured_input_signature
1780 names = list(self._function_spec.arg_names)
1781
1782 # If an explicit input_signature is provided to @tf.function, then any
1783 # arguments with defaults that are not covered by that explicit signature
1784 # are simply dropped from the signature.
1785 # TODO(b/159639913) Look into whether dropping arguments with default values
1786 # from the signature is the right thing to do.
1787
1788 # Note: we can skip bound args, since we already displayed their bound
1789 # value in the signature summary.
1790 arg_details = []
1791 for (name, spec) in zip(names[:len(arg_specs)], list(arg_specs)):
1792 if _contains_type_spec(spec):
1793 arg_details.append(" {}: {}".format(name, pretty_print_spec(spec)))
1794
1795 if kwarg_specs:
1796 for kwarg in sorted(kwarg_specs):
1797 spec = kwarg_specs[kwarg]
1798 if _contains_type_spec(spec):
1799 arg_details.append(" {}: {}".format(
1800 kwarg, pretty_print_spec(spec)))
1801
1802 if arg_details:
1803 lines.append(" Args:")
1804 lines.extend(arg_details)
1805 lines.append(" Returns:")
1806
1807 def spec_from_value(value):
1808 # For loaded function, structured_outputs are already specs.
1809 if isinstance(value, type_spec.TypeSpec):
1810 return value
1811 return type_spec.type_spec_from_value(value)
1812
1813 lines.append(" {}".format(
1814 pretty_print_spec(
1815 nest.map_structure(spec_from_value, self.structured_outputs))))
1816
1817 return "\n".join(lines)
1818
1819 def __repr__(self):
1820 if self._function_spec is not None:
1821 return "<ConcreteFunction {} at 0x{:X}>".format(
1822 self.pretty_printed_signature(verbose=False), id(self))
1823 elif not (self._num_positional_args is None or self._arg_keywords is None):
1824 return "<ConcreteFunction {} at 0x{:X}>".format(
1825 self._flat_signature_summary(), id(self))
1826 else:
1827 return object.__repr__(self)
1828
1829 def __str__(self):
1830 if self._function_spec is not None:
1831 return "ConcreteFunction {}".format(self.pretty_printed_signature())
1832 else:
1833 return self.__repr__()
1834
1835 def _trackable_children(self, save_type="checkpoint", **kwargs):
1836 """Implements `Trackable`."""
1837 if save_type == "checkpoint":
1838 # Checkpoint dependencies do not include functions at all. Users
1839 # expect the checkpointed variables to be saved using the model
1840 # architecture, e.g. `model.layers[1].kernel` or `model.variables`.
1841 return {}
1842
1843 captured_trackables = {}
1844 for n, (capture, _) in enumerate(self.graph.captures):
1845 if (capture.dtype not in (dtypes.variant, dtypes.resource) and
1846 not resource_variable_ops.is_resource_variable(capture)):
1847 # Variant/resource type tensors are skipped since we have no way of
1848 # getting the `Trackable` wrapper for these tensors. The wrappers are
1849 # expected to be elsewhere in the saved object graph.
1850 # TODO(b/223866972): Directly encode/decode tensor captures.
1851
1852 # Resource variable captures are also skipped at this time, to maintain
1853 # existing behavior.
1854 # TODO(b/217979389): Return the non-constant captures as children.
1855
1856 captured_trackables[f"capture_{n}"] = capture
1857
1858 return captured_trackables
1859
1860 def _deserialization_dependencies(self, children):
1861 return children
1862
1863 def _export_to_saved_model_graph(self, object_map, tensor_map,
1864 **unused_kwargs):
1865 if not self.graph.saveable:
1866 raise ValueError(
1867 (f"Unable to save function {self.name} for the following reason(s):\n"
1868 + "\n".join(self.graph.saving_errors)))
1869 self.add_to_graph()
1870 object_map[self] = saved_model_exported_concrete.ExportedConcreteFunction(
1871 self, tensor_map)
1872 return []
1873
1874
1875_pywrap_utils.RegisterType("Tensor", ops.Tensor)
1876_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
1877_pywrap_utils.RegisterType("IndexedSlices", indexed_slices.IndexedSlices)
1878
1879
1880class ConcreteFunctionGarbageCollector:
1881 """Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""
1882
1883 __slots__ = ["_func_graph"]
1884
1885 def __init__(self, func_graph):
1886 self._func_graph = func_graph
1887
1888 def release(self):
    """Cancels the pending FuncGraph deletion."""
1890 self._func_graph = None
1891
1892 def __del__(self):
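    # `func_graph_module` may already be None if the interpreter is shutting
    # down and module globals have been cleared; in that case there is nothing
    # left to dismantle.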
1893 if func_graph_module is None or self._func_graph is None:
1894 return
1895 try:
1896 func_graph_module.dismantle_func_graph(self._func_graph)
1897 except: # pylint: disable=bare-except
1898 pass
1899
1900
1901class _Marker(object):
1902 """Markers used to pretty-print nested args in function signatures."""
1903
1904 __slots__ = ["_s"]
1905
1906 def __init__(self, s):
1907 self._s = s
1908
1909 def __repr__(self):
1910 return str(self._s)
1911
1912
1913def _contains_type_spec(value):
1914 return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))