Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py: 22%
441 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FuncGraph and related functionality."""
17import traceback
18from typing import Any, Callable, Hashable
19import weakref
21from tensorflow.core.function import trace_type
22from tensorflow.core.function.capture import capture_container
23from tensorflow.python.eager import context
24from tensorflow.python.eager import execute
25from tensorflow.python.eager.polymorphic_function import composite_tensor_utils
26from tensorflow.python.framework import auto_control_deps
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import indexed_slices
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_spec
34from tensorflow.python.framework import type_spec
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import resource_variable_ops
37from tensorflow.python.ops import tensor_array_ops
38from tensorflow.python.ops import variable_scope
39from tensorflow.python.saved_model import save_context
40from tensorflow.python.types import core
41from tensorflow.python.util import compat
42from tensorflow.python.util import nest
43from tensorflow.python.util import object_identity
44from tensorflow.python.util import tf_contextlib
45from tensorflow.python.util import tf_decorator
46from tensorflow.python.util import tf_inspect
47from tensorflow.python.util import variable_utils
48from tensorflow.python.util.tf_export import tf_export
# Collections that a FuncGraph built with `collections=None` shares with its
# outer graph by reference (via `get_collection_ref`), so the function can both
# read and write them. Every other outer collection is only read (fetched with
# `get_collection`), never written back.
ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY,  # pylint: disable=protected-access
]
class UnknownArgument:
  """Marker for an argument that cannot be represented in a signature.

  `convert_structure_to_signature` substitutes a fresh instance of this class
  for any value whose type it does not know how to encode.
  """
def convert_structure_to_signature(structure, arg_names=None,
                                   signature_context=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where top level collection is a list or a
      tuple.
    arg_names: Optional sequence of argument names with the same number of
      elements as `structure`; used for naming the corresponding TensorSpecs.
    signature_context: TraceType InternalTracingContext to generate alias_ids
      for mutable objects, like ResourceVariables.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.

  Raises:
    ValueError: If `arg_names` is non-empty and its length differs from the
      length of `structure`.
  """

  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except (ValueError, AttributeError):
        # No usable `_user_specified_name`; the name is derived from `path`.
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = tensor_spec.sanitize_spec_name("_".join(str(p) for p in path))
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, resource_variable_ops.ResourceVariable):
      return trace_type.from_value(arg, signature_context)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, (
        int,
        float,
        bool,
        str,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      # Wrap `arg_names` in a 1-tuple: with a bare tuple of names, `%` would
      # treat each name as a separate format argument and raise TypeError
      # instead of this ValueError.
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." %
          (arg_names,))
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)
133@tf_export("__internal__.FuncGraph", v1=[])
134class FuncGraph(ops.Graph):
135 """Graph representing a function body.
137 Attributes:
138 name: The name of the function.
139 inputs: Placeholder tensors representing the inputs to this function. The
140 tensors are in this FuncGraph. This represents "regular" inputs as well as
141 captured inputs (i.e. the values of self.captures), with the regular
142 inputs coming first.
143 outputs: Tensors that will be returned by this function. The tensors are in
144 this FuncGraph.
145 control_outputs: Operations that must be executed before the function
146 represented by this graph can be said to have been executed.
147 structured_input_signature: A tuple of (args, kwargs), which are both
148 possibly-nested python objects that were received by this function. Note
149 that these structures might contain Python `None`s.
150 structured_outputs: A possibly-nested python object which will be returned
151 by this function. The Tensors in this structure are the same as those of
152 self.outputs. Note that this structure might contain Python `None`s.
153 variables: Variables that should be watched during function execution.
154 outer_graph: The graph this function is defined in. May be another FuncGraph
155 or the global default Graph.
156 captures: Maps external tensor -> internal tensor (i.e. input placeholder).
157 The entries are in the order they were captured.
158 seed: The graph-level random seed.
159 capture_by_value: If True, the func graph will capture Variables by value
160 instead of reference.
161 """
163 def __init__(self,
164 name,
165 collections=None,
166 capture_by_value=None,
167 structured_input_signature=None,
168 structured_outputs=None):
169 """Construct a new FuncGraph.
171 The graph will inherit its graph key, collections, seed, and distribution
172 strategy stack from the current context or graph.
174 Args:
175 name: the name of the function.
176 collections: a dictionary of collections this FuncGraph should start with.
177 If not specified (None), the FuncGraph will read (but not write to) the
178 outer graph's collections that are not allowlisted, and both read and
179 write to the outer graph's collections that are allowlisted. The current
180 allowlisted collections are the global variables, the local variables,
181 and the trainable variables. Defaults to None.
182 capture_by_value: An optional boolean. If True, the func graph will
183 capture Variables by value instead of reference. By default inherit from
184 outer graphs, and failing that will default to False.
185 structured_input_signature: Optional. The structured input signature to
186 use for initializing the FuncGraph. See the docstring for FuncGraph for
187 more information.
188 structured_outputs: Optional. The structured outputs to use for
189 initializing the FuncGraph. See the docstring for FuncGraph for more
190 information.
191 """
192 super().__init__()
193 self.name = name
194 # TODO(panzf): Separate captures from non-captures inputs in self.inputs
195 self.inputs = []
196 self.outputs = []
197 self.control_outputs = []
198 self.structured_input_signature = structured_input_signature
199 self.structured_outputs = structured_outputs
200 self._resource_tensor_inputs = object_identity.ObjectIdentitySet()
201 self._weak_variables = []
202 self._watched_variables = object_identity.ObjectIdentityWeakSet()
203 self.is_control_flow_graph = False
205 self._function_captures = capture_container.FunctionCaptures()
206 outer_graph = ops.get_default_graph()
207 self._weak_outer_graph = weakref.ref(outer_graph)
208 while outer_graph.building_function:
209 outer_graph = outer_graph.outer_graph
210 # If self._weak_outer_graph is deleted, we revert to the outermost Graph
211 # active when the FuncGraph was traced. This will not be a FuncGraph.
212 self._fallback_outer_graph = outer_graph
213 # If not None, records the names of output args of this function. Used to
214 # preserve the output names in the signature of a serialized+deserialized
215 # function. Private at the moment mostly because it's often out of date.
216 self._output_names = None
217 # Inherit capture-by-value from outer graph.
218 if capture_by_value is not None:
219 self.capture_by_value = capture_by_value
220 elif self.outer_graph is not None and isinstance(self.outer_graph,
221 FuncGraph):
222 self.capture_by_value = self.outer_graph.capture_by_value
223 else:
224 self.capture_by_value = False
226 self._building_function = True
228 graph = self.outer_graph
230 if context.executing_eagerly():
231 self.seed = context.global_seed()
232 # [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of
233 # any None op_seed for random_op in the function, in which case we end up
234 # using function seed, which could be unintended behavior for the op.
235 self._seed_used = False
236 else:
237 self.seed = graph.seed
238 self._seed_used = False
239 # TODO(allenl): Figure out if we can remove colocation stack
240 # specialization (currently used in cond_v2), here and in the cache key.
241 self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
243 if collections is None:
244 for collection_name in graph.get_all_collection_keys():
245 if collection_name not in ALLOWLIST_COLLECTIONS:
246 self._collections[collection_name] = graph.get_collection(
247 collection_name)
248 for collection_name in ALLOWLIST_COLLECTIONS:
249 self._collections[collection_name] = graph.get_collection_ref(
250 collection_name)
251 else:
252 self._collections = collections
254 # Keep track of whether this FuncGraph is exportable to SavedModel. Use
255 # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
256 # dependent functions as unsaveable.
257 self._saveable = True
258 self._saving_errors = set()
260 # Keep track of callbacks to run when this graph exits default scope
261 self._scope_exit_callbacks = None
263 def __str__(self):
264 return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))
266 def watch_variable(self, v):
267 """Marks the variable v as accessed while building this graph."""
268 # Don't watch `v` if it is one of ResourceVariable input arguments.
269 if (isinstance(v, resource_variable_ops.ResourceVariable) and
270 v.handle in self._resource_tensor_inputs):
271 return
273 while self is not None and isinstance(self, FuncGraph):
274 self._watched_variables.add(v)
275 self = self.outer_graph
277 def capture_call_time_value(self,
278 closure,
279 spec,
280 key=None,
281 default_value=None,
282 placeholder=None):
283 """Returns a placeholder which at call time has the value closure().
285 The `tf.function` supports the notion of captures, that is, it allows Python
286 functions to have closure variables, which bind over some value outside the
287 function. However, this name binding is "early binding" performed before the
288 program is run, i.e.,
289 ```
290 @tf.function
291 def f():
292 return x
294 x = tf.constant(1)
295 f() # returns 1
297 x = tf.constant(2)
298 f() # still returns 1!
299 ```
300 while in Python, name binding is performed as the program is running.
301 ```
302 def f():
303 return x
305 x = 1
306 f() # returns 1
308 x = 2
309 f() # returns 2
310 ```
311 `capture_call_time_value` allows tf.function to mimic late binding as a
312 Python function does, by passing in a `closure` callable argument to be
313 executed when the tf.function is invoked eagerly. E.g.
314 ```
315 @tf.function
316 def f():
317 return ops.get_default_graph.capture_call_time_value(lambda: x)
319 x = tf.constant(1)
320 f() # returns 1
322 x = tf.constant(2)
323 f() # returns 2
324 ```
325 Note that a `capture_call_time_value` function itself does not work well in
326 the saving process (since the tf.function in which it's called is not
327 invoked eagerly) unless passed a `default_value` argument. At saving time,
328 the `default_value` argument is returned instead.
330 Args:
331 closure: function which takes no arguments, to be evaluated at function
332 call time, returning a nest of tensors compatible with `spec`.
333 spec: nest of TypeSpec for the value to capture.
334 key: optional. If not None, multiple calls to lazy_capture with the same
335 key in the same graph will return the same placeholder, and the first
336 closure will be used at function call time.
337 default_value: optional value to return in environments that cannot safely
338 evaluate closure.
339 placeholder: optional. If not None, the graph will take the passed-in
340 `placeholder` as the internal capture instead of creating a new one.
341 This is useful when loading from a SavedModel.
343 Returns:
344 Nest of placeholders which, at function call time, will be fed with the
345 result of calling closure().
347 Raises:
348 ValueError: at function call time, if the return value of closure() is
349 not compatible with `spec`.
350 """
351 if key is None:
352 key = object()
353 if key not in self._function_captures.by_ref_internal:
354 trace_ctx = trace_type.InternalTracingContext(True)
355 spec = trace_type.from_value(spec, trace_ctx)
357 if placeholder is None:
358 placeholder_ctx = trace_type.InternalPlaceholderContext(self)
359 placeholder = spec.placeholder_value(placeholder_ctx)
361 def wrapped_closure():
363 # One major case requiring returning a `default_value` is when passing a
364 # concrete function to `save`, i.e.
365 # serving_fn = serve_fn.get_concrete_function(...)
366 # model.save(save_dir, signatures={"serving_default": serving_fn})
367 # `serving_fn` has deferred captures added through
368 # `capture_call_time_value`. It can't be saved correctly since
369 # `wrapped_closure` will end up executing under a default Graph instead
370 # of FuncGraph. The user of `capture_call_time_value` also cannot
371 # conditionally avoid this call since presence of `save_context` when
372 # executing `wrapped_closure` is not known at tracing time of
373 # `serving_fn`.
374 if save_context.in_save_context() and default_value is not None:
375 return default_value
376 # TODO(wxinyi): raise an error if in save context but no default value.
378 if not context.executing_eagerly():
379 graph = ops.get_default_graph()
380 assert isinstance(
381 graph,
382 FuncGraph), "This API should only be used in TF2 enviroment."
384 with graph.as_default():
385 ret_nest = graph.capture_call_time_value(
386 closure, spec, key=key, default_value=default_value)
387 else:
388 ret_nest = closure()
390 ret_nest = spec._cast(ret_nest, trace_type.InternalCastContext) # pylint: disable=protected-access
391 return spec._to_tensors(ret_nest) # pylint: disable=protected-access
393 wrapped_closure.output_spec = spec
394 self._function_captures.add_or_replace(
395 key=key,
396 external=wrapped_closure,
397 internal=placeholder,
398 tracetype=spec,
399 is_by_ref=True)
400 return self._function_captures.by_ref_internal[key]
402 def control_dependencies(self, control_inputs):
403 """Handles control dependencies.
405 FuncGraph wraps Graph's control_dependencies logic by first filtering out
406 any external tensors / operations and storing them in the graph's
407 control_captures member. Any consumers of this function graph must then
408 decide how to handle the control captures.
410 Args:
411 control_inputs: A list of `Operation` or `Tensor` objects which must be
412 executed or computed before running the operations defined in the
413 context. Can also be `None` to clear the control dependencies.
415 Returns:
416 A context manager that specifies control dependencies for all
417 operations constructed within the context.
419 Raises:
420 TypeError: If `control_inputs` is not a list of `Operation` or
421 `Tensor` objects.
422 """
423 if control_inputs is None:
424 return super().control_dependencies(control_inputs)
426 filtered_control_inputs = []
427 for c in control_inputs:
428 # Check for _UnreadVariable
429 if (isinstance(c, indexed_slices.IndexedSlices) or
430 (hasattr(c, "_handle") and hasattr(c, "op"))):
431 c = c.op
432 graph_element = ops._as_graph_element(c) # pylint: disable=protected-access
433 if graph_element is None:
434 graph_element = c
435 if graph_element is not None and getattr(graph_element, "graph",
436 None) is not self:
437 self._function_captures.control.add(graph_element)
438 else:
439 filtered_control_inputs.append(graph_element)
440 return super().control_dependencies(filtered_control_inputs)
  def as_default(self):
    """Returns a context manager that makes this FuncGraph the default graph.

    On entry, this graph temporarily adopts the distribution strategy stack,
    variable creator stack and graph key of the graph that is default just
    before it is pushed, and (only in the non-eager cases described below) that
    graph's device function stack. It also starts a fresh list of scope-exit
    callbacks. On exit, those callbacks are run in order, and then every one of
    these attributes is restored to its previous value, even if the body or a
    callback raised.
    """
    outer_cm = super().as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      # `graph` is read before `outer_cm` is entered below, so it is the graph
      # that was default before this FuncGraph is pushed.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      # Copy the list so pushes inside this scope do not mutate `graph`'s stack.
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's placement,
      # so that the same trace of the Python function may be placed on several
      # different devices and saved functions may be placed on new devices when
      # restored.
      # However, we need to preserve the outer device stack in the following
      # cases in non eager context:
      # 1. device stack is callable
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()

      # The creator stack is shared (not copied) with `graph` for the duration
      # of the scope.
      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      # Callbacks registered while this scope is active are collected here and
      # run on exit.
      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            # Restore all saved state even if a callback raised.
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key

    return inner_cm()
