Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/tpu_util.py: 27%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility functions for TPU.""" 

16 

17import contextlib 

18 

19from tensorflow.python.distribute import packed_distributed_variable as packed 

20from tensorflow.python.eager import context 

21from tensorflow.python.framework import ops 

22from tensorflow.python.tpu import tpu_replication 

23 

24 

def enclosing_tpu_context():
  """Returns the innermost enclosing TPUReplicateContext, or None.

  A TPUReplicateContext exists inside a tpu.rewrite(). This is a convenience
  wrapper around `enclosing_tpu_context_and_graph` that drops the graph.
  """
  tpu_ctx, _ = enclosing_tpu_context_and_graph()
  return tpu_ctx

28 

29 

def enclosing_tpu_context_and_graph():
  """Finds the TPUReplicateContext enclosing the current default graph.

  The search starts at `ops.get_default_graph()`. Within each graph it walks
  the control flow context chain through `outer_context`. It then moves to the
  next graph through the optional `outer_graph` attribute, until no graph is
  left.

  Returns:
    A `(ctx, graph)` tuple, where `ctx` is the first
    `tpu_replication.TPUReplicateContext` encountered and `graph` is the graph
    whose control flow context chain contains it. Returns `(None, None)` if no
    such context is found.
  """
  current_graph = ops.get_default_graph()
  while current_graph is not None:
    flow_ctx = current_graph._get_control_flow_context()  # pylint: disable=protected-access
    # Skip non-TPU contexts until a TPUReplicateContext or the chain's end.
    while (flow_ctx is not None and
           not isinstance(flow_ctx, tpu_replication.TPUReplicateContext)):
      flow_ctx = flow_ctx.outer_context
    if flow_ctx is not None:
      return flow_ctx, current_graph
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    current_graph = getattr(current_graph, "outer_graph", None)
  return None, None

43 

44 

@contextlib.contextmanager
def outside_or_skip_tpu_context():
  """Temporarily steps outside the enclosing TPUReplicateContext, if any.

  The TPU context is located with `enclosing_tpu_context_and_graph()`. If none
  is found, this context manager is a no-op. Otherwise, for the duration of the
  `with` body, the owning graph's control flow context is set to the TPU
  context's `outer_context`. The previously active control flow context is
  restored on exit, including when the body raises.

  Yields:
    None.
  """
  ctx, graph = enclosing_tpu_context_and_graph()
  if ctx is None:
    yield
    return

  saved_context = graph._get_control_flow_context()  # pylint: disable=protected-access
  graph._set_control_flow_context(ctx.outer_context)  # pylint: disable=protected-access
  try:
    yield
  finally:
    # Restore unconditionally. Without `finally`, an exception raised in the
    # `with` body would skip this line and leave the graph's control flow
    # context set to `ctx.outer_context`.
    graph._set_control_flow_context(saved_context)  # pylint: disable=protected-access

56 

57 

@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  """Makes `tensor.graph` the default graph for the body, when needed.

  Nothing is entered if any of these holds: we are executing eagerly, `tensor`
  is an `ops.EagerTensor`, or `ops.has_default_graph()` is true. Otherwise the
  body runs under `tensor.graph.as_default()`.
  """
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  must_enter_graph = not (context.executing_eagerly() or
                          isinstance(tensor, ops.EagerTensor) or
                          ops.has_default_graph())
  if must_enter_graph:
    with tensor.graph.as_default():
      yield
  else:
    yield

68 

69 

@contextlib.contextmanager
def _maybe_on_device(var):
  """Runs the body under `ops.device(var.device)` for packed variables.

  The device scope is added only when `var` is a `packed.PackedVarAndDevice`.
  For any other `var`, this context manager is a no-op.
  """
  if not isinstance(var, packed.PackedVarAndDevice):
    yield
    return
  with ops.device(var.device):
    yield

78 

79 

def make_raw_assign_fn(raw_assign_fn, use_handle=True):
  """Wraps `raw_assign_fn` with the proper graph context and device scope.

  Args:
    raw_assign_fn: the function to be wrapped. It is invoked as
      `raw_assign_fn(target, value_tensor, name=name)`. `target` is either
      `var.handle` or `var` (see `use_handle`), and `value_tensor` is `value`
      converted to `var.dtype`.
    use_handle: if True, the `raw_assign_fn` will be applied to the handle of a
      variable; otherwise it will be applied to the variable itself.

  Returns:
    The wrapped function, with signature
    `assign_fn(var, value, use_locking=False, name=None, read_value=True)`.
    `use_locking` is ignored.

    If `read_value` is true, the result is a read of `var` created under
    `ops.control_dependencies` on the assign op. The read uses
    `var._read_variable_op()` when `use_handle` is set, and `var.read_value()`
    otherwise. If `read_value` is false, the assign op itself is returned.
  """

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    del use_locking  # Unused.

    target = var.handle if use_handle else var
    with _maybe_enter_graph(target), _maybe_on_device(var):
      value_tensor = ops.convert_to_tensor(value, dtype=var.dtype)
      assign_op = raw_assign_fn(target, value_tensor, name=name)
      with ops.control_dependencies([assign_op]):
        if not read_value:
          return assign_op
        if use_handle:
          return var._read_variable_op()  # pylint: disable=protected-access
        return var.read_value()

  return assign_fn

106 

107 

def make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Wraps `raw_scatter_xxx_fn` so that it can be called w/ and w/o packed handle.

  Args:
    raw_scatter_xxx_fn: the scatter function to be wrapped. It is invoked as
      `raw_scatter_xxx_fn(var.handle, indices, values, name=name)`.

  Returns:
    The wrapped function, with signature
    `scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None)`.
    `use_locking` is ignored. `sparse_delta` must expose `.indices` and
    `.values`; presumably it is an IndexedSlices-like object (confirm against
    callers). The values are converted to `var.dtype` before the scatter. The
    result is `var._read_variable_op()`, created under
    `ops.control_dependencies` on the scatter op.
  """

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    var_handle = var.handle
    with _maybe_enter_graph(var_handle), _maybe_on_device(var):
      delta_indices = sparse_delta.indices
      delta_values = ops.convert_to_tensor(sparse_delta.values, var.dtype)
      scatter_op = raw_scatter_xxx_fn(
          var_handle, delta_indices, delta_values, name=name)
      with ops.control_dependencies([scatter_op]):
        return var._read_variable_op()  # pylint: disable=protected-access

  return scatter_xxx_fn