Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/api.py: 53%
90 statements
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core DTensor Python API."""

import contextlib
import threading
from typing import Any, Callable, Optional, Sequence

from tensorflow.dtensor.python import dtensor_device
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_dtensor_singleton = None
_dtensor_singleton_lock = threading.Lock()

# -----------------------------------------------------------------------------
# Main methods to launch DTensor computations.


@tf_export("experimental.dtensor.call_with_layout", v1=[])
def call_with_layout(
    fn: Callable[..., Any],
    layout: Optional[layout_lib.Layout],
    *args,
    **kwargs) -> Any:
39 """Calls a function in the DTensor device scope if `layout` is not None.
41 If `layout` is not None, `fn` consumes DTensor(s) as input and produces a
42 DTensor as output; a DTensor is a tf.Tensor with layout-related attributes.
44 If `layout` is None, `fn` consumes and produces regular tf.Tensors.
46 Args:
47 fn: A supported TF API function such as tf.zeros.
48 layout: Optional, the layout of the output DTensor.
49 *args: Arguments given to `fn`.
50 **kwargs: Keyword arguments given to `fn`.
52 Returns:
53 The return value of `fn` transformed to a DTensor if requested.
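
  Example:

  A minimal usage sketch. It assumes `dtensor.create_mesh` (provided by the
  DTensor mesh utilities, not by this module) and at least two available
  devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two devices.
  mesh = dtensor.create_mesh([("x", 2)])
  layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)

  # `tf.ones` runs under the DTensor device scope; the result is a DTensor
  # whose first dimension is sharded over the "x" mesh dimension.
  ones = dtensor.call_with_layout(tf.ones, layout, shape=[4, 2])
  ```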
54 """
55 if layout is not None:
56 with default_mesh(layout.mesh):
57 with _dtensor_device()._default_layout(layout): # pylint: disable=protected-access
58 return fn(*args, **kwargs)
59 return fn(*args, **kwargs)


@tf_export("experimental.dtensor.run_on", v1=[])
@deprecation.deprecated(None, "Use `dtensor.default_mesh` scope instead.")
@contextlib.contextmanager
def run_on(mesh: layout_lib.Mesh):
  """Runs enclosed functions in the DTensor device scope.

  This function returns a scope. All the ops and tf.functions in this scope
  will run on the DTensor device using the mesh provided.
  This is useful for wrapping any tf.function that doesn't take a DTensor as
  input but would like to produce a DTensor as a result. The scope will also
  make sure all small constants are replicated as DTensors.

  Args:
    mesh: A Mesh instance to extract a default mesh from.

  Yields:
    A context in which all ops and tf.functions will run on the DTensor
    device.
  """
  with default_mesh(mesh):
    yield


@tf_export("experimental.dtensor.default_mesh", v1=[])
@contextlib.contextmanager
def default_mesh(mesh: layout_lib.Mesh):
  """Sets the default DTensor device mesh to use for enclosed functions.

  This function returns a scope. All the ops and tf.functions in this scope
  will default to this DTensor mesh if a mesh cannot be inferred from any of
  the inputs.
  This is useful for wrapping any tf.function that doesn't take a DTensor as
  input but would like to produce a DTensor as a result. The scope will also
  make sure all small constants are replicated as DTensors.

  Args:
    mesh: A Mesh instance to extract a default mesh from.

  Yields:
    A context in which all ops and tf.functions will run on the given mesh.
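
  Example:

  A minimal sketch, assuming `dtensor.create_mesh` (from the mesh utilities)
  and at least two available devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two devices.
  mesh = dtensor.create_mesh([("x", 2)])

  with dtensor.default_mesh(mesh):
    # No input carries a layout from which a mesh could be inferred, so the
    # constant defaults to `mesh` and is replicated onto its devices.
    ones = tf.ones([2, 2])
  ```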
101 """
102 if not isinstance(mesh, layout_lib.Mesh):
103 raise ValueError(f"Expect `mesh` to be `Mesh`, got {type(mesh)}")
105 with _dtensor_device()._experimental_default_mesh(mesh): # pylint: disable=protected-access
106 with ops.device(device_name()):
107 yield


@tf_export("experimental.dtensor.device_name", v1=[])
def device_name() -> str:
  """Returns the singleton DTensor device's name.

  This function can be used in the following way:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  with tf.device(dtensor.device_name()):
    # ...
  ```
  """
  return _dtensor_device().name


@tf_export("experimental.dtensor.is_dtensor", v1=[])
def is_dtensor(tensor) -> bool:
  """Check whether the input tensor is a DTensor.

  In Python, a DTensor has the same type as a `tf.Tensor`. This method lets
  you check whether a given `tf.Tensor` is a DTensor so that it can be
  handled differently.

  Args:
    tensor: an object to be checked.

  Returns:
    bool, True if the given tensor is a DTensor.
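
  Example:

  A minimal sketch, assuming `dtensor.create_mesh` and at least two available
  devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two devices.
  mesh = dtensor.create_mesh([("x", 2)])

  t = tf.constant([1.0, 2.0])  # regular tf.Tensor
  d = dtensor.copy_to_mesh(t, dtensor.Layout.replicated(mesh, rank=1))

  dtensor.is_dtensor(t)  # False
  dtensor.is_dtensor(d)  # True
  ```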
138 """
139 return _dtensor_device().is_dtensor(tensor)


# -----------------------------------------------------------------------------
# Data transfer methods.


@tf_export("experimental.dtensor.copy_to_mesh", v1=[])
def copy_to_mesh(
    tensor: Any,
    layout: layout_lib.Layout,
    source_layout: Optional[layout_lib.Layout] = None) -> ops.Tensor:
151 """Copies a tf.Tensor onto the DTensor device with the given layout.
153 Copies a regular tf.Tensor onto the DTensor device. Use the mesh attached to
154 `layout` as target mesh. This method currently only supports replicated
155 layouts, or one-to-one copies for sharded layouts.
157 Args:
158 tensor: A regular tf.Tensor to be copied as a DTensor.
159 layout: Target layout (and mesh) for the result DTensor.
160 source_layout: Source layout of the tensor before copy. This argument
161 is deprecated.
163 Returns:
164 A DTensor on the DTensor device with the given layout.
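
  Example:

  A minimal sketch, assuming `dtensor.create_mesh` and at least two available
  devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two devices.
  mesh = dtensor.create_mesh([("x", 2)])

  # Replicate a host tensor onto every device of `mesh`.
  t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  d = dtensor.copy_to_mesh(t, dtensor.Layout.replicated(mesh, rank=2))
  ```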
165 """
166 del source_layout
167 with default_mesh(layout.mesh):
168 return gen_dtensor_ops.copy_to_mesh(tensor, layout.to_string())


@tf_export("experimental.dtensor.pack", v1=[])
def pack(tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
  """Packs `tf.Tensor` components into a DTensor.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed
     on each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the
     content value and the layout.

  **Shape, Rank, and Scalars**: The rank of the DTensor is the same as the
  rank of its raw components, i.e., rank is preserved. This leads to a
  consistent interpretation for packing scalar values into a DTensor. The only
  valid layout for a scalar value is fully replicated, and the individual
  components must be identical scalars.

  Each input `tensors[i]` will be copied to `layout.mesh.local_device[i]`
  if not already on the local device. Non-local components should not be
  passed to `pack`; use `copy_to_mesh` and `relayout` to place tensors on all
  global devices on a mesh.

  It is the caller's responsibility to ensure that the underlying values
  for `pack` adhere to the specified layout, and that only as many values are
  specified as there are local devices. Pack does not move data between
  clients. See the examples below for more detail about layouts.

  For example, assume we have a mesh `[X(2), Y(3)]`, which has 6 underlying
  devices in total. Furthermore, assume that the device-location mapping is
  the following:

  ```
  device_ID  |  location X, Y
          0     0, 0
          1     0, 1
          2     0, 2
          3     1, 0
          4     1, 1
          5     1, 2
  ```

  1. For a 1-D DTensor with shape `[128]`, layout `[mesh.X]`, and value
  `range(128)`, the raw components each have shape `[64]`, and the raw
  components are:

  ```
  device_ID  |  raw component
          0     range(0, 64)
          1     range(0, 64)
          2     range(0, 64)
          3     range(64, 128)
          4     range(64, 128)
          5     range(64, 128)
  ```

  This also means that for a 1-D DTensor with shape `[2]` and layout
  `[mesh.X]`, the raw components have shape `[1]` rather than the shape for
  scalar values `[]`.

  2. For a 2-D DTensor with shape `[2, 3]`, layout `[mesh.X, mesh.Y]`, and
  value `range(6)`, this is a fully sharded DTensor.

  From the global view, the content looks like:

  ```
  [
    [0.0, 1.0, 2.0],
    [3.0, 4.0, 5.0],
  ]
  ```

  The raw components will have shape `[1, 1]` each, and have the following
  content:

  ```
  device_ID  |  raw component
          0     [[0.0]]
          1     [[1.0]]
          2     [[2.0]]
          3     [[3.0]]
          4     [[4.0]]
          5     [[5.0]]
  ```

  3. A scalar DTensor with value `123.0` can only have one legitimate layout,
  `[]` (no dimensions, fully replicated).

  The raw components will have shape `[]` each, and have the following
  content:

  ```
  device_ID  |  raw component
          0     123.0
          1     123.0
          2     123.0
          3     123.0
          4     123.0
          5     123.0
  ```

  Again, the caller of `pack` is expected to provide 6 identical scalar-shaped
  raw components.

  4. For a 3-D DTensor with shape `[2, 2, 3]`, layout
  `[X, unsharded, unsharded]`, and value `range(12)`, the content from the
  global view looks like:

  ```
  [
    [
      [0.0, 1.0, 2.0],
      [3.0, 4.0, 5.0],
    ],
    [
      [6.0, 7.0, 8.0],
      [9.0, 10., 11.],
    ],
  ]
  ```

  The raw components will have shape `[1, 2, 3]` each, and have the following
  content:

  ```
  device_ID  |  raw component
          0     range(6).reshape([1, 2, 3])
          1     range(6).reshape([1, 2, 3])
          2     range(6).reshape([1, 2, 3])
          3     range(6, 12).reshape([1, 2, 3])
          4     range(6, 12).reshape([1, 2, 3])
          5     range(6, 12).reshape([1, 2, 3])
  ```

  Args:
    tensors: The list of local tensor components to pack into a DTensor.
    layout: The layout of the DTensor to be created.

  Returns:
    A DTensor created from the individual component tensors.

  Raises:
    RuntimeError: When `pack` is not called eagerly.
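
  Example:

  A minimal sketch in the spirit of case 1 above, on a smaller, single-client
  mesh. It assumes `dtensor.create_mesh` and two local devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two local devices.
  mesh = dtensor.create_mesh([("x", 2)])
  layout = dtensor.Layout(["x"], mesh)  # shard dimension 0 over "x"

  # One component per local device; each is one shard of the global value
  # [0., 1., 2., 3.].
  components = [tf.constant([0.0, 1.0]), tf.constant([2.0, 3.0])]
  d = dtensor.pack(components, layout)  # DTensor with global shape [4]
  ```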
318 """
319 return _dtensor_device().pack(tensors, layout)


@tf_export("experimental.dtensor.unpack", v1=[])
def unpack(tensor: Any) -> Sequence[Any]:
  """Unpacks a DTensor into `tf.Tensor` components.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed
     on each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the
     content value and the layout.

  See the documentation for `pack` for more information about how packing and
  unpacking works.

  Args:
    tensor: The DTensor to unpack.

  Returns:
    The individual component tensors of the DTensor. This will include only
    the client-local components, i.e. the components placed on the local
    devices.

  Raises:
    RuntimeError: When `unpack` is not called eagerly.
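
  Example:

  A minimal sketch, assuming `dtensor.create_mesh` and two local devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two local devices.
  mesh = dtensor.create_mesh([("x", 2)])
  layout = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)

  d = dtensor.call_with_layout(tf.ones, layout, shape=[4, 2])
  components = dtensor.unpack(d)  # two [2, 2] shards, one per local device
  ```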
351 """
352 return _dtensor_device().unpack(tensor)


# -----------------------------------------------------------------------------
# Layout-related methods.


@tf_export("experimental.dtensor.fetch_layout", v1=[])
def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout:
  """Fetches the layout of a DTensor.

  Args:
    tensor: The DTensor whose layout is to be fetched.

  Returns:
    The `Layout` of this DTensor.

  Raises:
    RuntimeError: When not called eagerly.
  """
  return _dtensor_device().fetch_layout(tensor)


@tf_export("experimental.dtensor.check_layout", v1=[])
def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None:
  """Asserts that the layout of the DTensor is `layout`.

  Args:
    tensor: A DTensor whose layout is to be checked.
    layout: The `Layout` to compare against.

  Raises:
    ValueError: If the layout of `tensor` does not match the supplied `layout`.
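
  Example:

  A minimal sketch combining `fetch_layout` and `check_layout`, assuming
  `dtensor.create_mesh` and two available devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two devices.
  mesh = dtensor.create_mesh([("x", 2)])
  layout = dtensor.Layout.replicated(mesh, rank=1)

  d = dtensor.copy_to_mesh(tf.constant([1.0, 2.0]), layout)
  dtensor.fetch_layout(d)          # the replicated `layout`
  dtensor.check_layout(d, layout)  # passes silently; raises on mismatch
  ```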
385 """
386 if fetch_layout(tensor) != layout:
387 raise ValueError("Layout of tensor: " + str(fetch_layout(tensor)) +
388 ", did not match expected layout: " + str(layout))


@tf_export("experimental.dtensor.relayout", v1=[])
def relayout(tensor: ops.Tensor, layout: layout_lib.Layout) -> ops.Tensor:
  """Changes the layout of `tensor`.

  Changes the layout of `tensor` to `layout`. This is used to fine-tune the
  behavior of ops following/connected to `tensor`, such as choosing one SPMD
  expansion pattern over another. This works by forward propagating `layout`
  to connected TensorFlow computation graphs during layout propagation.

  Currently, only converting layouts from replicated to sharded or sharded to
  replicated per mesh dimension is supported. That is, "x, y" -> "unsharded, y"
  is supported, while "x, y" -> "z, y" is not supported.

  We also support a special "match" sharding spec, which instructs the relayout
  to act as an identity operation with respect to any sharding on these
  mesh dimensions.

  Relayout is internally lowered to a set of Split and/or AllToAll ops. When
  tensor layouts are converted from replicated to sharded, the cost is
  comparatively low because only Split ops are inserted and no cross-device
  communication is needed. However, when tensor layouts are converted from
  sharded to replicated, cross-device communication may occur, which can
  impact performance.

  Args:
    tensor: A DTensor to specify a new layout for.
    layout: A Layout object specifying a new sharding spec.

  Returns:
    A DTensor output from the Relayout op.
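
  Example:

  A minimal sketch of re-sharding a replicated DTensor, assuming
  `dtensor.create_mesh` and two available devices:

  ```python
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Assumption: a 1-D mesh named "x" over two devices.
  mesh = dtensor.create_mesh([("x", 2)])
  replicated = dtensor.Layout.replicated(mesh, rank=2)
  sharded = dtensor.Layout(["x", dtensor.UNSHARDED], mesh)

  d = dtensor.call_with_layout(tf.ones, replicated, shape=[4, 2])

  # Replicated -> sharded: lowered to Split ops, no cross-device traffic.
  d_sharded = dtensor.relayout(d, sharded)
  ```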
421 """
422 layout_str = layout.to_string()
423 with default_mesh(layout.mesh):
424 return gen_dtensor_ops.relayout(tensor, layout_str)


def _set_dtensor_device(device: dtensor_device.DTensorDevice) -> None:
  global _dtensor_singleton
  _dtensor_singleton = device


def _dtensor_device() -> dtensor_device.DTensorDevice:
  with _dtensor_singleton_lock:
    if _dtensor_singleton is None:
      _set_dtensor_device(
          dtensor_device.DTensorDevice(meshes=[], is_async=True))
  return _dtensor_singleton


def _reset() -> None:
  global _dtensor_singleton
  if _dtensor_singleton is not None:
    _dtensor_singleton.clear_tpu_core_ids()
  with _dtensor_singleton_lock:
    _dtensor_singleton = None


# ----------------------------------------------------------------------------
# Gradients


@ops.RegisterGradient("Relayout")
def _relayout_gradient(op, grad):
  grad = gen_dtensor_ops.relayout_grad(grad, forward_input=op.inputs[0])
  return grad


@ops.RegisterGradient("RelayoutGrad")
def _relayout_grad_gradient(op, grad):
  # Gradient of RelayoutGrad is relayout to the original Relayout's output.
  grad = gen_dtensor_ops.relayout_grad(grad, forward_input=op.inputs[0])
  # Return None for forward_input's partial gradient since it is not connected
  # to the target's gradient.
  return grad, None


@ops.RegisterGradient("CopyToMesh")
def _copy_to_mesh_gradient(op, grad):
  grad = gen_dtensor_ops.copy_to_mesh_grad(
      grad,
      forward_input=op.inputs[0],
      reference_layout=op.get_attr("layout"),
  )
  return grad


@ops.RegisterGradient("CopyToMeshGrad")
def _copy_to_mesh_grad_gradient(op, grad):
  grad = gen_dtensor_ops.copy_to_mesh_grad(
      grad,
      forward_input=op.inputs[0],
      reference_layout=op.get_attr("reference_layout"),
  )
  return grad, None