Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cross_device_ops.py: 21%

504 statements  


1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Classes for different algorithms of reduction and broadcasting.""" 

16 

17import collections 

18import copy 

19import multiprocessing.dummy 

20import multiprocessing.pool 

21import threading 

22 

23import numpy as np 

24import six 

25 

26from tensorflow.python.client import device_lib 

27from tensorflow.python.distribute import collective_util 

28from tensorflow.python.distribute import cross_device_utils 

29from tensorflow.python.distribute import device_util 

30from tensorflow.python.distribute import distribute_utils 

31from tensorflow.python.distribute import ps_values 

32from tensorflow.python.distribute import reduce_util 

33from tensorflow.python.distribute import tpu_values 

34from tensorflow.python.distribute import values as value_lib 

35from tensorflow.python.distribute import values_util 

36from tensorflow.python.eager import context 

37from tensorflow.python.eager import def_function 

38from tensorflow.python.framework import indexed_slices 

39from tensorflow.python.framework import kernels 

40from tensorflow.python.framework import ops 

41from tensorflow.python.framework import tensor_util 

42from tensorflow.python.ops import array_ops 

43from tensorflow.python.ops import math_ops 

44from tensorflow.python.ops import resource_variable_ops 

45from tensorflow.python.platform import tf_logging as logging 

46from tensorflow.python.util import nest 

47from tensorflow.python.util.tf_export import tf_export 

48from tensorflow.tools.docs import doc_controls 

49 

50 

51def check_destinations(destinations): 

52 """Checks whether `destinations` is not empty. 

53 

54 Args: 

55 destinations: a `DistributedValues`, variable, or string object. 

56 

57 Returns: 

58 Boolean which is True if `destinations` is not empty. 

59 """ 

60 # Calling bool() on a ResourceVariable is not allowed. 

61 if isinstance(destinations, 

62 (resource_variable_ops.BaseResourceVariable, ops.Tensor)): 

63 return bool(destinations.device) 

64 return bool(destinations) 

65 

66 

67def validate_destinations(destinations): 

68 """Validates the `destination` is one of expected types.""" 

69 if not isinstance( 

70 destinations, 

71 (value_lib.DistributedValues, ops.Tensor, indexed_slices.IndexedSlices, 

72 ps_values.AggregatingVariable, six.string_types, 

73 tpu_values.TPUMirroredVariable 

74 )) and not resource_variable_ops.is_resource_variable(destinations): 

75 raise ValueError("destinations must be one of a `DistributedValues` object," 

76 " a tf.Variable object, or a device string.") 

77 

78 if not check_destinations(destinations): 

79 raise ValueError("destinations can not be empty") 

80 

81 

82def reduce_non_distributed_value(reduce_op, 

83 value, 

84 destinations, 

85 num_replicas_in_graph, 

86 canonicalize_devices=True): 

87 """Reduce a non-DistributedValue `value` to `destinations`.""" 

88 if isinstance(value, value_lib.DistributedValues): 

89 raise ValueError("You are passing a `DistributedValues` to " 

90 "`reduce_non_distributed_value`, which is not allowed.") 

91 

92 # If the same value is present on all replicas then the PerReplica value will 

93 # be a single value. We also handle the case when `value` is a single value 

94 # and equal to 0. 

95 # TODO(b/138823479): handle the tensor value properly. 

96 if not tensor_util.is_tf_type(value) and np.all(value == 0): 

97 return np.zeros(value.shape, dtype=value.dtype) 

98 # If there is only a single value and the reduce op is MEAN, 

99 # that value should be on all destinations. 

100 if reduce_op == reduce_util.ReduceOp.MEAN: 

101 return value 

102 elif num_replicas_in_graph != 1: 

103 # We do not support a reduce op of SUM if the value is the same across 

104 # all replicas. This is called as part of assign functions for 

105 # MirroredVariables, and summing up identical values across replicas is not 

106 # clearly defined. 

107 raise ValueError("A non-DistributedValues value %s cannot be reduced with " 

108 "the given reduce op %s." % (value, reduce_op)) 

109 else: 

110 validate_destinations(destinations) 

111 return simple_broadcast( 

112 value, destinations, canonicalize_devices=canonicalize_devices) 

113 

114 

115def _make_tensor_into_per_replica(input_tensor): 

116 """Converts a single tensor into a PerReplica object.""" 

117 if isinstance(input_tensor, value_lib.DistributedValues): 

118 return input_tensor 

119 

120 # If input is not a Tensor, convert it to a Tensor first. 

121 if not tensor_util.is_tensor(input_tensor): 

122 input_tensor = ops.convert_to_tensor(input_tensor) 

123 

124 if hasattr(input_tensor, "device"): 

125 return value_lib.PerReplica((input_tensor,)) 

126 

127 raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object " 

128 "because it doesn't have device set.") 

129 

130 

131def _normalize_value_destination_pairs(value_destination_pairs): 

132 """Converts each tensor into a PerReplica object in the input list.""" 

133 result = [] 

134 

135 value_destination_pairs = list(value_destination_pairs) 

136 

137 if not isinstance(value_destination_pairs, (list, tuple)): 

138 raise ValueError("`value_destination_pairs` should be a list or tuple") 

139 for pair in value_destination_pairs: 

140 if not isinstance(pair, tuple): 

141 raise ValueError( 

142 "Each element of `value_destination_pairs` should be a tuple.") 

143 if len(pair) != 2: 

144 raise ValueError("Each element of `value_destination_pairs` should be a " 

145 "tuple of size 2.") 

146 

147 per_replica = _make_tensor_into_per_replica(pair[0]) 

148 result.append((per_replica, pair[1])) 

149 return result 

150 

151 

152def _validate_value_destination_pairs(value_destination_pairs): 

153 """Validates value_destination_pairs are valid.""" 

154 # TODO(yuefengz): raise exceptions instead of returning False. 

155 if not value_destination_pairs: return False 

156 if not isinstance(value_destination_pairs, (list, tuple)): return False 

157 if not all(isinstance(pair, tuple) for pair in value_destination_pairs): 

158 return False 

159 if not all(isinstance(v[0], value_lib.PerReplica) 

160 for v in value_destination_pairs): 

161 return False 

162 return True 

163 

164 

165# TODO(yuefengz): consider calling this function in the caller of 

166# CrossDeviceOps. 

167def get_devices_from(destinations, canonicalize_devices=True): 

168 if isinstance(destinations, value_lib.DistributedValues): 

169 return destinations._devices # pylint: disable=protected-access 

170 if canonicalize_devices: 

171 if isinstance(destinations, six.string_types): 

172 return (device_util.resolve(destinations),) 

173 return (device_util.resolve(destinations.device),) 

174 

175 # Let placer canonicalize and resolve destination devices. 

176 if isinstance(destinations, six.string_types): 

177 return (device_util.canonicalize_without_job_and_task(destinations),) 

178 return (device_util.canonicalize_without_job_and_task(destinations.device),) 

179 

180 

181def _devices_match(left, right, canonicalize_devices=True): 

182 return left is right or set(get_devices_from( 

183 left, canonicalize_devices)) == set( 

184 get_devices_from(right, canonicalize_devices)) 

185 

186 

187def _all_devices_match(value_destination_pairs, canonicalize_devices=True): 

188 if not all( 

189 _devices_match(v, d, canonicalize_devices) 

190 for v, d in value_destination_pairs): 

191 return False 

192 if not all( 

193 _devices_match(v, value_destination_pairs[0][0], canonicalize_devices) 

194 for v, _ in value_destination_pairs[1:]): 

195 return False 

196 return True 

197 

198 

199def simple_broadcast(value, 

200 destinations, 

201 always_mirrored=False, 

202 canonicalize_devices=True): 

203 """Broadcast `value` to `destinations` using simple copies.""" 

204 devices = get_devices_from(destinations, canonicalize_devices) 

205 if len(devices) == 1 and not always_mirrored: 

206 return cross_device_utils.copy_tensor_or_indexed_slices_to_device( 

207 value, devices[0]) 

208 else: 

209 value_updates = [] 

210 for d in devices: 

211 value_updates.append( 

212 cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d)) 

213 return distribute_utils.regroup(value_updates, 

214 wrap_class=value_lib.Mirrored) 

215 

216 

217def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn, 

218 reduce_op): 

219 """Reduces the value by accumulation_fn and reduce_op.""" 

220 all_values = per_replica_value.values 

221 if not all_values: 

222 raise ValueError("`per_replica_value` must be non-empty") 

223 count = len(all_values) 

224 

225 with ops.device(reduce_to_device): 

226 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 

227 reduced = cross_device_utils.aggregate_tensors_or_indexed_slices( 

228 all_values, accumulation_fn) 

229 if reduce_op == reduce_util.ReduceOp.MEAN: 

230 reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices( 

231 reduced, count) 

232 elif reduce_op != reduce_util.ReduceOp.SUM: 

233 raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.") 

234 return reduced 

235 

236 

237def _simple_gather(per_replica_value, reduce_to_device, axis): 

238 """Concatenate all values in the DistributedValues input and return.""" 

239 all_values = per_replica_value.values 

240 if not all_values: 

241 raise ValueError("`per_replica_value` must be non-empty") 

242 

243 with ops.device(reduce_to_device): 

244 with context.device_policy(context.DEVICE_PLACEMENT_SILENT): 

245 gathered = array_ops.concat(all_values, axis) 

246 return gathered 

247 

248 

249@tf_export("distribute.CrossDeviceOps") 

250class CrossDeviceOps(object): 

251 """Base class for cross-device reduction and broadcasting algorithms. 

252 

253 The main purpose of this class is to be passed to 

254 `tf.distribute.MirroredStrategy` in order to choose among different cross 

255 device communication implementations. Prefer using the methods of 

256 `tf.distribute.Strategy` instead of the ones of this class. 

257 

258 Implementations: 

259 * `tf.distribute.ReductionToOneDevice` 

260 * `tf.distribute.NcclAllReduce` 

261 * `tf.distribute.HierarchicalCopyAllReduce` 
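
For illustration, a minimal sketch (assuming TensorFlow is imported as `tf`)
of passing one of these implementations to `tf.distribute.MirroredStrategy`:

```
strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
```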

262 """ 

263 

264 def __init__(self): 

265 self._canonicalize_devices = True 

266 pass 

267 

268 @property 

269 def _num_between_graph_workers(self): 

270 # Returns 1 by default; the value may be overridden by subclasses. 

271 return 1 

272 

273 def reduce(self, reduce_op, per_replica_value, destinations, options=None): 

274 """Reduce `per_replica_value` to `destinations`. 

275 

276 See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in 

277 the cross-replica context. 

278 

279 Args: 

280 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be 

281 combined. 

282 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` 

283 like object. 

284 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

285 `tf.Tensor` alike object, or a device string. It specifies the devices 

286 to reduce to. To perform an all-reduce, pass the same to `value` and 

287 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 

288 to the devices of that variable, and this method doesn't update the 

289 variable. 

290 options: a `tf.distribute.experimental.CommunicationOptions`. See 

291 `tf.distribute.experimental.CommunicationOptions` for details. 

292 

293 Returns: 

294 A `tf.Tensor` or `tf.distribute.DistributedValues`. 

295 

296 Raises: 

297 ValueError: if per_replica_value can't be converted to a 

298 `tf.distribute.DistributedValues` or if destinations is not a string, 

299 `tf.Variable` or `tf.distribute.DistributedValues`. 
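
Example (an illustrative sketch; assumes `cross_device_ops` is a concrete
implementation such as `tf.distribute.NcclAllReduce` and `per_replica` is a
`tf.distribute.DistributedValues`):

```
# Passing the same value as `destinations` performs an all-reduce.
reduced = cross_device_ops.reduce(
    tf.distribute.ReduceOp.SUM, per_replica, destinations=per_replica)
```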

300 """ 

301 if options is None: 

302 options = collective_util.Options() 

303 

304 per_replica_value = _make_tensor_into_per_replica(per_replica_value) 

305 

306 validate_destinations(destinations) 

307 

308 # Shortcut if `per_replica_value` only contains one value. 

309 if self._num_between_graph_workers == 1 and len( 

310 per_replica_value.values) == 1 and _devices_match( 

311 per_replica_value, destinations, self._canonicalize_devices): 

312 with ops.device(per_replica_value.values[0].device): 

313 v = array_ops.identity(per_replica_value.values[0]) 

314 return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored) 

315 

316 if options is None: 

317 options = collective_util.Options() 

318 return self.reduce_implementation(reduce_op, per_replica_value, 

319 destinations, options) 

320 

321 def _gather(self, per_replica_value, destinations, axis, options=None): 

322 """Gather `per_replica_value` to `destinations`. 

323 

324 Args: 

325 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` 

326 like object. 

327 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

328 `tf.Tensor` alike object, or a device string. It specifies the devices 

329 to gather to. To perform an all-gather, pass the same to `value` and 

330 `destinations`. Note that if it's a `tf.Variable`, the value is gathered 

331 to the devices of that variable, and this method doesn't update the 

332 variable. 

333 axis: specifies the dimension to gather along within each replica's 

334 tensor. 

335 options: a `tf.distribute.experimental.CommunicationOptions`. See 

336 `tf.distribute.experimental.CommunicationOptions` for details. 

337 

338 Returns: 

339 A `tf.Tensor` or `tf.distribute.DistributedValues` 

340 

341 Raises: 

342 ValueError: if per_replica_value can't be converted to a 

343 `tf.distribute.DistributedValues` or if destinations is not a string, 

344 `tf.Variable` or `tf.distribute.DistributedValues`. 

345 """ 

346 if isinstance(per_replica_value, indexed_slices.IndexedSlices): 

347 raise NotImplementedError("gather/all_gather does not support " 

348 "IndexedSlices") 

349 if options is None: 

350 options = collective_util.Options() 

351 

352 per_replica_value = _make_tensor_into_per_replica(per_replica_value) 

353 

354 validate_destinations(destinations) 

355 

356 # Shortcut if `per_replica_value` only contains one value. 

357 if self._num_between_graph_workers == 1 and len( 

358 per_replica_value.values) == 1 and _devices_match( 

359 per_replica_value, destinations, self._canonicalize_devices): 

360 with ops.device(per_replica_value.values[0].device): 

361 v = array_ops.identity(per_replica_value.values[0]) 

362 return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored) 

363 

364 return self._gather_implementation(per_replica_value, destinations, axis, 

365 options) 

366 

367 def _gather_implementation(self, per_replica_value, destinations, axis, 

368 options): 

369 """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`. 

370 

371 Overriding this method is useful for subclass implementers. 

372 

373 Args: 

374 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` 

375 like object. 

376 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

377 `tf.Tensor` alike object, or a device string. It specifies the devices 

378 to gather to. To perform an all-gather, pass the same to `value` and 

379 `destinations`. Note that if it's a `tf.Variable`, the value is gathered 

380 to the devices of that variable; this method doesn't update the 

381 variable. 

382 axis: specifies the dimension to gather along within each replica's 

383 tensor. 

384 options: a `tf.distribute.experimental.CommunicationOptions`. See 

385 `tf.distribute.experimental.CommunicationOptions` for details. 

386 

387 Returns: 

388 A `tf.Tensor` or `tf.distribute.DistributedValues`. 

389 

390 Raises: 

391 ValueError: if per_replica_value can't be converted to a 

392 `tf.distribute.DistributedValues` or if destinations is not a string, 

393 `tf.Variable` or `tf.distribute.DistributedValues`. 

394 """ 

395 raise NotImplementedError( 

396 "_gather method must be implemented in descendants.") 

397 

398 def batch_reduce(self, reduce_op, value_destination_pairs, options=None): 

399 """Reduce values to destinations in batches. 

400 

401 See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be 

402 called in the cross-replica context. 

403 

404 Args: 

405 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be 

406 combined. 

407 value_destination_pairs: a sequence of (value, destinations) pairs. See 

408 `tf.distribute.CrossDeviceOps.reduce` for descriptions. 

409 options: a `tf.distribute.experimental.CommunicationOptions`. See 

410 `tf.distribute.experimental.CommunicationOptions` for details. 

411 

412 Returns: 

413 A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair 

414 in `value_destination_pairs`. 

415 

416 Raises: 

417 ValueError: if `value_destination_pairs` is not an iterable of 

418 tuples of `tf.distribute.DistributedValues` and destinations. 
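
Example (an illustrative sketch; assumes `cross_device_ops` is a concrete
implementation and `pr_a`, `pr_b` are `tf.distribute.DistributedValues`):

```
# Reducing each value back onto its own devices performs a batched all-reduce.
results = cross_device_ops.batch_reduce(
    tf.distribute.ReduceOp.MEAN, [(pr_a, pr_a), (pr_b, pr_b)])
```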

419 """ 

420 if options is None: 

421 options = collective_util.Options() 

422 # TODO(yuefengz): if destinations are different, split into several 

423 # `_batch_reduce` invocations. 

424 if not _validate_value_destination_pairs(value_destination_pairs): 

425 # If the first element of each pair is a tensor, we try to turn it into a 

426 # PerReplica object. 

427 value_destination_pairs = _normalize_value_destination_pairs( 

428 value_destination_pairs) 

429 

430 for _, d in value_destination_pairs: 

431 validate_destinations(d) 

432 

433 # Shortcut if all PerReplica objects only contain one value. 

434 if self._num_between_graph_workers == 1 and _all_devices_match( 

435 value_destination_pairs, self._canonicalize_devices) and len( 

436 value_destination_pairs[0][0].values) == 1: 

437 return [ 

438 distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored) 

439 for v, _ in value_destination_pairs 

440 ] 

441 

442 if options is None: 

443 options = collective_util.Options() 

444 return self.batch_reduce_implementation(reduce_op, value_destination_pairs, 

445 options) 

446 

447 def broadcast(self, tensor, destinations): 

448 """Broadcast `tensor` to `destinations`. 

449 

450 This can only be called in the cross-replica context. 

451 

452 Args: 

453 tensor: a `tf.Tensor` like object. The value to broadcast. 

454 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

455 `tf.Tensor` alike object, or a device string. It specifies the devices 

456 to broadcast to. Note that if it's a `tf.Variable`, the value is 

457 broadcasted to the devices of that variable; this method doesn't update 

458 the variable. 

459 

460 Returns: 

461 A `tf.Tensor` or `tf.distribute.DistributedValues`. 
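
Example (an illustrative sketch; assumes `cross_device_ops` is a concrete
implementation, TensorFlow is imported as `tf`, and the device string is a
placeholder):

```
mirrored = cross_device_ops.broadcast(
    tf.constant(1.0), destinations="/gpu:0")
```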

462 """ 

463 validate_destinations(destinations) 

464 return self.broadcast_implementation(tensor, destinations) 

465 

466 @doc_controls.for_subclass_implementers 

467 def reduce_implementation(self, reduce_op, per_replica_value, destinations, 

468 options): 

469 """Implementation of `reduce`. 

470 

471 Overriding this method is useful for subclass implementers. 

472 

473 Args: 

474 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be 

475 combined. 

476 per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` 

477 like object. 

478 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

479 `tf.Tensor` alike object, or a device string. It specifies the devices 

480 to reduce to. To perform an all-reduce, pass the same to `value` and 

481 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 

482 to the devices of that variable; this method doesn't update the 

483 variable. 

484 options: a `tf.distribute.experimental.CommunicationOptions`. See 

485 `tf.distribute.experimental.CommunicationOptions` for details. 

486 

487 Returns: 

488 A `tf.Tensor` or `tf.distribute.DistributedValues`. 

489 

490 Raises: 

491 ValueError: if per_replica_value can't be converted to a 

492 `tf.distribute.DistributedValues` or if destinations is not a string, 

493 `tf.Variable` or `tf.distribute.DistributedValues`. 

494 """ 

495 raise NotImplementedError( 

496 "_reduce method must be implemented in descendants.") 

497 

498 @doc_controls.for_subclass_implementers 

499 def batch_reduce_implementation(self, reduce_op, value_destination_pairs, 

500 options): 

501 """Implementation of `batch_reduce`. 

502 

503 Overriding this method is useful for subclass implementers. 

504 

505 Args: 

506 reduce_op: a `tf.distribute.ReduceOp` specifying how values should be 

507 combined. 

508 value_destination_pairs: a sequence of (value, destinations) pairs. See 

509 `reduce` for descriptions. 

510 options: a `tf.distribute.experimental.CommunicationOptions`. See 

511 `tf.distribute.experimental.CommunicationOptions` for details. 

512 

513 Returns: 

514 A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair 

515 in `value_destination_pairs`. 

516 

517 Raises: 

518 ValueError: if `value_destination_pairs` is not an iterable of 

519 tuples of `tf.distribute.DistributedValues` and destinations. 

520 """ 

521 raise NotImplementedError( 

522 "batch_reduce_implementation method must be implemented in descendants." 

523 ) 

524 

525 @doc_controls.for_subclass_implementers 

526 def broadcast_implementation(self, tensor, destinations): 

527 """Implementation of `broadcast`. 

528 

529 Args: 

530 tensor: a `tf.Tensor` like object. The value to broadcast. 

531 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

532 `tf.Tensor` alike object, or a device string. It specifies the devices 

533 to broadcast to. 

534 Note that if it's a `tf.Variable`, the value is 

535 broadcasted to the devices of that variable; this method doesn't update 

536 the variable. 

537 

538 Returns: 

539 A `tf.Tensor` or `tf.distribute.DistributedValues`. 

540 """ 

541 return simple_broadcast( 

542 tensor, 

543 destinations, 

544 always_mirrored=True, 

545 canonicalize_devices=self._canonicalize_devices) 

546 

547 # ========================== Collective APIs ================================ 

548 # 

549 # Different from `reduce`, `batch_reduce` and `broadcast`, which must be called 

550 # in a cross-replica context, collective APIs are to be called in a replica 

551 # context. 

552 

553 def _all_reduce(self, reduce_op, value, replica_id, options): 

554 """All-reduce the `value` across all replicas so that all get the result. 

555 

556 `value` can be a nested structure of tensors or `IndexedSlices`. The 

557 implementation should generally batch the all-reduces when possible. 

558 `options` can be set to hint the batching behavior. 

559 

560 This API must be called in a replica context. 

561 

562 Args: 

563 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should 

564 be combined. 

565 value: Value to be reduced. A tensor or a nested structure of tensors or 

566 `IndexedSlices`. 

567 replica_id: An integer indicating the id of the replica on which this 

568 all_reduce is called. This is the local replica id, which ranges 

569 from 0 to len(local_devices) - 1. 

570 options: A `tf.distribute.experimental.CommunicationOptions`. 

571 

572 Returns: 

573 A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with 

574 the reduced values. The structure is the same as `value`. 
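
Example (an illustrative sketch; `concrete_ops` stands for a concrete subclass
such as `CollectiveAllReduce`, `grad_w` and `grad_b` are per-replica gradient
tensors, and the call must happen in a replica context):

```
reduced = concrete_ops._all_reduce(
    tf.distribute.ReduceOp.SUM, {"w": grad_w, "b": grad_b},
    replica_id=0, options=collective_util.Options())
```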

575 """ 

576 raise NotImplementedError("_all_reduce must be implemented in descendants.") 

577 

578 

579@tf_export("distribute.ReductionToOneDevice") 

580class ReductionToOneDevice(CrossDeviceOps): 

581 """A CrossDeviceOps implementation that copies values to one device to reduce. 

582 

583 This implementation always copies values to one device to reduce them, then 

584 broadcasts the reduced values to the destinations. It doesn't support efficient 

585 batching. 

586 

587 Here is how you can use `ReductionToOneDevice` in 

588 `tf.distribute.MirroredStrategy`: 

589 

590 ``` 

591 strategy = tf.distribute.MirroredStrategy( 

592 cross_device_ops=tf.distribute.ReductionToOneDevice()) 

593 ``` 

594 """ 

595 

596 def __init__(self, reduce_to_device=None, accumulation_fn=None): 

597 """Initializes with a device to reduce to and a way to accumulate. 

598 

599 Args: 

600 reduce_to_device: the intermediate device to reduce to. If None, reduce 

601 to the first device in `destinations` of the `reduce` method. 

602 accumulation_fn: a function that does accumulation. If None, 

603 `tf.math.add_n` is used. 
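
For example (a sketch; both arguments are optional and the device string is a
placeholder):

```
cross_ops = tf.distribute.ReductionToOneDevice(
    reduce_to_device="/cpu:0", accumulation_fn=tf.math.add_n)
```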

604 """ 

605 self.reduce_to_device = reduce_to_device 

606 self.accumulation_fn = accumulation_fn or math_ops.add_n 

607 super(ReductionToOneDevice, self).__init__() 

608 

609 def reduce_implementation(self, reduce_op, per_replica_value, destinations, 

610 options): 

611 del options # Unused. 

612 if check_destinations(destinations): 

613 devices = get_devices_from(destinations, self._canonicalize_devices) 

614 else: 

615 devices = get_devices_from(per_replica_value, self._canonicalize_devices) 

616 reduce_to_device = self.reduce_to_device or devices[0] 

617 logging.log_first_n( 

618 logging.INFO, 

619 "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10) 

620 reduced = _simple_reduce(per_replica_value, reduce_to_device, 

621 self.accumulation_fn, reduce_op) 

622 return self.broadcast(reduced, destinations) 

623 

624 def _gather_implementation(self, per_replica_value, destinations, axis, 

625 options): 

626 del options # Unused. 

627 if check_destinations(destinations): 

628 devices = get_devices_from(destinations, self._canonicalize_devices) 

629 else: 

630 devices = get_devices_from(per_replica_value, self._canonicalize_devices) 

631 reduce_to_device = self.reduce_to_device or devices[0] 

632 logging.log_first_n( 

633 logging.INFO, 

634 "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10) 

635 gathered = _simple_gather(per_replica_value, reduce_to_device, axis) 

636 return self.broadcast(gathered, destinations) 

637 

638 def batch_reduce_implementation(self, reduce_op, value_destination_pairs, 

639 options): 

640 return [ 

641 self.reduce_implementation( 

642 reduce_op, t, destinations=v, options=options) 

643 for t, v in value_destination_pairs 

644 ] 

645 

646 

647def _group_value_by_device(per_replica_values): 

648 """Group values into sublists by their devices. 

649 

650 This grouping is needed to call the all-reduce library because it expects a 

651 list of the following form: 

652 [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...], 

653 [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...], 

654 [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...], 

655 ... 

656 ] 

657 

658 Args: 

659 per_replica_values: a list of PerReplica objects. 

660 

661 Returns: 

662 a list of lists; each sublist contains, for its corresponding device, the 

663 components of the PerReplica objects, each paired with a None. 

664 """ 

665 destinations = per_replica_values[0]._devices # pylint: disable=protected-access 

666 grouped = [[] for _ in range(len(destinations))] 

667 for per_replica_value in per_replica_values: 

668 # pylint: disable=protected-access 

669 for i, v in enumerate(per_replica_value.values): 

670 assert per_replica_value._devices == destinations 

671 grouped[i].append((v, None)) 

672 return grouped 

673 

674 

675def _ungroup_and_make_mirrored(grouped_reduced, 

676 destinations, 

677 reduce_op, 

678 num_between_graph_workers=1): 

679 """Ungroup results from all-reduce and make Mirrored objects. 

680 

681 Each all-reduce result will be divided by the number of destinations before 

682 Mirrored objects are created if reduce_op is "mean". 

683 

684 Args: 

685 grouped_reduced: a list of lists, each sublist has components for each 

686 device, paired with a None. It is the result from 

687 cross_device_utils.aggregate_gradients_using*. 

688 destinations: a value to colocate the result with. 

689 reduce_op: Indicates how values will be aggregated. Accepted values 

690 are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`. 

691 num_between_graph_workers: number of workers in the between-graph 

692 replication. 

693 

694 Returns: 

695 a list of Mirrored objects. 

696 """ 

697 num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers 

698 index = [[] for _ in range(len(grouped_reduced[0]))] 

699 for per_replica_reduced in grouped_reduced: 

700 for i, (v, _) in enumerate(per_replica_reduced): 

701 if reduce_op == reduce_util.ReduceOp.MEAN: 

702 with ops.device(v.device): 

703 index[i].append(v / num_replicas) 

704 else: 

705 index[i].append(v) 

706 return [distribute_utils.regroup( 

707 v, wrap_class=value_lib.Mirrored) for v in index] 

708 

709 

710class _ConcatAndSplitPacker(object): 

711 """Concatenate and split tensors for reduction.""" 

712 

713 def __init__(self, num_packs=1): 

714 """Initialize the _ConcatAndSplitPacker object. 

715 

716 Args: 

717 num_packs: specifies the number of split packs that will be 

718 formed. 

719 

720 Raises: 

721 ValueError: if num_packs is not greater than 0. 

722 """ 

723 if num_packs <= 0: 

724 raise ValueError("num_packs must be greater than zero.") 

725 self.num_packs = num_packs 

726 

727 def pack(self, grouped_grads_and_vars): 

728 """Pack tensors.""" 

729 self.grouped_grads_and_vars = grouped_grads_and_vars 

730 self.all_device_shapes = [] 

731 self.all_device_sizes = [] 

732 

733 device_grad_packs = [] 

734 for device_grads_and_vars in grouped_grads_and_vars: 

735 with ops.colocate_with(device_grads_and_vars[0][0]): 

736 # Flatten all the grads. 

737 flat_grads = [ 

738 array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars 

739 ] 

740 # Remember the original shape of all the grads. 

741 device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars] 

742 # Remember the original sizes of all the grads. 

743 device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars] 

744 # Concat all the flat grads into a big flat tensor. 

745 concat_grads = array_ops.concat(flat_grads, 0) 

746 

747 # Split the big tensor into num_splits packs. In cases where the 

748 # total size is not divisible by num_splits, the last pack gets 

749 # more elements. 

750 # TODO(zhengxq): it is also possible to optimize away all the concat 

751 # as well. 

752 num_splits = self.num_packs 

753 

754 # The array_ops.size function will sometimes remove static shapes. So if 

755 # all gradient shapes are defined, we use another method to get the 

756 # total size. 

757 # TODO(yuefengz): move this logic to array_ops.size. 

758 if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars): 

759 total_grad_size = sum( 

760 [g.shape.num_elements() for g, _ in device_grads_and_vars]) 

761 else: 

762 total_grad_size = array_ops.size(concat_grads) 

763 

764 split_size = total_grad_size // num_splits 

765 split_size_last = total_grad_size - split_size * (num_splits - 1) 

766 split_sizes = [split_size] * (num_splits - 1) + [split_size_last] 
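# Worked example (illustrative): if total_grad_size = 10 and
# num_splits = 3, then split_size = 3, split_size_last = 4, and
# split_sizes = [3, 3, 4].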

767 grad_packs = array_ops.split(concat_grads, split_sizes) 

768 

769 # Ready to aggregate the repacked gradients, with fake variables. 

770 # TODO(zhengxq): It is hacky to have to use fake variables. 

771 # We should remove the need for variables in 

772 # aggregate_gradients_using*. 

773 device_grad_packs.append(zip(grad_packs, [None] * num_splits)) 

774 self.all_device_shapes.append(device_shapes) 

775 self.all_device_sizes.append(device_sizes) 

776 

777 return device_grad_packs 

778 

779 def unpack(self, summed_device_grad_packs): 

780 """Reverse the pack.""" 

781 aggregated_device_grads = [] 

782 for (summed_device_grad_packs, 

783 device_grads_and_vars, device_shapes, device_sizes) in zip( 

784 summed_device_grad_packs, self.grouped_grads_and_vars, 

785 self.all_device_shapes, self.all_device_sizes): 

786 # pylint: enable=line-too-long 

787 # Reverse the packing operations in the previous steps. Form the 

788 # summed gradients back into their original shapes. 

789 with ops.colocate_with(summed_device_grad_packs[0][0]): 

790 # Form a list of the summed grad packs. 

791 device_grad_packs = [g for g, _ in summed_device_grad_packs] 

792 

793 # Concat them back into a big flat tensor. 

794 device_grads_concat = array_ops.concat(device_grad_packs, 0) 

795 

796 # Split the tensors back into their original sizes. 

797 grads_with_sizes = array_ops.split(device_grads_concat, device_sizes) 

798 

799 # Reshape the tensors back into their original shapes. 

800 grads_with_shapes = [ 

801 array_ops.reshape(grad, shape) 

802 for shape, grad in zip(device_shapes, grads_with_sizes) 

803 ] 

804 

805 # Form the list with the original list of variables. 

806 summed_device_grads = [ 

807 (g, v) for g, (_, v) in zip(grads_with_shapes, 

808 device_grads_and_vars) 

809 ] 

810 aggregated_device_grads.append(summed_device_grads) 

811 return aggregated_device_grads 

812 

813 

814def _pack_tensors(device_grads, num_packs=0): 

815 """Pack tensors if specified.""" 

816 if num_packs > 0: 

817 tensor_packer = _ConcatAndSplitPacker(num_packs) 

818 device_grad_packs = tensor_packer.pack(device_grads) 

819 else: 

820 tensor_packer = None 

821 device_grad_packs = device_grads 

822 return device_grad_packs, tensor_packer 

823 

824 

825def _unpack_tensors(reduced, tensor_packer=None): 

826 """Unpack tensors if they are packed before all-reduce.""" 

827 if tensor_packer: 

828 return tensor_packer.unpack(reduced) 

829 return reduced 

830 

831 

832class AllReduceCrossDeviceOps(CrossDeviceOps): 

833 """All-reduce implementation of CrossDeviceOps. 

834 

835 It performs all-reduce when applicable using NCCL or hierarchical copy. For 

836 the batch API, tensors will be repacked or aggregated for more efficient 

837 cross-device transportation. 

838 

839 For reduces that are not all-reduce, it falls back to 

840 `tf.distribute.ReductionToOneDevice`. 
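
A minimal construction sketch (illustrative only; in user code this class is
normally reached through `tf.distribute.NcclAllReduce` or
`tf.distribute.HierarchicalCopyAllReduce`):

```
cross_ops = AllReduceCrossDeviceOps(
    all_reduce_alg="hierarchical_copy", num_packs=2)
```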

841 """ 

842 

843 def __init__(self, all_reduce_alg="nccl", num_packs=1): 

844 """Initializes the object. 

845 

846 Args: 

847 all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or 

848 "hierarchical_copy" are supported. 

849 num_packs: a non-negative integer. The number of packs to split values 

850 into. If zero, no packing will be done. 

851 """ 

852 self._all_reduce_alg = all_reduce_alg 

853 self._num_packs = num_packs 

854 self._simple_cross_replica_ops = ReductionToOneDevice() 

855 super(AllReduceCrossDeviceOps, self).__init__() 

856 

857 def reduce_implementation(self, reduce_op, per_replica_value, destinations, 

858 options): 

859 del options # Unused. 

860 # To use NCCL or all-reduce, source and destination devices should match, 

861 # and none of the devices should be CPU. 

862 if (_devices_match(per_replica_value, destinations) and 

863 not any("cpu" in d.lower() for d in get_devices_from(destinations))): 

864 return self._batch_all_reduce(reduce_op, [per_replica_value])[0] 

865 else: 

866 return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value, 

867 destinations) 

868 

869 def batch_reduce_implementation(self, reduce_op, value_destination_pairs, 

870 options): 

871 if _all_devices_match(value_destination_pairs): 

872 return self._batch_all_reduce(reduce_op, 

873 [v[0] for v in value_destination_pairs]) 

874 else: 

875 return [ 

876 self.reduce_implementation(reduce_op, value, dest, options) 

877 for value, dest in value_destination_pairs 

878 ] 

879 

880 def _batch_all_reduce(self, reduce_op, per_replica_values): 

881 """All-reduce algorithm in a batch.""" 

882 dense_values, dense_indices, sparse_values, sparse_indices = ( 

883 cross_device_utils.split_by_sparsity(per_replica_values)) 

884 if dense_values: 

885 dense_results = self._do_batch_all_reduce(reduce_op, dense_values) 

886 else: 

887 dense_results = [] 

888 if sparse_values: 

889 sparse_results = self._do_batch_all_reduce_sparse(reduce_op, 

890 sparse_values) 

891 else: 

892 sparse_results = [] 

893 return cross_device_utils.stitch_values(((dense_results, dense_indices), 

894 (sparse_results, sparse_indices))) 

895 

896 def _do_batch_all_reduce(self, reduce_op, dense_values): 

897 """Run batch all-reduces.""" 

898 logging.log_first_n( 

899 logging.INFO, 

900 "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" % 

901 (len(dense_values), self._all_reduce_alg, self._num_packs), 10) 

902 

903 destinations = dense_values[0]._devices # pylint: disable=protected-access 

904 grouped = _group_value_by_device(dense_values) 

905 

906 # device_grad_packs: 

907 # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]] 

908 device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs) 

909 

910 # The actual aggregation of the repacked gradients. Note that they are 

911 # sharded among different aggregation trees. So it is important to strike 

912 # the balance on num_splits. 

913 if self._all_reduce_alg == "nccl": 

914 # TODO(yuefengz): merge this into the all-reduce library. 

915 reduced = cross_device_utils.aggregate_gradients_using_nccl( 

916 device_grad_packs) 

917 else: 

918 # TODO(yuefengz): check that gpu ids in `destinations` are in ascending 

919 # order. 

920 reduced = ( 

921 cross_device_utils.aggregate_gradients_using_hierarchical_copy( 

922 destinations, device_grad_packs)) 

923 

924 reduced = _unpack_tensors(reduced, tensor_packer) 

925 return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op) 

926 

927 def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values): 

928 """Run batch all-reduce for sparse values.""" 

929 logging.log_first_n( 

930 logging.WARN, 

931 "Efficient allreduce is not supported for %d IndexedSlices" % 

932 len(sparse_values), 10) 

933 # Use `sparse_values` as destinations to do all-reduces. It is effectively 

934 # an allgather under the hood but not an efficient one. 

935 return self._simple_cross_replica_ops.batch_reduce( 

936 reduce_op, zip(sparse_values, sparse_values)) 

937 

938 def _gather_implementation(self, per_replica_value, destinations, axis, 

939 options): 

940 logging.log_first_n( 

941 logging.WARN, 

942 "gather/all_gather with NCCL or HierarchicalCopy is not supported. " 

943 "Falling back to gather on one device and then broadcast. We're working" 

944 " on a more efficient implementation.", 3) 

945 return ReductionToOneDevice()._gather(per_replica_value, destinations, axis, # pylint: disable=protected-access 

946 options) 

947 

948 

949# For compatibility with code using the old name of `AllReduceCrossDeviceOps`. 

950AllReduceCrossTowerOps = AllReduceCrossDeviceOps 

951 

952 

953AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", 

954 "alg shards limit") 

955 

956 

957@tf_export("distribute.NcclAllReduce") 

958class NcclAllReduce(AllReduceCrossDeviceOps): 

959 """NCCL all-reduce implementation of CrossDeviceOps. 

960 

961 It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be 

962 repacked or aggregated for more efficient cross-device transportation. 

963 

964 For reduces that are not all-reduce, it falls back to 

965 `tf.distribute.ReductionToOneDevice`. 

966 

967 Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`: 

968 

969 

970 ``` 

971 strategy = tf.distribute.MirroredStrategy( 

972 cross_device_ops=tf.distribute.NcclAllReduce()) 

973 ``` 

974 """ 

975 

976 def __init__(self, num_packs=1): 

977 """Initializes the object. 

978 

979 Args: 

980 num_packs: a non-negative integer. The number of packs to split values 

981 into. If zero, no packing will be done. 

982 

983 Raises: 

984 ValueError: if `num_packs` is negative. 
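
For example (a sketch; `num_packs=2` is an arbitrary illustrative choice):

```
cross_ops = tf.distribute.NcclAllReduce(num_packs=2)
```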

985 """ 

986 if num_packs < 0: 

987 raise ValueError( 

988 "NCCL all-reduce requires num_packs >= 0, but {} is specified".format( 

989 num_packs)) 

990 super(NcclAllReduce, self).__init__( 

991 all_reduce_alg="nccl", num_packs=num_packs) 

992 

993 

994@tf_export("distribute.HierarchicalCopyAllReduce") 

995class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps): 

996 """Hierarchical copy all-reduce implementation of CrossDeviceOps. 

997 

998 It reduces to one GPU along edges in some hierarchy and broadcasts back to 

999 each GPU along the same path. For the batch API, tensors will be repacked or 

1000 aggregated for more efficient cross-device transportation. 

1001 

1002 This is a reduction created for the NVIDIA DGX-1, which assumes GPUs are connected 

1003 like those on a DGX-1 machine. If you have a different GPU interconnect topology, 

1004 it is likely to be slower than `tf.distribute.ReductionToOneDevice`. 

1005 

1006 For reduces that are not all-reduce, it falls back to 

1007 `tf.distribute.ReductionToOneDevice`. 

1008 

1009 Here is how you can use `HierarchicalCopyAllReduce` in 

1010 `tf.distribute.MirroredStrategy`: 

1011 

1012 ``` 

1013 strategy = tf.distribute.MirroredStrategy( 

1014 cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()) 

1015 ``` 

1016 """ 

1017 

1018 def __init__(self, num_packs=1): 

1019 """Initializes the object. 

1020 

1021 Args: 

1022 num_packs: a non-negative integer. The number of packs to split values 

1023 into. If zero, no packing will be done. 

1024 

1025 Raises: 

1026 ValueError: if `num_packs` is negative. 

1027 """ 

1028 if num_packs < 0: 

1029 raise ValueError( 

1030 "HierarchicalCopy requires num_packs >= 0, but {} is specified" 

1031 .format(num_packs)) 

1032 super(HierarchicalCopyAllReduce, self).__init__( 

1033 all_reduce_alg="hierarchical_copy", 

1034 num_packs=num_packs) 

1035 

1036 

1037# TODO(crccw): remove after migrating all callers. 

1038CollectiveCommunication = collective_util.CommunicationImplementation 

1039CommunicationImplementation = collective_util.CommunicationImplementation 

1040 

1041 

1042# TODO(yuefengz): support in-graph collective all-reduce. 

1043class CollectiveAllReduce(CrossDeviceOps): 

1044 """All-reduce cross device ops using collective ops. 

1045 

1046 In between-graph replicated training, it still does all-reduces across 

1047 all workers and then puts the results on the right destinations. 

1048 """ 

1049 

1050 def __init__(self, 

1051 devices, 

1052 group_size, 

1053 options, 

1054 collective_keys=None, 

1055 canonicalize_devices=True): 

1056 """Initializes the object. 

1057 

1058 Args: 

1059 devices: a list of device strings to run collectives on. 

1060 group_size: the global group size. For between-graph replicated training 

1061 it's the total number of devices across all workers. 

1062 options: a `tf.distribute.experimental.CommunicationOptions`. 

1063 collective_keys: an optional CollectiveKey object. 

1064 canonicalize_devices: Whether to canonicalize devices for workers or not. 
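
A construction sketch (illustrative only; this class is internal and is
normally created by a `tf.distribute.Strategy`; the device names are
placeholders):

```
collective_ops = CollectiveAllReduce(
    devices=["/gpu:0", "/gpu:1"], group_size=2,
    options=collective_util.Options())
```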

1065 """ 

1066 if group_size % len(devices) > 0: 

1067 raise ValueError("group_size must be divisible by the number of devices.") 

1068 

1069 self._group_size = group_size 

1070 self._options = options 

1071 self._collective_keys = (collective_keys or 

1072 cross_device_utils.CollectiveKeys()) 

1073 # This lock guards all collective launches, i.e. calls to 

1074 # cross_device_utils.build_collective_*. 

1075 # 

1076 # In a multi threaded eager program we need to ensure different groups of 

1077 # collectives don't interleave each other, otherwise there could be 

1078 # deadlocks. E.g. if two user threads both are launching collectives: 

1079 # user-thread-0 device0 device1 

1080 # user-thread-1 device0 device1 

1081 # In eager mode, we use one thread per device to launch collective ops, so 

1082 # the above launch sequences end up with the following queues: 

1083 # device-0 collective-0 collective-1 

1084 # device-1 collective-1 collective-0 

1085 # This deadlocks since neither collective is able to finish. 

1086 self._lock = threading.Lock() 

1087 

1088 if canonicalize_devices: 

1089 self._devices = tuple(device_util.canonicalize(d) for d in devices) 

1090 else: 

1091 self._devices = tuple( 

1092 device_util.canonicalize_without_job_and_task(d) for d in devices) 

1093 group_key = self._collective_keys.get_group_key(self._devices) 

1094 self._launchers = [] 

1095 # Whether to only use NCCL for batched all-reduce when NCCL is requested. 

1096 # This is because of the lack of a mechanism to order NCCL operations 

1097 # deterministically. 

1098 self._limited_nccl = False 

1099 for device in self._devices: 

1100 launcher = cross_device_utils.CollectiveReplicaLauncher( 

1101 group_key, group_size, self._collective_keys, device, options) 

1102 self._launchers.append(launcher) 

1103 if not launcher.can_order_nccl(): 

1104 self._limited_nccl = True 

1105 

1106 super(CollectiveAllReduce, self).__init__() 

1107 self._canonicalize_devices = canonicalize_devices 

1108 

1109 @property 

1110 def _num_between_graph_workers(self): 

1111 # Currently we only support equal number of devices on each worker. 

1112 return self._group_size / len(self._devices) 

1113 

1114 def _all_reduce(self, reduce_op, value, replica_id, options): 

1115 """Implements CrossDeviceOps.all_reduce.""" 

1116 # TODO(b/122840926): reuse this method in _batch_all_reduce. 

1117 flat_values = nest.flatten(value) 

1118 

1119 # If NCCL launches can't be ordered (self._limited_nccl == True), we only 

1120 # use NCCL when batch_size > 1, hoping that there's only one batched 

1121 # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x, 

1122 # NCCL launches are always ordered. 

1123 if (self._limited_nccl and options.implementation 

1124 == collective_util.CommunicationImplementation.NCCL and 

1125 len(flat_values) == 1): 

1126 options = options.merge( 

1127 collective_util.Options( 

1128 implementation=collective_util.CommunicationImplementation.RING)) 

1129 

1130 launcher = self._launchers[replica_id] 

1131 dense_values, dense_indices, sparse_values, sparse_indices = ( 

1132 cross_device_utils.split_by_sparsity(flat_values)) 

1133 dense_results = [] 

1134 sparse_results = [] 

1135 

1136 if dense_values: 

1137 # Reverse the lists so that there's a better chance that values follow 

1138 # the order in which they are calculated (e.g. when they're gradients), so 

1139 # as to overlap calculation with communication. However, this may not be 

1140 # optimal for cases like gradients of complicated non-sequential models. 

1141 # 

1142 # Note that we reverse the list before packing so that the first pack 

1143 # won't be too small, since it's more likely for first few packs to have 

1144 # long queuing time due to concurrent intense computation. 

1145 # 

1146 # TODO(b/147393503): explore solutions for optimal ordering. 

1147 dense_values.reverse() 

1148 packs = cross_device_utils.group_by_size(dense_values, 

1149 options.bytes_per_pack) 

1150 

1151 if not context.executing_eagerly() and replica_id == 0: 

1152 logging.info( 

1153 "Collective all_reduce tensors: %d all_reduces, num_devices = %d, " 

1154 "group_size = %d, implementation = %s, num_packs = %d", 

1155 len(dense_values), len(self._launchers), self._group_size, 

1156 options.implementation, len(packs)) 

1157 

1158 dense_results = launcher.batch_all_reduce(packs, options) 

1159 if reduce_op == reduce_util.ReduceOp.MEAN: 

1160 for i, v in enumerate(dense_results): 

1161 with ops.device(self._devices[replica_id]): 

1162 dense_results[i] = v / self._group_size 

1163 dense_results.reverse() 

1164 

1165 if sparse_values: 

1166 if not context.executing_eagerly() and replica_id == 0: 

1167 logging.info( 

1168 "Collective all_reduce IndexedSlices: %d all_reduces, num_devices =" 

1169 "%d, group_size = %d, implementation = %s", len(sparse_values), 

1170 len(self._launchers), self._group_size, options.implementation) 

1171 

1172 for indexed_slice in sparse_values: 

1173 sparse_results.append( 

1174 launcher.all_reduce_indexed_slices(indexed_slice, options)) 

1175 

1176 if reduce_op == reduce_util.ReduceOp.MEAN: 

1177 for i, v in enumerate(sparse_results): 

1178 with ops.device(self._devices[replica_id]): 

1179 sparse_results[i] = indexed_slices.IndexedSlices( 

1180 values=sparse_results[i].values / self._group_size, 

1181 indices=sparse_results[i].indices, 

1182 dense_shape=sparse_results[i].dense_shape) 

1183 

1184 flat_results = cross_device_utils.stitch_values( 

1185 ((dense_results, dense_indices), (sparse_results, sparse_indices))) 

1186 return nest.pack_sequence_as(value, flat_results) 

1187 

1188 def _all_reduce_per_replica_values(self, reduce_op, per_replica_values, 

1189 options): 

1190 """All reduce a list of per_replica_value.""" 

1191 values_by_device = [[] for _ in self._devices] 

1192 num_devices = len(self._devices) 

1193 for per_replica in per_replica_values: 

1194 for i in range(num_devices): 

1195 values_by_device[i].append(per_replica.values[i]) 

1196 

1197 if context.executing_eagerly(): 

1198 

1199 def thread_fn(device_id): 

1200 with context.eager_mode(): 

1201 return self._all_reduce(reduce_op, values_by_device[device_id], 

1202 device_id, options) 

1203 

1204 with self._lock: 

1205 pool = multiprocessing.pool.ThreadPool(len(self._devices)) 

1206 outputs_by_device = pool.map(thread_fn, list(range(num_devices))) 

1207 pool.close() 

1208 else: 

1209 outputs_by_device = [] 

1210 with self._lock: 

1211 for i in range(num_devices): 

1212 outputs_by_device.append( 

1213 self._all_reduce(reduce_op, values_by_device[i], i, options)) 

1214 

1215 result = [] 

1216 for values in zip(*outputs_by_device): 

1217 result.append( 

1218 distribute_utils.regroup(values, wrap_class=value_lib.Mirrored)) 

1219 return result 

1220 

1221 def reduce_implementation(self, reduce_op, per_replica_value, destinations, 

1222 options): 

1223 values_util.mark_as_unsaveable() 

1224 all_reduced = self._all_reduce_per_replica_values(reduce_op, 

1225 [per_replica_value], 

1226 options)[0] 

1227 devices = get_devices_from(destinations, self._canonicalize_devices) 

1228 

1229 if _devices_match(per_replica_value, destinations, 

1230 self._canonicalize_devices): 

1231 return all_reduced 

1232 

1233 # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform 

1234 # utility to access component for a particular device. 

1235 if not isinstance(all_reduced, value_lib.Mirrored): 

1236 all_reduced = value_lib.Mirrored([all_reduced]) 

1237 

1238 # If we got this far, the destination devices do not match the all-reduce 

1239 # devices, so we must map from one to the other. 

1240 index = [] 

1241 # We must add these control dependencies, otherwise we can get deadlock. 

1242 with ops.control_dependencies(all_reduced.values): 

1243 for d in devices: 

1244 with ops.device(d): 

1245 for v in all_reduced.values: 

1246 if v.device == d: 

1247 index.append(array_ops.identity(v)) 

1248 break 

1249 else: 

1250 # TODO(josh11b): Once we add support for model parallelism, get the 

1251 # copy from the corresponding replica instead of the primary. 

1252 index.append(array_ops.identity(all_reduced._primary)) # pylint: disable=protected-access 

1253 return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored) 

1254 

1255 def batch_reduce_implementation(self, reduce_op, value_destination_pairs, 

1256 options): 

1257 values_util.mark_as_unsaveable() 

1258 all_devices_match = _all_devices_match(value_destination_pairs, 

1259 self._canonicalize_devices) 

1260 if all_devices_match: 

1261 return self._all_reduce_per_replica_values( 

1262 reduce_op, [v[0] for v in value_destination_pairs], options) 

1263 else: 

1264 if not all_devices_match: 

1265 logging.log_first_n( 

1266 logging.WARN, "Efficient batch_reduce is not supported if " 

1267 "destinations are different.", 10) 

1268 

1269 return [ 

1270 self.reduce_implementation(reduce_op, value, dest, options) 

1271 for value, dest in value_destination_pairs 

1272 ] 

1273 

1274 def _gather_implementation(self, per_replica_value, destinations, axis, 

1275 options): 

1276 all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0] 

1277 values_util.mark_as_unsaveable() 

1278 devices = get_devices_from(destinations, self._canonicalize_devices) 

1279 

1280 if _devices_match(per_replica_value, destinations, 

1281 self._canonicalize_devices): 

1282 return all_gathered 

1283 

1284 # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform 

1285 # utility to access component for a particular device. 

1286 if not isinstance(all_gathered, value_lib.Mirrored): 

1287 all_gathered = value_lib.Mirrored([all_gathered]) 

1288 

1289 # If we got this far, the destination devices do not match the all-gather 

1290 # devices, so we must map from one to the other. 

1291 index = [] 

1292 # We must add these control dependencies, otherwise we can get deadlock. 

1293 with ops.control_dependencies(all_gathered.values): 

1294 for d in devices: 

1295 with ops.device(d): 

1296 for v in all_gathered.values: 

1297 if v.device == d: 

1298 index.append(array_ops.identity(v)) 

1299 break 

1300 else: 

1301 index.append(array_ops.identity(all_gathered._primary)) # pylint: disable=protected-access 

1302 return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored) 

1303 

1304 def _batch_all_gather(self, per_replica_values, axis, options): 

1305 """all gather multiple per-replica-values.""" 

1306 batch_size = len(per_replica_values) 

1307 # For now, we use NCCL only when batch_size > 1. 

1308 # TODO(b/132575814): switch to NCCL for all collectives when implementation 

1309 # is NCCL. 

1310 if (self._limited_nccl and options.implementation 

1311 == collective_util.CommunicationImplementation.NCCL and 

1312 batch_size == 1): 

1313 options = options.merge( 

1314 collective_util.Options( 

1315 implementation=collective_util.CommunicationImplementation.RING)) 

1316 

1317 logging.log_first_n( 

1318 logging.INFO, "Collective batch_all_gather: %d all-gathers, " 

1319 "num_devices = %d, group_size = %d, implementation = %s, " % 

1320 (batch_size, len( 

1321 self._devices), self._group_size, options.implementation), 10) 

1322 

1323 def compute_gathered_values(): 

1324 gathered_values = [] 

1325 with self._lock, ops.name_scope("allgather"): 

1326 for per_replica in per_replica_values: 

1327 outputs = [] 

1328 for i in range(len(self._devices)): 

1329 outputs.append(self._launchers[i].all_gather( 

1330 per_replica.values[i], axis, options)) 

1331 gathered_values.append(outputs) 

1332 return gathered_values 

1333 

1334 if context.executing_eagerly(): 

1335 gathered_values = def_function.function(compute_gathered_values)() 

1336 else: 

1337 gathered_values = compute_gathered_values() 

1338 

1339 mirrored = [] 

1340 for value in gathered_values: 

1341 mirrored.append( 

1342 distribute_utils.regroup(value, wrap_class=value_lib.Mirrored)) 

1343 return mirrored 

1344 

1345 def __deepcopy__(self, memo): 

1346 # distribute_coordinator deep-copies the strategy object, so 

1347 # CollectiveAllReduce needs to support deep copy as well. 

1348 collective_keys = copy.deepcopy(self._collective_keys, memo) 

1349 return CollectiveAllReduce(self._devices, self._group_size, self._options, 

1350 collective_keys, self._canonicalize_devices) 

1351 

1352 

1353def select_cross_device_ops(devices, session_config=None): 

1354 """Find the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`. 

1355 

1356 Args: 

1357 devices: a list of devices passed to `tf.distribute.Strategy`. 

1358 session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will 

1359 make the decision based on all logical devices. 

1360 

1361 Returns: 

1362 An instance of a `CrossDeviceOps` subclass. 
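
Usage sketch (illustrative only; assumes two local GPUs and that TensorFlow is
imported as `tf`):

```
cross_ops = select_cross_device_ops(["/gpu:0", "/gpu:1"])
strategy = tf.distribute.MirroredStrategy(
    devices=["/gpu:0", "/gpu:1"], cross_device_ops=cross_ops)
```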

1363 """ 

1364 requested_devices = set(device_util.canonicalize(d) for d in devices) 

1365 if ops.executing_eagerly_outside_functions(): 

1366 logical_gpus = context.context().list_logical_devices(device_type="GPU") 

1367 physical_gpus = context.context().list_physical_devices(device_type="GPU") 

1368 if len(logical_gpus) != len(physical_gpus): 

1369 logging.warning("NCCL is not supported when using virtual GPUs, falling" 

1370 "back to reduction to one device") 

1371 return ReductionToOneDevice() 

1372 

1373 machine_devices = context.context().list_logical_devices() 

1374 else: 

1375 machine_devices = device_lib.list_local_devices( 

1376 session_config=session_config) 

1377 using_devices = set() 

1378 for d in machine_devices: 

1379 if device_util.canonicalize(d.name) in requested_devices: 

1380 using_devices.add(d.name) 

1381 

1382 if len(using_devices) != len(requested_devices): 

1383 logging.warning( 

1384 "Some requested devices in `tf.distribute.Strategy` are not visible " 

1385 "to TensorFlow: %s", ",".join(list(requested_devices - using_devices))) 

1386 

1387 if any("gpu" not in d.lower() for d in requested_devices): 

1388 logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, " 

1389 "not using nccl allreduce.") 

1390 return ReductionToOneDevice() 

1391 

1392 if kernels.get_registered_kernels_for_op("NcclAllReduce"): 

1393 return NcclAllReduce(num_packs=1) 

1394 else: 

1395 logging.warning("Nccl kernel is not found, not using nccl allreduce.") 

1396 return ReductionToOneDevice()