Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/parallel_device/parallel_device.py: 34%
86 statements
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for eagerly executing operations in parallel on multiple devices."""

import threading
import weakref

from tensorflow.python import _pywrap_parallel_device
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils

_next_device_number = 0
_next_device_number_lock = threading.Lock()

_all_parallel_devices = weakref.WeakValueDictionary()


def unpack(tensor):
  """Finds `tensor`'s parallel device and unpacks its components."""
  parallel_device = _all_parallel_devices.get(tensor.device, None)
  if parallel_device is None:
    raise ValueError("{} is not a parallel device".format(tensor.device))
  return parallel_device.unpack(tensor)
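
# Illustrative usage sketch (not part of the library): `unpack` dispatches on
# the tensor's device string, so it only works on tensors produced by a
# registered ParallelDevice, e.g.
#
#   parallel_tensor = device.pack([tf.constant(1.), tf.constant(2.)])
#   components = unpack(parallel_tensor)  # one tensor per component device
#
# where `device` is assumed to be a ParallelDevice constructed elsewhere.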


# TODO(allenl): Expand this docstring once things like getting components on and
# off the device are stable.
#
# TODO(allenl): Make multi-client work; we need an offset for device IDs, and an
# indication of how many other devices there are total for collectives which
# don't have a number of participants hard-coded in their attributes.
class ParallelDevice(object):
  """A device which executes operations in parallel."""

  def __init__(self, components):
    """Creates a device which executes operations in parallel on `components`.

    Args:
      components: A list of device names. Each operation executed on the
        parallel device executes on these component devices.
    """
    global _next_device_number, _next_device_number_lock
    self.components = tuple(device_util.canonicalize(d) for d in components)
    if not self.components:
      raise ValueError("ParallelDevice requires at least one component.")
    ctx = context.context()
    with _next_device_number_lock:
      # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
      # special-cased).
      self._name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                _next_device_number)
      _next_device_number += 1
    device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
        self._name, self.components)
    context.register_custom_device(device, self._name, device_info)
    self._device_ids = None
    self._device_scope = None
    _all_parallel_devices[self._name] = self
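
  # Construction sketch (illustrative; the component device names are
  # assumptions and depend on the local configuration):
  #
  #   device = ParallelDevice(["/job:localhost/device:CPU:0",
  #                            "/job:localhost/device:CPU:1"])
  #
  # Operations placed on this device execute once per component.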

  def _pack_tensor(self, *tensors):
    """Helper to pack plain-old-tensors, not structures or composites."""
    for tensor in tensors:
      if not isinstance(tensor, (ops.Tensor, composite_tensor.CompositeTensor,
                                 variables.Variable)):
        raise ValueError(
            ("Every component to pack onto the ParallelDevice must already be "
             "a tensor, got {}. Consider running `tf.constant` or "
             "`tf.convert_to_tensor` first on literal values.")
            .format(tensors))
    with ops.device(self._name):
      return tpu_ops.tpu_replicated_input(inputs=tensors)

  def pack(self, tensors):
    """Create a tensor on the parallel device from a sequence of tensors.

    Args:
      tensors: A list of tensors, one per device in `self.components`. The list
        can contain composite tensors and nests (lists, dicts, etc. supported
        by `tf.nest`) with the same structure for each device, but every
        component of nests must already be a `tf.Tensor` or composite. Passing
        `tf.Variable` objects reads their value; it does not share a mutable
        reference between the packed and unpacked forms.

    Returns:
      A tensor placed on the ParallelDevice. For nested structures, returns a
      single structure containing tensors placed on the ParallelDevice (same
      structure as each component of `tensors`).

    Raises:
      ValueError: If the length of `tensors` does not match the number of
        component devices, or if there are non-tensor inputs.
    """
    self._assert_eager()
    if len(tensors) != len(self.components):
      raise ValueError(
          ("Creating a parallel tensor requires one tensor per component. "
           "Got {} but was expecting {}.")
          .format(len(tensors), len(self.components)))
    with ops.device(None):
      # Explicitly read variable values. This cannot be done on the parallel
      # device since the tensors are about to be packed onto it.
      tensors = variable_utils.convert_variables_to_tensors(tensors)
    return nest.map_structure(self._pack_tensor, *tensors,
                              expand_composites=True)
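
  # Pack sketch (illustrative, assuming `device` has two CPU components):
  #
  #   x = device.pack([tf.constant(1.), tf.constant(2.)])
  #
  # `x` is a single tensor placed on the parallel device; operations applied
  # to it run once per component, and `device.unpack(x)` recovers the
  # per-device values.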

  def _unpack_tensor(self, parallel_tensor):
    """Helper to unpack a single tensor."""
    if not isinstance(parallel_tensor, (
        ops.Tensor, composite_tensor.CompositeTensor, variables.Variable)):
      raise ValueError(
          "Expected a tensor, got {}.".format(parallel_tensor))
    with ops.device(self._name):
      return tpu_ops.tpu_replicated_output(
          parallel_tensor, num_replicas=len(self.components))

  def unpack(self, parallel_tensor):
    """Unpack a parallel tensor into its components.

    Args:
      parallel_tensor: A tensor, composite tensor, or `tf.nest` of such placed
        on the ParallelDevice. Passing `tf.Variable` objects reads their
        value; it does not share a mutable reference between the packed and
        unpacked forms.

    Returns:
      A list with the same length as `self.components`, each element with the
      same structure as `parallel_tensor` and containing that device's
      component tensors.
    """
    self._assert_eager()
    unpacked_components = [[] for _ in range(len(self.components))]
    with ops.device(self._name):
      parallel_tensor = variable_utils.convert_variables_to_tensors(
          parallel_tensor)
      for tensor in nest.flatten(parallel_tensor, expand_composites=True):
        for accumulator, unpacked_tensor in zip(
            unpacked_components, self._unpack_tensor(tensor)):
          accumulator.append(unpacked_tensor)
    return [nest.pack_sequence_as(parallel_tensor, unpacked,
                                  expand_composites=True)
            for unpacked in unpacked_components]
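
  # Unpack sketch (illustrative): continuing the pack example above,
  #
  #   device.unpack(x)  # -> [<1.0 on CPU:0>, <2.0 on CPU:1>]
  #
  # returns one component tensor per device, in the order of
  # `self.components`.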

  @property
  def device_ids(self):
    """A parallel tensor with scalar integers numbering component devices.

    Each device ID is placed on its corresponding device, in the same order as
    the `components` constructor argument.

    Returns:
      A parallel tensor containing 0 on the first device, 1 on the second, etc.
    """
    if self._device_ids is None:
      # device_ids may be called from inside a tf.function, in which case the
      # function captures the eager tensor. We can't pack tensors in a function
      # at the moment, and even if we could we don't want to hold on to a
      # symbolic tensor, so we need to init_scope out of the function
      # temporarily.
      with ops.init_scope():
        # TODO(allenl): Functions which capture eager device ID tensors won't
        # be saveable in SavedModels. Ideally we'd run a DeviceID op every time
        # device IDs are required, with functions using the op in their bodies
        # but not hard-coding a fixed number of devices (so they can be re-used
        # with a different replica count).
        device_ids_list = []
        for index, device in enumerate(self.components):
          with ops.device(device):
            # The identity op ensures each device ID tensor is placed on its
            # device.
            device_ids_list.append(
                array_ops.identity(constant_op.constant(index)))
        self._device_ids = self.pack(device_ids_list)

    return self._device_ids
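
  # device_ids sketch (illustrative, two components assumed):
  #
  #   device.unpack(device.device_ids)  # -> [<0 on CPU:0>, <1 on CPU:1>]
  #
  # This is useful for per-replica branching, e.g. selecting a different input
  # shard on each component device.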

  def _assert_eager(self):
    """Verifies that tracing is not active."""
    if not context.executing_eagerly():
      raise NotImplementedError(
          "ParallelDevice is currently not supported inside `tf.function`. It "
          "can however run calls to a `tf.function` in parallel:\n\n"
          "with ParallelDevice(components) as p:\n  f()")

  def __enter__(self):
    """Runs ops in parallel, makes variables which save independent buffers."""
    if self._device_scope is not None:
      raise AssertionError(
          "Re-entered a ParallelDevice scope without first exiting it.")
    self._assert_eager()
    self._device_scope = ops.device(self._name)
    self._device_scope.__enter__()
    return self

  def __exit__(self, typ, exc, tb):
    self._device_scope.__exit__(typ, exc, tb)
    self._device_scope = None
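

# End-to-end usage sketch (illustrative only; device names are assumptions and
# depend on the local configuration):
#
#   device = ParallelDevice(["/job:localhost/device:CPU:0",
#                            "/job:localhost/device:CPU:1"])
#   packed = device.pack([tf.constant(2.), tf.constant(3.)])
#   with device:
#     squared = packed * packed  # runs once per component device
#   device.unpack(squared)  # -> [<4.0 on CPU:0>, <9.0 on CPU:1>]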