Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/dtensor/python/dtensor_device.py: 24%

185 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Propagates information about tensor layouts across operations."""

import contextlib
import logging
import threading
from typing import Any, List, Sequence, Set

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python import _pywrap_dtensor_device
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables


# TODO(allenl): Allow something other than "CUSTOM" so we don't need device
# numbering hacks to avoid collisions between parallel devices and dtensor
# devices.
_next_device_number = 0
_next_device_number_lock = threading.Lock()



class DTensorDevice(object):
  """Wraps a custom device which attempts to propagate tensor layouts."""

  def __init__(self,
               meshes: List[layout_lib.Mesh],
               is_async=True,
               in_flight_nodes_limit=8):
    """Create a new DTensorDevice which executes ops on the given meshes.

    Args:
      meshes: A list of `Mesh` objects indicating groups of devices to execute
        on. These may also be registered lazily.
      is_async: Indicates whether DTensor operations on this client will return
        immediately (with "non-ready" handles) or block until executed. This is
        on by default and is exposed as an option for ease of debugging.
      in_flight_nodes_limit: Indicates the limit of in-flight nodes before
        enqueueing of async operations to DTensorDevice is blocked. This limit
        is per mesh. 0 for no limits from DTensor. Default is 8.
    """
    if any(not isinstance(mesh, layout_lib.Mesh) for mesh in meshes):
      raise TypeError(
          "Expected a flat list of Mesh objects, got {}".format(meshes))
    global _next_device_number
    ctx = context.context()
    with _next_device_number_lock:
      self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                               _next_device_number)
      _next_device_number += 1
    device, device_info = _pywrap_dtensor_device.Allocate(
        self.name, is_async, in_flight_nodes_limit
    )
    context.register_custom_device(device, self.name, device_info)

    self._device_info = device_info
    self._current_output_layout = None
    self._current_default_mesh = None
    self._meshes = set()
    self._mesh_lock = threading.Lock()
    for mesh in meshes:
      self._register_mesh(mesh)

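  # Illustrative construction sketch (not prescribed by this module): the
  # `dtensor.create_mesh` helper and the 1-D "x" mesh of size 2 are assumptions
  # for the example only.
  #
  #   from tensorflow.experimental import dtensor
  #   mesh = dtensor.create_mesh([("x", 2)], device_type="CPU")
  #   device = DTensorDevice(meshes=[mesh])
  #   print(device.name)  # e.g. ".../device:CUSTOM:0"
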

  def _create_host_array(self, shape, host_id):
    """Returns ID and device lists that can be used to create a host mesh."""
    num_global_devices = np.prod(shape)
    global_device_ids = np.arange(num_global_devices).reshape(shape)
    local_device_list = [
        tf_device.DeviceSpec(
            job=config.full_job_name(), device_type="CPU", device_index=0)
    ]
    num_local_devices = len(local_device_list)
    local_device_ids = [
        x + host_id * num_local_devices for x in range(num_local_devices)
    ]
    return global_device_ids, local_device_ids, local_device_list

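  # Worked example for the helper above (values follow directly from the code):
  # with shape=[2] and host_id=1 there is a single local CPU device, so
  # global_device_ids == array([0, 1]), local_device_ids == [1], and
  # local_device_list holds one DeviceSpec for CPU:0 of this job.
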

  def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
    """Returns the embedding host mesh for each client."""
    if tpu_mesh.device_type().upper() != "TPU":
      raise ValueError("Must pass a TPU mesh as input.")

    # Global device ids are global host ids, while local device ids contain the
    # local host id.

    ts_local_device_ids = []
    ts_local_devices = []
    for local_device_str in tpu_mesh.local_devices():
      # We only need to keep TPU:0 for each client.
      if not local_device_str.endswith("TPU:0"):
        continue

      device_spec = tf_device.DeviceSpec.from_string(local_device_str)
      ts_local_device_ids.append(device_spec.task)
      ts_local_devices.append(device_spec.replace(device_type="CPU"))

    if not ts_local_device_ids or not ts_local_devices:
      logging.info(
          "Cannot create tpu system mesh as %s has no `TPU:0` local device "
          "found", tpu_mesh.to_string())
      return None

    ts_global_device_ids = np.arange(config.num_clients())
    # TODO(zhonglinhan): parse global device specs as input when not None.
    return layout_lib.Mesh(
        dim_names=[tpu_mesh.dim_names[0]],  # 1D mesh.
        global_device_ids=ts_global_device_ids,
        local_device_ids=ts_local_device_ids,
        local_devices=ts_local_devices)


  def _register_mesh(self, mesh: layout_lib.Mesh):
    """Idempotently register `mesh` with the dtensor device."""
    with self._mesh_lock:
      if mesh not in self._meshes:
        _pywrap_dtensor_device.AddMesh(
            self._device_info, mesh.to_string(), False
        )
        self._meshes.add(mesh)
        if mesh.device_type().upper() == "TPU":
          logging.info(
              "Registering virtual 1:1 mapped host mesh %s for mesh %s",
              mesh.host_mesh().to_string(), mesh.to_string())
          _pywrap_dtensor_device.AddMesh(
              self._device_info, mesh.host_mesh().to_string(), True
          )
          self._meshes.add(mesh.host_mesh())
          embedding_host_mesh = self._create_embedding_host_mesh(mesh)
          if embedding_host_mesh:
            logging.info(
                "Registering embedding host mesh %s on each client for mesh %s",
                embedding_host_mesh.to_string(), mesh.to_string())
            _pywrap_dtensor_device.AddMesh(
                self._device_info, embedding_host_mesh.to_string(), False
            )
            self._meshes.add(embedding_host_mesh)

  @property
  def meshes(self) -> Set[layout_lib.Mesh]:
    return self._meshes


  def copy_to_mesh(self, tensor, new_layout) -> ops.Tensor:
    """Copies `tensor` to this DTensor device with the given layout."""
    self._register_mesh(new_layout.mesh)
    with ops.device(self.name):
      return gen_dtensor_ops.copy_to_mesh(tensor, layout=new_layout.to_string())

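  # Illustrative sketch (assumes the `mesh`, `device`, and `tf` names from the
  # construction sketch above): copying a regular eager tensor onto the device
  # with a replicated layout.
  #
  #   layout = layout_lib.Layout.replicated(mesh, rank=1)
  #   d = device.copy_to_mesh(tf.constant([1.0, 2.0]), layout)
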

  def pack(self, tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
    """Packs tensors into a DTensor handle on this DTensor device.

    Packing and unpacking are inverse operations:

    ```
    * unpack(pack(tensors)) == tensors
    * pack(unpack(dtensor)) == dtensor
    ```

    Refer to `dtensor.pack` for more information.

    Args:
      tensors: The list of tensors to pack into a DTensor.
      layout: The layout of the DTensor to be created.

    Returns:
      A DTensor created from the individual component tensors.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`pack` must be called eagerly.")
    if any(
        issubclass(type(t), resource_variable_ops.BaseResourceVariable)
        for t in tensors):
      raise TypeError(
          "Received Variable input to Pack, Variable is not supported.")
    self._register_mesh(layout.mesh)
    with ops.device(self.name):
      if all(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
        if not all(t.shape == tensors[0].shape for t in tensors):
          raise TypeError("All input SparseTensors to Pack must be same shape.")
        is_sparse = True
        tensors = [t.indices for t in tensors] + [t.values for t in tensors] + [
            ops.convert_to_tensor(t.shape, dtype=dtypes.int64) for t in tensors
        ]
      elif any(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
        raise TypeError("Cannot Pack SparseTensors with Tensors.")
      else:
        is_sparse = False
      try:
        return _pywrap_dtensor_device.Pack(
            context.context()._handle,  # pylint: disable=protected-access
            tensors,
            layout.to_string(),
            self._device_info,
            is_sparse)
      except core._NotOkStatusException as e:  # pylint: disable=protected-access
        raise core._status_to_exception(e) from None  # pylint: disable=protected-access

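  # Illustrative sketch of `pack` (assumes the `mesh`/`device` names from the
  # earlier sketches and a 1-D mesh dimension "x" of size 2, i.e. one component
  # tensor per local device):
  #
  #   layout = layout_lib.Layout.batch_sharded(mesh, batch_dim="x", rank=1)
  #   d = device.pack([tf.constant([1.0]), tf.constant([2.0])], layout)
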

  def unpack(self, dtensor: Any) -> Sequence[Any]:
    """Unpacks a DTensor handle on this DTensor device.

    Packing and unpacking are inverse operations:

    ```
    * unpack(pack(tensors)) == tensors
    * pack(unpack(dtensor)) == dtensor
    ```

    Refer to `dtensor.unpack` for more information.

    Args:
      dtensor: The DTensor to unpack.

    Returns:
      The raw underlying tensor components of the DTensor.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`unpack` must be called eagerly.")
    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
      raise TypeError(
          "Received Variable input to unpack, Variable is not supported.")
    try:
      tensors = _pywrap_dtensor_device.Unpack(
          context.context()._handle,  # pylint: disable=protected-access
          dtensor,
          self._device_info)
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access

    is_sparse = _pywrap_dtensor_device.IsSparseDTensor(
        context.context()._handle,  # pylint: disable=protected-access
        dtensor,
        self._device_info)
    if is_sparse:
      result = []
      for i in range(len(tensors) // 3):
        result.append(
            sparse_tensor.SparseTensor(tensors[i],
                                       tensors[i + len(tensors) // 3],
                                       tensors[i + 2 * len(tensors) // 3]))
      return result
    else:
      return tensors

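  # Illustrative sketch of the pack/unpack round trip (assumes `d` is the
  # DTensor produced by the `pack` sketch above):
  #
  #   components = device.unpack(d)
  #   # `components` is a list with one component tensor per local device,
  #   # here holding [1.0] and [2.0] for the batch-sharded example.
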

  def fetch_layout(self, dtensor: Any) -> layout_lib.Layout:
    """Fetches the layout of the DTensor.

    Args:
      dtensor: The DTensor whose layout is to be fetched.

    Returns:
      The `Layout` of this DTensor.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`fetch_layout` must be called eagerly.")
    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
      dtensor = dtensor.read_value()
    try:
      layout_string = _pywrap_dtensor_device.FetchLayout(
          context.context()._handle,  # pylint: disable=protected-access
          dtensor,
          self._device_info)
    except core._NotOkStatusException as e:  # pylint: disable=protected-access
      raise core._status_to_exception(e) from None  # pylint: disable=protected-access

    if layout_string is None:
      return None
    return layout_lib.Layout.from_string(layout_string)

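  # Illustrative sketch (assumes `d` from the `pack` sketch): the fetched
  # layout round-trips through its string form.
  #
  #   layout = device.fetch_layout(d)
  #   assert layout == layout_lib.Layout.from_string(layout.to_string())
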

  def is_dtensor(self, tensor: Any) -> bool:
    """Checks whether the input tensor is a DTensor.

    In Python, a DTensor has the same type as a `tf.Tensor`. This method lets
    you check whether a given `tf.Tensor` is backed by DTensor so that it can
    be handled differently.

    Args:
      tensor: an object to be checked.

    Returns:
      bool, True if the given tensor is a DTensor.

    Raises:
      RuntimeError: When not called eagerly.
    """
    if not context.executing_eagerly():
      raise RuntimeError("`is_dtensor` must be called eagerly.")
    if not tensor_util.is_tensor(tensor):
      return False
    if isinstance(tensor, variables.Variable):
      # Get the resource handle for tf.Variable.
      tensor = tensor._handle  # pylint: disable=protected-access
    return _pywrap_dtensor_device.IsDTensor(
        context.context()._handle,  # pylint: disable=protected-access
        tensor,
        self._device_info,
    )

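  # Illustrative sketch (assumes `d` and `tf` from the earlier sketches):
  #
  #   device.is_dtensor(d)                  # True
  #   device.is_dtensor(tf.constant([1.]))  # False
  #   device.is_dtensor("not a tensor")     # False; non-tensors short-circuit
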

  def set_tpu_core_ids(self, mesh_name, tpu_core_ids):
    """Sets the singleton global device ID-to-physical core ID map.

    Args:
      mesh_name: The name of a mesh. If empty, set the default mapping.
      tpu_core_ids: TPU core IDs sorted by TF task/device ordinal.
    """
    _pywrap_dtensor_device.SetTPUCoreIDs(self._device_info, mesh_name,
                                         tpu_core_ids)

  def clear_tpu_core_ids(self):
    _pywrap_dtensor_device.ClearTPUCoreIDs(self._device_info)

  def tpu_core_ids_to_locations(self, tpu_core_ids):
    """Translates TPU core IDs to TPU core locations.

    Args:
      tpu_core_ids: A list of TPU core IDs. Each one is an unsigned integer.

    Returns:
      A list of corresponding TPU core locations.
    """
    return _pywrap_dtensor_device.TPUCoreIDsToLocations(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
        tpu_core_ids)

  def tpu_core_locations_to_ids(self, tpu_core_locations):
    """Translates TPU core locations to TPU core IDs.

    Args:
      tpu_core_locations: A list of TPU core locations. Each one is a list of
        four unsigned integers, [x, y, z, core].

    Returns:
      A list of corresponding TPU core IDs.
    """
    return _pywrap_dtensor_device.TPUCoreLocationsToIDs(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
        tpu_core_locations)

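  # Data-format sketch for the two translation helpers above (the values are
  # purely illustrative, not real topology data):
  #
  #   locations = device.tpu_core_ids_to_locations([0, 1])
  #   # e.g. [[0, 0, 0, 0], [0, 0, 0, 1]] -- one [x, y, z, core] entry per ID.
  #   ids = device.tpu_core_locations_to_ids(locations)
  #   # ids == [0, 1] when the mapping is the identity.
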

  def _get_function_cache_stats(self):
    """Returns the number of cache hits and misses for function compilation.

    Returns:
      A dictionary with keys:
        'miss': number of cache misses;
        'hit': number of cache hits; and
        'size': number of entries in the cache.
    """
    return _pywrap_dtensor_device.GetFunctionCacheStats(
        context.context()._handle,  # pylint: disable=protected-access
        self._device_info,
    )


  def set_iterator_element_layouts(self, iterator_resource_dtensor,
                                   layouts: List[layout_lib.Layout]):
    """Sets the element layouts on an iterator resource tensor.

    Args:
      iterator_resource_dtensor: a DTensor created by packing the individual
        iterator resource tensors.
      layouts: the flattened list of layouts to be applied to the elements
        emitted by the iterator resource DTensor.
    """
    _pywrap_dtensor_device.SetIteratorElementLayouts(
        context.context()._handle,  # pylint: disable=protected-access
        iterator_resource_dtensor,
        [layout.to_string() for layout in layouts],
        self._device_info)


  @contextlib.contextmanager
  def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
    """Sets a default mesh for all ops in the scope.

    Note: This is an internal helper method, not a user-facing API.

    Useful for requesting a specific mesh for ops which would have no inferred
    layout, e.g. tf.zeros.

    Args:
      mesh: A Mesh to be used for ops without Mesh.

    Yields:
      Nothing.
    """
    previous_default = self._current_default_mesh
    self._register_mesh(mesh)
    _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
        self._device_info,
        mesh.to_string().encode("utf-8"))
    self._current_default_mesh = mesh
    yield
    _pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
    if previous_default:
      _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
          self._device_info,
          previous_default.to_string().encode("utf-8"))
    self._current_default_mesh = previous_default

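  # Illustrative sketch of the default-mesh scope (assumes `mesh`, `device`,
  # and `tf` from the earlier sketches): ops with no inferred layout, such as
  # tf.zeros, pick up `mesh` while the scope is active.
  #
  #   with device._experimental_default_mesh(mesh):
  #     z = tf.zeros([2, 2])
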

  @contextlib.contextmanager
  def _default_layout(self, layout: layout_lib.Layout):
    """Sets a default output layout for all ops in the scope.

    Note: This is an internal helper method, not a user-facing API.

    Useful for requesting a specific layout for ops which would have no
    inferred layout, e.g. tf.zeros.

    Caveats:

    - Currently only affects the first output of an op. Ops with multiple
      outputs are not supported yet.

    - All ops in the scope will be attached with the same layout, which might
      not be valid when their output ranks differ. The current suggestion is to
      wrap the raw op whenever possible.

    Args:
      layout: A Layout for the outputs of all operations in this scope.

    Yields:
      Nothing.
    """
    previous_default = None
    previous_graph_size = None
    graph = None

    self._register_mesh(layout.mesh)
    try:
      previous_default = self._current_output_layout
      self._current_output_layout = layout.to_string().encode("utf-8")
      _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
          self._device_info, self._current_output_layout)
      if context.executing_eagerly():
        with ops.device(self.name):
          yield
      else:
        # Custom devices currently don't affect graph building, so we need a
        # separate way to indicate layouts.
        #
        # TODO(allenl): Remove this case once the DTensor device is active
        # during tracing.
        graph = ops.get_default_graph()
        previous_graph_size = len(graph.get_operations())
        yield
    finally:
      if graph is not None:
        # Tag operations added under this scope.
        for operation in graph.get_operations()[previous_graph_size:]:
          # Set layout directly on the Op itself.
          operation._set_attr(  # pylint: disable=protected-access
              "_layout",
              attr_value_pb2.AttrValue(
                  list=attr_value_pb2.AttrValue.ListValue(
                      s=[self._current_output_layout])))
          operation._set_attr(  # pylint: disable=protected-access
              "_mesh",
              attr_value_pb2.AttrValue(
                  s=layout.mesh.to_string().encode("utf-8")))

      self._current_output_layout = previous_default
      if self._current_output_layout is None:
        _pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
      else:
        _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
            self._device_info, self._current_output_layout.decode("utf-8"))
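
  # Illustrative sketch of the default-layout scope (assumes `mesh`, `device`,
  # and `tf` from the earlier sketches): requesting a replicated rank-2 layout
  # for an op whose layout cannot be inferred from its inputs.
  #
  #   layout = layout_lib.Layout.replicated(mesh, rank=2)
  #   with device._default_layout(layout):
  #     z = tf.zeros([4, 4])
  #   # device.fetch_layout(z) is expected to match `layout`.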