503 @property
504 def outer_graph(self):
505 """The Graph this FuncGraph is nested in.
507 Functions may capture Tensors from graphs they are nested in (transitive).
509 Returns:
510 A Graph object. Initially set to the current default graph when the
511 FuncGraph was created. If the previous `outer_graph` was deleted because
512 the function that owns it was deleted, `outer_graph` is reset to the
513 outermost default graph active when the FuncGraph was created. This
514 FuncGraph won't have captured anything from the new `outer_graph` (and
515 likely not from the previous setting, since that would have created a
516 strong reference), but it is returned so that FuncGraphs always have a
517 parent.
518 """
519 current = self._weak_outer_graph()
520 if current is None:
521 return self._fallback_outer_graph
522 return current
524 @outer_graph.setter
525 def outer_graph(self, new_outer_graph):
526 """Sets `outer_graph` to `new_outer_graph`."""
527 self._weak_outer_graph = weakref.ref(new_outer_graph)
529 @property
530 def output_types(self):
531 return [t.dtype for t in self.outputs]
533 @property
534 def output_shapes(self):
535 return [t.shape for t in self.outputs]
537 @property
538 def trainable_variables(self):
539 """A sequence of trainable variables accessed by this FuncGraph.
541 Note that functions keep only weak references to variables. Calling the
542 function after a variable it accesses has been deleted is an error.
544 Returns:
545 Sequence of trainable variables for this func graph.
546 """
547 return tuple(v for v in self.variables if v.trainable)
549 @property
550 def variables(self):
551 """A sequence of variables accessed by this FuncGraph.
553 Note that functions keep only weak references to variables. Calling the
554 function after a variable it accesses has been deleted is an error.
556 Returns:
557 Sequence of variables for this func graph.
558 """
560 def deref(weak_v):
561 v = weak_v()
562 if v is None:
563 raise AssertionError(
564 "Called a function referencing variables which have been deleted. "
565 "This likely means that function-local variables were created and "
566 "not referenced elsewhere in the program. This is generally a "
567 "mistake; consider storing variables in an object attribute on "
568 "first call.")
569 return v
571 return tuple(deref(v) for v in self._weak_variables)
573 @variables.setter
574 def variables(self, var_list):
575 self._weak_variables = [weakref.ref(v) for v in var_list]
577 def _capture_by_value(
578 self,
579 op_type,
580 inputs,
581 dtypes, # pylint: disable=redefined-outer-name
582 input_types=None,
583 name=None,
584 attrs=None,
585 op_def=None,
586 compute_device=True):
587 # When capturing by value, do the read outside
588 reverse_captures = dict((id(v), k) for k, v in self.captures)
589 uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
590 with ops.init_scope():
591 if context.executing_eagerly():
592 attr_list = ("dtype", int(attrs["dtype"].type))
593 value, = execute.execute(
594 compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
595 context.context())
596 else:
597 op = ops.get_default_graph()._create_op_internal( # pylint: disable=protected-access
598 op_type, uncaptured_inputs, dtypes, input_types, name, attrs,
599 op_def, compute_device)
600 value = op.outputs[0]
601 captured_value = self.capture(value)
602 return captured_value.op
604 def _create_op_internal(
605 self,
606 op_type,
607 inputs,
608 dtypes=None, # pylint: disable=redefined-outer-name
609 input_types=None,
610 name=None,
611 attrs=None,
612 op_def=None,
613 compute_device=True):
614 """Like Graph.create_op, except handles external input tensors.
616 This overload adds functionality to create_op to "capture" any external
617 input tensors, i.e. tensors from the eager context or outer function graphs
618 if this is a nested function. See `capture` for more information.
620 Args:
621 op_type: The `Operation` type to create. This corresponds to the
622 `OpDef.name` field for the proto that defines the operation.
623 inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
624 dtypes: (Optional) A list of `DType` objects that will be the types of the
625 tensors that the operation produces.
626 input_types: (Optional.) A list of `DType`s that will be the types of the
627 tensors that the operation consumes. By default, uses the base `DType`
628 of each input in `inputs`. Operations that expect reference-typed inputs
629 must specify `input_types` explicitly.
630 name: (Optional.) A string name for the operation. If not specified, a
631 name is generated based on `op_type`.
632 attrs: (Optional.) A dictionary where the key is the attribute name (a
633 string) and the value is the respective `attr` attribute of the
634 `NodeDef` proto that will represent the operation (an `AttrValue`
635 proto).
636 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
637 the operation will have.
638 compute_device: (Optional.) If True, device functions will be executed to
639 compute the device property of the Operation.
641 Returns:
642 An `Operation` object.
643 """
644 if self.capture_by_value and op_type in [
645 "ReadVariableOp", "ResourceGather"
646 ]:
647 return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
648 attrs, op_def, compute_device)
650 # This capturing logic interacts poorly with control flow contexts which
651 # want to replace inputs of ops far too late in the process. This can lead
652 # the context to get confused and try to create an Enter for an Enter. We
653 # can detect this here and skip the additional Enter which can confuse loop
654 # validation logic.
655 if op_type == "Enter" and inputs[0].op.type == "Enter":
656 if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
657 return inputs[0].op
658 # Calling AddValue on the control flow contexts to force creation of the
659 # backward accumulators in the original graph before we create placeholders
660 # to capture the inputs.
661 ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
662 # Use a different list to avoid modifying the original inputs list.
663 captured_inputs = []
664 for inp in inputs:
665 # TPU Estimator defines a control flow context with no AddValue method.
666 if ctxt is not None and hasattr(ctxt, "AddValue"):
667 inp = ctxt.AddValue(inp)
668 inp = self.capture(inp)
669 captured_inputs.append(inp)
670 return super()._create_op_internal( # pylint: disable=protected-access
671 op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
672 compute_device)
674 def capture(self, tensor, name=None, shape=None):
675 return self._function_captures.capture_by_value(self, tensor, name)
677 def _validate_in_scope(self, tensor):
678 inner_graph = tensor.graph
679 while inner_graph is not None and isinstance(inner_graph, FuncGraph):
680 if inner_graph is self:
681 try:
682 tb = tensor.op.traceback
683 except AttributeError:
684 tensor_traceback = "<unknown>"
685 else:
686 tensor_traceback_list = []
687 for frame in traceback.format_list(tb.get_user_frames()):
688 tensor_traceback_list.extend(
689 [f" {line}" for line in frame.split("\n") if line.strip()])
690 tensor_traceback = "\n".join(tensor_traceback_list)
691 # Keep in sync with tfe_wrapper.cc.
692 # TODO(b/200991648): Unify those two paths.
693 raise errors.InaccessibleTensorError(
694 f"{tensor!r} is out of scope and cannot be used here. Use return "
695 "values, explicit Python locals or TensorFlow collections to "
696 "access it.\n"
697 "Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values " # pylint: disable=line-too-long
698 "for more information.\n\n"
699 f"{tensor!r} was defined here:\n{tensor_traceback}\n\n"
700 f"The tensor {tensor!r} cannot be accessed from {self}, because "
701 f"it was defined in {tensor.graph}, which is out of scope.")
702 inner_graph = inner_graph.outer_graph
704 # TODO(panzf): Rename this method along with usages in cond/while graph.
705 def _capture_helper(self, tensor, name):
706 return self._function_captures._create_placeholder_helper( # pylint: disable=protected-access
707 self, tensor, name)
709 def _experimental_capture_side_input_by_ref(self, identifier: Hashable,
710 func: Callable[[], Any]) ->...:
711 """Implement capturing side input by reference for tf.function.
713 Note that this API will only register the capture in the func_graph where
714 it is called. In the case of nested graph, like nested tf.function or
715 tf.while, the outer graph is not aware of this capture in the inner graph.
716 Thus, the outer tf.function will not retrace when the by-ref capture
717 changes. It's the user's responsibility to call this API in the outer
718 func_graph as well if proper retracing is needed.
720 For example:
722 ```
723 x = 1
725 # Correct usage
726 @tf.function
727 def f_1():
728 graph = tf.compat.v1.get_default_graph()
729 # Capture the same x for the outer tf.function
730 graph._experimental_capture_side_input_by_ref("x", lambda: x)
732 @tf.function
733 def g():
734 graph = tf.compat.v1.get_default_graph()
735 cap_x = graph._experimental_capture_side_input_by_ref("x", lambda: x)
736 return cap_x + 1
738 return g()
740 # Incorrect usage
741 @tf.function
742 def f_2():
744 @tf.function
745 def g():
746 graph = tf.compat.v1.get_default_graph()
747 cap_x = graph._experimental_capture_side_input_by_ref("x", lambda: x)
748 return cap_x + 1
750 return g()
752 assert f_1() == 2
753 assert f_2() == 2
754 x = 2
755 assert f_1() == 3
756 assert f_2() == 2 # This is incorrect
757 ```
759 Args:
760 identifier: A hashable object as the key for the capture.
761 func: A Python function that takes no arguments and returns the value of
762 side input. The function is evaluated at function call time.
764 Returns:
765 A nested structure with the same structure as the side input. Tensors
766 are replaced with placehoders, and non-tensors remain the same.
768 """
769 if context.executing_eagerly():
770 return func()
772 def maybe_convert_to_tensor():
773 value = func()
774 if not (isinstance(value, core.Value) or isinstance(value, core.Symbol)):
775 value = constant_op.constant(value)
776 return value
778 placeholder = self._function_captures._capture_by_ref( # pylint: disable=protected-access
779 self, maybe_convert_to_tensor, identifier)
780 return placeholder
782 @property
783 def captures(self):
784 """Order list of tuples containing external and internal captures."""
785 return self._function_captures.by_val_capture_tuples
787 def add_capture(self, tensor, placeholder):
788 """Capture a specific tensor and utilize the provided placeholder.
790 Args:
791 tensor: Tensor to captures.
792 placeholder: Provided placeholder for the tensor.
793 """
794 self._function_captures.add_or_replace(
795 key=id(tensor),
796 external=tensor,
797 internal=placeholder,
798 is_by_ref=False)
799 self.inputs.append(placeholder)
801 def replace_capture(self, tensor, placeholder):
802 """Replace already existing capture."""
803 self._function_captures.add_or_replace(
804 key=id(tensor),
805 external=tensor,
806 internal=placeholder,
807 is_by_ref=False)
def replace_capture_with_deferred_capture(self,
                                          tensor,
                                          closure,
                                          spec,
                                          placeholder,
                                          default_value=None):
  """Replaces existing capture `tensor` with a deferred capture `closure`.

  Caution: It is the caller's responsibility to make sure that, after calling
  this function, the TypeSpec of the `inputs` (i.e. internal placeholders) and
  the `_captured_inputs` (i.e. external captures) of a concrete function that
  wraps this function graph are still compatible. Thus users should pair
  calls to this function with `ConcreteFunction.set_external_captures` to make
  sure the order still matches. For example,
  ```
  # concrete_fn._captured_inputs == [tensor1, tensor2, tensor3]
  # concrete_fn.inputs == [placeholder1, placeholder2, placeholder3]
  # replace external capture `tensor2` with a deferred_capture, i.e., a
  # closure, `closure2`
  concrete_fn.graph.replace_capture_with_deferred_capture(
      tensor2,
      closure2,
      spec=some_spec,
      placeholder=placeholder2,
      default_value=some_default)
  concrete_fn.set_external_captures([tensor1, closure2, tensor3])
  ```

  Args:
    tensor: Tensor already captured.
    closure: function which takes no arguments, to be evaluated at function
      call time, returning a nest of tensors compatible with `spec`.
    spec: nest of TypeSpec for the value to capture.
    placeholder: the internal placeholder corresponding to the captured
      `tensor`.
    default_value: optional value to use in environments that cannot safely
      evaluate closure.
  """
  # Remove the by-value capture keyed by `id(tensor)`, then register `closure`
  # as a call-time capture under the same key, reusing `placeholder`.
  self._function_captures.pop(id(tensor), is_by_ref=False)
  self.capture_call_time_value(
      closure,
      spec,
      key=id(tensor),
      default_value=default_value,
      placeholder=placeholder)
