Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/tpu_util.py: 27%
59 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for TPU."""
17import contextlib
19from tensorflow.python.distribute import packed_distributed_variable as packed
20from tensorflow.python.eager import context
21from tensorflow.python.framework import ops
22from tensorflow.python.tpu import tpu_replication
def enclosing_tpu_context():
  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
  ctx, _ = enclosing_tpu_context_and_graph()
  return ctx
def enclosing_tpu_context_and_graph():
  """Returns the TPUReplicateContext which exists inside a tpu.rewrite(), and its associated graph."""
  current_graph = ops.get_default_graph()
  # Walk outward through nested graphs (FuncGraphs from defuns / v2 control
  # flow) until we find a graph whose control-flow-context chain contains a
  # TPUReplicateContext.
  while current_graph is not None:
    cf_ctx = current_graph._get_control_flow_context()  # pylint: disable=protected-access
    # Walk up the control-flow-context chain inside this graph.
    while cf_ctx is not None:
      if isinstance(cf_ctx, tpu_replication.TPUReplicateContext):
        return cf_ctx, current_graph
      cf_ctx = cf_ctx.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    current_graph = getattr(current_graph, "outer_graph", None)
  # No enclosing TPU context anywhere in the graph hierarchy.
  return None, None
@contextlib.contextmanager
def outside_or_skip_tpu_context():
  """Returns a context manager that skips current enclosing context if there is any.

  If a TPUReplicateContext encloses the caller, the graph's control flow
  context is temporarily reset to the TPU context's outer context while the
  managed block runs, then restored.

  Yields:
    Nothing.
  """
  ctx, graph = enclosing_tpu_context_and_graph()
  if ctx is None:
    # No TPU context to skip; run the body as-is.
    yield
  else:
    saved_context = graph._get_control_flow_context()  # pylint: disable=protected-access
    graph._set_control_flow_context(ctx.outer_context)  # pylint: disable=protected-access
    # Bug fix: restore the saved context even if the managed block raises;
    # previously an exception would leave the graph stuck in the outer
    # context, corrupting subsequent op construction.
    try:
      yield
    finally:
      graph._set_control_flow_context(saved_context)  # pylint: disable=protected-access
@contextlib.contextmanager
def _maybe_enter_graph(tensor):
  # Note: might have an eager tensor but not be executing eagerly when
  # building functions.
  needs_graph_scope = not (
      context.executing_eagerly()
      or isinstance(tensor, ops.EagerTensor)
      or ops.has_default_graph())
  if needs_graph_scope:
    # Make the tensor's owning graph the default so ops land in it.
    with tensor.graph.as_default():
      yield
  else:
    yield
@contextlib.contextmanager
def _maybe_on_device(var):
  # Add a device scope for packed variables.
  if not isinstance(var, packed.PackedVarAndDevice):
    yield
  else:
    # Pin ops to the packed variable's component device.
    with ops.device(var.device):
      yield
def make_raw_assign_fn(raw_assign_fn, use_handle=True):
  """Wrap `raw_assign_fn` with the proper graph context and device scope.

  Args:
    raw_assign_fn: the function to be wrapped.
    use_handle: if True, the `raw_assign_fn` will be applied to the handle of a
      variable; otherwise it will be applied to the variable itself.

  Returns:
    The wrapped function.
  """

  def assign_fn(var, value, use_locking=False, name=None, read_value=True):
    del use_locking  # Unused.

    target = var.handle if use_handle else var
    with _maybe_enter_graph(target), _maybe_on_device(var):
      converted = ops.convert_to_tensor(value, dtype=var.dtype)
      assign_op = raw_assign_fn(target, converted, name=name)
      with ops.control_dependencies([assign_op]):
        if not read_value:
          return assign_op
        # Read through the handle or the variable, matching `use_handle`.
        if use_handle:
          return var._read_variable_op()  # pylint: disable=protected-access
        return var.read_value()

  return assign_fn
def make_raw_scatter_xxx_fn(raw_scatter_xxx_fn):
  """Wrap `raw_scatter_xxx_fn` so that it can be called w/ and w/o packed handle."""

  def scatter_xxx_fn(var, sparse_delta, use_locking=False, name=None):  # pylint: disable=missing-docstring
    del use_locking  # Unused.

    handle = var.handle
    with _maybe_enter_graph(handle), _maybe_on_device(var):
      values = ops.convert_to_tensor(sparse_delta.values, var.dtype)
      scatter_op = raw_scatter_xxx_fn(
          handle, sparse_delta.indices, values, name=name)
      # Ensure the read observes the scatter's result.
      with ops.control_dependencies([scatter_op]):
        return var._read_variable_op()  # pylint: disable=protected-access

  return scatter_xxx_fn