Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/function_context.py: 31%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Context information for a tf.function.""" 

16 

17from typing import NamedTuple, Any 

18 

19from tensorflow.core.function.polymorphism import function_cache 

20from tensorflow.python.eager import context 

21from tensorflow.python.framework import device as pydev 

22from tensorflow.python.framework import func_graph as func_graph_module 

23from tensorflow.python.framework import ops 

24from tensorflow.python.ops import control_flow_ops 

25from tensorflow.python.saved_model import save_context 

26 

27 

# EagerContext is folded into tf.function's cache key so that a trace is
# redone whenever the ambient conditions -- rather than the call
# arguments -- change between calls.
class EagerContext(NamedTuple):
  """Ambient tracing state that participates in the function cache key."""
  parent_graph: Any
  device_functions: Any
  colocation_stack: Any
  in_cross_replica_context: Any
  variable_policy: Any
  xla_context_id: Any

37 

38 

def make_function_context() -> function_cache.FunctionContext:
  """Builds a FunctionContext from the current ambient tracing conditions.

  Returns:
    A `function_cache.FunctionContext` wrapping an `EagerContext` that
    records the parent graph, device/colocation stacks, cross-replica
    status, variable policy, and XLA context identity.
  """
  ctx = context.context()

  # When the tf.function call is already in eager mode there is no need to
  # open an init_scope to discover the parent graph.
  eagerly = ctx.executing_eagerly()
  parent_graph = None
  xla_context_id = 0
  if not eagerly:
    # Each distinct XLAControlFlowContext must force a fresh trace, so its
    # identity is mixed into the context key.
    xla_ctx = _enclosing_xla_context()
    if xla_ctx is not None and xla_ctx.RequiresUniqueFunctionRetracing():
      xla_context_id = id(xla_ctx)

    with ops.init_scope():
      # The graph (or eager-ness) belongs in the cache key so tensors such
      # as variables are not improperly captured across contexts.
      eagerly = ctx.executing_eagerly()
      if not eagerly:
        parent_graph = ops.get_default_graph()

  # pylint: disable=protected-access
  default_graph = ops.get_default_graph()
  # TODO(b/117617952): The active distribution strategy can change graph
  # building (e.g. which variables are reached from which device), which
  # requires a retrace per device.
  strategy_stack = default_graph._distribution_strategy_stack
  retrace_per_device = bool(
      strategy_stack and
      strategy_stack[-1].strategy.extended._retrace_functions_for_each_device)

  if eagerly:
    colocation_stack = ()
    device_functions = ((pydev.merge_device(ctx.device_name),)
                        if retrace_per_device else ())
  else:
    colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
    if (retrace_per_device or
        func_graph_module.device_stack_has_callable(
            default_graph._device_function_stack)):
      # Including the device stack in the key keeps call-site device
      # annotations honored across traces.
      device_functions = tuple(default_graph._device_functions_outer_to_inner)
    else:
      device_functions = ()

  # An empty strategy stack or a strategy without replica_context simply
  # means we are not in a cross-replica context.
  try:
    in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    in_cross_replica_context = False

  variable_policy = (
      save_context.get_save_options().experimental_variable_policy
      if save_context.in_save_context() else None)

  return function_cache.FunctionContext(
      EagerContext(parent_graph, device_functions, colocation_stack,
                   in_cross_replica_context, variable_policy, xla_context_id))

103 

104 

def _enclosing_xla_context():
  """Finds the XLAControlFlowContext enclosing the current graph, if any.

  Such a context exists inside a `tpu.rewrite()`. Returns None when no
  enclosing graph carries one.
  """
  graph = ops.get_default_graph()
  while graph is not None:
    # Walk this graph's chain of control-flow contexts looking for an XLA one.
    flow_ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
    while flow_ctx is not None:
      if isinstance(flow_ctx, control_flow_ops.XLAControlFlowContext):
        return flow_ctx
      flow_ctx = flow_ctx.outer_context
    # The graph may be a FuncGraph (defuns or v2 control flow); the
    # XLAControlFlowContext then lives on an outer graph, so keep climbing.
    graph = getattr(graph, "outer_graph", None)
  return None