@property
def external_captures(self):
  """External values captured by value, as a new list in capture order."""
  by_val_external = self._function_captures.by_val_external
  return [*by_val_external.values()]
@property
def internal_captures(self):
  """Placeholders standing in for by-value captures, as a new list."""
  by_val_internal = self._function_captures.by_val_internal
  return [*by_val_internal.values()]
@property
def deferred_external_captures(self):
  """External values of by-reference captures, whose placeholders are fed at call time."""
  by_ref_external = self._function_captures.by_ref_external
  return [*by_ref_external.values()]
@property
def deferred_internal_captures(self):
  """Placeholders (possibly nested) of by-reference captures, fed at call time."""
  by_ref_internal = self._function_captures.by_ref_internal
  return [*by_ref_internal.values()]
@property
def variable_captures(self):
  """Variables captured by this graph; returns `self.variables` unchanged."""
  captured_variables = self.variables
  return captured_variables
@property
def function_captures(self):
  """The capture container backing this graph (`self._function_captures`)."""
  container = self._function_captures
  return container
def mark_as_unsaveable(self, error_message):
  """Flags this FuncGraph as not exportable to SavedModel.

  Later export attempts will raise an error carrying the recorded message(s).

  Args:
    error_message: A single string, or a list of strings, describing why the
      graph cannot be saved. Messages are merged into `self._saving_errors`.
  """
  self._saveable = False
  # A bare string is wrapped so that `update` adds it as one message rather
  # than as individual characters.
  messages = [error_message] if isinstance(error_message, str) else error_message
  self._saving_errors.update(messages)
