Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/dtensor_device.py: 24% (185 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Propagates information about tensor layouts across operations."""

import contextlib
import logging
import threading
from typing import Any, List, Sequence, Set

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python import _pywrap_dtensor_device
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables

# TODO(allenl): Allow something other than "CUSTOM" so we don't need device
# numbering hacks to avoid collisions between parallel devices and dtensor
# devices.
_next_device_number = 0
_next_device_number_lock = threading.Lock()


class DTensorDevice(object):
  """Wraps a custom device which attempts to propagate tensor layouts."""

  def __init__(self,
               meshes: List[layout_lib.Mesh],
               is_async=True,
               in_flight_nodes_limit=8):
    """Creates a new DTensorDevice that executes ops on the given meshes.

    Args:
      meshes: A list of `Mesh` objects indicating groups of devices to execute
        on. These may also be registered lazily.
      is_async: Indicates whether DTensor operations on this client will return
        immediately (with "non-ready" handles) or block until executed. This is
        on by default and is exposed as an option for ease of debugging.
      in_flight_nodes_limit: Indicates the limit of in-flight nodes before
        enqueueing of async operations to DTensorDevice is blocked. This limit
        is per mesh. 0 means no limit from DTensor. Default is 8.
    """
    if any(not isinstance(mesh, layout_lib.Mesh) for mesh in meshes):
      raise TypeError(
          "Expected a flat list of Mesh objects, got {}".format(meshes))
    global _next_device_number
    ctx = context.context()
    with _next_device_number_lock:
      self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                               _next_device_number)
      _next_device_number += 1
    device, device_info = _pywrap_dtensor_device.Allocate(
        self.name, is_async, in_flight_nodes_limit
    )
    context.register_custom_device(device, self.name, device_info)

    self._device_info = device_info
    self._current_output_layout = None
    self._current_default_mesh = None
    self._meshes = set()
    self._mesh_lock = threading.Lock()
    for mesh in meshes:
      self._register_mesh(mesh)
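
  # Illustrative usage sketch (kept as a comment; `mesh` is an assumed,
  # pre-built `layout_lib.Mesh`, e.g. a 1-D mesh over the local CPU devices):
  #
  #   dtensor_device = DTensorDevice(meshes=[mesh])
  #   print(dtensor_device.name)    # e.g. ".../device:CUSTOM:0"
  #   print(dtensor_device.meshes)  # {mesh}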

  def _create_host_array(self, shape, host_id):
    """Returns ID and device lists that can be used to create a host mesh."""
    num_global_devices = np.prod(shape)
    global_device_ids = np.arange(num_global_devices).reshape(shape)
    local_device_list = [
        tf_device.DeviceSpec(
            job=config.full_job_name(), device_type="CPU", device_index=0)
    ]
    num_local_devices = len(local_device_list)
    local_device_ids = [
        x + host_id * num_local_devices for x in range(num_local_devices)
    ]
    return global_device_ids, local_device_ids, local_device_list
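
  # Worked example (illustrative): with `shape=[2]`, `host_id=1`, and a single
  # local CPU device, this helper returns roughly:
  #
  #   global_device_ids -> np.array([0, 1])
  #   local_device_ids  -> [1]
  #   local_device_list -> [DeviceSpec(job=<full job name>, device_type="CPU",
  #                                    device_index=0)]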

  def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
    """Returns an embedding host mesh for each client."""
    if tpu_mesh.device_type().upper() != "TPU":
      raise ValueError("Must pass in a TPU mesh.")

    # Global device ids are global host ids, while local device ids contain the
    # local host id.
    ts_local_device_ids = []
    ts_local_devices = []
    for local_device_str in tpu_mesh.local_devices():
      # We only need to keep TPU:0 for each client.
      if not local_device_str.endswith("TPU:0"):
        continue

      device_spec = tf_device.DeviceSpec.from_string(local_device_str)
      ts_local_device_ids.append(device_spec.task)
      ts_local_devices.append(device_spec.replace(device_type="CPU"))

    if not ts_local_device_ids or not ts_local_devices:
      logging.info(
          "Cannot create tpu system mesh as %s has no `TPU:0` local device "
          "found", tpu_mesh.to_string())
      return None

    ts_global_device_ids = np.arange(config.num_clients())
    # TODO(zhonglinhan): parse global device specs as input when not None.
    return layout_lib.Mesh(
        dim_names=[tpu_mesh.dim_names[0]],  # 1D mesh.
        global_device_ids=ts_global_device_ids,
        local_device_ids=ts_local_device_ids,
        local_devices=ts_local_devices)

  def _register_mesh(self, mesh: layout_lib.Mesh):
    """Idempotently register `mesh` with the dtensor device."""
    with self._mesh_lock:
      if mesh not in self._meshes:
        _pywrap_dtensor_device.AddMesh(
            self._device_info, mesh.to_string(), False
        )
        self._meshes.add(mesh)
        if mesh.device_type().upper() == "TPU":
          logging.info(
              "Registering virtual 1:1 mapped host mesh %s for mesh %s",
              mesh.host_mesh().to_string(), mesh.to_string())
          _pywrap_dtensor_device.AddMesh(
              self._device_info, mesh.host_mesh().to_string(), True
          )
          self._meshes.add(mesh.host_mesh())
          embedding_host_mesh = self._create_embedding_host_mesh(mesh)
          if embedding_host_mesh:
            logging.info(
                "Registering embedding host mesh %s on each client for mesh %s",
                embedding_host_mesh.to_string(), mesh.to_string())
            _pywrap_dtensor_device.AddMesh(
                self._device_info, embedding_host_mesh.to_string(), False
            )
            self._meshes.add(embedding_host_mesh)

  @property
  def meshes(self) -> Set[layout_lib.Mesh]:
    return self._meshes

  def copy_to_mesh(self, tensor, new_layout) -> ops.Tensor:
    """Copies `tensor` to this DTensor device with the given layout."""
    self._register_mesh(new_layout.mesh)
    with ops.device(self.name):
      return gen_dtensor_ops.copy_to_mesh(tensor, layout=new_layout.to_string())

  def pack(self, tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
    """Packs tensors into a DTensor handle on this DTensor device.

    Packing and unpacking are inverse operations:

    ```
    * unpack(pack(tensors)) == tensors
    * pack(unpack(dtensor)) == dtensor
    ```

    Refer to `dtensor.pack` for more information.

    Args:
      tensors: The list of tensors to pack into a DTensor.
      layout: The layout of the DTensor to be created.

    Returns:
      A DTensor created from the individual component tensors.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`pack` must be called eagerly.")
    if any(
        issubclass(type(t), resource_variable_ops.BaseResourceVariable)
        for t in tensors):
      raise TypeError(
          "Received Variable input to Pack, Variable is not supported.")
    self._register_mesh(layout.mesh)
    with ops.device(self.name):
      if all(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
        if not all(t.shape == tensors[0].shape for t in tensors):
          raise TypeError("All input SparseTensors to Pack must be same shape.")
        is_sparse = True
        tensors = [t.indices for t in tensors] + [t.values for t in tensors] + [
            ops.convert_to_tensor(t.shape, dtype=dtypes.int64) for t in tensors
        ]
      elif any(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
        raise TypeError("Cannot Pack SparseTensors with Tensors.")
      else:
        is_sparse = False
      try:
        return _pywrap_dtensor_device.Pack(
            context.context()._handle,  # pylint: disable=protected-access
            tensors,
            layout.to_string(),
            self._device_info,
            is_sparse)
      except core._NotOkStatusException as e:  # pylint: disable=protected-access
        raise core._status_to_exception(e) from None  # pylint: disable=protected-access
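
  # Usage sketch (illustrative; assumes eager mode, `import tensorflow as tf`,
  # a constructed `dtensor_device`, and a `layout` whose mesh has N local
  # devices). For a fully replicated layout, every component is a full copy:
  #
  #   components = [tf.constant([1.0, 2.0]) for _ in range(N)]
  #   packed = dtensor_device.pack(components, layout)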

  def unpack(self, dtensor: Any) -> Sequence[Any]:
    """Unpacks a DTensor handle on this DTensor device.

    Packing and unpacking are inverse operations:

    ```
    * unpack(pack(tensors)) == tensors
    * pack(unpack(dtensor)) == dtensor
    ```

    Refer to `dtensor.unpack` for more information.

    Args:
      dtensor: The DTensor to unpack.

    Returns:
      The raw underlying tensor components of the DTensor.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`unpack` must be called eagerly.")
    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
      raise TypeError(
          "Received Variable input to unpack, Variable is not supported.")
    try:
      tensors = _pywrap_dtensor_device.Unpack(
          context.context()._handle,  # pylint: disable=protected-access
          dtensor,
          self._device_info)
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access

    is_sparse = _pywrap_dtensor_device.IsSparseDTensor(
        context.context()._handle,  # pylint: disable=protected-access
        dtensor,
        self._device_info)
    if is_sparse:
      result = []
      for i in range(len(tensors) // 3):
        result.append(
            sparse_tensor.SparseTensor(tensors[i],
                                       tensors[i + len(tensors) // 3],
                                       tensors[i + 2 * len(tensors) // 3]))
      return result
    else:
      return tensors
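
  # Usage sketch (illustrative): `unpack` inverts `pack`, so for the
  # hypothetical `packed` value from the `pack` example above:
  #
  #   components = dtensor_device.unpack(packed)
  #   # len(components) == number of local devices in the layout's mesh.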

  def fetch_layout(self, dtensor: Any) -> layout_lib.Layout:
    """Fetches the layout of the DTensor.

    Args:
      dtensor: The DTensor whose layout is to be fetched.

    Returns:
      The `Layout` of this DTensor.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`fetch_layout` must be called eagerly.")
    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
      dtensor = dtensor.read_value()
    try:
      layout_string = _pywrap_dtensor_device.FetchLayout(
          context.context()._handle,  # pylint: disable=protected-access
          dtensor,
          self._device_info)
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access

    if layout_string is None:
      return None
    return layout_lib.Layout.from_string(layout_string)
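
  # Usage sketch (illustrative, reusing the hypothetical `packed` DTensor from
  # the `pack` example above):
  #
  #   layout = dtensor_device.fetch_layout(packed)
  #   print(layout.to_string())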

  def is_dtensor(self, tensor: Any) -> bool:
    """Checks whether the input tensor is a DTensor.

    In Python, a DTensor has the same type as a `tf.Tensor`. This method lets
    you check whether a given `tf.Tensor` is a DTensor and handle it
    differently if so.

    Args:
      tensor: The object to be checked.

    Returns:
      bool, True if the given tensor is a DTensor.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`is_dtensor` must be called eagerly.")
    if not tensor_util.is_tensor(tensor):
      return False
    if isinstance(tensor, variables.Variable):
      # Get the resource handle for tf.Variable.
      tensor = tensor._handle  # pylint: disable=protected-access
    return _pywrap_dtensor_device.IsDTensor(
        context.context()._handle,  # pylint: disable=protected-access
        tensor,
        self._device_info,
    )
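
  # Usage sketch (illustrative; `packed` is the hypothetical DTensor from the
  # `pack` example above, and `tf` is assumed to be imported):
  #
  #   dtensor_device.is_dtensor(tf.constant([1.0]))  # False: regular tensor.
  #   dtensor_device.is_dtensor(packed)              # True: a DTensor.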

  def set_tpu_core_ids(self, mesh_name, tpu_core_ids):
    """Sets the singleton global device ID-to-physical core ID map.

    Args:
      mesh_name: The name of a mesh. If empty, set the default mapping.
      tpu_core_ids: TPU core IDs sorted by TF task/device ordinal.
    """
    _pywrap_dtensor_device.SetTPUCoreIDs(self._device_info, mesh_name,
                                         tpu_core_ids)

  def clear_tpu_core_ids(self):
    _pywrap_dtensor_device.ClearTPUCoreIDs(self._device_info)

  def tpu_core_ids_to_locations(self, tpu_core_ids):
    """Translates TPU core IDs to TPU core locations.

    Args:
      tpu_core_ids: A list of TPU core IDs. Each one is an unsigned integer.

    Returns:
      A list of corresponding TPU core locations.
    """
    return _pywrap_dtensor_device.TPUCoreIDsToLocations(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
        tpu_core_ids)

  def tpu_core_locations_to_ids(self, tpu_core_locations):
    """Translates TPU core locations to TPU core IDs.

    Args:
      tpu_core_locations: A list of TPU core locations. Each one is a list of
        four unsigned integers, [x, y, z, core].

    Returns:
      A list of corresponding TPU core IDs.
    """
    return _pywrap_dtensor_device.TPUCoreLocationsToIDs(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
        tpu_core_locations)
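
  # Usage sketch (illustrative; requires an initialized TPU system). The two
  # translation methods are inverses of each other:
  #
  #   locations = dtensor_device.tpu_core_ids_to_locations([0, 1])
  #   ids = dtensor_device.tpu_core_locations_to_ids(locations)
  #   # ids == [0, 1]; each location is a list of four ints [x, y, z, core].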

  def _get_function_cache_stats(self):
    """Returns the number of cache hits and misses for function compilation.

    Returns:
      A dictionary with keys:
        'hit': number of cache hits;
        'miss': number of cache misses; and
        'size': size of the cache.
    """
    return _pywrap_dtensor_device.GetFunctionCacheStats(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
    )

  def set_iterator_element_layouts(self, iterator_resource_dtensor,
                                   layouts: List[layout_lib.Layout]):
    """Sets the element layouts on an iterator resource tensor.

    Args:
      iterator_resource_dtensor: a DTensor created by packing the individual
        iterator resource tensors.
      layouts: the flattened list of layouts to be applied to the elements
        emitted by the iterator resource DTensor.
    """
    _pywrap_dtensor_device.SetIteratorElementLayouts(
        context.context()._handle,  # pylint: disable=protected-access
        iterator_resource_dtensor,
        [layout.to_string() for layout in layouts],
        self._device_info)

  @contextlib.contextmanager
  def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
    """Sets a default mesh for all ops in the scope.

    Note: This is an internal helper method, not a user-facing API.

    Useful for requesting a specific mesh for ops which would have no inferred
    layout, e.g. tf.zeros.

    Args:
      mesh: A Mesh to be used for ops without a mesh.

    Yields:
      Nothing.
    """
    previous_default = self._current_default_mesh
    self._register_mesh(mesh)
    _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
        self._device_info,
        mesh.to_string().encode("utf-8"))
    self._current_default_mesh = mesh
    yield
    _pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
    if previous_default:
      _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
          self._device_info,
          previous_default.to_string().encode("utf-8"))
    self._current_default_mesh = previous_default
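
  # Usage sketch (illustrative, internal API; `mesh` is an assumed
  # `layout_lib.Mesh` and `tf` is assumed to be imported). Ops executed on this
  # DTensor device with no inferred layout, e.g. tf.zeros, pick up `mesh`
  # inside the scope:
  #
  #   with ops.device(dtensor_device.name):
  #     with dtensor_device._experimental_default_mesh(mesh):
  #       z = tf.zeros([4, 4])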

  @contextlib.contextmanager
  def _default_layout(self, layout: layout_lib.Layout):
    """Sets a default output layout for all ops in the scope.

    Note: This is an internal helper method, not a user-facing API.

    Useful for requesting a specific layout for ops which would have no
    inferred layout, e.g. tf.zeros.

    Caveats:

    - Currently this only affects the first output of an op. Ops with multiple
      outputs are not supported yet.

    - All ops in the scope are tagged with the same layout, which may be
      invalid if their ranks differ. The current suggestion is to wrap the raw
      op whenever possible.

    Args:
      layout: A Layout for the outputs of all operations in this scope.

    Yields:
      Nothing.
    """
    previous_default = None
    previous_graph_size = None
    graph = None

    self._register_mesh(layout.mesh)
    try:
      previous_default = self._current_output_layout
      self._current_output_layout = layout.to_string().encode("utf-8")
      _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
          self._device_info, self._current_output_layout)
      if context.executing_eagerly():
        with ops.device(self.name):
          yield
      else:
        # Custom devices currently don't affect graph building, so we need a
        # separate way to indicate layouts.
        #
        # TODO(allenl): Remove this case once the DTensor device is active
        # during tracing.
        graph = ops.get_default_graph()
        previous_graph_size = len(graph.get_operations())
        yield
    finally:
      if graph is not None:
        # Tag operations added under this scope.
        for operation in graph.get_operations()[previous_graph_size:]:
          # Set layout directly on the Op itself.
          operation._set_attr(  # pylint: disable=protected-access
              "_layout",
              attr_value_pb2.AttrValue(
                  list=attr_value_pb2.AttrValue.ListValue(
                      s=[self._current_output_layout])))
          operation._set_attr(  # pylint: disable=protected-access
              "_mesh",
              attr_value_pb2.AttrValue(
                  s=layout.mesh.to_string().encode("utf-8")))

      self._current_output_layout = previous_default
      if self._current_output_layout is None:
        _pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
      else:
        _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
            self._device_info, self._current_output_layout.decode("utf-8"))
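
  # Usage sketch (illustrative, internal API; `layout` is an assumed rank-2
  # `layout_lib.Layout` and `tf` is assumed to be imported). Within the scope,
  # the first output of each op is requested to have `layout`; in eager mode
  # the op also runs on this DTensor device:
  #
  #   with dtensor_device._default_layout(layout):
  #     z = tf.zeros([4, 4])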