Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/api.py: 53%

90 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core DTensor Python API."""

import contextlib
import threading
from typing import Any, Callable, Optional, Sequence

from tensorflow.dtensor.python import dtensor_device
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.framework import ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_dtensor_singleton = None
_dtensor_singleton_lock = threading.Lock()

# -----------------------------------------------------------------------------
# Main methods to launch DTensor computations.


@tf_export("experimental.dtensor.call_with_layout", v1=[])
def call_with_layout(fn: Callable[..., Any],
                     layout: Optional[layout_lib.Layout],
                     *args, **kwargs) -> Any:
  """Calls a function in the DTensor device scope if `layout` is not None.

  If `layout` is not None, `fn` consumes DTensor(s) as input and produces a
  DTensor as output; a DTensor is a tf.Tensor with layout-related attributes.

  If `layout` is None, `fn` consumes and produces regular tf.Tensors.

  Args:
    fn: A supported TF API function such as tf.zeros.
    layout: Optional, the layout of the output DTensor.
    *args: Arguments given to `fn`.
    **kwargs: Keyword arguments given to `fn`.

  Returns:
    The return value of `fn` transformed to a DTensor if requested.
  """
  if layout is not None:
    with default_mesh(layout.mesh):
      with _dtensor_device()._default_layout(layout):  # pylint: disable=protected-access
        return fn(*args, **kwargs)
  return fn(*args, **kwargs)
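# A minimal usage sketch for `call_with_layout` (hypothetical helper, not part
# of the original api.py or the DTensor API). It assumes two CPU devices are
# available for a 1-D mesh named "x"; adjust the device list for the actual
# environment.
def _example_call_with_layout():
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  # Build a 1-D mesh over two CPU devices and a fully replicated rank-2 layout.
  mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])
  layout = dtensor.Layout.replicated(mesh, rank=2)

  # `tf.ones` runs under the DTensor device scope and yields a DTensor.
  ones = dtensor.call_with_layout(tf.ones, layout, shape=[4, 4])
  return ones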

@tf_export("experimental.dtensor.run_on", v1=[])
@deprecation.deprecated(None, "Use `dtensor.default_mesh` scope instead.")
@contextlib.contextmanager
def run_on(mesh: layout_lib.Mesh):
  """Runs enclosed functions in the DTensor device scope.

  This function returns a scope. All the ops and tf.functions in this scope
  will run on the DTensor device using the mesh provided.
  This is useful for wrapping any tf.function that doesn't take a DTensor as
  input but would like to produce a DTensor as its result. The scope will also
  make sure all small constants are replicated as DTensors.

  Args:
    mesh: A Mesh instance to extract a default mesh from.

  Yields:
    A context in which all ops and tf.functions will run on the DTensor device.
  """
  with default_mesh(mesh):
    yield


@tf_export("experimental.dtensor.default_mesh", v1=[])
@contextlib.contextmanager
def default_mesh(mesh: layout_lib.Mesh):
  """Sets the default DTensor device mesh to use for enclosed functions.

  This function returns a scope. All the ops and tf.functions in this scope
  will default to this DTensor mesh if a mesh cannot be inferred from any of
  the inputs.
  This is useful for wrapping any tf.function that doesn't take a DTensor as
  input but would like to produce a DTensor as its result. The scope will also
  make sure all small constants are replicated as DTensors.

  Args:
    mesh: A Mesh instance to extract a default mesh from.

  Yields:
    A context in which all ops and tf.functions will run on the given mesh.
  """
  if not isinstance(mesh, layout_lib.Mesh):
    raise ValueError(f"Expect `mesh` to be `Mesh`, got {type(mesh)}")

  with _dtensor_device()._experimental_default_mesh(mesh):  # pylint: disable=protected-access
    with ops.device(device_name()):
      yield
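# A minimal sketch of the `default_mesh` scope (hypothetical helper, not part
# of the original api.py or the DTensor API). Ops created inside the scope
# whose mesh cannot be inferred from their inputs default to `mesh`; a
# 2-device CPU mesh is assumed.
def _example_default_mesh_scope():
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])
  with dtensor.default_mesh(mesh):
    # With no DTensor inputs to infer a mesh from, this small constant is
    # placed on `mesh` as a replicated DTensor.
    value = tf.constant([1.0, 2.0, 3.0])
  return value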

@tf_export("experimental.dtensor.device_name", v1=[])
def device_name() -> str:
  """Returns the singleton DTensor device's name.

  This function can be used in the following way:

  ```python
  import tensorflow as tf

  with tf.device(dtensor.device_name()):
    # ...
  ```
  """
  return _dtensor_device().name

@tf_export("experimental.dtensor.is_dtensor", v1=[])
def is_dtensor(tensor) -> bool:
  """Checks whether the input tensor is a DTensor.

  In Python, a DTensor has the same type as a `tf.Tensor`. This method lets
  you check whether a given tf.Tensor is a DTensor so that it can be handled
  differently.

  Args:
    tensor: an object to be checked.

  Returns:
    bool, True if the given tensor is a DTensor.
  """
  return _dtensor_device().is_dtensor(tensor)
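# A small sketch contrasting a regular tf.Tensor with a DTensor under
# `is_dtensor` (hypothetical helper, not part of the original api.py or the
# DTensor API); it assumes a 2-device CPU mesh.
def _example_is_dtensor():
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])
  regular = tf.ones([2, 2])  # plain tf.Tensor
  replicated = dtensor.call_with_layout(
      tf.ones, dtensor.Layout.replicated(mesh, rank=2), shape=[2, 2])

  assert not dtensor.is_dtensor(regular)
  assert dtensor.is_dtensor(replicated)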

# -----------------------------------------------------------------------------
# Data transfer methods.


@tf_export("experimental.dtensor.copy_to_mesh", v1=[])
def copy_to_mesh(
    tensor: Any,
    layout: layout_lib.Layout,
    source_layout: Optional[layout_lib.Layout] = None) -> ops.Tensor:
  """Copies a tf.Tensor onto the DTensor device with the given layout.

  Copies a regular tf.Tensor onto the DTensor device, using the mesh attached
  to `layout` as the target mesh. This method currently only supports
  replicated layouts, or one-to-one copies for sharded layouts.

  Args:
    tensor: A regular tf.Tensor to be copied as a DTensor.
    layout: Target layout (and mesh) for the result DTensor.
    source_layout: Source layout of the tensor before copy. This argument is
      deprecated.

  Returns:
    A DTensor on the DTensor device with the given layout.
  """
  del source_layout
  with default_mesh(layout.mesh):
    return gen_dtensor_ops.copy_to_mesh(tensor, layout.to_string())
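# A minimal sketch for `copy_to_mesh` (hypothetical helper, not part of the
# original api.py or the DTensor API): copy a host tf.Tensor onto a 2-device
# CPU mesh with a replicated layout, then reshard it with `relayout`.
def _example_copy_to_mesh():
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])
  host_tensor = tf.reshape(tf.range(8, dtype=tf.float32), [4, 2])

  # Copy with a fully replicated layout; every device holds the full value.
  replicated = dtensor.copy_to_mesh(
      host_tensor, dtensor.Layout.replicated(mesh, rank=2))

  # Reshard the first axis across the "x" mesh dimension.
  sharded = dtensor.relayout(
      replicated, dtensor.Layout(["x", dtensor.UNSHARDED], mesh))
  return sharded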

@tf_export("experimental.dtensor.pack", v1=[])
def pack(tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
  """Packs `tf.Tensor` components into a DTensor.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed on
     each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the
     content value and the layout.

  **Shape, Rank, and Scalars**: The rank of the DTensor is the same as the
  rank of its raw components, i.e., rank is preserved. This leads to a
  consistent interpretation for packing scalar values into a DTensor. The only
  valid layout for a scalar value is fully replicated, and the individual
  components must be identical scalars.

  Each input `tensors[i]` will be copied to `layout.mesh.local_device[i]`
  if not already on the local device. Non-local components should not be passed
  to `pack`; use `copy_to_mesh` and `relayout` to place tensors on all global
  devices on a mesh.

  It is the caller's responsibility to ensure that the underlying values
  for `pack` adhere to the specified layout, and that only as many values are
  specified as there are local devices. Pack does not move data between
  clients. See the examples below for more detail about layouts.

  For example, assume we have a mesh `[X(2), Y(3)]`, which has in total 6
  underlying devices. Furthermore, assume that the device location mapping is
  the following:

  ```
  device_ID  |  location X, Y
          0     0, 0
          1     0, 1
          2     0, 2
          3     1, 0
          4     1, 1
          5     1, 2
  ```

  1. For a 1-D DTensor with shape `[128]`, layout `[mesh.X]` and value
     `range(128)`, the raw components will have shape `[64]` each, and the raw
     components will be:

     ```
     device_ID  |  raw component
             0     range(0, 64)
             1     range(0, 64)
             2     range(0, 64)
             3     range(64, 128)
             4     range(64, 128)
             5     range(64, 128)
     ```

     This also means that for a 1-D DTensor with shape `[2]` and layout
     `[mesh.X]`, the raw components have shape `[1]` rather than the scalar
     shape `[]`.

  2. For a 2-D DTensor with shape `[2, 3]`, layout `[mesh.X, mesh.Y]` and
     value `range(6)`, this is basically a fully-sharded DTensor.

     From the global view, the content looks like:

     ```
     [
       [0.0, 1.0, 2.0],
       [3.0, 4.0, 5.0],
     ]
     ```

     The raw components will have shape `[1, 1]` each, and have the following
     content:

     ```
     device_ID  |  raw component
             0     [[0.0]]
             1     [[1.0]]
             2     [[2.0]]
             3     [[3.0]]
             4     [[4.0]]
             5     [[5.0]]
     ```

  3. For a scalar DTensor with value `123.0`, it can only have one legitimate
     layout `[]` (no dimension, but fully replicated).

     The raw components will have shape `[]` each, and have the following
     content:

     ```
     device_ID  |  raw component
             0     123.0
             1     123.0
             2     123.0
             3     123.0
             4     123.0
             5     123.0
     ```

     Again, the caller of `pack` is expected to provide 6 identical raw
     components with scalar shapes.

  4. For a 3-D DTensor with shape `[2, 2, 3]`, layout
     `[X, unsharded, unsharded]` and value `range(12)`:

     From the global view, the content looks like:

     ```
     [
       [
         [0.0, 1.0, 2.0],
         [3.0, 4.0, 5.0],
       ],
       [
         [6.0, 7.0, 8.0],
         [9.0, 10., 11.],
       ],
     ]
     ```

     The raw components will have shape `[1, 2, 3]` each, and have the
     following content:

     ```
     device_ID  |  raw component
             0     range(6).reshape([1, 2, 3])
             1     range(6).reshape([1, 2, 3])
             2     range(6).reshape([1, 2, 3])
             3     range(6, 12).reshape([1, 2, 3])
             4     range(6, 12).reshape([1, 2, 3])
             5     range(6, 12).reshape([1, 2, 3])
     ```

  Args:
    tensors: The list of local tensor components to pack into a DTensor.
    layout: The layout of the DTensor to be created.

  Returns:
    A DTensor created from the individual component tensors.

  Raises:
    RuntimeError: When `pack` is not called eagerly.
  """
  return _dtensor_device().pack(tensors, layout)
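# A minimal sketch for `pack` (hypothetical helper, not part of the original
# api.py or the DTensor API), mirroring case 1 of the docstring on a small
# scale: a 1-D tensor of shape [4] sharded over a 2-device "x" mesh dimension,
# so each local component has shape [2].
def _example_pack():
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])
  layout = dtensor.Layout(["x"], mesh)

  # One component per local device, in local-device order.
  components = [tf.constant([0.0, 1.0]), tf.constant([2.0, 3.0])]
  packed = dtensor.pack(components, layout)  # global shape [4], sharded on x
  return packed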

@tf_export("experimental.dtensor.unpack", v1=[])
def unpack(tensor: Any) -> Sequence[Any]:
  """Unpacks a DTensor into `tf.Tensor` components.

  Packing and unpacking are inverse operations:

  ```
  * unpack(pack(tensors)) == tensors
  * pack(unpack(dtensor)) == dtensor
  ```

  1. For any DTensor on the mesh, `unpack` returns the raw components placed on
     each underlying device.
  2. Packing these raw components in the same order using `pack` returns a
     DTensor which should be identical to the original DTensor--both the
     content value and the layout.

  See the documentation for `pack` for more information about how packing and
  unpacking work.

  Args:
    tensor: The DTensor to unpack.

  Returns:
    The individual component tensors of the DTensor. This will include only the
    client-local components, i.e. the components placed on the local devices.

  Raises:
    RuntimeError: When `unpack` is not called eagerly.
  """
  return _dtensor_device().unpack(tensor)
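# A minimal sketch for `unpack` (hypothetical helper, not part of the original
# api.py or the DTensor API): round-trip a DTensor through `unpack` and `pack`
# on a 2-device CPU mesh.
def _example_unpack_roundtrip():
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])
  layout = dtensor.Layout(["x"], mesh)
  packed = dtensor.pack([tf.constant([0.0, 1.0]), tf.constant([2.0, 3.0])],
                        layout)

  components = dtensor.unpack(packed)  # two tf.Tensors of shape [2]
  repacked = dtensor.pack(components, layout)
  return components, repacked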

# -----------------------------------------------------------------------------
# Layout-related methods.


@tf_export("experimental.dtensor.fetch_layout", v1=[])
def fetch_layout(tensor: ops.Tensor) -> layout_lib.Layout:
  """Fetches the layout of a DTensor.

  Args:
    tensor: The DTensor whose layout is to be fetched.

  Returns:
    The `Layout` of this DTensor.

  Raises:
    RuntimeError: When not called eagerly.
  """
  return _dtensor_device().fetch_layout(tensor)

@tf_export("experimental.dtensor.check_layout", v1=[])
def check_layout(tensor: ops.Tensor, layout: layout_lib.Layout) -> None:
  """Asserts that the layout of the DTensor is `layout`.

  Args:
    tensor: A DTensor whose layout is to be checked.
    layout: The `Layout` to compare against.

  Raises:
    ValueError: If the layout of `tensor` does not match the supplied `layout`.
  """
  if fetch_layout(tensor) != layout:
    raise ValueError("Layout of tensor: " + str(fetch_layout(tensor)) +
                     ", did not match expected layout: " + str(layout))

@tf_export("experimental.dtensor.relayout", v1=[])
def relayout(tensor: ops.Tensor, layout: layout_lib.Layout) -> ops.Tensor:
  """Changes the layout of `tensor`.

  Changes the layout of `tensor` to `layout`. This is used to fine-tune the
  behavior of ops following/connected to `tensor`, such as choosing one SPMD
  expansion pattern over another. This works by forward propagating `layout`
  to connected TensorFlow computation graphs during layout propagation.

  Currently, only converting layouts from replicated to sharded or sharded to
  replicated per mesh dimension is supported. That is, "x, y" -> "unsharded, y"
  is supported, while "x, y" -> "z, y" is not supported.

  We also support a special "match" sharding spec, which instructs the relayout
  to act as an identity operation with respect to any sharding on these mesh
  dimensions.

  Relayout is internally lowered to a set of Split and/or AllToAll ops. When
  tensor layouts are converted from replicated to sharded, the cost is
  comparatively low because we only insert Split ops and no cross-device
  communication is needed. However, when tensor layouts are converted from
  sharded to replicated, cross-device communication may occur, causing a
  potential performance impact.

  Args:
    tensor: A DTensor to specify a new layout for.
    layout: A Layout object specifying a new sharding spec.

  Returns:
    A DTensor output from the Relayout op.
  """
  layout_str = layout.to_string()
  with default_mesh(layout.mesh):
    return gen_dtensor_ops.relayout(tensor, layout_str)
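# A minimal sketch for `relayout` (hypothetical helper, not part of the
# original api.py or the DTensor API): reshard a replicated DTensor over the
# "x" mesh dimension and back, assuming a 2-device CPU mesh. Replicated ->
# sharded is cheap (Split ops only); sharded -> replicated may involve
# cross-device communication.
def _example_relayout():
  import tensorflow as tf
  from tensorflow.experimental import dtensor

  mesh = dtensor.create_mesh([("x", 2)], devices=["CPU:0", "CPU:1"])
  replicated = dtensor.call_with_layout(
      tf.ones, dtensor.Layout.replicated(mesh, rank=2), shape=[4, 4])

  sharded = dtensor.relayout(
      replicated, dtensor.Layout(["x", dtensor.UNSHARDED], mesh))
  gathered = dtensor.relayout(
      sharded, dtensor.Layout.replicated(mesh, rank=2))
  return gathered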

def _set_dtensor_device(device: dtensor_device.DTensorDevice) -> None:
  global _dtensor_singleton
  _dtensor_singleton = device


def _dtensor_device() -> dtensor_device.DTensorDevice:
  with _dtensor_singleton_lock:
    if _dtensor_singleton is None:
      _set_dtensor_device(
          dtensor_device.DTensorDevice(meshes=[], is_async=True))
    return _dtensor_singleton


def _reset() -> None:
  global _dtensor_singleton
  if _dtensor_singleton is not None:
    _dtensor_singleton.clear_tpu_core_ids()
  with _dtensor_singleton_lock:
    _dtensor_singleton = None


# -----------------------------------------------------------------------------
# Gradients


@ops.RegisterGradient("Relayout")
def _relayout_gradient(op, grad):
  grad = gen_dtensor_ops.relayout_grad(grad, forward_input=op.inputs[0])
  return grad


@ops.RegisterGradient("RelayoutGrad")
def _relayout_grad_gradient(op, grad):
  # Gradient of RelayoutGrad is relayout to the original Relayout's output.
  grad = gen_dtensor_ops.relayout_grad(grad, forward_input=op.inputs[0])
  # Return None for forward_input's partial gradient since it is not connected
  # to the target's gradient.
  return grad, None


@ops.RegisterGradient("CopyToMesh")
def _copy_to_mesh_gradient(op, grad):
  grad = gen_dtensor_ops.copy_to_mesh_grad(
      grad,
      forward_input=op.inputs[0],
      reference_layout=op.get_attr("layout"),
  )
  return grad


@ops.RegisterGradient("CopyToMeshGrad")
def _copy_to_mesh_grad_gradient(op, grad):
  grad = gen_dtensor_ops.copy_to_mesh_grad(
      grad,
      forward_input=op.inputs[0],
      reference_layout=op.get_attr("reference_layout"),
  )
  return grad, None