@property
def saveable(self):
  """Whether this FuncGraph may be exported (False once marked unsaveable)."""
  is_saveable = self._saveable
  return is_saveable
@property
def saving_errors(self):
  """The collection of messages explaining why this FuncGraph can't be saved."""
  errors_so_far = self._saving_errors
  return errors_so_far
def _add_scope_exit_callback(self, fn):
  """Registers `fn` to be invoked when this graph exits the default scope.

  Args:
    fn: A callable to append to `self._scope_exit_callbacks`.

  Raises:
    TypeError: If `fn` is not callable.
    RuntimeError: If `self._scope_exit_callbacks` is None, i.e. this graph is
      not currently the context scope graph.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: {}".format(fn))
  callbacks = self._scope_exit_callbacks
  if callbacks is None:
    raise RuntimeError(
        "Attempting to add a scope exit callback, but the default graph is "
        "not the context scope graph. Did you forget to call "
        "'with graph.as_default(): ...'?")
  callbacks.append(fn)
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            create_placeholders=True):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start with.
      If not specified (None), the FuncGraph will read (but not write to) the
      outer graph's collections that are not allowlisted, and both read and
      write to the outer graph's collections that are allowlisted. The current
      allowlisted collections are the global variables, the local variables, and
      the trainable variables. Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.
    create_placeholders: An optional boolean. If True, then func graph will
      create placeholders for the inputs as graph ops. If False, the input args
      and kwargs will be treated as the input placeholders.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None`, a
      `Tensor` or a `tf.experimental.ExtensionType`.
    ValueError: If `python_func` modifies its Python input arguments (raised by
      `check_func_mutation`).
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(
        name, collections=collections, capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies()
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    # Restore the caller's `use_resource` setting even if placeholder creation,
    # tracing `python_func`, output conversion or the mutation check raises, so
    # a failed trace does not leave the current variable scope modified.
    try:
      if signature is not None:
        args = signature
        kwargs = {}

      if create_placeholders:
        func_args, func_kwargs = _create_placeholders(args, kwargs, arg_names)
      else:
        func_args, func_kwargs = args, kwargs

      input_trace_types = trace_type.from_value([func_args, func_kwargs])
      func_graph.inputs = input_trace_types._to_tensors([func_args, func_kwargs])  # pylint: disable=protected-access
      for arg in func_graph.inputs:
        if arg.dtype == dtypes.resource:
          func_graph._resource_tensor_inputs.add(arg)  # pylint:disable=protected-access

      signature_context = trace_type.InternalTracingContext()
      # Convert all Tensors into TensorSpecs before saving the structured
      # inputs. If storing pure concrete functions that are not called through
      # polymorphic functions, we don't have access to FunctionSpec, so we need
      # to call the TensorSpecs by their `arg_names` for later binding.
      func_graph.structured_input_signature = (
          convert_structure_to_signature(
              func_args, arg_names, signature_context=signature_context),
          convert_structure_to_signature(
              func_kwargs, signature_context=signature_context))

      # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
      # Variables to help check whether mutation happens in calling the
      # function. Copy the recursive list, tuple and map structure, but not
      # base objects.
      func_args_before = nest.pack_sequence_as(
          func_args,
          nest.flatten(func_args, expand_composites=True),
          expand_composites=True)
      func_kwargs_before = nest.pack_sequence_as(
          func_kwargs,
          nest.flatten(func_kwargs, expand_composites=True),
          expand_composites=True)

      def convert(x):
        """Converts a function output to a Tensor."""
        if x is None:
          return None
        if op_return_value is not None and isinstance(x, ops.Operation):
          # TODO(b/79881896): we currently can't capture external control
          # deps, so this won't work if x needs to be captured (i.e. if
          # python_func returns captured Operations).
          with ops.control_dependencies([x]):
            x = array_ops.identity(op_return_value)
        elif not isinstance(x, tensor_array_ops.TensorArray):
          try:
            x = ops.convert_to_tensor_or_composite(x)
          except (ValueError, TypeError):
            raise TypeError(
                "To be compatible with tf.function, Python functions "
                "must return zero or more Tensors or ExtensionTypes or None "
                f"values; in compilation of {str(python_func)}, found return "
                f"value of type {type(x).__name__}, which is not a Tensor or "
                "ExtensionType.")
        if add_control_dependencies:
          x = deps_ctx.mark_as_return(x)
        return x

      _, original_func = tf_decorator.unwrap(python_func)
      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
      # TensorArrays and `None`s.
      func_outputs = variable_utils.convert_variables_to_tensors(func_outputs)
      func_outputs = nest.map_structure(
          convert, func_outputs, expand_composites=True)

      # flatten and unflatten func_args and func_kwargs to maintain parity
      # from flattening which sorts by key
      func_args = nest.pack_sequence_as(
          func_args,
          nest.flatten(func_args, expand_composites=True),
          expand_composites=True)
      func_kwargs = nest.pack_sequence_as(
          func_kwargs,
          nest.flatten(func_kwargs, expand_composites=True),
          expand_composites=True)
      check_func_mutation(func_args_before, func_kwargs_before, func_args,
                          func_kwargs, original_func)
    finally:
      current_scope.set_use_resource(default_use_resource)

    inputs = []
    for arg in composite_tensor_utils.flatten_with_variables([func_args,
                                                              func_kwargs]):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        capture = func_graph._function_captures.pop(id(arg.handle), False)  # pylint: disable=protected-access
        assert len(capture) >= 2
        resource_placeholder = capture[1]
        if resource_placeholder is None:
          continue
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = func_graph._watched_variables  # pylint: disable=protected-access

  # Read the control-dependency manager's results only after the `with`
  # statement has exited it.
  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
    func_graph.collective_manager_ids_used = (
        deps_control_manager.collective_manager_ids_used)

  return func_graph
def maybe_captured(tensor):
  """If `tensor` is a captured-value placeholder, returns the captured value.

  When `tensor` is a `Placeholder` in a function graph and matches the
  internal side of one of that graph's captures, the corresponding external
  value is looked up. That value is resolved again with `maybe_captured`, so
  chains of captures through enclosing function graphs are followed.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph. `tensor` itself
    is returned when it is an EagerTensor, is not a placeholder of a function
    graph, or does not match any capture.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      # Match by identity: `==` on graph tensors may build an elementwise
      # equality op rather than return a Python bool.
      if tensor is placeholder_t:
        return maybe_captured(input_t)
  return tensor
