1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Context information for a tf.function."""
16
17from typing import NamedTuple, Any
18
19from tensorflow.core.function.polymorphism import function_cache
20from tensorflow.python.eager import context
21from tensorflow.python.framework import device as pydev
22from tensorflow.python.framework import func_graph as func_graph_module
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.saved_model import save_context
26
27
# EagerContext is used by tf.function to identify cases where tracing
# needs to occur due to a change in conditions other than the arguments.
class EagerContext(NamedTuple):
  """Non-argument conditions that form part of a tf.function's trace key.

  Instances are constructed by `make_function_context` in this module and
  wrapped in a `function_cache.FunctionContext`. The field descriptions below
  describe how `make_function_context` populates each field; the tuple itself
  does not validate its contents.
  """
  # Default graph seen under `ops.init_scope()` when the call is not eager, or
  # None when the call (or that outer scope) executes eagerly.
  parent_graph: Any
  # Tuple of call-site device functions; empty unless device placement must
  # participate in the key (per-device distribution-strategy retracing or a
  # callable on the graph's device function stack).
  device_functions: Any
  # Tuple of the objects on the default graph's colocation stack; always an
  # empty tuple when executing eagerly.
  colocation_stack: Any
  # True when the innermost distribution strategy on the stack has no replica
  # context; False when there is no strategy or the lookup fails.
  in_cross_replica_context: Any
  # `experimental_variable_policy` of the active save options, or None when
  # not inside a save context.
  variable_policy: Any
  # `id()` of an enclosing XLAControlFlowContext that requires unique function
  # retracing, otherwise 0.
  xla_context_id: Any
37
38
def make_function_context() -> function_cache.FunctionContext:
  """Builds the FunctionContext describing the current calling environment.

  The returned context wraps an `EagerContext` whose fields capture the
  conditions, other than the call arguments, that should force a tf.function
  to retrace: the outer graph visible from `ops.init_scope()`, call-site
  device functions, the colocation stack, cross-replica state, the active
  save-time variable policy and the identity of an enclosing XLA control-flow
  context.

  Returns:
    A `function_cache.FunctionContext` wrapping an `EagerContext`.
  """
  eager_ctx = context.context()
  outer_graph = None
  xla_id = 0

  # Opening an init_scope is unnecessary when the call is already eager.
  is_eager = eager_ctx.executing_eagerly()
  if not is_eager:
    # Every distinct XLAControlFlowContext that asks for it gets its own
    # trace, so its identity becomes part of the key.
    enclosing_xla = _enclosing_xla_context()
    if (enclosing_xla is not None and
        enclosing_xla.RequiresUniqueFunctionRetracing()):
      xla_id = id(enclosing_xla)

    with ops.init_scope():
      # Eagerness and the outer graph belong in the key so that tensors such
      # as variables are not captured from the wrong graph.
      is_eager = eager_ctx.executing_eagerly()
      if not is_eager:
        outer_graph = ops.get_default_graph()

  # pylint: disable=protected-access
  current_graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategies = current_graph._distribution_strategy_stack
  retrace_per_device = (
      strategies and
      strategies[-1].strategy.extended._retrace_functions_for_each_device)

  if is_eager:
    colocation = ()
    devices = ((pydev.merge_device(eager_ctx.device_name),)
               if retrace_per_device else ())
  else:
    colocation = tuple(current_graph._colocation_stack.peek_objs())
    key_on_devices = (
        retrace_per_device or
        func_graph_module.device_stack_has_callable(
            current_graph._device_function_stack))
    # Keying on the device functions keeps call-site device annotations
    # respected.
    devices = (tuple(current_graph._device_functions_outer_to_inner)
               if key_on_devices else ())

  # An empty strategy stack (IndexError) or a strategy entry without a
  # replica context attribute (AttributeError) means "not cross-replica".
  try:
    cross_replica = strategies[-1].replica_context is None
  except (AttributeError, IndexError):
    cross_replica = False
  # pylint: enable=protected-access

  variable_policy = (
      save_context.get_save_options().experimental_variable_policy
      if save_context.in_save_context() else None)

  return function_cache.FunctionContext(
      EagerContext(
          parent_graph=outer_graph,
          device_functions=devices,
          colocation_stack=colocation,
          in_cross_replica_context=cross_replica,
          variable_policy=variable_policy,
          xla_context_id=xla_id))
103
104
def _enclosing_xla_context():
  """Finds the XLAControlFlowContext enclosing the current default graph.

  Such a context exists inside a tpu.rewrite(). The search starts from the
  default graph's control-flow context and walks outward through
  `outer_context` links. If nothing matches, it moves to the graph's
  `outer_graph` (present on FuncGraphs created by defuns or v2 control flow)
  and repeats the search.

  Returns:
    The innermost enclosing `control_flow_ops.XLAControlFlowContext`, or None
    if no graph in the chain has one.
  """
  current = ops.get_default_graph()
  while current is not None:
    cf_ctx = current._get_control_flow_context()  # pylint: disable=protected-access
    # Skip outward past control-flow contexts that are not XLA contexts.
    while (cf_ctx is not None and
           not isinstance(cf_ctx, control_flow_ops.XLAControlFlowContext)):
      cf_ctx = cf_ctx.outer_context
    if cf_ctx is not None:
      return cf_ctx
    # A FuncGraph has an `outer_graph`. A plain graph does not, which ends
    # the walk.
    current = getattr(current, "outer_graph", None)
  return None