def device_stack_has_callable(device_stack):
  """Returns True if any spec on `device_stack` holds a callable device function.

  Args:
    device_stack: An object whose `peek_objs()` yields device specs, each with
      a `_device_name_or_function` attribute.

  Returns:
    True as soon as one spec's `_device_name_or_function` is callable, else
    False.
  """
  for device_spec in device_stack.peek_objs():
    if callable(device_spec._device_name_or_function):  # pylint: disable=protected-access
      return True
  return False
def has_mutation(n1, n2):
  """Returns True if `n1` and `n2` differ in structure or in any leaf object.

  Structures are compared with `nest.assert_same_structure` (composites
  expanded); a ValueError from it counts as a difference. Otherwise the
  flattened leaves are compared pairwise by identity (`is`).

  Args:
    n1: A nested structure.
    n2: Another nested structure.

  Returns:
    True if a difference was found, False otherwise.
  """
  try:
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    return True

  leaves_1 = nest.flatten(n1, expand_composites=True)
  leaves_2 = nest.flatten(n2, expand_composites=True)
  return any(leaf_1 is not leaf_2 for leaf_1, leaf_2 in zip(leaves_1, leaves_2))
def check_func_mutation(old_args, old_kwargs, new_args, new_kwargs, func):
  """Raises ValueError if calling `func` changed its Python arguments.

  Args:
    old_args: Positional arguments as they were before `func` was called.
    old_kwargs: Keyword arguments as they were before `func` was called.
    new_args: Positional arguments after the call.
    new_kwargs: Keyword arguments after the call.
    func: The Python function that was called; used for the error message.

  Raises:
    ValueError: If `has_mutation` reports a difference. When the arguments can
      be bound to `func`'s signature, the message names the modified
      parameters.
  """
  if not has_mutation((old_args, old_kwargs), (new_args, new_kwargs)):
    return

  # A mutation was found; build an error message that is as specific as
  # possible.
  fallback_name = getattr(func, "__name__", func)
  func_name = getattr(func, "__qualname__", fallback_name)
  sig = tf_inspect.signature(func)
  try:
    bound_before = sig.bind(*old_args, **old_kwargs).arguments
    bound_after = sig.bind(*new_args, **new_kwargs).arguments
  except TypeError as e:
    # Binding fails for calls made with the (deprecated) "flat signature"
    # (see ConcreteFunction._call_with_flat_signature); the individual
    # modified parameters cannot be identified then.
    raise ValueError(
        f"{func_name}{sig} should not modify its Python input "
        f"arguments. Check if it modifies any lists or dicts passed as "
        f"arguments. Modifying a copy is allowed.") from e

  assert set(bound_before) == set(bound_after)
  changes = ", ".join(
      param for param in bound_after
      if has_mutation(bound_before[param], bound_after[param]))
  raise ValueError(f"{func_name}{sig} should not modify its Python "
                   f"input arguments. Modifying a copy is allowed. The "
                   f"following parameter(s) were modified: {changes}")
# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Every TensorArray leaf is replaced by its `flow` tensor; all other leaves
  are passed through unchanged.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and TensorArrays.

  Returns:
    A list of tensors.
  """
  leaves = []
  for leaf in nest.flatten(sequence, expand_composites=True):
    if isinstance(leaf, tensor_array_ops.TensorArray):
      leaves.append(leaf.flow)
    else:
      leaves.append(leaf)
  return leaves
# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors. An entry whose position
      matches a TensorArray in the flattened `structure` is treated as that
      TensorArray's new flow and rebuilt into a TensorArray.

  Returns:
    A nested structure.

  Raises:
    ValueError: If the flattened `structure` and `flat_sequence` have a
      different number of elements. Errors from `nest.pack_sequence_as`
      propagate unchanged.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i, template in enumerate(flattened_structure):
    if isinstance(template, tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=template, flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)
def _create_placeholders(args, kwargs, arg_names=None):
  """Create placeholders given positional args and keyword args.

  Args:
    args: Positional arguments to create placeholders for.
    kwargs: Dict of keyword arguments to create placeholders for.
    arg_names: Optional list of names, one per positional argument, used as
      the naming scope of each positional placeholder. When `None`, `None` is
      used as every positional naming scope.

  Returns:
    A tuple `(func_args, func_kwargs)`: a tuple of placeholders for `args`,
    and a dict mapping each keyword name to its placeholder.
  """
  signature_context = trace_type.InternalTracingContext(
      is_legacy_signature=True)
  arg_trace_types = trace_type.from_value(tuple(args), signature_context)
  kwarg_trace_types = trace_type.from_value(kwargs, signature_context)

  placeholder_mapping = signature_context.get_placeholder_mapping()
  placeholder_context = trace_type.InternalPlaceholderContext(
      ops.get_default_graph(), placeholder_mapping)

  if arg_names is None:
    arg_names = [None] * len(arg_trace_types.components)

  # Create placeholders for trace type args and trace type kwargs.
  # NOTE(review): zip() stops at the shorter of `arg_names` and the
  # components, so a too-short `arg_names` drops trailing positional args.
  func_args = []
  for name, trace_type_arg in zip(arg_names, arg_trace_types.components):
    placeholder_context.update_naming_scope(name)
    placeholder = trace_type_arg.placeholder_value(placeholder_context)
    func_args.append(placeholder)

  # Iterate the (name, trace_type) pairs directly, in sorted-name order.
  # Unpacking them with `zip(*...)` would transpose the pairs into a tuple of
  # names and a tuple of trace types.
  func_kwargs = {}
  for name, trace_type_kwarg in sorted(kwarg_trace_types.mapping.items()):
    placeholder_context.update_naming_scope(name)
    placeholder = trace_type_kwarg.placeholder_value(placeholder_context)
    func_kwargs[name] = placeholder

  return tuple(func_args), func_kwargs
def dismantle_func_graph(func_graph):
  """Breaks the reference cycles held by a `FuncGraph`.

  This lets the graph be reclaimed without waiting for the cyclic garbage
  collector once it goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: The `FuncGraph` to tear down. Its captures are cleared first,
      then the graph itself is dismantled; it must not be used afterwards.
  """
  captures = func_graph._function_captures  # pylint: disable=protected-access
  captures.clear()
  ops.dismantle_graph(func_graph)
def override_func_graph_name_scope(func_graph, name_scope):
  """Overwrites the name stack of `func_graph` with `name_scope`.

  Args:
    func_graph: The graph whose private `_name_stack` is replaced.
    name_scope: The value to install as the graph's name stack.
  """
  new_stack = name_scope
  func_graph._name_stack = new_stack  # pylint: disable=protected-access