Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py: 36%

983 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=line-too-long 

16"""Library for running a computation across multiple devices. 

17 

18The intent of this library is that you can write an algorithm in a stylized way 

19and it will be usable with a variety of different `tf.distribute.Strategy` 

20implementations. Each descendant will implement a different strategy for 

21distributing the algorithm across multiple devices/machines. Furthermore, these 

22changes can be hidden inside the specific layers and other library classes that 

23need special treatment to run in a distributed setting, so that most users' 

24model definition code can run unchanged. The `tf.distribute.Strategy` API works 

25the same way with eager and graph execution. 

26 

27*Guides* 

28 

29* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training) 

30* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb) 

31 

32*Tutorials* 

33 

34* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/) 

35 

36 The tutorials cover how to use `tf.distribute.Strategy` to do distributed 

37 training with native Keras APIs, custom training loops, 

38 and Estimator APIs. They also cover how to save/load models when using 

39 `tf.distribute.Strategy`. 

40 

41*Glossary* 

42 

43* _Data parallelism_ is where we run multiple copies of the model 

44 on different slices of the input data. This is in contrast to 

45 _model parallelism_ where we divide up a single copy of a model 

46 across multiple devices. 

47 Note: we only support data parallelism for now, but 

48 hope to add support for model parallelism in the future. 

49* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that 

50 TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple 

51 devices on a single machine, or be connected to devices on multiple 

52 machines. Devices used to run computations are called _worker devices_. 

53 Devices used to store variables are _parameter devices_. For some strategies, 

54 such as `tf.distribute.MirroredStrategy`, the worker and parameter devices 

55 will be the same (see mirrored variables below). For others they will be 

56 different. For example, `tf.distribute.experimental.CentralStorageStrategy` 

57 puts the variables on a single device (which may be a worker device or may be 

58 the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the 

59 variables on separate machines called _parameter servers_ (see below). 

60* A _replica_ is one copy of the model, running on one slice of the 

61 input data. Right now each replica is executed on its own 

62 worker device, but once we add support for model parallelism 

63 a replica may span multiple worker devices. 

64* A _host_ is the CPU device on a machine with worker devices, typically 

65 used for running input pipelines. 

66* A _worker_ is defined to be the physical machine(s) containing the physical 

67 devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A 

68 worker may contain one or more replicas, but contains at least one 

69 replica. Typically one worker will correspond to one machine, but in the case 

70 of very large models with model parallelism, one worker may span multiple 

71 machines. We typically run one input pipeline per worker, feeding all the 

72 replicas on that worker. 

73* _Synchronous_, or more commonly _sync_, training is where the updates from 

74 each replica are aggregated together before updating the model variables. This 

75 is in contrast to _asynchronous_, or _async_ training, where each replica 

76 updates the model variables independently. You may also have replicas 

77 partitioned into groups which are in sync within each group but async between 

78 groups. 

79* _Parameter servers_: These are machines that hold a single copy of 

80 parameters/variables, used by some strategies (right now just 

81 `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want 

82 to operate on a variable retrieve it at the beginning of a step and send an 

83 update to be applied at the end of the step. These can in principle support 

84 either sync or async training, but right now we only have support for async 

85 training with parameter servers. Compare to 

86 `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables 

87 on a single device on the same machine (and does sync training), and 

88 `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices 

89 (see below). 

90 

91* _Replica context_ vs. _Cross-replica context_ vs _Update context_ 

92 

93 A _replica context_ applies 

94 when you execute the computation function that was called with `strategy.run`. 

95 Conceptually, you're in replica context when executing the computation 

96 function that is being replicated. 

97 

98 An _update context_ is entered in a `tf.distribute.StrategyExtended.update` 

99 call. 

100 

101 A _cross-replica context_ is entered when you enter a `strategy.scope`. This 

102 is useful for calling `tf.distribute.Strategy` methods which operate across 

103 the replicas (like `reduce_to()`). By default you start in a _replica context_ 

104 (the "default single _replica context_") and then some methods can switch you 

105 back and forth. 

106 

107* _Distributed value_: A distributed value is represented by the base class 

108 `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful 

109 to represent values on multiple devices, and it contains a map from replica ids 

110 to values. Two representative types of `tf.distribute.DistributedValues` 

111 are `tf.types.experimental.PerReplica` and `tf.types.experimental.Mirrored` 

112 values. 

113 

114 `PerReplica` values exist on the worker devices, with a different value for 

115 each replica. They are produced by iterating through a distributed dataset 

116 returned by `tf.distribute.Strategy.experimental_distribute_dataset` and 

117 `tf.distribute.Strategy.distribute_datasets_from_function`. They are also the 

118 typical result returned by `tf.distribute.Strategy.run`. 

119 

120 `Mirrored` values are like `PerReplica` values, except we know that the values 

121 on all replicas are the same. `Mirrored` values are kept synchronized by the 

122 distribution strategy in use, while `PerReplica` values are left 

123 unsynchronized. `Mirrored` values typically represent model weights. We can 

124 safely read a `Mirrored` value in a cross-replica context by using the value 

125 on any replica, while `PerReplica` values can only be read within a replica 

126 context. 

127 

128* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple 

129 replicas, like `strategy.run(fn, args=[w])` with an 

130 argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will 

131 have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc. 

132 `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on 

133 device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return 

134 values from `fn()`, which leads to one common object if the returned values 

135 are the same object from every replica, or a `DistributedValues` object 

136 otherwise. 

137 

138* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating 

139 multiple values into one value, like "sum" or "mean". If a strategy is doing 

140 sync training, we will perform a reduction on the gradients to a parameter 

141 from all replicas before applying the update. _All-reduce_ is an algorithm for 

142 performing a reduction on values from multiple devices and making the result 

143 available on all of those devices. 

144 

145* _Mirrored variables_: These are variables that are created on multiple 

146 devices, where we keep the variables in sync by applying the same 

147 updates to every copy. Mirrored variables are created with 

148 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`. 

149 Normally they are only used in synchronous training. 

150 

151* _SyncOnRead variables_ 

152 

153 _SyncOnRead variables_ are created by 

154 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and 

155 they are created on multiple devices. In replica context, each 

156 component variable on the local replica can perform reads and writes without 

157 synchronization with each other. When the 

158 _SyncOnRead variable_ is read in cross-replica context, the values from 

159 component variables are aggregated and returned. 

160 

161 _SyncOnRead variables_ bring a lot of custom configuration difficulty to the 

162 underlying logic, so we do not encourage users to instantiate and use 

163 _SyncOnRead variable_ on their own. We have mainly used _SyncOnRead 

164 variables_ for use cases such as batch norm and metrics. For performance 

165 reasons, we often don't need to keep these statistics in sync every step and 

166 they can be accumulated on each replica independently. The only time we want 

167 to sync them is reporting or checkpointing, which typically happens in 

168 cross-replica context. _SyncOnRead variables_ are also often used by advanced 

169 users who want to control when variable values are aggregated. For example, 

170 users sometimes want to maintain gradients independently on each replica for a 

171 couple of steps without aggregation. 

172 

173* _Distribute-aware layers_ 

174 

175 Layers are generally called in a replica context, except when defining a 

176 Keras functional model. `tf.distribute.in_cross_replica_context` will let you 

177 determine which case you are in. If in a replica context, 

178 the `tf.distribute.get_replica_context` function will return the default 

179 replica context outside a strategy scope, `None` within a strategy scope, and 

180 a `tf.distribute.ReplicaContext` object inside a strategy scope and within a 

181 `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an 

182 `all_reduce` method for aggregating across all replicas. 

183 

184 

185Note that we provide a default version of `tf.distribute.Strategy` that is 

186used when no other strategy is in scope and that provides the same API with 

187reasonable default behavior. 
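
For illustration, here is a minimal sketch (assuming TF2 eager execution and
that no strategy scope has been entered) showing that the default strategy
accepts the same `run`/`reduce` calls as the other strategies, just on a
single replica:

```
default_strategy = tf.distribute.get_strategy()   # the default strategy
per_replica = default_strategy.run(lambda: tf.constant(1.0))
total = default_strategy.reduce(
    tf.distribute.ReduceOp.SUM, per_replica, axis=None)  # == 1.0
```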

188""" 

189# pylint: enable=line-too-long 

190 

191import collections 

192import contextlib 

193import copy 

194import enum # pylint: disable=g-bad-import-order 

195import functools 

196import threading 

197import weakref 

198 

199import six 

200 

201from tensorflow.python import tf2 

202from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 

203from tensorflow.python.autograph.impl import api as autograph 

204from tensorflow.python.data.ops import dataset_ops 

205from tensorflow.python.distribute import collective_util 

206from tensorflow.python.distribute import device_util 

207from tensorflow.python.distribute import numpy_dataset 

208from tensorflow.python.distribute import reduce_util 

209from tensorflow.python.eager import context as eager_context 

210from tensorflow.python.eager import def_function 

211from tensorflow.python.eager import monitoring 

212from tensorflow.python.eager import tape 

213from tensorflow.python.framework import constant_op 

214from tensorflow.python.framework import dtypes 

215from tensorflow.python.framework import indexed_slices 

216from tensorflow.python.framework import ops 

217from tensorflow.python.framework import tensor_shape 

218from tensorflow.python.framework import tensor_util 

219from tensorflow.python.ops import array_ops 

220from tensorflow.python.ops import control_flow_ops 

221from tensorflow.python.ops import custom_gradient 

222from tensorflow.python.ops import math_ops 

223from tensorflow.python.ops import ref_variable 

224from tensorflow.python.ops import summary_ops_v2 

225from tensorflow.python.ops import variable_scope 

226from tensorflow.python.ops import variable_v1 

227from tensorflow.python.platform import tf_logging 

228from tensorflow.python.trackable import base as trackable 

229from tensorflow.python.types import distribute as ds_types 

230from tensorflow.python.util import deprecation 

231from tensorflow.python.util import nest 

232from tensorflow.python.util import tf_contextlib 

233from tensorflow.python.util.deprecation import deprecated 

234from tensorflow.python.util.tf_export import tf_export 

235from tensorflow.tools.docs import doc_controls 

236 

237# ------------------------------------------------------------------------------ 

238# Context tracking whether in a strategy.update() or .update_non_slot() call. 

239 

240 

241_update_replica_id = threading.local() 

242 

243 

244def get_update_replica_id(): 

245  """Get the current replica id if in a `tf.distribute.Strategy.update()` call.""" 

246 try: 

247 return _update_replica_id.current 

248 except AttributeError: 

249 return None 

250 

251 

252class UpdateContext(object): 

253 """Context manager when you are in `update()` or `update_non_slot()`.""" 

254 

255 __slots__ = ["_replica_id", "_old_replica_id"] 

256 

257 def __init__(self, replica_id): 

258 self._replica_id = replica_id 

259 self._old_replica_id = None 

260 

261 def __enter__(self): 

262 self._old_replica_id = get_update_replica_id() 

263 _update_replica_id.current = self._replica_id 

264 

265 def __exit__(self, exception_type, exception_value, traceback): 

266 del exception_type, exception_value, traceback 

267 _update_replica_id.current = self._old_replica_id 

268 
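# Example (illustrative sketch, assuming no update is already in progress on
# this thread): entering an `UpdateContext` makes `get_update_replica_id()`
# report that replica id, and the previous value is restored on exit:
#
#   assert get_update_replica_id() is None
#   with UpdateContext(replica_id=0):
#     assert get_update_replica_id() == 0
#   assert get_update_replica_id() is None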

269 

270# ------------------------------------------------------------------------------ 

271# Internal API for validating the current thread mode 

272 

273 

274def _require_cross_replica_or_default_context_extended(extended, 

275 error_message=None): 

276  """Verify we are in a cross-replica context for this strategy, or in the default context.""" 

277 context = _get_per_thread_mode() 

278 cross_replica = context.cross_replica_context 

279 if cross_replica is not None and cross_replica.extended is extended: 

280 return 

281 if context is _get_default_replica_mode(): 

282 return 

283 strategy = extended._container_strategy() # pylint: disable=protected-access 

284 # We have an error to report, figure out the right message. 

285 if context.strategy is not strategy: 

286 _wrong_strategy_scope(strategy, context) 

287 assert cross_replica is None 

288 if not error_message: 

289 error_message = ("Method requires being in cross-replica context, use " 

290 "get_replica_context().merge_call()") 

291 raise RuntimeError(error_message) 

292 

293 

294def _wrong_strategy_scope(strategy, context): 

295 # Figure out the right error message. 

296 if not has_strategy(): 

297 raise RuntimeError( 

298 'Need to be inside "with strategy.scope()" for %s' % 

299 (strategy,)) 

300 else: 

301 raise RuntimeError( 

302 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 

303 (context.strategy, strategy)) 

304 

305 

306def require_replica_context(replica_ctx): 

307 """Verify in `replica_ctx` replica context.""" 

308 context = _get_per_thread_mode() 

309 if context.replica_context is replica_ctx: return 

310 # We have an error to report, figure out the right message. 

311 if context.replica_context is None: 

312 raise RuntimeError("Need to be inside `call_for_each_replica()`") 

313 if context.strategy is replica_ctx.strategy: 

314 # Two different ReplicaContexts with the same tf.distribute.Strategy. 

315 raise RuntimeError("Mismatching ReplicaContext.") 

316 raise RuntimeError( 

317 "Mismatching tf.distribute.Strategy objects: %s is not %s." % 

318 (context.strategy, replica_ctx.strategy)) 

319 

320 

321def _require_strategy_scope_strategy(strategy): 

322 """Verify in a `strategy.scope()` in this thread.""" 

323 context = _get_per_thread_mode() 

324 if context.strategy is strategy: return 

325 _wrong_strategy_scope(strategy, context) 

326 

327 

328def _require_strategy_scope_extended(extended): 

329 """Verify in a `distribution_strategy.scope()` in this thread.""" 

330 context = _get_per_thread_mode() 

331 if context.strategy.extended is extended: return 

332 # Report error. 

333 strategy = extended._container_strategy() # pylint: disable=protected-access 

334 _wrong_strategy_scope(strategy, context) 

335 

336 

337_creating_default_strategy_singleton = False 

338 

339# ------------------------------------------------------------------------------ 

340# Internal API for setting the current thread mode as being either in a 

341# replica or cross-replica context for a particular tf.distribute.Strategy. 

342 

343 

344class _ThreadMode(object): 

345 

346 def __init__(self, dist, cross, replica): 

347 self.strategy = dist 

348 self.cross_replica_context = cross 

349 self.replica_context = replica 

350 

351 

352class _CrossReplicaThreadMode(_ThreadMode): 

353 

354 def __init__(self, strategy): 

355 _ThreadMode.__init__(self, strategy, strategy, None) 

356 

357 

358class _InReplicaThreadMode(_ThreadMode): 

359 

360 def __init__(self, replica_ctx): 

361 _ThreadMode.__init__(self, replica_ctx.strategy, None, replica_ctx) 

362 

363 

364def _push_per_thread_mode(context): 

365 ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access 

366 

367 

368def _pop_per_thread_mode(): 

369 ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access 

370 

371 

372class _DefaultReplicaThreadMode(_ThreadMode): 

373 """Type of default value returned by `_get_per_thread_mode()`. 

374 

375 Used when the thread-local stack is empty. 

376 """ 

377 

378 def __init__(self): 

379 _ThreadMode.__init__(self, _get_default_strategy(), None, 

380 _get_default_replica_context()) 

381 

382 

383def _get_per_thread_mode(): 

384 try: 

385 return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access 

386 except (AttributeError, IndexError): 

387 return _get_default_replica_mode() 

388 

389 

390_variable_sync_on_read_context = threading.local() 

391 

392 

393@tf_export("__internal__.distribute.variable_sync_on_read_context", v1=[]) 

394@contextlib.contextmanager 

395def variable_sync_on_read_context(): 

396 """A context that forces SyncOnReadVariable to aggregate upon reading. 

397 

398 This context is useful if one wants to read the aggregated value out of a 

399 SyncOnReadVariable in replica context. By default the aggregation is turned 

400 off per the definition of SyncOnReadVariable. 

401 

402 When reading a SyncOnReadVariable in cross-replica context, aggregation is 

403 always turned on so there is no need for such context. 

404 

405 By reading a SyncOnReadVariable, we mean: 

406  1. Converting the variable to a tensor using `convert_to_tensor`, or 

407 2. Calling `variable.value()` or `variable.read_value()`. 

408 

409 Example usage: 

410 

411 ``` 

412 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 

413 with strategy.scope(): 

414 v = tf.Variable(1.0, synchronization=tf.VariableSynchronization.ON_READ, 

415 aggregation=tf.VariableAggregation.SUM) 

416 

417 def replica_fn(): 

418 return v + 10.0 

419 

420 non_aggregated = strategy.run(replica_fn) 

421 print(non_aggregated) # PerReplica: {0: 11.0, 1: 11.0} 

422 

423 def replica_fn(): 

424 with variable_sync_on_read_context(): 

425 return v + 10.0 

426 

427 aggregated = strategy.run(replica_fn) 

428 print(aggregated) # PerReplica: {0: 12.0, 1: 12.0} 

429 ``` 

430 

431 Yields: 

432 Context manager for aggregating SyncOnReadVariable upon reading. 

433 """ 

434 try: 

435 _variable_sync_on_read_context.entered = True 

436 yield 

437 finally: 

438 _variable_sync_on_read_context.entered = False 

439 

440 

441def in_variable_sync_on_read_context(): 

442 try: 

443 return _variable_sync_on_read_context.entered 

444 except AttributeError: 

445 return False 

446 

447# ------------------------------------------------------------------------------ 

448# Public API for accessing the current thread mode 

449 

450 

451@tf_export("distribute.get_replica_context") 

452def get_replica_context(): 

453 """Returns the current `tf.distribute.ReplicaContext` or `None`. 

454 

455 Returns `None` if in a cross-replica context. 

456 

457 Note that execution: 

458 

459 1. starts in the default (single-replica) replica context (this function 

460 will return the default `ReplicaContext` object); 

461 2. switches to cross-replica context (in which case this will return 

462 `None`) when entering a `with tf.distribute.Strategy.scope():` block; 

463 3. switches to a (non-default) replica context inside `strategy.run(fn, ...)`; 

464 4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then 

465 inside `merge_fn` you are back in the cross-replica context (and again 

466 this function will return `None`). 

467 

468 Most `tf.distribute.Strategy` methods may only be executed in 

469  a cross-replica context; in a replica context you should use the 

470 API of the `tf.distribute.ReplicaContext` object returned by this 

471 method instead. 

472 

473 ``` 

474 assert tf.distribute.get_replica_context() is not None # default 

475 with strategy.scope(): 

476 assert tf.distribute.get_replica_context() is None 

477 

478 def f(): 

479 replica_context = tf.distribute.get_replica_context() # for strategy 

480 assert replica_context is not None 

481 tf.print("Replica id: ", replica_context.replica_id_in_sync_group, 

482 " of ", replica_context.num_replicas_in_sync) 

483 

484 strategy.run(f) 

485 ``` 

486 

487 Returns: 

488 The current `tf.distribute.ReplicaContext` object when in a replica context 

489 scope, else `None`. 

490 

491 Within a particular block, exactly one of these two things will be true: 

492 

493 * `get_replica_context()` returns non-`None`, or 

494  * `tf.distribute.in_cross_replica_context()` returns True. 

495 """ 

496 return _get_per_thread_mode().replica_context 

497 

498 

499def get_cross_replica_context(): 

500 """Returns the current tf.distribute.Strategy if in a cross-replica context. 

501 

502 DEPRECATED: Please use `in_cross_replica_context()` and 

503 `get_strategy()` instead. 

504 

505 Returns: 

506 Returns the current `tf.distribute.Strategy` object in a cross-replica 

507 context, or `None`. 

508 

509 Exactly one of `get_replica_context()` and `get_cross_replica_context()` 

510 will return `None` in a particular block. 

511 """ 

512 return _get_per_thread_mode().cross_replica_context 

513 
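# Example (illustrative sketch of the replacement recommended above):
#
#   if tf.distribute.in_cross_replica_context():
#     strategy = tf.distribute.get_strategy()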

514 

515@tf_export("distribute.in_cross_replica_context") 

516def in_cross_replica_context(): 

517 """Returns `True` if in a cross-replica context. 

518 

519 See `tf.distribute.get_replica_context` for details. 

520 

521 ``` 

522 assert not tf.distribute.in_cross_replica_context() 

523 with strategy.scope(): 

524 assert tf.distribute.in_cross_replica_context() 

525 

526 def f(): 

527 assert not tf.distribute.in_cross_replica_context() 

528 

529 strategy.run(f) 

530 ``` 

531 

532 Returns: 

533 `True` if in a cross-replica context (`get_replica_context()` returns 

534 `None`), or `False` if in a replica context (`get_replica_context()` returns 

535 non-`None`). 

536 """ 

537 return _get_per_thread_mode().cross_replica_context is not None 

538 

539 

540@tf_export("distribute.get_strategy") 

541def get_strategy(): 

542 """Returns the current `tf.distribute.Strategy` object. 

543 

544 Typically only used in a cross-replica context: 

545 

546 ``` 

547 if tf.distribute.in_cross_replica_context(): 

548 strategy = tf.distribute.get_strategy() 

549 ... 

550 ``` 

551 

552 Returns: 

553 A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block, 

554 it returns `strategy`, otherwise it returns the default (single-replica) 

555 `tf.distribute.Strategy` object. 

556 """ 

557 return _get_per_thread_mode().strategy 

558 

559 

560@tf_export("distribute.has_strategy") 

561def has_strategy(): 

562  """Returns whether there is a current non-default `tf.distribute.Strategy`. 

563 

564 ``` 

565 assert not tf.distribute.has_strategy() 

566 with strategy.scope(): 

567 assert tf.distribute.has_strategy() 

568 ``` 

569 

570 Returns: 

571 True if inside a `with strategy.scope():`. 

572 """ 

573 return get_strategy() is not _get_default_strategy() 

574 

575 

576def get_strategy_and_replica_context(): 

577 per_thread_mode = _get_per_thread_mode() 

578 return (per_thread_mode.strategy, per_thread_mode.replica_context) 

579 

580 

581@tf_export("distribute.experimental_set_strategy") 

582def experimental_set_strategy(strategy): 

583 """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`. 

584 

585 ``` 

586 tf.distribute.experimental_set_strategy(strategy1) 

587 f() 

588 tf.distribute.experimental_set_strategy(strategy2) 

589 g() 

590 tf.distribute.experimental_set_strategy(None) 

591 h() 

592 ``` 

593 

594 is equivalent to: 

595 

596 ``` 

597 with strategy1.scope(): 

598 f() 

599 with strategy2.scope(): 

600 g() 

601 h() 

602 ``` 

603 

604 In general, you should use the `with strategy.scope():` API, but this 

605  alternative may be convenient in notebooks where you would otherwise have to put 

606 each cell in a `with strategy.scope():` block. 

607 

608 Note: This should only be called outside of any TensorFlow scope to 

609 avoid improper nesting. 

610 

611 Args: 

612 strategy: A `tf.distribute.Strategy` object or None. 

613 

614 Raises: 

615 RuntimeError: If called inside a `with strategy.scope():`. 

616 """ 

617 old_scope = ops.get_default_graph()._global_distribute_strategy_scope # pylint: disable=protected-access 

618 if old_scope is not None: 

619 old_scope.__exit__(None, None, None) 

620 ops.get_default_graph()._global_distribute_strategy_scope = None # pylint: disable=protected-access 

621 if has_strategy(): 

622 raise RuntimeError( 

623 "Must not be called inside a `tf.distribute.Strategy` scope.") 

624 if strategy is not None: 

625 new_scope = strategy.scope() 

626 new_scope.__enter__() 

627 ops.get_default_graph()._global_distribute_strategy_scope = new_scope # pylint: disable=protected-access 

628 

629 

630# ------------------------------------------------------------------------------ 

631# Internal helpers. 

632 

633 

634@contextlib.contextmanager 

635def enter_or_assert_strategy(strategy): 

636 if has_strategy(): 

637 _assert_strategy(strategy) 

638 yield 

639 else: 

640 with strategy.scope(): 

641 yield 

642 
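# Example (illustrative sketch, assuming `strategy` is a `tf.distribute.Strategy`
# instance and `tf` is imported): library code that must run under `strategy`'s
# scope, whether or not the caller has already entered it:
#
#   with enter_or_assert_strategy(strategy):
#     v = tf.Variable(0.0)  # created under `strategy`'s variable policy
#
# If a different strategy's scope is already active, a RuntimeError is raised.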

643 

644# ------------------------------------------------------------------------------ 

645# Defaults that are used when no tf.distribute.Strategy is explicitly created. 

646# We create them lazily in a function so that we can workaround the circular 

647# dependency on distribute_lib. See lazy loader at the top of this file. 

648 

649_defaults = { 

650 "strategy": None, 

651 "replica_context": None, 

652 "replica_mode": None 

653} 

654# Note: These need to be different locks since _get_default_replica_context 

655# calls _get_default_strategy inside its lock, and using the same lock for both 

656# can lead to deadlock. 

657_default_strategy_lock = threading.Lock() 

658_default_replica_context_lock = threading.Lock() 

659_default_replica_mode_lock = threading.Lock() 

660 

661 

662def _assert_strategy(strategy): 

663 if not has_strategy(): 

664 raise RuntimeError('Need to be inside "with strategy.scope()" for %s' % 

665 (strategy,)) 

666 current_strategy = get_strategy() 

667 if current_strategy is not strategy: 

668 raise RuntimeError( 

669 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 

670 (current_strategy, strategy)) 

671 

672 

673def _get_default_strategy(): 

674 if _defaults["strategy"] is None: 

675 # Avoid race condition causing two defaults to be created 

676 with _default_strategy_lock: 

677 if _defaults["strategy"] is None: 

678 # pylint: disable=protected-access 

679 # Make sure distribute_lib module is loaded by accessing some member. 

680 global _creating_default_strategy_singleton 

681 _creating_default_strategy_singleton = True 

682 if tf2.enabled(): 

683 _defaults["strategy"] = _DefaultDistributionStrategy() 

684 else: 

685 _defaults["strategy"] = ( 

686 _DefaultDistributionStrategyV1()) 

687 _creating_default_strategy_singleton = False 

688 # pylint: enable=protected-access 

689 return _defaults["strategy"] 

690 

691 

692def _get_default_replica_context(): 

693 if _defaults["replica_context"] is None: 

694 # Avoid race condition causing two defaults to be created 

695 with _default_replica_context_lock: 

696 if _defaults["replica_context"] is None: 

697 # pylint: disable=protected-access 

698 _defaults["replica_context"] = _DefaultReplicaContext( 

699 _get_default_strategy(), replica_id_in_sync_group=0) 

700 # pylint: enable=protected-access 

701 return _defaults["replica_context"] 

702 

703 

704def _get_default_replica_mode(): 

705 if _defaults["replica_mode"] is None: 

706 # Avoid race condition causing two defaults to be created 

707 with _default_replica_mode_lock: 

708 if _defaults["replica_mode"] is None: 

709 _defaults["replica_mode"] = _DefaultReplicaThreadMode() 

710 return _defaults["replica_mode"] 

711 

712 

713# Aliases for compatibility with old names. 

714get_distribution_strategy = get_strategy 

715has_distribution_strategy = has_strategy 

716 

717 

718# ------------------------------------------------------------------------------ 

719# Internal context managers used to implement the DistributionStrategy 

720# base class 

721 

722 

723class _CurrentDistributionContext(object): 

724 """Context manager setting the current `tf.distribute.Strategy`. 

725 

726 Also: overrides the variable creator and optionally the current device. 

727 """ 

728 

729 def __init__(self, 

730 strategy, 

731 var_creator_scope, 

732 var_scope=None, 

733 resource_creator_scope=None, 

734 default_device=None): 

735 self._context = _CrossReplicaThreadMode( # pylint: disable=protected-access 

736 strategy) 

737 self._var_creator_scope = var_creator_scope 

738 self._var_scope = var_scope 

739 self._resource_creator_scope = resource_creator_scope 

740 if default_device: 

741 self._device_scope = ops.device(default_device) 

742 else: 

743 self._device_scope = None 

744 self._same_scope_again_count = 0 

745 

746 def __enter__(self): 

747 # Allow this scope to be entered if this strategy is already in scope. 

748 if has_strategy(): 

749 _require_cross_replica_or_default_context_extended( 

750 self._context.strategy.extended) 

751 self._same_scope_again_count += 1 

752 else: 

753 _push_per_thread_mode(self._context) 

754 if self._var_scope: 

755 self._var_scope.__enter__() 

756 self._var_creator_scope.__enter__() 

757 if self._resource_creator_scope: 

758 nest.map_structure(lambda scope: scope.__enter__(), 

759 self._resource_creator_scope) 

760 if self._device_scope: 

761 self._device_scope.__enter__() 

762 return self._context.strategy 

763 

764 def __exit__(self, exception_type, exception_value, traceback): 

765 if self._same_scope_again_count > 0: 

766 self._same_scope_again_count -= 1 

767 return 

768 if self._device_scope: 

769 try: 

770 self._device_scope.__exit__(exception_type, exception_value, traceback) 

771 except RuntimeError as e: 

772 six.raise_from( 

773 RuntimeError("Device scope nesting error: move call to " 

774 "tf.distribute.set_strategy() out of `with` scope."), 

775 e) 

776 

777 try: 

778 self._var_creator_scope.__exit__( 

779 exception_type, exception_value, traceback) 

780 except RuntimeError as e: 

781 six.raise_from( 

782 RuntimeError("Variable creator scope nesting error: move call to " 

783 "tf.distribute.set_strategy() out of `with` scope."), 

784 e) 

785 

786 if self._resource_creator_scope: 

787 try: 

788 if isinstance(self._resource_creator_scope, list): 

789 reversed_resource_creator_scope = self._resource_creator_scope[::-1] 

790 nest.map_structure( 

791 lambda scope: scope.__exit__(exception_type, exception_value, # pylint:disable=g-long-lambda 

792 traceback), 

793 reversed_resource_creator_scope) 

794 

795 else: 

796 self._resource_creator_scope.__exit__(exception_type, exception_value, 

797 traceback) 

798 except RuntimeError as e: 

799 six.raise_from( 

800 RuntimeError("Resource creator scope nesting error: move call " 

801 "to tf.distribute.set_strategy() out of `with` " 

802 "scope."), e) 

803 

804 if self._var_scope: 

805 try: 

806 self._var_scope.__exit__(exception_type, exception_value, traceback) 

807 except RuntimeError as e: 

808 six.raise_from( 

809 RuntimeError("Variable scope nesting error: move call to " 

810 "tf.distribute.set_strategy() out of `with` scope."), 

811 e) 

812 _pop_per_thread_mode() 

813 

814 

815# TODO(yuefengz): add more replication modes. 

816@tf_export("distribute.InputReplicationMode") 

817class InputReplicationMode(enum.Enum): 

818 """Replication mode for input function. 

819 

820 * `PER_WORKER`: The input function will be called on each worker 

821  independently, creating as many input pipelines as there are workers. 

822 Replicas will dequeue from the local Dataset on their worker. 

823 `tf.distribute.Strategy` doesn't manage any state sharing between such 

824 separate input pipelines. 

825 * `PER_REPLICA`: The input function will be called on each replica separately. 

826 `tf.distribute.Strategy` doesn't manage any state sharing between such 

827 separate input pipelines. 

828 """ 

829 PER_WORKER = "PER_WORKER" 

830 PER_REPLICA = "PER_REPLICA" 

831 

832 

833@tf_export("distribute.InputContext") 

834class InputContext(object): 

835 """A class wrapping information needed by an input function. 

836 

837 This is a context class that is passed to the user's input function and 

838 contains information about the compute replicas and input pipelines. The 

839 number of compute replicas (in sync training) helps compute the local batch 

840 size from the desired global batch size for each replica. The input pipeline 

841 information can be used to return a different subset of the input in each 

842  replica (e.g. to shard the input pipeline, use a different input 

843  source, etc.). 
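
A typical input function built around this context might look like the
following sketch (the helper name and batch size below are illustrative, not
part of this API):

```python
def make_dataset_fn(global_batch_size=8):
  def dataset_fn(input_context):
    # Derive the per-replica batch size from the desired global batch size.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.Dataset.range(100)
    # Give each input pipeline a disjoint shard of the data.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(batch_size)
  return dataset_fn
```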

844 """ 

845 

846 __slots__ = [ 

847 "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync" 

848 ] 

849 

850 def __init__(self, 

851 num_input_pipelines=1, 

852 input_pipeline_id=0, 

853 num_replicas_in_sync=1): 

854 """Initializes an InputContext object. 

855 

856 Args: 

857 num_input_pipelines: the number of input pipelines in a cluster. 

858 input_pipeline_id: the current input pipeline id, should be an int in 

859 [0,`num_input_pipelines`). 

860 num_replicas_in_sync: the number of replicas that are in sync. 

861 """ 

862 self._num_input_pipelines = num_input_pipelines 

863 self._input_pipeline_id = input_pipeline_id 

864 self._num_replicas_in_sync = num_replicas_in_sync 

865 

866 @property 

867 def num_replicas_in_sync(self): 

868 """Returns the number of compute replicas in sync.""" 

869 return self._num_replicas_in_sync 

870 

871 @property 

872 def input_pipeline_id(self): 

873 """Returns the input pipeline ID.""" 

874 return self._input_pipeline_id 

875 

876 @property 

877 def num_input_pipelines(self): 

878 """Returns the number of input pipelines.""" 

879 return self._num_input_pipelines 

880 

881 def get_per_replica_batch_size(self, global_batch_size): 

882 """Returns the per-replica batch size. 

883 

884 Args: 

885 global_batch_size: the global batch size which should be divisible by 

886 `num_replicas_in_sync`. 

887 

888 Returns: 

889 the per-replica batch size. 

890 

891 Raises: 

892 ValueError: if `global_batch_size` not divisible by 

893 `num_replicas_in_sync`. 
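
Example (a minimal illustration using only the arguments defined above):

```python
ctx = tf.distribute.InputContext(num_input_pipelines=2,
                                 input_pipeline_id=0,
                                 num_replicas_in_sync=4)
assert ctx.get_per_replica_batch_size(global_batch_size=8) == 2
```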

894 """ 

895 if global_batch_size % self._num_replicas_in_sync != 0: 

896 raise ValueError("The `global_batch_size` %r is not divisible by " 

897 "`num_replicas_in_sync` %r " % 

898 (global_batch_size, self._num_replicas_in_sync)) 

899 return global_batch_size // self._num_replicas_in_sync 

900 

901 def __str__(self): 

902 return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( 

903 self.input_pipeline_id, self.num_input_pipelines) 

904 

905 

906@tf_export("distribute.experimental.ValueContext", v1=[]) 

907class ValueContext(object): 

908 """A class wrapping information needed by a distribute function. 

909 

910 This is a context class that is passed to the `value_fn` in 

911 `strategy.experimental_distribute_values_from_function` and contains 

912 information about the compute replicas. The `num_replicas_in_sync` and 

913 `replica_id` can be used to customize the value on each replica. 

914 

915 Example usage: 

916 

917 1. Directly constructed. 

918 

919 >>> def value_fn(context): 

920 ... return context.replica_id_in_sync_group/context.num_replicas_in_sync 

921 >>> context = tf.distribute.experimental.ValueContext( 

922 ... replica_id_in_sync_group=2, num_replicas_in_sync=4) 

923 >>> per_replica_value = value_fn(context) 

924 >>> per_replica_value 

925 0.5 

926 

927 2. Passed in by `experimental_distribute_values_from_function`. {: value=2} 

928 

929 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

930 >>> def value_fn(value_context): 

931 ... return value_context.num_replicas_in_sync 

932 >>> distributed_values = ( 

933 ... strategy.experimental_distribute_values_from_function( 

934 ... value_fn)) 

935 >>> local_result = strategy.experimental_local_results(distributed_values) 

936 >>> local_result 

937 (2, 2) 

938 

939 """ 

940 

941 __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"] 

942 

943 def __init__(self, 

944 replica_id_in_sync_group=0, 

945 num_replicas_in_sync=1): 

946    """Initializes a ValueContext object. 

947 

948 Args: 

949 replica_id_in_sync_group: the current replica_id, should be an int in 

950 [0,`num_replicas_in_sync`). 

951 num_replicas_in_sync: the number of replicas that are in sync. 

952 """ 

953 self._replica_id_in_sync_group = replica_id_in_sync_group 

954 self._num_replicas_in_sync = num_replicas_in_sync 

955 

956 @property 

957 def num_replicas_in_sync(self): 

958 """Returns the number of compute replicas in sync.""" 

959 return self._num_replicas_in_sync 

960 

961 @property 

962 def replica_id_in_sync_group(self): 

963 """Returns the replica ID.""" 

964 return self._replica_id_in_sync_group 

965 

966 def __str__(self): 

967 return (("tf.distribute.ValueContext(replica id {}, " 

968 " total replicas in sync: ""{})") 

969 .format(self.replica_id_in_sync_group, self.num_replicas_in_sync)) 

970 

971 

972@tf_export("distribute.RunOptions") 

973class RunOptions( 

974 collections.namedtuple("RunOptions", [ 

975 "experimental_enable_dynamic_batch_size", 

976 "experimental_bucketizing_dynamic_shape", 

977 "experimental_xla_options", 

978 ])): 

979 """Run options for `strategy.run`. 

980 

981 This can be used to hold some strategy specific configs. 

982 

983 Attributes: 

984 experimental_enable_dynamic_batch_size: Boolean. Only applies to 

985    TPUStrategy. Defaults to True. If True, TPUStrategy will enable the dynamic 

986    padder to support dynamic batch size for the inputs. Otherwise only static 

987    shape inputs are allowed. 

988    experimental_bucketizing_dynamic_shape: Boolean. Only applies to 

989    TPUStrategy. Defaults to False. If True, TPUStrategy will automatically 

990    bucketize inputs passed into `run` if the input shape is 

991    dynamic. This is a performance optimization to reduce XLA recompilation, 

992    which should not have an impact on correctness. 

993    experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to 

994    TPUStrategy. Controls the XLA compilation options on TPUs. Defaults to None. 

995 """ 

996 

997 def __new__(cls, 

998 experimental_enable_dynamic_batch_size=True, 

999 experimental_bucketizing_dynamic_shape=False, 

1000 experimental_xla_options=None): 

1001 return super(RunOptions, 

1002 cls).__new__(cls, experimental_enable_dynamic_batch_size, 

1003 experimental_bucketizing_dynamic_shape, 

1004 experimental_xla_options) 

1005 
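# Example (illustrative sketch): `RunOptions` is passed through the `options`
# argument of `Strategy.run`. Assuming `strategy` is a `tf.distribute.TPUStrategy`
# and `step_fn` / `inputs` are defined elsewhere:
#
#   options = tf.distribute.RunOptions(
#       experimental_enable_dynamic_batch_size=False)
#   strategy.run(step_fn, args=(inputs,), options=options)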

1006 

1007@tf_export("distribute.InputOptions", v1=[]) 

1008class InputOptions( 

1009 collections.namedtuple("InputOptions", [ 

1010 "experimental_fetch_to_device", 

1011 "experimental_replication_mode", 

1012 "experimental_place_dataset_on_device", 

1013 "experimental_per_replica_buffer_size", 

1014 ])): 

1015 """Run options for `experimental_distribute_dataset(s_from_function)`. 

1016 

1017 This can be used to hold some strategy specific configs. 

1018 

1019 ```python 

1020 # Setup TPUStrategy 

1021 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 

1022 tf.config.experimental_connect_to_cluster(resolver) 

1023 tf.tpu.experimental.initialize_tpu_system(resolver) 

1024 strategy = tf.distribute.TPUStrategy(resolver) 

1025 

1026 dataset = tf.data.Dataset.range(16) 

1027 distributed_dataset_on_host = ( 

1028 strategy.experimental_distribute_dataset( 

1029 dataset, 

1030 tf.distribute.InputOptions( 

1031 experimental_replication_mode= 

1032          tf.distribute.InputReplicationMode.PER_WORKER, 

1033 experimental_place_dataset_on_device=False, 

1034 experimental_per_replica_buffer_size=1))) 

1035 ``` 

1036 

1037 Attributes: 

1038 experimental_fetch_to_device: Boolean. If True, dataset 

1039 elements will be prefetched to accelerator device memory. When False, 

1040 dataset elements are prefetched to host device memory. Must be False when 

1041      using the TPUEmbedding API. experimental_fetch_to_device can only be used 

1042      with experimental_replication_mode=PER_WORKER. Default behavior is the same as 

1043 setting it to True. 

1044 experimental_replication_mode: Replication mode for the input function. 

1045      Currently, InputReplicationMode.PER_REPLICA is only supported with 

1046      tf.distribute.MirroredStrategy's 

1047      experimental_distribute_datasets_from_function. 

1048      The default value is InputReplicationMode.PER_WORKER. 

1049    experimental_place_dataset_on_device: Boolean. Defaults to False. When True, 

1050      the dataset will be placed on the device, otherwise it will remain on the 

1051      host. experimental_place_dataset_on_device=True can only be used with 

1052      experimental_replication_mode=PER_REPLICA. 

1053    experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the 

1054 prefetch buffer size in the replica device memory. Users can set it 

1055 to 0 to completely disable prefetching behavior, or a number greater than 

1056 1 to enable larger buffer size. Note that this option is still 

1057 valid with `experimental_fetch_to_device=False`. 

1058 """ 

1059 

1060 def __new__(cls, 

1061 experimental_fetch_to_device=None, 

1062 experimental_replication_mode=InputReplicationMode.PER_WORKER, 

1063 experimental_place_dataset_on_device=False, 

1064 experimental_per_replica_buffer_size=1): 

1065 if experimental_fetch_to_device is None: 

1066 experimental_fetch_to_device = True 

1067 

1068 return super(InputOptions, 

1069 cls).__new__(cls, experimental_fetch_to_device, 

1070 experimental_replication_mode, 

1071 experimental_place_dataset_on_device, 

1072 experimental_per_replica_buffer_size) 

1073 

1074# ------------------------------------------------------------------------------ 

1075# Base classes for all distribution strategies. 

1076 

1077 

1078# Base class for v1 Strategy and v2 Strategy classes. For API's specific to 

1079# v1/v2 Strategy, add to implementing classes of StrategyBase. 

1080# pylint: disable=line-too-long 

1081class StrategyBase(object): 

1082 """A state & compute distribution policy on a list of devices. 

1083 

1084 See [the guide](https://www.tensorflow.org/guide/distributed_training) 

1085 for overview and examples. See `tf.distribute.StrategyExtended` and 

1086 [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute) 

1087 for a glossary of concepts mentioned on this page such as "per-replica", 

1088 _replica_, and _reduce_. 

1089 

1090 In short: 

1091 

1092 * To use it with Keras `compile`/`fit`, 

1093 [please 

1094 read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). 

1095  * You may pass a descendant of `tf.distribute.Strategy` to 

1096 `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` 

1097 should distribute its computation. See 

1098 [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support). 

1099 * Otherwise, use `tf.distribute.Strategy.scope` to specify that a 

1100    strategy should be used when building and executing your model. 

1101 (This puts you in the "cross-replica context" for this strategy, which 

1102 means the strategy is put in control of things like variable placement.) 

1103 * If you are writing a custom training loop, you will need to call a few more 

1104 methods, 

1105 [see the 

1106 guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): 

1107 

1108 * Start by creating a `tf.data.Dataset` normally. 

1109 * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert 

1110 a `tf.data.Dataset` to something that produces "per-replica" values. 

1111 If you want to manually specify how the dataset should be partitioned 

1112 across replicas, use 

1113 `tf.distribute.Strategy.distribute_datasets_from_function` 

1114 instead. 

1115 * Use `tf.distribute.Strategy.run` to run a function 

1116 once per replica, taking values that may be "per-replica" (e.g. 

1117 from a `tf.distribute.DistributedDataset` object) and returning 

1118 "per-replica" values. 

1119 This function is executed in "replica context", which means each 

1120 operation is performed separately on each replica. 

1121 * Finally use a method (such as `tf.distribute.Strategy.reduce`) to 

1122 convert the resulting "per-replica" values into ordinary `Tensor`s. 

1123 

1124 A custom training loop can be as simple as: 

1125 

1126 ``` 

1127 with my_strategy.scope(): 

1128 @tf.function 

1129 def distribute_train_epoch(dataset): 

1130 def replica_fn(input): 

1131 # process input and return result 

1132 return result 

1133 

1134 total_result = 0 

1135 for x in dataset: 

1136 per_replica_result = my_strategy.run(replica_fn, args=(x,)) 

1137 total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, 

1138 per_replica_result, axis=None) 

1139 return total_result 

1140 

1141 dist_dataset = my_strategy.experimental_distribute_dataset(dataset) 

1142 for _ in range(EPOCHS): 

1143 train_result = distribute_train_epoch(dist_dataset) 

1144 ``` 

1145 

1146 This takes an ordinary `dataset` and `replica_fn` and runs it 

1147 distributed using a particular `tf.distribute.Strategy` named 

1148 `my_strategy` above. Any variables created in `replica_fn` are created 

1149 using `my_strategy`'s policy, and library functions called by 

1150 `replica_fn` can use the `get_replica_context()` API to implement 

1151 distributed-specific behavior. 

1152 

1153 You can use the `reduce` API to aggregate results across replicas and use 

1154 this as a return value from one iteration over a 

1155 `tf.distribute.DistributedDataset`. Or 

1156 you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to 

1157 accumulate metrics across steps in a given epoch. 

1158 

1159 See the 

1160 [custom training loop 

1161 tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) 

1162 for a more detailed example. 

1163 

1164 Note: `tf.distribute.Strategy` currently does not support TensorFlow's 

1165 partitioned variables (where a single variable is split across multiple 

1166  devices). 

1167 """ 

1168 # pylint: enable=line-too-long 

1169 

1170 # TODO(josh11b): Partitioned computations, state; sharding 

1171 # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling 

1172 

1173 def __init__(self, extended): 

1174 self._extended = extended 

1175 

1176 # Flag that is used to indicate whether distribution strategy is used with 

1177 # Estimator. This is required for backward compatibility of loss scaling 

1178 # when using v1 optimizer with estimator. 

1179 self._scale_loss_for_estimator = False 

1180 

1181 if not hasattr(extended, "_retrace_functions_for_each_device"): 

1182 # pylint: disable=protected-access 

1183 # `extended._retrace_functions_for_each_device` dictates 

1184 # whether the same function will be retraced when it is called on 

1185 # different devices. 

1186 try: 

1187 extended._retrace_functions_for_each_device = ( 

1188 len(extended.worker_devices) > 1) 

1189 distribution_strategy_replica_gauge.get_cell("num_replicas").set( 

1190 self.num_replicas_in_sync) 

1191 except: # pylint: disable=bare-except 

1192 # Default for the case where extended.worker_devices can't return 

1193 # a sensible value. 

1194 extended._retrace_functions_for_each_device = True 

1195 

1196 # Below are the dicts of axis(int) -> `tf.function`. 

1197 self._mean_reduce_helper_fns = {} 

1198 self._reduce_sum_fns = {} 

1199 

1200 # Whether this strategy is designed to work with `ClusterCoordinator`. 

1201 self._should_use_with_coordinator = False 

1202 

1203 @property 

1204 def extended(self): 

1205 """`tf.distribute.StrategyExtended` with additional methods.""" 

1206 return self._extended 

1207 

1208 @tf_contextlib.contextmanager 

1209 def _scale_loss_for_estimator_enabled(self): 

1210 """Scope which sets a flag used for scaling losses in optimizer. 

1211 

1212 Yields: 

1213 `_scale_loss_for_estimator_enabled` is a context manager with a 

1214 side effect, but doesn't return a value. 

1215 """ 

1216 self._scale_loss_for_estimator = True 

1217 try: 

1218 yield 

1219 finally: 

1220 self._scale_loss_for_estimator = False 

1221 

1222 # pylint: disable=line-too-long 

1223 def scope(self): 

1224 """Context manager to make the strategy current and distribute variables. 

1225 

1226 This method returns a context manager, and is used as follows: 

1227 

1228 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

1229 >>> # Variable created inside scope: 

1230 >>> with strategy.scope(): 

1231 ... mirrored_variable = tf.Variable(1.) 

1232 >>> mirrored_variable 

1233 MirroredVariable:{ 

1234 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, 

1235 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> 

1236 } 

1237 >>> # Variable created outside scope: 

1238 >>> regular_variable = tf.Variable(1.) 

1239 >>> regular_variable 

1240 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 

1241 

1242 _What happens when Strategy.scope is entered?_ 

1243 

1244 * `strategy` is installed in the global context as the "current" strategy. 

1245 Inside this scope, `tf.distribute.get_strategy()` will now return this 

1246 strategy. Outside this scope, it returns the default no-op strategy. 

1247 * Entering the scope also enters the "cross-replica context". See 

1248 `tf.distribute.StrategyExtended` for an explanation on cross-replica and 

1249 replica contexts. 

1250 * Variable creation inside `scope` is intercepted by the strategy. Each 

1251 strategy defines how it wants to affect the variable creation. Sync 

1252 strategies like `MirroredStrategy`, `TPUStrategy` and 

1253    `MultiWorkerMirroredStrategy` create variables replicated on each replica, 

1254 whereas `ParameterServerStrategy` creates variables on the parameter 

1255 servers. This is done using a custom `tf.variable_creator_scope`. 

1256 * In some strategies, a default device scope may also be entered: in 

1257    `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is 

1258 entered on each worker. 

1259 

1260 Note: Entering a scope does not automatically distribute a computation, except 

1261    in the case of high-level training frameworks like Keras `model.fit`. If 

1262    you're not using `model.fit`, you 

1263    need to use the `strategy.run` API to explicitly distribute that computation. 

1264 See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training). 

1265 

1266 

1267 _What should be in scope and what should be outside?_ 

1268 

1269 There are a number of requirements on what needs to happen inside the scope. 

1270 However, in places where we have information about which strategy is in use, 

1271 we often enter the scope for the user, so they don't have to do it 

1272 explicitly (i.e. calling those either inside or outside the scope is OK). 

1273 

1274 * Anything that creates variables that should be distributed variables 

1275 must be called in a `strategy.scope`. This can be accomplished either by 

1276 directly calling the variable creating function within the scope context, 

1277 or by relying on another API like `strategy.run` or `keras.Model.fit` to 

1278 automatically enter it for you. Any variable that is created outside scope 

1279 will not be distributed and may have performance implications. Some common 

1280 objects that create variables in TF are Models, Optimizers, Metrics. Such 

1281 objects should always be initialized in the scope, and any functions 

1282 that may lazily create variables (e.g., `Model.__call__()`, tracing a 

1283 `tf.function`, etc.) should similarly be called within scope. Another 

1284 source of variable creation can be a checkpoint restore - when variables 

1285 are created lazily. Note that any variable created inside a strategy 

1286 captures the strategy information. So reading and writing to these 

1287 variables outside the `strategy.scope` can also work seamlessly, without 

1288 the user having to enter the scope. 

1289 * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which 

1290      require being in a strategy's scope enter the scope automatically, which 

1291 means when using those APIs you don't need to explicitly enter the scope 

1292 yourself. 

1293 * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model 

1294 object captures the scope information. When high level training framework 

1295 methods such as `model.compile`, `model.fit`, etc. are then called, the 

1296 captured scope will be automatically entered, and the associated strategy 

1297 will be used to distribute the training etc. See a detailed example in 

1298 [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras). 

1299 WARNING: Simply calling `model(..)` does not automatically enter the 

1300 captured scope -- only high level training framework APIs support this 

1301 behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict` 

1302 and `model.save` can all be called inside or outside the scope. 

1303 * The following can be either inside or outside the scope: 

1304 * Creating the input datasets 

1305 * Defining `tf.function`s that represent your training step 

1306 * Saving APIs such as `tf.saved_model.save`. Loading creates variables, 

1307 so that should go inside the scope if you want to train the model in a 

1308 distributed way. 

1309 * Checkpoint saving. As mentioned above - `checkpoint.restore` may 

1310 sometimes need to be inside scope if it creates variables. 

1311 

1312 Returns: 

1313 A context manager. 

1314 """ 

1315 return self._extended._scope(self) # pylint: disable=protected-access 

1316 # pylint: enable=line-too-long 

1317 

1318 @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` 

1319 @deprecated(None, "use extended.colocate_vars_with() instead.") 

1320 def colocate_vars_with(self, colocate_with_variable): 

1321 """DEPRECATED: use extended.colocate_vars_with() instead.""" 

1322 return self._extended.colocate_vars_with(colocate_with_variable) 

1323 

1324 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 

1325 def make_dataset_iterator(self, dataset): 

1326 """DEPRECATED TF 1.x ONLY.""" 

1327 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 

1328 

1329 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 

1330 def make_input_fn_iterator(self, 

1331 input_fn, 

1332 replication_mode=InputReplicationMode.PER_WORKER): 

1333 """DEPRECATED TF 1.x ONLY.""" 

1334 if replication_mode != InputReplicationMode.PER_WORKER: 

1335 raise ValueError( 

1336 "Input replication mode not supported: %r" % replication_mode) 

1337 with self.scope(): 

1338 return self.extended._make_input_fn_iterator( # pylint: disable=protected-access 

1339 input_fn, replication_mode=replication_mode) 

1340 

1341 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 

1342 @deprecated(None, "use run() instead") 

1343 def experimental_run(self, fn, input_iterator=None): 

1344 """DEPRECATED TF 1.x ONLY.""" 

1345 with self.scope(): 

1346 args = (input_iterator.get_next(),) if input_iterator is not None else () 

1347 return self.run(fn, args=args) 

1348 

1349 def experimental_distribute_dataset(self, dataset, options=None): 

1350 # pylint: disable=line-too-long 

1351 """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`. 

1352 

1353 The returned `tf.distribute.DistributedDataset` can be iterated over 

1354 similar to regular datasets. 

1355 NOTE: The user cannot add any more transformations to a 

1356 `tf.distribute.DistributedDataset`. You can only create an iterator or 

1357 examine the `tf.TypeSpec` of the data generated by it. See API docs of 

1358 `tf.distribute.DistributedDataset` to learn more. 

1359 

1360 The following is an example: 

1361 

1362 >>> global_batch_size = 2 

1363 >>> # Passing the devices is optional. 

1364 ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 

1365 >>> # Create a dataset 

1366 ... dataset = tf.data.Dataset.range(4).batch(global_batch_size) 

1367 >>> # Distribute that dataset 

1368 ... dist_dataset = strategy.experimental_distribute_dataset(dataset) 

1369 >>> @tf.function 

1370 ... def replica_fn(input): 

1371 ... return input*2 

1372 >>> result = [] 

1373 >>> # Iterate over the `tf.distribute.DistributedDataset` 

1374 ... for x in dist_dataset: 

1375 ... # process dataset elements 

1376 ... result.append(strategy.run(replica_fn, args=(x,))) 

1377 >>> print(result) 

1378 [PerReplica:{ 

1379 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>, 

1380 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])> 

1381 }, PerReplica:{ 

1382 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>, 

1383 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])> 

1384 }] 

1385 

1386 

1387 Three key actions happening under the hood of this method are batching, 

1388 sharding, and prefetching. 

1389 

1390 In the code snippet above, `dataset` is batched by `global_batch_size`, and 

1391 calling `experimental_distribute_dataset` on it rebatches `dataset` to a 

1392 new batch size that is equal to the global batch size divided by the number 

1393 of replicas in sync. We iterate through it using a Pythonic for loop. 

1394 `x` is a `tf.distribute.DistributedValues` containing data for all replicas, 

1395 and each replica gets data of the new batch size. 

1396 `tf.distribute.Strategy.run` will take care of feeding the right per-replica 

1397 data in `x` to the right `replica_fn` executed on each replica. 

1398 

1399 Sharding covers autosharding across multiple workers and within every 

1400 worker. First, in multi-worker distributed training (i.e. when you use 

1401 `tf.distribute.experimental.MultiWorkerMirroredStrategy` 

1402 or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of 

1403 workers means that each worker is assigned a subset of the entire dataset 

1404 (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to 

1405 ensure that at each step, a global batch size of non-overlapping dataset 

1406 elements will be processed by each worker. Autosharding has a couple of 

1407 different options that can be specified using 

1408 `tf.data.experimental.DistributeOptions`. Then, sharding within each worker 

1409 means the method will split the data among all the worker devices (if more 

1410 than one is present). This will happen regardless of multi-worker 

1411 autosharding. 

1412 

1413 Note: for autosharding across multiple workers, the default mode is 

1414 `tf.data.experimental.AutoShardPolicy.AUTO`. This mode 

1415 will attempt to shard the input dataset by files if the dataset is 

1416 being created out of reader datasets (e.g. `tf.data.TFRecordDataset`, 

1417 `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data, 

1418 where each of the workers will read the entire dataset and only process the 

1419 shard assigned to it. However, if you have less than one input file per 

1420 worker, we suggest that you disable dataset autosharding across workers by 

1421 setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be 

1422 `tf.data.experimental.AutoShardPolicy.OFF`. 

1423 
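For example, one way to turn multi-worker autosharding off (a sketch; `dataset` and `strategy` refer to the objects in the snippet above):

```python
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```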

1424 By default, this method adds a prefetch transformation at the end of the 

1425 user provided `tf.data.Dataset` instance. The argument to the prefetch 

1426 transformation which is `buffer_size` is equal to the number of replicas in 

1427 sync. 

1428 

1429 If the above batch splitting and dataset sharding logic is undesirable, 

1430 please use 

1431 `tf.distribute.Strategy.distribute_datasets_from_function` 

1432 instead, which does not do any automatic batching or sharding for you. 

1433 

1434 Note: If you are using TPUStrategy, the order in which the data is processed 

1435 by the workers when using 

1436 `tf.distribute.Strategy.experimental_distribute_dataset` or 

1437 `tf.distribute.Strategy.distribute_datasets_from_function` is 

1438 not guaranteed. A deterministic output order is typically required if you are using 

1439 `tf.distribute` to scale prediction. You can, however, insert an index for 

1440 each element in the batch and order outputs accordingly. Refer to [this 

1441 snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) 

1442 for an example of how to order outputs. 

1443 

1444 Note: Stateful dataset transformations are currently not supported with 

1445 `tf.distribute.experimental_distribute_dataset` or 

1446 `tf.distribute.distribute_datasets_from_function`. Any stateful 

1447 ops that the dataset may have are currently ignored. For example, if your 

1448 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, 

1449 then you have a dataset graph that depends on state (i.e. the random seed) on 

1450 the local machine where the Python process is being executed. 

1451 

1452 For a tutorial on more usage and properties of this method, refer to the 

1453 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset). 

1454 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). 

1455 

1456 Args: 

1457 dataset: `tf.data.Dataset` that will be sharded across all replicas using 

1458 the rules stated above. 

1459 options: `tf.distribute.InputOptions` used to control options on how this 

1460 dataset is distributed. 

1461 

1462 Returns: 

1463 A `tf.distribute.DistributedDataset`. 

1464 """ 

1465 distribution_strategy_input_api_counter.get_cell( 

1466 self.__class__.__name__, "distribute_dataset").increase_by(1) 

1467 # pylint: enable=line-too-long 

1468 return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access 

1469 

1470 def distribute_datasets_from_function(self, dataset_fn, options=None): 

1471 # pylint: disable=line-too-long 

1472 """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. 

1473 

1474 The argument `dataset_fn` that users pass in is an input function that has a 

1475 `tf.distribute.InputContext` argument and returns a `tf.data.Dataset` 

1476 instance. It is expected that the returned dataset from `dataset_fn` is 

1477 already batched by per-replica batch size (i.e. global batch size divided by 

1478 the number of replicas in sync) and sharded. 

1479 `tf.distribute.Strategy.distribute_datasets_from_function` does 

1480 not batch or shard the `tf.data.Dataset` instance 

1481 returned from the input function. `dataset_fn` will be called on the CPU 

1482 device of each of the workers and each generates a dataset where every 

1483 replica on that worker will dequeue one batch of inputs (i.e. if a worker 

1484 has two replicas, two batches will be dequeued from the `Dataset` every 

1485 step). 

1486 

1487 This method can be used for several purposes. First, it allows you to 

1488 specify your own batching and sharding logic. (In contrast, 

1489 `tf.distribute.experimental_distribute_dataset` does batching and sharding 

1490 for you.) For example, where 

1491 `experimental_distribute_dataset` is unable to shard the input files, this 

1492 method might be used to manually shard the dataset (avoiding the slow 

1493 fallback behavior in `experimental_distribute_dataset`). In cases where the 

1494 dataset is infinite, this sharding can be done by creating dataset replicas 

1495 that differ only in their random seed. 

1496 

1497 The `dataset_fn` should take a `tf.distribute.InputContext` instance where 

1498 information about batching and input replication can be accessed. 

1499 

1500 You can use the `element_spec` property of the 

1501 `tf.distribute.DistributedDataset` returned by this API to query the 

1502 `tf.TypeSpec` of the elements returned by the iterator. This can be used to 

1503 set the `input_signature` property of a `tf.function`. Follow 

1504 `tf.distribute.DistributedDataset.element_spec` to see an example. 

1505 

1506 IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a 

1507 per-replica batch size, unlike `experimental_distribute_dataset`, which uses 

1508 the global batch size. This may be computed using 

1509 `input_context.get_per_replica_batch_size`. 

1510 
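A sketch of a typical `dataset_fn` following these rules (here `global_batch_size` and `strategy` are assumed to be defined elsewhere):

```python
def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(batch_size)
  # Shard by input pipeline so each worker reads a disjoint part of the data.
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```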

1511 Note: If you are using TPUStrategy, the order in which the data is processed 

1512 by the workers when using 

1513 `tf.distribute.Strategy.experimental_distribute_dataset` or 

1514 `tf.distribute.Strategy.distribute_datasets_from_function` is 

1515 not guaranteed. A deterministic output order is typically required if you are using 

1516 `tf.distribute` to scale prediction. You can, however, insert an index for 

1517 each element in the batch and order outputs accordingly. Refer to [this 

1518 snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) 

1519 for an example of how to order outputs. 

1520 

1521 Note: Stateful dataset transformations are currently not supported with 

1522 `tf.distribute.experimental_distribute_dataset` or 

1523 `tf.distribute.distribute_datasets_from_function`. Any stateful 

1524 ops that the dataset may have are currently ignored. For example, if your 

1525 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, 

1526 then you have a dataset graph that depends on state (i.e. the random seed) on 

1527 the local machine where the Python process is being executed. 

1528 

1529 For a tutorial on more usage and properties of this method, refer to the 

1530 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function). 

1531 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). 

1532 

1533 Args: 

1534 dataset_fn: A function taking a `tf.distribute.InputContext` instance and 

1535 returning a `tf.data.Dataset`. 

1536 options: `tf.distribute.InputOptions` used to control options on how this 

1537 dataset is distributed. 

1538 

1539 Returns: 

1540 A `tf.distribute.DistributedDataset`. 

1541 """ 

1542 distribution_strategy_input_api_counter.get_cell( 

1543 self.__class__.__name__, 

1544 "distribute_datasets_from_function").increase_by(1) 

1545 # pylint: enable=line-too-long 

1546 return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access 

1547 dataset_fn, options) 

1548 

1549 # TODO(b/162776748): Remove deprecated symbol. 

1550 @doc_controls.do_not_doc_inheritable 

1551 @deprecation.deprecated(None, "rename to distribute_datasets_from_function") 

1552 def experimental_distribute_datasets_from_function(self, 

1553 dataset_fn, 

1554 options=None): 

1555 return self.distribute_datasets_from_function(dataset_fn, options) 

1556 

1557 def run(self, fn, args=(), kwargs=None, options=None): 

1558 """Invokes `fn` on each replica, with the given arguments. 

1559 

1560 This method is the primary way to distribute your computation with a 

1561 tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs` 

1562 have `tf.distribute.DistributedValues`, such as those produced by a 

1563 `tf.distribute.DistributedDataset` from 

1564 `tf.distribute.Strategy.experimental_distribute_dataset` or 

1565 `tf.distribute.Strategy.distribute_datasets_from_function`, 

1566 when `fn` is executed on a particular replica, it will be executed with the 

1567 component of `tf.distribute.DistributedValues` that correspond to that 

1568 replica. 

1569 

1570 `fn` is invoked under a replica context. `fn` may call 

1571 `tf.distribute.get_replica_context()` to access members such as 

1572 `all_reduce`. Please see the module-level docstring of tf.distribute for the 

1573 concept of replica context. 

1574 

1575 All arguments in `args` or `kwargs` can be a nested structure of tensors, 

1576 e.g. a list of tensors, in which case `args` and `kwargs` will be passed to 

1577 the `fn` invoked on each replica. Or `args` or `kwargs` can be 

1578 `tf.distribute.DistributedValues` containing tensors or composite tensors, 

1579 i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call 

1580 will get the component of a `tf.distribute.DistributedValues` corresponding 

1581 to its replica. Note that arbitrary Python values that are not of the types 

1582 above are not supported. 

1583 

1584 IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and 

1585 whether eager execution is enabled, `fn` may be called one or more times. If 

1586 `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is 

1587 called inside a `tf.function` (eager execution is disabled inside a 

1588 `tf.function` by default), `fn` is called once per replica to generate a 

1589 TensorFlow graph, which will then be reused for execution with new inputs. 

1590 Otherwise, if eager execution is enabled, `fn` will be called once per 

1591 replica every step just like regular Python code. 

1592 

1593 Example usage: 

1594 

1595 1. Constant tensor input. 

1596 

1597 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

1598 >>> tensor_input = tf.constant(3.0) 

1599 >>> @tf.function 

1600 ... def replica_fn(input): 

1601 ... return input*2.0 

1602 >>> result = strategy.run(replica_fn, args=(tensor_input,)) 

1603 >>> result 

1604 PerReplica:{ 

1605 0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>, 

1606 1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 

1607 } 

1608 

1609 2. DistributedValues input. {: value=2} 

1610 

1611 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

1612 >>> @tf.function 

1613 ... def run(): 

1614 ... def value_fn(value_context): 

1615 ... return value_context.num_replicas_in_sync 

1616 ... distributed_values = ( 

1617 ... strategy.experimental_distribute_values_from_function( 

1618 ... value_fn)) 

1619 ... def replica_fn2(input): 

1620 ... return input*2 

1621 ... return strategy.run(replica_fn2, args=(distributed_values,)) 

1622 >>> result = run() 

1623 >>> result 

1624 <tf.Tensor: shape=(), dtype=int32, numpy=4> 

1625 

1626 3. Use `tf.distribute.ReplicaContext` to allreduce values. {: value=3} 

1627 

1628 >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"]) 

1629 >>> @tf.function 

1630 ... def run(): 

1631 ... def value_fn(value_context): 

1632 ... return tf.constant(value_context.replica_id_in_sync_group) 

1633 ... distributed_values = ( 

1634 ... strategy.experimental_distribute_values_from_function( 

1635 ... value_fn)) 

1636 ... def replica_fn(input): 

1637 ... return tf.distribute.get_replica_context().all_reduce( 

1638 ... "sum", input) 

1639 ... return strategy.run(replica_fn, args=(distributed_values,)) 

1640 >>> result = run() 

1641 >>> result 

1642 PerReplica:{ 

1643 0: <tf.Tensor: shape=(), dtype=int32, numpy=1>, 

1644 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 

1645 } 

1646 
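4. A typical custom training step wrapped in a `tf.function` (a sketch; `strategy`, `model` and `optimizer` are assumed to have been created inside the strategy scope, and `dist_inputs` to come from a distributed dataset): {: value=4}

```python
@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      # Divide by the global batch size so gradients sum correctly across replicas.
      loss = tf.nn.compute_average_loss(per_example_loss)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
```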

1647 Args: 

1648 fn: The function to run on each replica. 

1649 args: Optional positional arguments to `fn`. Its element can be a tensor, 

1650 a nested structure of tensors or a `tf.distribute.DistributedValues`. 

1651 kwargs: Optional keyword arguments to `fn`. Its element can be a tensor, 

1652 a nested structure of tensors or a `tf.distribute.DistributedValues`. 

1653 options: An optional instance of `tf.distribute.RunOptions` specifying 

1654 the options to run `fn`. 

1655 

1656 Returns: 

1657 Merged return value of `fn` across replicas. The structure of the return 

1658 value is the same as the return value from `fn`. Each element in the 

1659 structure can either be `tf.distribute.DistributedValues` or `Tensor` 

1660 objects (for example, if running on a single replica). 

1661 """ 

1662 del options 

1663 

1664 if not isinstance(args, (list, tuple)): 

1665 raise ValueError( 

1666 "positional args must be a list or tuple, got {}".format(type(args))) 

1667 

1668 with self.scope(): 

1669 # tf.distribute supports Eager functions, so AutoGraph should not be 

1670 # applied when the caller is also in Eager mode. 

1671 fn = autograph.tf_convert( 

1672 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 

1673 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) 

1674 

1675 def reduce(self, reduce_op, value, axis): 

1676 """Reduce `value` across replicas and return result on current device. 

1677 

1678 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

1679 >>> def step_fn(): 

1680 ... i = tf.distribute.get_replica_context().replica_id_in_sync_group 

1681 ... return tf.identity(i) 

1682 >>> 

1683 >>> per_replica_result = strategy.run(step_fn) 

1684 >>> total = strategy.reduce("SUM", per_replica_result, axis=None) 

1685 >>> total 

1686 <tf.Tensor: shape=(), dtype=int32, numpy=1> 

1687 

1688 To see how this would look with multiple replicas, consider the same 

1689 example with MirroredStrategy with 2 GPUs: 

1690 

1691 ```python 

1692 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 

1693 def step_fn(): 

1694 i = tf.distribute.get_replica_context().replica_id_in_sync_group 

1695 return tf.identity(i) 

1696 

1697 per_replica_result = strategy.run(step_fn) 

1698 # Check devices on which per replica result is: 

1699 strategy.experimental_local_results(per_replica_result)[0].device 

1700 # /job:localhost/replica:0/task:0/device:GPU:0 

1701 strategy.experimental_local_results(per_replica_result)[1].device 

1702 # /job:localhost/replica:0/task:0/device:GPU:1 

1703 

1704 total = strategy.reduce("SUM", per_replica_result, axis=None) 

1705 # Check device on which reduced result is: 

1706 total.device 

1707 # /job:localhost/replica:0/task:0/device:CPU:0 

1708 

1709 ``` 

1710 

1711 This API is typically used for aggregating the results returned from 

1712 different replicas, for reporting etc. For example, loss computed from 

1713 different replicas can be averaged using this API before printing. 

1714 

1715 Note: The result is copied to the "current" device - which would typically 

1716 be the CPU of the worker on which the program is running. For `TPUStrategy`, 

1717 it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`, 

1718 this is the CPU of each worker. 

1719 

1720 There are a number of different tf.distribute APIs for reducing values 

1721 across replicas: 

1722 * `tf.distribute.ReplicaContext.all_reduce`: This differs from 

1723 `Strategy.reduce` in that it is for replica context and does 

1724 not copy the results to the host device. `all_reduce` should be typically 

1725 used for reductions inside the training step such as gradients. 

1726 * `tf.distribute.StrategyExtended.reduce_to` and 

1727 `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more 

1728 advanced versions of `Strategy.reduce` as they allow customizing the 

1729 destination of the result. They are also called in cross-replica context. 

1730 

1731 _What should axis be?_ 

1732 

1733 Given a per-replica value returned by `run`, say a 

1734 per-example loss, the batch will be divided across all the replicas. This 

1735 function allows you to aggregate across replicas and optionally also across 

1736 batch elements by specifying the axis parameter accordingly. 

1737 

1738 For example, if you have a global batch size of 8 and 2 

1739 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and 

1740 `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will 

1741 aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. 

1742 This is useful when each replica is computing a scalar or some other value 

1743 that doesn't have a "batch" dimension (like a gradient or loss). 

1744 ``` 

1745 strategy.reduce("sum", per_replica_result, axis=None) 

1746 ``` 

1747 

1748 Sometimes, you will want to aggregate across both the global batch _and_ 

1749 all replicas. You can get this behavior by specifying the batch 

1750 dimension as the `axis`, typically `axis=0`. In this case it would return a 

1751 scalar `0+1+2+3+4+5+6+7`. 

1752 ``` 

1753 strategy.reduce("sum", per_replica_result, axis=0) 

1754 ``` 

1755 

1756 If there is a last partial batch, you will need to specify an axis so 

1757 that the resulting shape is consistent across replicas. So if the last 

1758 batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you 

1759 would get a shape mismatch unless you specify `axis=0`. If you specify 

1760 `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct 

1761 denominator of 6. Contrast this with computing `reduce_mean` on each replica 

1762 to get a scalar value and then using this function to average those means, 

1763 which would weight some values `1/8` and others `1/4`. 

1764 
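Putting the two cases together (a sketch; `per_replica_scalar` and `per_replica_losses` are assumed results of `strategy.run`, the latter having a batch dimension):

```python
# Average across replicas only (each replica returned a scalar):
mean_scalar = strategy.reduce("MEAN", per_replica_scalar, axis=None)

# Average across both replicas and the per-replica batch dimension, which
# also uses the correct denominator for a last partial batch:
mean_loss = strategy.reduce("MEAN", per_replica_losses, axis=0)
```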

1765 Args: 

1766 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 

1767 be combined. Allows using string representation of the enum such as 

1768 "SUM", "MEAN". 

1769 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 

1770 `Strategy.run`, to be combined into a single tensor. It can also be a 

1771 regular tensor when used with `OneDeviceStrategy` or default strategy. 

1772 axis: specifies the dimension to reduce along within each 

1773 replica's tensor. Should typically be set to the batch dimension, or 

1774 `None` to only reduce across replicas (e.g. if the tensor has no batch 

1775 dimension). 

1776 

1777 Returns: 

1778 A `Tensor`. 

1779 """ 

1780 # TODO(josh11b): support `value` being a nest. 

1781 _require_cross_replica_or_default_context_extended(self._extended) 

1782 if isinstance(reduce_op, six.string_types): 

1783 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 

1784 if axis is None: 

1785 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 

1786 if reduce_op == reduce_util.ReduceOp.SUM: 

1787 

1788 def reduce_sum(v): 

1789 return math_ops.reduce_sum(v, axis=axis) 

1790 

1791 if eager_context.executing_eagerly(): 

1792 # As some strategies (e.g. TPUStrategy) don't support pure eager 

1793 # execution, wrap `reduce_sum` with a `tf.function` so it can be 

1794 # run from eager mode. Cache the tf.function by `axis` to avoid 

1795 # retracing the same function. 

1796 if axis not in self._reduce_sum_fns: 

1797 self._reduce_sum_fns[axis] = def_function.function(reduce_sum) 

1798 value = self.run(self._reduce_sum_fns[axis], args=(value,)) 

1799 else: 

1800 value = self.run(reduce_sum, args=(value,)) 

1801 

1802 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 

1803 if reduce_op != reduce_util.ReduceOp.MEAN: 

1804 raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " 

1805 "not: %r" % reduce_op) 

1806 

1807 def mean_reduce_helper(v, axes=axis): 

1808 """Computes the numerator and denominator on each replica.""" 

1809 numer = math_ops.reduce_sum(v, axis=axes) 

1810 def dimension(axis): 

1811 if v.shape.rank is not None: 

1812 # Note(joshl): We support axis < 0 to be consistent with the 

1813 # tf.math.reduce_* operations. 

1814 if axis < 0: 

1815 if axis + v.shape.rank < 0: 

1816 raise ValueError( 

1817 "`axis` = %r out of range for `value` with rank %d" % 

1818 (axis, v.shape.rank)) 

1819 axis += v.shape.rank 

1820 elif axis >= v.shape.rank: 

1821 raise ValueError( 

1822 "`axis` = %r out of range for `value` with rank %d" % 

1823 (axis, v.shape.rank)) 

1824 # TF v2 returns `None` for unknown dimensions and an integer for 

1825 # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) 

1826 # or tensor_shape.Dimension(integer). `dimension_value` hides this 

1827 # difference, always returning `None` or an integer. 

1828 dim = tensor_shape.dimension_value(v.shape[axis]) 

1829 if dim is not None: 

1830 # By returning a python value in the static shape case, we can 

1831 # maybe get a fast path for reducing the denominator. 

1832 # TODO(b/151871486): Remove array_ops.identity after we fallback to 

1833 # simple reduction if inputs are all on CPU. 

1834 return array_ops.identity( 

1835 constant_op.constant(dim, dtype=dtypes.int64)) 

1836 elif axis < 0: 

1837 axis = axis + array_ops.rank(v) 

1838 # TODO(b/151871486): Remove array_ops.identity after we fallback to 

1839 # simple reduction if inputs are all on CPU. 

1840 return array_ops.identity( 

1841 array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) 

1842 if isinstance(axis, six.integer_types): 

1843 denom = dimension(axis) 

1844 elif isinstance(axis, (tuple, list)): 

1845 denom = math_ops.reduce_prod([dimension(a) for a in axes]) 

1846 else: 

1847 raise TypeError( 

1848 "Expected `axis` to be an integer, tuple or list not: %r" % axis) 

1849 # TODO(josh11b): Should we cast denom to v.dtype here instead of after the 

1850 # reduce is complete? 

1851 return numer, denom 

1852 

1853 if eager_context.executing_eagerly(): 

1854 # As some strategies (e.g. TPUStrategy) don't support pure eager 

1855 # execution, wrap `mean_reduce_helper` with a `tf.function` so it can 

1856 # be run from eager mode. Cache the tf.function by `axis` to avoid 

1857 # retracing the same function. 

1858 if axis not in self._mean_reduce_helper_fns: 

1859 self._mean_reduce_helper_fns[axis] = def_function.function( 

1860 mean_reduce_helper) 

1861 numer, denom = self.run(self._mean_reduce_helper_fns[axis], args=(value,)) 

1862 else: 

1863 numer, denom = self.run(mean_reduce_helper, args=(value,)) 

1864 

1865 # TODO(josh11b): Should batch reduce here instead of doing two. 

1866 numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access 

1867 denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access 

1868 denom = math_ops.cast(denom, numer.dtype) 

1869 return math_ops.truediv(numer, denom) 

1870 

1871 @doc_controls.do_not_doc_inheritable # DEPRECATED 

1872 @deprecated(None, "use `experimental_local_results` instead.") 

1873 def unwrap(self, value): 

1874 """Returns the list of all local per-replica values contained in `value`. 

1875 

1876 DEPRECATED: Please use `experimental_local_results` instead. 

1877 

1878 Note: This only returns values on the workers initiated by this client. 

1879 When using a `tf.distribute.Strategy` like 

1880 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 

1881 will be its own client, and this function will only return values 

1882 computed on that worker. 

1883 

1884 Args: 

1885 value: A value returned by `experimental_run()`, 

1886 `extended.call_for_each_replica()`, or a variable created in `scope`. 

1887 

1888 Returns: 

1889 A tuple of values contained in `value`. If `value` represents a single 

1890 value, this returns `(value,)`. 

1891 """ 

1892 return self._extended._local_results(value) # pylint: disable=protected-access 

1893 

1894 def experimental_local_results(self, value): 

1895 """Returns the list of all local per-replica values contained in `value`. 

1896 

1897 Note: This only returns values on the worker initiated by this client. 

1898 When using a `tf.distribute.Strategy` like 

1899 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 

1900 will be its own client, and this function will only return values 

1901 computed on that worker. 

1902 
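For example (a sketch; `step_fn` is any function suitable for `strategy.run`):

```python
per_replica_output = strategy.run(step_fn)
# A tuple with one tensor per local replica (length 1 under the default
# strategy or `OneDeviceStrategy`).
local_values = strategy.experimental_local_results(per_replica_output)
```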

1903 Args: 

1904 value: A value returned by `experimental_run()`, `run()`, or a variable 

1905 created in `scope`. 

1906 

1907 Returns: 

1908 A tuple of values contained in `value` where the ith element corresponds to 

1909 the ith replica. If `value` represents a single value, this returns 

1910 `(value,)`. 

1911 """ 

1912 return self._extended._local_results(value) # pylint: disable=protected-access 

1913 

1914 @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only 

1915 def group(self, value, name=None): 

1916 """Shortcut for `tf.group(self.experimental_local_results(value))`.""" 

1917 return self._extended._group(value, name) # pylint: disable=protected-access 

1918 

1919 @property 

1920 def num_replicas_in_sync(self): 

1921 """Returns number of replicas over which gradients are aggregated.""" 

1922 return self._extended._num_replicas_in_sync # pylint: disable=protected-access 

1923 

1924 @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string 

1925 @deprecated(None, "use `update_config_proto` instead.") 

1926 def configure(self, 

1927 session_config=None, 

1928 cluster_spec=None, 

1929 task_type=None, 

1930 task_id=None): 

1931 # pylint: disable=g-doc-return-or-yield,g-doc-args 

1932 """DEPRECATED: use `update_config_proto` instead. 

1933 

1934 Configures the strategy class. 

1935 

1936 DEPRECATED: This method's functionality has been split into the strategy 

1937 constructor and `update_config_proto`. In the future, we will allow passing 

1938 cluster and config_proto to the constructor to configure the strategy, and 

1939 `update_config_proto` can be used to update the config_proto based on the 

1940 specific strategy. 

1941 """ 

1942 return self._extended._configure( # pylint: disable=protected-access 

1943 session_config, cluster_spec, task_type, task_id) 

1944 

1945 @doc_controls.do_not_generate_docs # DEPRECATED 

1946 def update_config_proto(self, config_proto): 

1947 """DEPRECATED TF 1.x ONLY.""" 

1948 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 

1949 

1950 def __deepcopy__(self, memo): 

1951 # First do a regular deepcopy of `self`. 

1952 cls = self.__class__ 

1953 result = cls.__new__(cls) 

1954 memo[id(self)] = result 

1955 for k, v in self.__dict__.items(): 

1956 setattr(result, k, copy.deepcopy(v, memo)) 

1957 # One little fix-up: we want `result._extended` to reference `result` 

1958 # instead of `self`. 

1959 result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access 

1960 return result 

1961 

1962 def __copy__(self): 

1963 raise RuntimeError("Must only deepcopy DistributionStrategy.") 

1964 

1965 @property 

1966 def cluster_resolver(self): 

1967 """Returns the cluster resolver associated with this strategy. 

1968 

1969 In general, when using a multi-worker `tf.distribute` strategy such as 

1970 `tf.distribute.experimental.MultiWorkerMirroredStrategy` or 

1971 `tf.distribute.TPUStrategy()`, there is a 

1972 `tf.distribute.cluster_resolver.ClusterResolver` associated with the 

1973 strategy used, and such an instance is returned by this property. 

1974 

1975 Strategies that intend to have an associated 

1976 `tf.distribute.cluster_resolver.ClusterResolver` must set the 

1977 relevant attribute, or override this property; otherwise, `None` is returned 

1978 by default. Those strategies should also provide information regarding what 

1979 is returned by this property. 

1980 

1981 Single-worker strategies usually do not have a 

1982 `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this 

1983 property will return `None`. 

1984 

1985 The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the 

1986 user needs to access information such as the cluster spec, task type or task 

1987 id. For example, 

1988 

1989 ```python 

1990 

1991 os.environ['TF_CONFIG'] = json.dumps({ 

1992 'cluster': { 

1993 'worker': ["localhost:12345", "localhost:23456"], 

1994 'ps': ["localhost:34567"] 

1995 }, 

1996 'task': {'type': 'worker', 'index': 0} 

1997 }) 

1998 

1999 # This implicitly uses TF_CONFIG for the cluster and current task info. 

2000 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 

2001 

2002 ... 

2003 

2004 if strategy.cluster_resolver.task_type == 'worker': 

2005 # Perform something that's only applicable on workers. Since we set this 

2006 # as a worker above, this block will run on this particular instance. 

2007 elif strategy.cluster_resolver.task_type == 'ps': 

2008 # Perform something that's only applicable on parameter servers. Since we 

2009 # set this as a worker above, this block will not run on this particular 

2010 # instance. 

2011 ``` 

2012 

2013 For more information, please see 

2014 `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring. 

2015 

2016 Returns: 

2017 The cluster resolver associated with this strategy. Returns `None` if a 

2018 cluster resolver is not applicable or available in this strategy. 

2019 """ 

2020 if hasattr(self.extended, "_cluster_resolver"): 

2021 return self.extended._cluster_resolver # pylint: disable=protected-access 

2022 return None 

2023 

2024 

2025@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring 

2026class Strategy(StrategyBase): 

2027 

2028 __doc__ = StrategyBase.__doc__ 

2029 

2030 def experimental_distribute_values_from_function(self, value_fn): 

2031 """Generates `tf.distribute.DistributedValues` from `value_fn`. 

2032 

2033 This function is to generate `tf.distribute.DistributedValues` to pass 

2034 into `run`, `reduce`, or other methods that take 

2035 distributed values when not using datasets. 

2036 

2037 Args: 

2038 value_fn: The function to run to generate values. It is called for 

2039 each replica with `tf.distribute.ValueContext` as the sole argument. It 

2040 must return a Tensor or a type that can be converted to a Tensor. 

2041 Returns: 

2042 A `tf.distribute.DistributedValues` containing a value for each replica. 

2043 

2044 Example usage: 

2045 

2046 1. Return constant value per replica: 

2047 

2048 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

2049 >>> def value_fn(ctx): 

2050 ... return tf.constant(1.) 

2051 >>> distributed_values = ( 

2052 ... strategy.experimental_distribute_values_from_function( 

2053 ... value_fn)) 

2054 >>> local_result = strategy.experimental_local_results( 

2055 ... distributed_values) 

2056 >>> local_result 

2057 (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>, 

2058 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>) 

2059 

2060 2. Distribute values in array based on replica_id: {: value=2} 

2061 

2062 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

2063 >>> array_value = np.array([3., 2., 1.]) 

2064 >>> def value_fn(ctx): 

2065 ... return array_value[ctx.replica_id_in_sync_group] 

2066 >>> distributed_values = ( 

2067 ... strategy.experimental_distribute_values_from_function( 

2068 ... value_fn)) 

2069 >>> local_result = strategy.experimental_local_results( 

2070 ... distributed_values) 

2071 >>> local_result 

2072 (3.0, 2.0) 

2073 

2074 3. Specify values using num_replicas_in_sync: {: value=3} 

2075 

2076 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

2077 >>> def value_fn(ctx): 

2078 ... return ctx.num_replicas_in_sync 

2079 >>> distributed_values = ( 

2080 ... strategy.experimental_distribute_values_from_function( 

2081 ... value_fn)) 

2082 >>> local_result = strategy.experimental_local_results( 

2083 ... distributed_values) 

2084 >>> local_result 

2085 (2, 2) 

2086 

2087 4. Place values on devices and distribute: {: value=4} 

2088 

2089 ``` 

2090 strategy = tf.distribute.TPUStrategy() 

2091 worker_devices = strategy.extended.worker_devices 

2092 multiple_values = [] 

2093 for i in range(strategy.num_replicas_in_sync): 

2094 with tf.device(worker_devices[i]): 

2095 multiple_values.append(tf.constant(1.0)) 

2096 

2097 def value_fn(ctx): 

2098 return multiple_values[ctx.replica_id_in_sync_group] 

2099 

2100 distributed_values = ( 

2101     strategy.experimental_distribute_values_from_function( 

2102         value_fn)) 

2103 ``` 

2104 

2105 """ 

2106 return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access 

2107 value_fn) 

2108 

2109 def gather(self, value, axis): 

2110 # pylint: disable=line-too-long, protected-access 

2111 """Gather `value` across replicas along `axis` to the current device. 

2112 

2113 Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like 

2114 object `value`, this API gathers and concatenates `value` across replicas 

2115 along the `axis`-th dimension. The result is copied to the "current" device, 

2116 which would typically be the CPU of the worker on which the program is 

2117 running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For 

2118 multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of 

2119 each worker. 

2120 

2121 This API can only be called in the cross-replica context. For a counterpart 

2122 in the replica context, see `tf.distribute.ReplicaContext.all_gather`. 

2123 

2124 Note: For all strategies except `tf.distribute.TPUStrategy`, the input 

2125 `value` on different replicas must have the same rank, and their shapes must 

2126 be the same in all dimensions except the `axis`-th dimension. In other 

2127 words, their shapes cannot be different in a dimension `d` where `d` does 

2128 not equal to the `axis` argument. For example, given a 

2129 `tf.distribute.DistributedValues` with component tensors of shape 

2130 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 

2131 `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or 

2132 `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, 

2133 all tensors must have exactly the same rank and same shape. 

2134 

2135 Note: Given a `tf.distribute.DistributedValues` `value`, its component 

2136 tensors must have a non-zero rank. Otherwise, consider using 

2137 `tf.expand_dims` before gathering them. 

2138 

2139 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

2140 >>> # A DistributedValues with component tensor of shape (2, 1) on each replica 

2141 ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) 

2142 >>> @tf.function 

2143 ... def run(): 

2144 ... return strategy.gather(distributed_values, axis=0) 

2145 >>> run() 

2146 <tf.Tensor: shape=(4, 1), dtype=int32, numpy= 

2147 array([[1], 

2148 [2], 

2149 [1], 

2150 [2]], dtype=int32)> 

2151 

2152 

2153 Consider the following example for more combinations: 

2154 

2155 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) 

2156 >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) 

2157 >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) 

2158 >>> @tf.function 

2159 ... def run(axis): 

2160 ... return strategy.gather(distributed_values, axis=axis) 

2161 >>> axis=0 

2162 >>> run(axis) 

2163 <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy= 

2164 array([[[0, 1, 2], 

2165 [3, 4, 5]], 

2166 [[0, 1, 2], 

2167 [3, 4, 5]], 

2168 [[0, 1, 2], 

2169 [3, 4, 5]], 

2170 [[0, 1, 2], 

2171 [3, 4, 5]]], dtype=int32)> 

2172 >>> axis=1 

2173 >>> run(axis) 

2174 <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy= 

2175 array([[[0, 1, 2], 

2176 [3, 4, 5], 

2177 [0, 1, 2], 

2178 [3, 4, 5], 

2179 [0, 1, 2], 

2180 [3, 4, 5], 

2181 [0, 1, 2], 

2182 [3, 4, 5]]], dtype=int32)> 

2183 >>> axis=2 

2184 >>> run(axis) 

2185 <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy= 

2186 array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], 

2187 [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)> 

2188 

2189 

2190 Args: 

2191 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 

2192 `Strategy.run`, to be combined into a single tensor. It can also be a 

2193 regular tensor when used with `tf.distribute.OneDeviceStrategy` or the 

2194 default strategy. The tensors that constitute the DistributedValues 

2195 can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. 

2196 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 

2197 range [0, rank(value)). 

2198 

2199 Returns: 

2200 A `Tensor` that's the concatenation of `value` across replicas along 

2201 `axis` dimension. 

2202 """ 

2203 # pylint: enable=line-too-long 

2204 error_message = ("tf.distribute.Strategy.gather method requires " 

2205 "cross-replica context, use " 

2206 "get_replica_context().all_gather() instead") 

2207 _require_cross_replica_or_default_context_extended(self._extended, 

2208 error_message) 

2209 dst = device_util.current( 

2210 ) or self._extended._default_device or "/device:CPU:0" 

2211 if isinstance(value, indexed_slices.IndexedSlices): 

2212 raise NotImplementedError("gather does not support IndexedSlices") 

2213 return self._extended._local_results( 

2214 self._extended._gather_to(value, dst, axis))[0] 

2215 

2216 

2217# TF v1.x version has additional deprecated APIs 

2218@tf_export(v1=["distribute.Strategy"]) 

2219class StrategyV1(StrategyBase): 

2220 """A list of devices with a state & compute distribution policy. 

2221 

2222 See [the guide](https://www.tensorflow.org/guide/distribute_strategy) 

2223 for overview and examples. 

2224 

2225 Note: Not all `tf.distribute.Strategy` implementations currently support 

2226 TensorFlow's partitioned variables (where a single variable is split across 

2227 multiple devices). 

2228 """ 

2229 

2230 def make_dataset_iterator(self, dataset): 

2231 """Makes an iterator for input provided via `dataset`. 

2232 

2233 DEPRECATED: This method is not available in TF 2.x. 

2234 

2235 Data from the given dataset will be distributed evenly across all the 

2236 compute replicas. We will assume that the input dataset is batched by the 

2237 global batch size. With this assumption, we will make a best effort to 

2238 divide each batch across all the replicas (one or more workers). 

2239 If this effort fails, an error will be thrown, and the user should instead 

2240 use `make_input_fn_iterator` which provides more control to the user, and 

2241 does not try to divide a batch across replicas. 

2242 

2243 The user could also use `make_input_fn_iterator` if they want to 

2244 customize which input is fed to which replica/worker etc. 

2245 

2246 Args: 

2247 dataset: `tf.data.Dataset` that will be distributed evenly across all 

2248 replicas. 

2249 

2250 Returns: 

2251 A `tf.distribute.InputIterator` which returns inputs for each step of the 

2252 computation. The user should call `initialize` on the returned iterator. 

2253 """ 

2254 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 

2255 

2256 def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation 

2257 input_fn, 

2258 replication_mode=InputReplicationMode.PER_WORKER): 

2259 """Returns an iterator split across replicas created from an input function. 

2260 

2261 DEPRECATED: This method is not available in TF 2.x. 

2262 

2263 The `input_fn` should take a `tf.distribute.InputContext` object where 

2264 information about batching and input sharding can be accessed: 

2265 

2266 ``` 

2267 def input_fn(input_context): 

2268 batch_size = input_context.get_per_replica_batch_size(global_batch_size) 

2269 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) 

2270 return d.shard(input_context.num_input_pipelines, 

2271 input_context.input_pipeline_id) 

2272 with strategy.scope(): 

2273 iterator = strategy.make_input_fn_iterator(input_fn) 

2274 replica_results = strategy.experimental_run(replica_fn, iterator) 

2275 ``` 

2276 

2277 The `tf.data.Dataset` returned by `input_fn` should have a per-replica 

2278 batch size, which may be computed using 

2279 `input_context.get_per_replica_batch_size`. 

2280 

2281 Args: 

2282 input_fn: A function taking a `tf.distribute.InputContext` object and 

2283 returning a `tf.data.Dataset`. 

2284 replication_mode: an enum value of `tf.distribute.InputReplicationMode`. 

2285 Only `PER_WORKER` is supported currently, which means there will be 

2286 a single call to `input_fn` per worker. Replicas will dequeue from the 

2287 local `tf.data.Dataset` on their worker. 

2288 

2289 Returns: 

2290 An iterator object that should first be `.initialize()`-ed. It may then 

2291 either be passed to `strategy.experimental_run()` or you can call 

2292 `iterator.get_next()` to get the next value to pass to 

2293 `strategy.extended.call_for_each_replica()`. 

2294 """ 

2295 return super(StrategyV1, self).make_input_fn_iterator( 

2296 input_fn, replication_mode) 

2297 

2298 def experimental_make_numpy_dataset(self, numpy_input, session=None): 

2299 """Makes a tf.data.Dataset for input provided via a numpy array. 

2300 

2301 This avoids adding `numpy_input` as a large constant in the graph, 

2302 and copies the data to the machine or machines that will be processing 

2303 the input. 

2304 

2305 Note that you will likely need to use 

2306 `tf.distribute.Strategy.experimental_distribute_dataset` 

2307 with the returned dataset to further distribute it with the strategy. 

2308 

2309 Example: 

2310 ``` 

2311 numpy_input = np.ones([10], dtype=np.float32) 

2312 dataset = strategy.experimental_make_numpy_dataset(numpy_input) 

2313 dist_dataset = strategy.experimental_distribute_dataset(dataset) 

2314 ``` 

2315 

2316 Args: 

2317 numpy_input: A nest of NumPy input arrays that will be converted into a 

2318 dataset. Note that lists of NumPy arrays are stacked, as that is normal 

2319 `tf.data.Dataset` behavior. 

2320 session: (TensorFlow v1.x graph execution only) A session used for 

2321 initialization. 

2322 

2323 Returns: 

2324 A `tf.data.Dataset` representing `numpy_input`. 

2325 """ 

2326 return self.extended.experimental_make_numpy_dataset( 

2327 numpy_input, session=session) 

2328 

2329 @deprecated( 

2330 None, 

2331 "This method is not available in TF 2.x. Please switch to using `run` instead." 

2332 ) 

2333 def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation 

2334 """Runs ops in `fn` on each replica, with inputs from `input_iterator`. 

2335 

2336 DEPRECATED: This method is not available in TF 2.x. Please switch 

2337 to using `run` instead. 

2338 

2339 When eager execution is enabled, executes ops specified by `fn` on each 

2340 replica. Otherwise, builds a graph to execute the ops on each replica. 

2341 

2342 Each replica will take a single, different input from the inputs provided by 

2343 one `get_next` call on the input iterator. 

2344 

2345 `fn` may call `tf.distribute.get_replica_context()` to access members such 

2346 as `replica_id_in_sync_group`. 

2347 

2348 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being 

2349 used, and whether eager execution is enabled, `fn` may be called one or more 

2350 times (once for each replica). 

2351 

2352 Args: 

2353 fn: The function to run. The inputs to the function must match the outputs 

2354 of `input_iterator.get_next()`. The output must be a `tf.nest` of 

2355 `Tensor`s. 

2356 input_iterator: (Optional) input iterator from which the inputs are taken. 

2357 

2358 Returns: 

2359 Merged return value of `fn` across replicas. The structure of the return 

2360 value is the same as the return value from `fn`. Each element in the 

2361 structure can either be `PerReplica` (if the values are unsynchronized), 

2362 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a 

2363 single replica). 

2364 """ 

2365 return super(StrategyV1, self).experimental_run( 

2366 fn, input_iterator) 

2367 

2368 def reduce(self, reduce_op, value, axis=None): 

2369 return super(StrategyV1, self).reduce(reduce_op, value, axis) 

2370 

2371 reduce.__doc__ = StrategyBase.reduce.__doc__ 

2372 

2373 def update_config_proto(self, config_proto): 

2374 """Returns a copy of `config_proto` modified for use with this strategy. 

2375 

2376 DEPRECATED: This method is not available in TF 2.x. 

2377 

2378 The updated config contains what is needed to run a strategy, e.g. 

2379 configuration to run collective ops, or device filters to improve 

2380 distributed training performance. 

2381 

2382 Args: 

2383 config_proto: a `tf.ConfigProto` object. 

2384 

2385 Returns: 

2386 The updated copy of the `config_proto`. 

2387 """ 

2388 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 

2389 

2390 

2391# NOTE(josh11b): For any strategy that needs to support tf.compat.v1, 

2392# instead descend from StrategyExtendedV1. 

2393@tf_export("distribute.StrategyExtended", v1=[]) 

2394class StrategyExtendedV2(object): 

2395 """Additional APIs for algorithms that need to be distribution-aware. 

2396 

2397 Note: For most usage of `tf.distribute.Strategy`, there should be no need to 

2398 call these methods, since TensorFlow libraries (such as optimizers) already 

2399 call these methods when needed on your behalf. 

2400 

2401 

2402 Some common use cases of functions on this page: 

2403 

2404 * _Locality_ 

2405 

2406 `tf.distribute.DistributedValues` can have the same _locality_ as a 

2407 _distributed variable_, which leads to a mirrored value residing on the same 

2408 devices as the variable (as opposed to the compute devices). Such values may 

2409 be passed to a call to `tf.distribute.StrategyExtended.update` to update the 

2410 value of a variable. You may use 

2411 `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the 

2412 same locality as another variable. You may convert a "PerReplica" value to a 

2413 variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or 

2414 `tf.distribute.StrategyExtended.batch_reduce_to`. 

2415 

2416 * _How to update a distributed variable_ 

2417 

2418 A distributed variable is a variable created on multiple devices. As discussed 

2419 in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute), 

2420 mirrored variables and SyncOnRead variables are two examples. The standard 

2421 pattern for updating distributed variables is to: 

2422 

2423 1. In your function passed to `tf.distribute.Strategy.run`, 

2424 compute a list of (update, variable) pairs. For example, the update might 

2425 be a gradient of the loss with respect to the variable. 

2426 2. Switch to cross-replica mode by calling 

2427 `tf.distribute.get_replica_context().merge_call()` with the updates and 

2428 variables as arguments. 

2429 3. Call 

2430 `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` 

2431 (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` 

2432 (for a list of variables) to sum the updates. 

2433 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update 

2434 its value. 

2435 

2436 Steps 2 through 4 are done automatically by class 

2437 `tf.keras.optimizers.Optimizer` if you call its 

2438 `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context. 

2439 

2440 In fact, a higher-level way to update a distributed variable is to call 

2441 `assign` on the variable as you would with a regular `tf.Variable`. 

2442 You can call the method in both _replica context_ and _cross-replica context_. 

2443 For a _mirrored variable_, calling `assign` in _replica context_ requires you 

2444 to specify the `aggregation` type in the variable constructor. In that case, 

2445 the context switching and sync described in steps 2 through 4 are handled for 

2446 you. If you call `assign` on _mirrored variable_ in _cross-replica context_, 

2447 you can only assign a single value or assign values from another mirrored 

2448 variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead 

2449 variable_, in _replica context_, you can simply call `assign` on it and no 

2450 aggregation happens under the hood. In _cross-replica context_, you can only 

2451 assign a single value to a SyncOnRead variable. One example case is restoring 

2452 from a checkpoint: if the `aggregation` type of the variable is 

2453 `tf.VariableAggregation.SUM`, it is assumed that replica values were added 

2454 before checkpointing, so at the time of restoring, the value is divided by 

2455 the number of replicas and then assigned to each replica; if the `aggregation` 

2456 type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica 

2457 directly. 

2458 
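A sketch of this higher-level `assign` pattern (assuming two mirrored GPU replicas; the variable here is only illustrative):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  # `aggregation` must be specified to call assign ops in replica context.
  v = tf.Variable(0., aggregation=tf.VariableAggregation.MEAN)

def step_fn():
  # The per-replica deltas are combined (here averaged) for you.
  v.assign_add(1.)

strategy.run(step_fn)
# In cross-replica context a single value can be assigned directly:
v.assign(0.)
```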

2459 """ 

2460 

2461 def __init__(self, container_strategy): 

2462 self._container_strategy_weakref = weakref.ref(container_strategy) 

2463 self._default_device = None 

2464 # This property is used to determine if we should set drop_remainder=True 

2465 # when creating Datasets from numpy array inputs. 

2466 self._require_static_shapes = False 

2467 

2468 def _resource_creator_scope(self): 

2469 """Returns one or a list of ops.resource_creator_scope for some Strategy.""" 

2470 return None 

2471 

2472 def _container_strategy(self): 

2473 """Get the containing `tf.distribute.Strategy`. 

2474 

2475 This should not generally be needed except when creating a new 

2476 `ReplicaContext` and to validate that the caller is in the correct 

2477 `scope()`. 

2478 

2479 Returns: 

2480 The `tf.distribute.Strategy` such that `strategy.extended` is `self`. 

2481 """ 

2482 container_strategy = self._container_strategy_weakref() 

2483 assert container_strategy is not None 

2484 return container_strategy 

2485 

2486 def _scope(self, strategy): 

2487 """Implementation of tf.distribute.Strategy.scope().""" 

2488 

2489 def creator_with_resource_vars(next_creator, **kwargs): 

2490 """Variable creator to use in `_CurrentDistributionContext`.""" 

2491 if ops.inside_function(): 

2492 if_graph_building = "graph_building" 

2493 else: 

2494 if_graph_building = "not_graph_building" 

2495 

2496 with monitoring.MonitoredTimer(distributed_variable_creation_time_counter.get_cell(strategy.__class__.__name__, if_graph_building)): 

2497 _require_strategy_scope_extended(self) 

2498 kwargs["use_resource"] = True 

2499 kwargs["distribute_strategy"] = strategy 

2500 

2501 # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid 

2502 # dereferencing a `Tensor` that is without a `name`. We still need to 

2503 # propagate the metadata it's holding. 

2504 if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): 

2505 checkpoint_restore_uid = kwargs[ 

2506 "initial_value"].checkpoint_position.restore_uid 

2507 kwargs["initial_value"] = kwargs["initial_value"].wrapped_value 

2508 elif isinstance(kwargs["initial_value"], 

2509 trackable.CheckpointInitialValueCallable): 

2510 checkpoint_restore_uid = kwargs[ 

2511 "initial_value"].checkpoint_position.restore_uid 

2512 elif (isinstance(kwargs["initial_value"], functools.partial) and 

2513 isinstance(kwargs["initial_value"].func, 

2514 trackable.CheckpointInitialValueCallable)): 

2515 # Some libraries (e.g., Keras) create a partial function out of the initializer 

2516 # to bind shape/dtype, for example: 

2517 # initial_val = functools.partial(initializer, shape, dtype=dtype) 

2518 # Therefore to get the restore_uid we need to examine the "func" of 

2519 # the partial function. 

2520 checkpoint_restore_uid = kwargs[ 

2521 "initial_value"].func.checkpoint_position.restore_uid 

2522 else: 

2523 checkpoint_restore_uid = None 

2524 

2525 created = self._create_variable(next_creator, **kwargs) 

2526 

2527 if checkpoint_restore_uid is not None: 

2528 # pylint: disable=protected-access 

2529 # Let the checkpointing infrastructure know that the variable was 

2530 # already restored so it doesn't waste memory loading the value again. 

2531 # In the case of CheckpointInitialValueCallable this may already be 

2532 # done by the final variable creator, but it doesn't hurt to do it 

2533 # again. 

2534 created._maybe_initialize_trackable() 

2535 created._update_uid = checkpoint_restore_uid 

2536 # pylint: enable=protected-access 

2537 return created 

2538 

2539 def distributed_getter(getter, *args, **kwargs): 

2540 if not self._allow_variable_partition(): 

2541 if kwargs.pop("partitioner", None) is not None: 

2542 tf_logging.log_first_n( 

2543 tf_logging.WARN, "Partitioned variables are disabled when using " 

2544 "current tf.distribute.Strategy.", 1) 

2545 return getter(*args, **kwargs) 

2546 

2547 return _CurrentDistributionContext( 

2548 strategy, 

2549 variable_scope.variable_creator_scope(creator_with_resource_vars), 

2550 variable_scope.variable_scope( 

2551 variable_scope.get_variable_scope(), 

2552 custom_getter=distributed_getter), 

2553 strategy.extended._resource_creator_scope(), # pylint: disable=protected-access 

2554 self._default_device) 
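# Editor's illustrative sketch (not part of distribute_lib.py): the public
# `tf.variable_creator_scope` mechanism that `_scope` builds on. The creator
# intercepts every `tf.Variable` constructed inside the scope, much like
# `creator_with_resource_vars` above; the names here are hypothetical.
def _sketch_variable_creator_scope():
  import tensorflow as tf

  created_names = []

  def recording_creator(next_creator, **kwargs):
    # Delegate to the next creator in the chain, then record what was made.
    v = next_creator(**kwargs)
    created_names.append(v.name)
    return v

  with tf.variable_creator_scope(recording_creator):
    w = tf.Variable(1.0, name="w")
  return w, created_names  # created_names == ["w:0"]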

2555 

2556 def _allow_variable_partition(self): 

2557 return False 

2558 

2559 def _create_variable(self, next_creator, **kwargs): 

2560 # Note: should support "colocate_with" argument. 

2561 raise NotImplementedError("must be implemented in descendants") 

2562 

2563 def variable_created_in_scope(self, v): 

2564 """Tests whether `v` was created while this strategy scope was active. 

2565 

2566 Variables created inside the strategy scope are "owned" by it: 

2567 

2568 >>> strategy = tf.distribute.MirroredStrategy() 

2569 >>> with strategy.scope(): 

2570 ... v = tf.Variable(1.) 

2571 >>> strategy.extended.variable_created_in_scope(v) 

2572 True 

2573 

2574 Variables created outside the strategy are not owned by it: 

2575 

2576 >>> strategy = tf.distribute.MirroredStrategy() 

2577 >>> v = tf.Variable(1.) 

2578 >>> strategy.extended.variable_created_in_scope(v) 

2579 False 

2580 

2581 Args: 

2582 v: A `tf.Variable` instance. 

2583 

2584 Returns: 

2585 True if `v` was created inside the scope, False if not. 

2586 """ 

2587 return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access 

2588 

2589 def colocate_vars_with(self, colocate_with_variable): 

2590 """Scope that controls which devices variables will be created on. 

2591 

2592 No operations should be added to the graph inside this scope; it 

2593 should only be used when creating variables (some implementations 

2594 work by changing variable creation, others work by using a 

2595 tf.compat.v1.colocate_with() scope). 

2596 

2597 This may only be used inside `self.scope()`. 

2598 

2599 Example usage: 

2600 

2601 ``` 

2602 with strategy.scope(): 

2603 var1 = tf.Variable(...) 

2604 with strategy.extended.colocate_vars_with(var1): 

2605 # var2 and var3 will be created on the same device(s) as var1 

2606 var2 = tf.Variable(...) 

2607 var3 = tf.Variable(...) 

2608 

2609 def fn(v1, v2, v3): 

2610 # operates on v1 from var1, v2 from var2, and v3 from var3 

2611 

2612 # `fn` runs on every device `var1` is on, `var2` and `var3` will be there 

2613 # too. 

2614 strategy.extended.update(var1, fn, args=(var2, var3)) 

2615 ``` 

2616 

2617 Args: 

2618 colocate_with_variable: A variable created in this strategy's `scope()`. 

2619 Variables created while in the returned context manager will be on the 

2620 same set of devices as `colocate_with_variable`. 

2621 

2622 Returns: 

2623 A context manager. 

2624 """ 

2625 

2626 def create_colocated_variable(next_creator, **kwargs): 

2627 _require_strategy_scope_extended(self) 

2628 kwargs["use_resource"] = True 

2629 kwargs["colocate_with"] = colocate_with_variable 

2630 return next_creator(**kwargs) 

2631 

2632 _require_strategy_scope_extended(self) 

2633 self._validate_colocate_with_variable(colocate_with_variable) 

2634 return variable_scope.variable_creator_scope(create_colocated_variable) 
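# Editor's illustrative sketch (not part of distribute_lib.py): a concrete,
# hypothetical version of the `colocate_vars_with` example in the docstring
# above, assuming a MirroredStrategy over two devices.
def _sketch_colocate_vars_with():
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  with strategy.scope():
    var1 = tf.Variable(1.0)
    with strategy.extended.colocate_vars_with(var1):
      # var2 is placed on the same device(s) as var1.
      var2 = tf.Variable(2.0)

  def add_fn(v1, v2):
    return v1.assign_add(v2)

  # add_fn runs once per device holding var1; var2's component is colocated there.
  return strategy.extended.update(var1, add_fn, args=(var2,))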

2635 

2636 def _validate_colocate_with_variable(self, colocate_with_variable): 

2637 """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" 

2638 pass 

2639 

2640 def _make_dataset_iterator(self, dataset): 

2641 raise NotImplementedError("must be implemented in descendants") 

2642 

2643 def _make_input_fn_iterator(self, input_fn, replication_mode): 

2644 raise NotImplementedError("must be implemented in descendants") 

2645 

2646 def _experimental_distribute_dataset(self, dataset, options): 

2647 raise NotImplementedError("must be implemented in descendants") 

2648 

2649 def _distribute_datasets_from_function(self, dataset_fn, options): 

2650 raise NotImplementedError("must be implemented in descendants") 

2651 

2652 def _experimental_distribute_values_from_function(self, value_fn): 

2653 raise NotImplementedError("must be implemented in descendants") 

2654 

2655 def _reduce(self, reduce_op, value): 

2656 # Default implementation until we have an implementation for each strategy. 

2657 dst = device_util.current() or self._default_device or "/device:CPU:0" 

2658 return self._local_results(self.reduce_to(reduce_op, value, dst))[0] 

2659 

2660 def reduce_to(self, reduce_op, value, destinations, options=None): 

2661 """Combine (via e.g. sum or mean) values across replicas. 

2662 

2663 `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed 

2664 variables. It supports both dense values and `tf.IndexedSlices`. 

2665 

2666 This API currently can only be called in cross-replica context. Other 

2667 variants to reduce values across replicas are: 

2668 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of 

2669 this API. 

2670 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 

2671 in replica context. It supports both batched and non-batched all-reduce. 

2672 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 

2673 to the host in cross-replica context. 

2674 

2675 `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can 

2676 also pass in a `Tensor`, and the destinations will be the device of that 

2677 tensor. For all-reduce, pass the same to `value` and `destinations`. 

2678 

2679 It can be used in `tf.distribute.ReplicaContext.merge_call` to write code 

2680 that works for all `tf.distribute.Strategy`. 

2681 

2682 @tf.function 

2683 def step_fn(var): 

2684 

2685 def merge_fn(strategy, value, var): 

2686 # All-reduce the value. Note that `value` here is a 

2687 # `tf.distribute.DistributedValues`. 

2688 reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, 

2689 value, destinations=var) 

2690 strategy.extended.update(var, lambda var, value: var.assign(value), 

2691 args=(reduced,)) 

2692 

2693 value = tf.identity(1.) 

2694 tf.distribute.get_replica_context().merge_call(merge_fn, 

2695 args=(value, var)) 

2696 

2697 def run(strategy): 

2698 with strategy.scope(): 

2699 v = tf.Variable(0.) 

2700 strategy.run(step_fn, args=(v,)) 

2701 return v 

2702 

2703 run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 

2704 MirroredVariable:{ 

2705 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 

2706 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 

2707 } 

2708 run(tf.distribute.experimental.CentralStorageStrategy( 

2709 compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 

2710 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 

2711 run(tf.distribute.OneDeviceStrategy("GPU:0")) 

2712 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 

2713 

2714 Args: 

2715 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 

2716 be combined. Allows using string representation of the enum such as 

2717 "SUM", "MEAN". 

2718 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 

2719 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

2720 `tf.Tensor`-like object, or a device string. It specifies the devices 

2721 to reduce to. To perform an all-reduce, pass the same to `value` and 

2722 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 

2723 to the devices of that variable, and this method doesn't update the 

2724 variable. 

2725 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 

2726 perform collective operations. This overrides the default options if the 

2727 `tf.distribute.Strategy` takes one in the constructor. See 

2728 `tf.distribute.experimental.CommunicationOptions` for details of the 

2729 options. 

2730 

2731 Returns: 

2732 A tensor or value reduced to `destinations`. 

2733 """ 

2734 if options is None: 

2735 options = collective_util.Options() 

2736 _require_cross_replica_or_default_context_extended(self) 

2737 assert not isinstance(destinations, (list, tuple)) 

2738 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 

2739 if isinstance(reduce_op, six.string_types): 

2740 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 

2741 assert (reduce_op == reduce_util.ReduceOp.SUM or 

2742 reduce_op == reduce_util.ReduceOp.MEAN) 

2743 return self._reduce_to(reduce_op, value, destinations, options) 
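# Editor's illustrative sketch (not part of distribute_lib.py): calling
# `reduce_to` directly from cross-replica context on per-replica values built
# with `experimental_distribute_values_from_function`. Names are hypothetical;
# assumes a MirroredStrategy over two devices.
def _sketch_reduce_to():
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  per_replica = strategy.experimental_distribute_values_from_function(
      lambda ctx: tf.constant(float(ctx.replica_id_in_sync_group)))
  # Sum 0.0 and 1.0 from the two replicas onto a single destination device.
  return strategy.extended.reduce_to(
      tf.distribute.ReduceOp.SUM, per_replica, destinations="GPU:0")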

2744 

2745 def _reduce_to(self, reduce_op, value, destinations, options): 

2746 raise NotImplementedError("must be implemented in descendants") 

2747 

2748 def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None): 

2749 """Combine multiple `reduce_to` calls into one for faster execution. 

2750 

2751 Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. 

2752 It's more efficient than reducing each value separately. 

2753 

2754 This API currently can only be called in cross-replica context. Other 

2755 variants to reduce values across replicas are: 

2756 * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of 

2757 this API. 

2758 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 

2759 in replica context. It supports both batched and non-batched all-reduce. 

2760 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 

2761 to the host in cross-replica context. 

2762 

2763 See `reduce_to` for more information. 

2764 

2765 @tf.function 

2766 def step_fn(var): 

2767 

2768 def merge_fn(strategy, value, var): 

2769 # All-reduce the value. Note that `value` here is a 

2770 # `tf.distribute.DistributedValues`. 

2771 reduced = strategy.extended.batch_reduce_to( 

2772 tf.distribute.ReduceOp.SUM, [(value, var)])[0] 

2773 strategy.extended.update(var, lambda var, value: var.assign(value), 

2774 args=(reduced,)) 

2775 

2776 value = tf.identity(1.) 

2777 tf.distribute.get_replica_context().merge_call(merge_fn, 

2778 args=(value, var)) 

2779 

2780 def run(strategy): 

2781 with strategy.scope(): 

2782 v = tf.Variable(0.) 

2783 strategy.run(step_fn, args=(v,)) 

2784 return v 

2785 

2786 run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 

2787 MirroredVariable:{ 

2788 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 

2789 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 

2790 } 

2791 run(tf.distribute.experimental.CentralStorageStrategy( 

2792 compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 

2793 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 

2794 run(tf.distribute.OneDeviceStrategy("GPU:0")) 

2795 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 

2796 

2797 Args: 

2798 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 

2799 be combined. Allows using string representation of the enum such as 

2800 "SUM", "MEAN". 

2801 value_destination_pairs: a sequence of (value, destinations) pairs. See 

2802 `tf.distribute.Strategy.reduce_to` for descriptions. 

2803 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 

2804 perform collective operations. This overrides the default options if the 

2805 `tf.distribute.Strategy` takes one in the constructor. See 

2806 `tf.distribute.experimental.CommunicationOptions` for details of the 

2807 options. 

2808 

2809 Returns: 

2810 A list of reduced values, one per pair in `value_destination_pairs`. 

2811 """ 

2812 if options is None: 

2813 options = collective_util.Options() 

2814 _require_cross_replica_or_default_context_extended(self) 

2815 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 

2816 if isinstance(reduce_op, six.string_types): 

2817 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 

2818 return self._batch_reduce_to(reduce_op, value_destination_pairs, options) 
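# Editor's illustrative sketch (not part of distribute_lib.py): batching two
# reductions into a single `batch_reduce_to` call, mirroring the sketch after
# `reduce_to` above. Assumes a MirroredStrategy over two devices.
def _sketch_batch_reduce_to():
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  make = strategy.experimental_distribute_values_from_function
  a = make(lambda ctx: tf.constant(1.0))
  b = make(lambda ctx: tf.constant(2.0))
  # One batched collective instead of two separate reduce_to calls.
  return strategy.extended.batch_reduce_to(
      tf.distribute.ReduceOp.SUM, [(a, "GPU:0"), (b, "GPU:0")])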

2819 

2820 def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): 

2821 return [ 

2822 self.reduce_to(reduce_op, t, destinations=v, options=options) 

2823 for t, v in value_destination_pairs 

2824 ] 

2825 

2826 def _replica_ctx_all_reduce(self, reduce_op, value, options=None): 

2827 """All-reduce `value` across all replicas so that all get the final result. 

2828 

2829 If `value` is a nested structure of tensors, all-reduces of these tensors 

2830 will be batched when possible. `options` can be set to hint the batching 

2831 behavior. 

2832 

2833 This API must be called in a replica context. 

2834 

2835 Args: 

2836 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should 

2837 be combined. 

2838 value: Value to be reduced. A tensor or a nested structure of tensors. 

2839 options: A `tf.distribute.experimental.CommunicationOptions`. Options to 

2840 perform collective operations. This overrides the default options if the 

2841 `tf.distribute.Strategy` takes one in the constructor. 

2842 

2843 Returns: 

2844 A tensor or a nested structure of tensors with the reduced values. The 

2845 structure is the same as `value`. 

2846 """ 

2847 if options is None: 

2848 options = collective_util.Options() 

2849 replica_context = get_replica_context() 

2850 assert replica_context, ( 

2851 "`StrategyExtended._replica_ctx_all_reduce` must be called in" 

2852 " a replica context") 

2853 

2854 def merge_fn(_, flat_value): 

2855 return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value], 

2856 options) 

2857 

2858 reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),)) 

2859 return nest.pack_sequence_as(value, reduced) 

2860 

2861 def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True): 

2862 """Run `fn` with `args` and `kwargs` to update `var`.""" 

2863 # This method is called by ReplicaContext.update. Strategies that want to 

2864 # remove merge_call in this path should override this method. 

2865 replica_context = get_replica_context() 

2866 if not replica_context: 

2867 raise ValueError("`StrategyExtended._replica_ctx_update` must be called " 

2868 "in a replica context.") 

2869 

2870 def merge_fn(_, *merged_args, **merged_kwargs): 

2871 return self.update(var, fn, merged_args, merged_kwargs, group=group) 

2872 

2873 return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) 

2874 

2875 def _gather_to(self, value, destinations, axis, options=None): 

2876 """Gather `value` across replicas along axis-th dimension to `destinations`. 

2877 

2878 `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like 

2879 object, along `axis`-th dimension. It supports only dense tensors but NOT 

2880 sparse tensors. This API can only be called in cross-replica context. 

2881 

2882 Args: 

2883 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 

2884 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 

2885 `tf.Tensor`-like object, or a device string. It specifies the devices 

2886 to gather to. To perform an all-gather, pass the same to `value` and 

2887 `destinations`. Note that if it's a `tf.Variable`, the value is gathered 

2888 to the devices of that variable, and this method doesn't update the 

2889 variable. 

2890 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 

2891 range [0, rank(value)). 

2892 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 

2893 perform collective operations. This overrides the default options if the 

2894 `tf.distribute.Strategy` takes one in the constructor. See 

2895 `tf.distribute.experimental.CommunicationOptions` for details of the 

2896 options. 

2897 

2898 Returns: 

2899 A tensor or value gathered to `destinations`. 

2900 """ 

2901 _require_cross_replica_or_default_context_extended(self) 

2902 assert not isinstance(destinations, (list, tuple)) 

2903 if options is None: 

2904 options = collective_util.Options() 

2905 return self._gather_to_implementation(value, destinations, axis, options) 
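# Editor's illustrative sketch (not part of distribute_lib.py): the public
# cross-replica counterpart of `_gather_to` is `tf.distribute.Strategy.gather`,
# shown here on per-replica tensors whose leading dimensions differ. Assumes a
# MirroredStrategy over two devices.
def _sketch_gather():
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  per_replica = strategy.experimental_distribute_values_from_function(
      lambda ctx: tf.ones([ctx.replica_id_in_sync_group + 1, 2]))
  # Concatenates the (1, 2) and (2, 2) components along axis 0 -> shape (3, 2).
  return strategy.gather(per_replica, axis=0)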

2906 

2907 def _gather_to_implementation(self, value, destinations, axis, options): 

2908 raise NotImplementedError("_gather_to must be implemented in descendants") 

2909 

2910 def _batch_gather_to(self, value_destination_pairs, axis, options=None): 

2911 _require_cross_replica_or_default_context_extended(self) 

2912 if options is None: 

2913 options = collective_util.Options() 

2914 return [ 

2915 self._gather_to(t, destinations=v, axis=axis, options=options) 

2916 for t, v in value_destination_pairs 

2917 ] 

2918 

2919 def update(self, var, fn, args=(), kwargs=None, group=True): 

2920 """Run `fn` to update `var` using inputs mirrored to the same devices. 

2921 

2922 `tf.distribute.StrategyExtended.update` takes a distributed variable `var` 

2923 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It 

2924 applies `fn` to each component variable of `var` and passes corresponding 

2925 values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain 

2926 per-replica values. If they contain mirrored values, they will be unwrapped 

2927 before calling `fn`. For example, `fn` can be `assign_add` and `args` can be 

2928 a mirrored DistributedValues where each component contains the value to be 

2929 added to this mirrored variable `var`. Calling `update` will call 

2930 `assign_add` on each component variable of `var` with the corresponding 

2931 tensor value on that device. 

2932 

2933 Example usage: 

2934 

2935 ```python 

2936 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 

2937 # devices 

2938 with strategy.scope(): 

2939 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) 

2940 def update_fn(v): 

2941 return v.assign(1.0) 

2942 result = strategy.extended.update(v, update_fn) 

2943 # result is 

2944 # Mirrored:{ 

2945 # 0: tf.Tensor(1.0, shape=(), dtype=float32), 

2946 # 1: tf.Tensor(1.0, shape=(), dtype=float32) 

2947 # } 

2948 ``` 

2949 

2950 If `var` is mirrored across multiple devices, then this method implements 

2951 the following logic: 

2952 

2953 ```python 

2954 results = {} 

2955 for device, v in var: 

2956 with tf.device(device): 

2957 # args and kwargs will be unwrapped if they are mirrored. 

2958 results[device] = fn(v, *args, **kwargs) 

2959 return merged(results) 

2960 ``` 

2961 

2962 Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with 

2963 `var`. 

2964 

2965 Args: 

2966 var: Variable, possibly mirrored to multiple devices, to operate on. 

2967 fn: Function to call. Should take the variable as the first argument. 

2968 args: Tuple or list. Additional positional arguments to pass to `fn()`. 

2969 kwargs: Dict with keyword arguments to pass to `fn()`. 

2970 group: Boolean. Defaults to True. If False, the return value will be 

2971 unwrapped. 

2972 

2973 Returns: 

2974 By default, the merged return value of `fn` across all replicas. The 

2975 merged result has dependencies to make sure that if it is evaluated at 

2976 all, the side effects (updates) will happen on every replica. If instead 

2977 "group=False" is specified, this function will return a nest of lists 

2978 where each list has an element per replica, and the caller is responsible 

2979 for ensuring all elements are executed. 

2980 """ 

2981 # TODO(b/178944108): Update the documentation to reflect the fact that 

2982 # `update` can be called in a replica context. 

2983 if kwargs is None: 

2984 kwargs = {} 

2985 replica_context = get_replica_context() 

2986 # pylint: disable=protected-access 

2987 if (replica_context is None or replica_context is 

2988 _get_default_replica_context()): 

2989 fn = autograph.tf_convert( 

2990 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 

2991 with self._container_strategy().scope(): 

2992 return self._update(var, fn, args, kwargs, group) 

2993 else: 

2994 return self._replica_ctx_update( 

2995 var, fn, args=args, kwargs=kwargs, group=group) 
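# Editor's illustrative sketch (not part of distribute_lib.py): a runnable,
# hypothetical variant of the `update` example in the docstring above, using a
# plain tensor argument that is forwarded unchanged to every component update.
def _sketch_extended_update():
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  with strategy.scope():
    v = tf.Variable(5.0)

  def update_fn(var, delta):
    return var.assign_add(delta)

  # Runs update_fn once per component variable of v, on that component's device.
  return strategy.extended.update(v, update_fn, args=(tf.constant(1.0),))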

2996 

2997 def _update(self, var, fn, args, kwargs, group): 

2998 raise NotImplementedError("must be implemented in descendants") 

2999 

3000 def _local_results(self, val): 

3001 """Returns local results per replica as a tuple.""" 

3002 if isinstance(val, ds_types.DistributedValues): 

3003 return val._values # pylint: disable=protected-access 

3004 

3005 if nest.is_nested(val): 

3006 replica_values = [] 

3007 

3008 def get_values(x, index): 

3009 if isinstance(x, ds_types.DistributedValues): 

3010 return x._values[index] # pylint: disable=protected-access 

3011 return x 

3012 

3013 for i in range(len(self.worker_devices)): 

3014 replica_values.append( 

3015 nest.map_structure( 

3016 lambda x: get_values(x, i), # pylint: disable=cell-var-from-loop 

3017 val)) 

3018 return tuple(replica_values) 

3019 return (val,) 
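# Editor's illustrative sketch (not part of distribute_lib.py): the public
# wrapper over `_local_results` is `Strategy.experimental_local_results`, which
# unpacks a distributed value into one entry per replica. Assumes a
# MirroredStrategy over two devices.
def _sketch_local_results():
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  def step_fn():
    ctx = tf.distribute.get_replica_context()
    return tf.cast(ctx.replica_id_in_sync_group, tf.float32)

  per_replica = strategy.run(step_fn)
  # Returns a tuple with one tensor per replica: (0.0, 1.0).
  return strategy.experimental_local_results(per_replica)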

3020 

3021 def value_container(self, value): 

3022 """Returns the container that this per-replica `value` belongs to. 

3023 

3024 Args: 

3025 value: A value returned by `run()` or a variable created in `scope()`. 

3026 

3027 Returns: 

3028 A container that `value` belongs to. 

3029 If value does not belong to any container (including the case where the 

3030 container has been destroyed), returns the value itself. 

3031 `value in experimental_local_results(value_container(value))` will 

3032 always be true. 

3033 """ 

3034 raise NotImplementedError("must be implemented in descendants") 

3035 

3036 def _group(self, value, name=None): 

3037 """Implementation of `group`.""" 

3038 value = nest.flatten(self._local_results(value)) 

3039 

3040 if len(value) != 1 or name is not None: 

3041 return control_flow_ops.group(value, name=name) 

3042 # Special handling for the common case of one op. 

3043 v, = value 

3044 if hasattr(v, "op"): 

3045 v = v.op 

3046 return v 

3047 

3048 @property 

3049 def experimental_require_static_shapes(self): 

3050 """Returns `True` if static shape is required; `False` otherwise.""" 

3051 return self._require_static_shapes 

3052 

3053 @property 

3054 def _num_replicas_in_sync(self): 

3055 """Returns number of replicas over which gradients are aggregated.""" 

3056 raise NotImplementedError("must be implemented in descendants") 

3057 

3058 @property 

3059 def worker_devices(self): 

3060 """Returns the tuple of all devices used to for compute replica execution. 

3061 """ 

3062 # TODO(josh11b): More docstring 

3063 raise NotImplementedError("must be implemented in descendants") 

3064 

3065 @property 

3066 def parameter_devices(self): 

3067 """Returns the tuple of all devices used to place variables.""" 

3068 # TODO(josh11b): More docstring 

3069 raise NotImplementedError("must be implemented in descendants") 

3070 

3071 def _configure(self, 

3072 session_config=None, 

3073 cluster_spec=None, 

3074 task_type=None, 

3075 task_id=None): 

3076 """Configures the strategy class.""" 

3077 del session_config, cluster_spec, task_type, task_id 

3078 

3079 def _update_config_proto(self, config_proto): 

3080 return copy.deepcopy(config_proto) 

3081 

3082 def _in_multi_worker_mode(self): 

3083 """Whether this strategy indicates working in multi-worker settings. 

3084 

3085 Multi-worker training refers to the setup where the training is 

3086 distributed across multiple workers, as opposed to the case where 

3087 only a local process performs the training. This function is 

3088 used by higher-level APIs such as Keras' `model.fit()` to infer, 

3089 for example, whether or not a distribute coordinator should be run 

3090 (and thus whether TensorFlow servers should be started for communication 

3091 with other servers in the cluster), or whether or not saving/restoring 

3092 checkpoints is relevant for preemption fault tolerance. 

3093 

3094 Subclasses should override this to provide whether the strategy is 

3095 currently in a multi-worker setup. 

3096 

3097 Experimental. Signature and implementation are subject to change. 

3098 """ 

3099 raise NotImplementedError("must be implemented in descendants") 

3100 

3101 

3102@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring 

3103class StrategyExtendedV1(StrategyExtendedV2): 

3104 

3105 __doc__ = StrategyExtendedV2.__doc__ 

3106 

3107 def experimental_make_numpy_dataset(self, numpy_input, session=None): 

3108 """Makes a dataset for input provided via a numpy array. 

3109 

3110 This avoids adding `numpy_input` as a large constant in the graph, 

3111 and copies the data to the machine or machines that will be processing 

3112 the input. 

3113 

3114 Args: 

3115 numpy_input: A nest of NumPy input arrays that will be distributed evenly 

3116 across all replicas. Note that lists of NumPy arrays are stacked, as 

3117 that is normal `tf.data.Dataset` behavior. 

3118 session: (TensorFlow v1.x graph execution only) A session used for 

3119 initialization. 

3120 

3121 Returns: 

3122 A `tf.data.Dataset` representing `numpy_input`. 

3123 """ 

3124 _require_cross_replica_or_default_context_extended(self) 

3125 return self._experimental_make_numpy_dataset(numpy_input, session=session) 

3126 

3127 def _experimental_make_numpy_dataset(self, numpy_input, session): 

3128 raise NotImplementedError("must be implemented in descendants") 

3129 

3130 def broadcast_to(self, tensor, destinations): 

3131 """Mirror a tensor on one device to all worker devices. 

3132 

3133 Args: 

3134 tensor: A Tensor value to broadcast. 

3135 destinations: A mirrored variable or device string specifying the 

3136 destination devices to copy `tensor` to. 

3137 

3138 Returns: 

3139 A value mirrored to `destinations` devices. 

3140 """ 

3141 assert destinations is not None # from old strategy.broadcast() 

3142 # TODO(josh11b): More docstring 

3143 _require_cross_replica_or_default_context_extended(self) 

3144 assert not isinstance(destinations, (list, tuple)) 

3145 return self._broadcast_to(tensor, destinations) 

3146 

3147 def _broadcast_to(self, tensor, destinations): 

3148 raise NotImplementedError("must be implemented in descendants") 

3149 

3150 @deprecated(None, "please use `run` instead.") 

3151 def experimental_run_steps_on_iterator(self, 

3152 fn, 

3153 iterator, 

3154 iterations=1, 

3155 initial_loop_values=None): 

3156 """DEPRECATED: please use `run` instead. 

3157 

3158 Run `fn` with input from `iterator` for `iterations` times. 

3159 

3160 This method can be used to run a step function for training a number of 

3161 times using input from a dataset. 

3162 

3163 Args: 

3164 fn: function to run using this distribution strategy. The function must 

3165 have the following signature: `def fn(context, inputs)`. `context` is an 

3166 instance of `MultiStepContext` that will be passed when `fn` is run. 

3167 `context` can be used to specify the outputs to be returned from `fn` 

3168 by calling `context.set_last_step_output`. It can also be used to 

3169 capture non-tensor outputs by `context.set_non_tensor_output`. See 

3170 `MultiStepContext` documentation for more information. `inputs` will 

3171 have same type/structure as `iterator.get_next()`. Typically, `fn` 

3172 will use `call_for_each_replica` method of the strategy to distribute 

3173 the computation over multiple replicas. 

3174 iterator: Iterator of a dataset that represents the input for `fn`. The 

3175 caller is responsible for initializing the iterator as needed. 

3176 iterations: (Optional) Number of iterations that `fn` should be run. 

3177 Defaults to 1. 

3178 initial_loop_values: (Optional) Initial values to be passed into the 

3179 loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove 

3180 initial_loop_values argument when we have a mechanism to infer the 

3181 outputs of `fn`. 

3182 

3183 Returns: 

3184 Returns the `MultiStepContext` object which has the following properties, 

3185 among other things: 

3186 - run_op: An op that runs `fn` `iterations` times. 

3187 - last_step_outputs: A dictionary containing tensors set using 

3188 `context.set_last_step_output`. Evaluating this returns the value of 

3189 the tensors after the last iteration. 

3190 - non_tensor_outputs: A dictionary containing anything that was set by 

3191 `fn` by calling `context.set_non_tensor_output`. 

3192 """ 

3193 _require_cross_replica_or_default_context_extended(self) 

3194 with self._container_strategy().scope(): 

3195 return self._experimental_run_steps_on_iterator(fn, iterator, iterations, 

3196 initial_loop_values) 

3197 

3198 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, 

3199 initial_loop_values): 

3200 raise NotImplementedError("must be implemented in descendants") 

3201 

3202 def call_for_each_replica(self, fn, args=(), kwargs=None): 

3203 """Run `fn` once per replica. 

3204 

3205 `fn` may call `tf.get_replica_context()` to access methods such as 

3206 `replica_id_in_sync_group` and `merge_call()`. 

3207 

3208 `merge_call()` is used to communicate between the replicas and 

3209 re-enter the cross-replica context. All replicas pause their execution 

3210 when they encounter a `merge_call()` call. After that the 

3211 `merge_fn`-function is executed. Its results are then unwrapped and 

3212 given back to each replica call. After that execution resumes until 

3213 `fn` is complete or encounters another `merge_call()`. Example: 

3214 

3215 ```python 

3216 # Called once in "cross-replica" context. 

3217 def merge_fn(distribution, three_plus_replica_id): 

3218 # sum the values across replicas 

3219 return sum(distribution.experimental_local_results(three_plus_replica_id)) 

3220 

3221 # Called once per replica in `distribution`, in a "replica" context. 

3222 def fn(three): 

3223 replica_ctx = tf.get_replica_context() 

3224 v = three + replica_ctx.replica_id_in_sync_group 

3225 # Computes the sum of the `v` values across all replicas. 

3226 s = replica_ctx.merge_call(merge_fn, args=(v,)) 

3227 return s + v 

3228 

3229 with distribution.scope(): 

3230 # in "cross-replica" context 

3231 ... 

3232 merged_results = distribution.run(fn, args=[3]) 

3233 # merged_results has the values from every replica execution of `fn`. 

3234 # This statement prints a list: 

3235 print(distribution.experimental_local_results(merged_results)) 

3236 ``` 

3237 

3238 Args: 

3239 fn: function to run (will be run once per replica). 

3240 args: Tuple or list with positional arguments for `fn`. 

3241 kwargs: Dict with keyword arguments for `fn`. 

3242 

3243 Returns: 

3244 Merged return value of `fn` across all replicas. 

3245 """ 

3246 _require_cross_replica_or_default_context_extended(self) 

3247 if kwargs is None: 

3248 kwargs = {} 

3249 with self._container_strategy().scope(): 

3250 return self._call_for_each_replica(fn, args, kwargs) 

3251 

3252 def _call_for_each_replica(self, fn, args, kwargs): 

3253 raise NotImplementedError("must be implemented in descendants") 

3254 

3255 def read_var(self, v): 

3256 """Reads the value of a variable. 

3257 

3258 Returns the aggregate value of a replica-local variable, or the 

3259 (read-only) value of any other variable. 

3260 

3261 Args: 

3262 v: A variable allocated within the scope of this `tf.distribute.Strategy`. 

3263 

3264 Returns: 

3265 A tensor representing the value of `v`, aggregated across replicas if 

3266 necessary. 

3267 """ 

3268 raise NotImplementedError("must be implemented in descendants") 

3269 

3270 def update_non_slot( 

3271 self, colocate_with, fn, args=(), kwargs=None, group=True): 

3272 """Runs `fn(*args, **kwargs)` on `colocate_with` devices. 

3273 

3274 Used to update non-slot variables. 

3275 

3276 DEPRECATED: TF 1.x ONLY. 

3277 

3278 Args: 

3279 colocate_with: Devices returned by `non_slot_devices()`. 

3280 fn: Function to execute. 

3281 args: Tuple or list. Positional arguments to pass to `fn()`. 

3282 kwargs: Dict with keyword arguments to pass to `fn()`. 

3283 group: Boolean. Defaults to True. If False, the return value will be 

3284 unwrapped. 

3285 

3286 Returns: 

3287 Return value of `fn`, possibly merged across devices. 

3288 """ 

3289 _require_cross_replica_or_default_context_extended(self) 

3290 if kwargs is None: 

3291 kwargs = {} 

3292 fn = autograph.tf_convert( 

3293 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 

3294 with self._container_strategy().scope(): 

3295 return self._update_non_slot(colocate_with, fn, args, kwargs, group) 

3296 

3297 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 

3298 raise NotImplementedError("must be implemented in descendants") 

3299 

3300 def non_slot_devices(self, var_list): 

3301 """Device(s) for non-slot variables. 

3302 

3303 DEPRECATED: TF 1.x ONLY. 

3304 

3305 This method returns non-slot devices where non-slot variables are placed. 

3306 Users can create non-slot variables on these devices by using a block: 

3307 

3308 ```python 

3309 with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)): 

3310 ... 

3311 ``` 

3312 

3313 Args: 

3314 var_list: The list of variables being optimized, needed with the 

3315 default `tf.distribute.Strategy`. 

3316 Returns: 

3317 A sequence of devices for non-slot variables. 

3318 """ 

3319 raise NotImplementedError("must be implemented in descendants") 

3320 

3321 def _use_merge_call(self): 

3322 """Whether to use merge-calls inside the distributed strategy.""" 

3323 return True 

3324 

3325 @property 

3326 def experimental_between_graph(self): 

3327 """Whether the strategy uses between-graph replication or not. 

3328 

3329 This is expected to return a constant value that will not be changed 

3330 throughout its life cycle. 

3331 """ 

3332 raise NotImplementedError("must be implemented in descendants") 

3333 

3334 @property 

3335 def experimental_should_init(self): 

3336 """Whether initialization is needed.""" 

3337 raise NotImplementedError("must be implemented in descendants") 

3338 

3339 @property 

3340 def should_checkpoint(self): 

3341 """Whether checkpointing is needed.""" 

3342 raise NotImplementedError("must be implemented in descendants") 

3343 

3344 @property 

3345 def should_save_summary(self): 

3346 """Whether saving summaries is needed.""" 

3347 raise NotImplementedError("must be implemented in descendants") 

3348 

3349 

3350# A note about the difference between the context managers 

3351# `ReplicaContext` (defined here) and `_CurrentDistributionContext` 

3352# (defined above) used by `tf.distribute.Strategy.scope()`: 

3353# 

3354# * a ReplicaContext is only present during a `run()` 

3355 # call (except during a `merge_call` call) and in such a scope it 

3356# will be returned by calls to `get_replica_context()`. Implementers of new 

3357# Strategy descendants will frequently also need to 

3358# define a descendant of ReplicaContext, and are responsible for 

3359# entering and exiting this context. 

3360# 

3361# * Strategy.scope() sets up a variable_creator scope that 

3362# changes variable creation calls (e.g. to make mirrored 

3363# variables). This is intended as an outer scope that users enter once 

3364# around their model creation and graph definition. There is no 

3365# anticipated need to define descendants of _CurrentDistributionContext. 

3366# It sets the current Strategy for purposes of 

3367# `get_strategy()` and `has_strategy()` 

3368# and switches the thread mode to a "cross-replica context". 

3369class ReplicaContextBase(object): 

3370 """A class with a collection of APIs that can be called in a replica context. 

3371 

3372 You can use `tf.distribute.get_replica_context` to get an instance of 

3373 `ReplicaContext`, which can only be called inside the function passed to 

3374 `tf.distribute.Strategy.run`. 

3375 

3376 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) 

3377 >>> def func(): 

3378 ... replica_context = tf.distribute.get_replica_context() 

3379 ... return replica_context.replica_id_in_sync_group 

3380 >>> strategy.run(func) 

3381 PerReplica:{ 

3382 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>, 

3383 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 

3384 } 

3385 """ 

3386 

3387 def __init__(self, strategy, replica_id_in_sync_group): 

3388 """Creates a ReplicaContext. 

3389 

3390 Args: 

3391 strategy: A `tf.distribute.Strategy`. 

3392 replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an 

3393 integer whenever possible to avoid issues with nested `tf.function`. It 

3394 accepts a `Tensor` only to be compatible with `tpu.replicate`. 

3395 """ 

3396 self._strategy = strategy 

3397 self._thread_context = _InReplicaThreadMode( # pylint: disable=protected-access 

3398 self) 

3399 if not (replica_id_in_sync_group is None or 

3400 tensor_util.is_tf_type(replica_id_in_sync_group) or 

3401 isinstance(replica_id_in_sync_group, int)): 

3402 raise ValueError( 

3403 "replica_id_in_sync_group can only be an integer, a Tensor or None.") 

3404 self._replica_id_in_sync_group = replica_id_in_sync_group 

3405 # We need this check because TPUContext extends from ReplicaContext and 

3406 # does not pass a strategy object since it is used by TPUEstimator. 

3407 if strategy: 

3408 self._local_replica_id = strategy.extended._get_local_replica_id( 

3409 replica_id_in_sync_group) 

3410 self._summary_recording_distribution_strategy = None 

3411 

3412 @doc_controls.do_not_generate_docs 

3413 def __enter__(self): 

3414 _push_per_thread_mode(self._thread_context) 

3415 

3416 def replica_id_is_zero(): 

3417 return math_ops.equal(self.replica_id_in_sync_group, 

3418 constant_op.constant(0)) 

3419 

3420 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 

3421 self._summary_recording_distribution_strategy = ( 

3422 summary_state.is_recording_distribution_strategy) 

3423 summary_state.is_recording_distribution_strategy = replica_id_is_zero 

3424 

3425 @doc_controls.do_not_generate_docs 

3426 def __exit__(self, exception_type, exception_value, traceback): 

3427 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 

3428 summary_state.is_recording_distribution_strategy = ( 

3429 self._summary_recording_distribution_strategy) 

3430 _pop_per_thread_mode() 

3431 

3432 def merge_call(self, merge_fn, args=(), kwargs=None): 

3433 """Merge args across replicas and run `merge_fn` in a cross-replica context. 

3434 

3435 This allows communication and coordination when there are multiple calls 

3436 to the step_fn triggered by a call to `strategy.run(step_fn, ...)`. 

3437 

3438 See `tf.distribute.Strategy.run` for an explanation. 

3439 

3440 If not inside a distributed scope, this is equivalent to: 

3441 

3442 ``` 

3443 strategy = tf.distribute.get_strategy() 

3444 with cross-replica-context(strategy): 

3445 return merge_fn(strategy, *args, **kwargs) 

3446 ``` 

3447 

3448 Args: 

3449 merge_fn: Function that joins arguments from threads that are given as 

3450 PerReplica. It accepts a `tf.distribute.Strategy` object as 

3451 its first argument. 

3452 args: List or tuple with positional per-thread arguments for `merge_fn`. 

3453 kwargs: Dict with keyword per-thread arguments for `merge_fn`. 

3454 

3455 Returns: 

3456 The return value of `merge_fn`, except for `PerReplica` values which are 

3457 unpacked. 

3458 """ 

3459 require_replica_context(self) 

3460 if kwargs is None: 

3461 kwargs = {} 

3462 

3463 merge_fn = autograph.tf_convert( 

3464 merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 

3465 return self._merge_call(merge_fn, args, kwargs) 
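# Editor's illustrative sketch (not part of distribute_lib.py): using
# `merge_call` to hop from replica context into cross-replica context, reduce
# the per-replica values, and hand the result back to every replica. Assumes a
# MirroredStrategy over two devices; names are hypothetical.
def _sketch_merge_call():
  import tensorflow as tf
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  def merge_sum(strat, per_replica_value):
    # Runs once, in cross-replica context, with the values packed as PerReplica.
    return strat.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_value, axis=None)

  def step_fn():
    ctx = tf.distribute.get_replica_context()
    local = tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0
    # Every replica receives the same merged result (1.0 + 2.0 = 3.0).
    return ctx.merge_call(merge_sum, args=(local,))

  return strategy.run(step_fn)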

3466 

3467 def _merge_call(self, merge_fn, args, kwargs): 

3468 """Default implementation for single replica.""" 

3469 _push_per_thread_mode( # thread-local, so not needed with multiple threads 

3470 _CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access 

3471 try: 

3472 return merge_fn(self._strategy, *args, **kwargs) 

3473 finally: 

3474 _pop_per_thread_mode() 

3475 

3476 @property 

3477 def num_replicas_in_sync(self): 

3478 """Returns number of replicas that are kept in sync.""" 

3479 return self._strategy.num_replicas_in_sync 

3480 

3481 @property 

3482 def replica_id_in_sync_group(self): 

3483 """Returns the id of the replica. 

3484 

3485 This identifies the replica among all replicas that are kept in sync. The 

3486 value of the replica id can range from 0 to 

3487 `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1. 

3488 

3489 NOTE: This is not guaranteed to be the same ID as the XLA replica ID used 

3490 for low-level operations such as collective_permute. 

3491 

3492 Returns: 

3493 a `Tensor`. 

3494 """ 

3495 # It's important to prefer making the Tensor at call time whenever possible. 

3496 # Keeping Tensors in global states doesn't work well with nested 

3497 # tf.function, since it's possible that the tensor is generated in one func 

3498 # graph, and gets captured by another, which will result in a subtle "An op 

3499 # outside of the function building code is being passed a Graph tensor" 

3500 # error. Making the tensor at call time ensures it is in the same graph where 

3501 # it's used. However, to be compatible with tpu.replicate(), 

3502 # self._replica_id_in_sync_group can also be a Tensor. 

3503 if tensor_util.is_tf_type(self._replica_id_in_sync_group): 

3504 return self._replica_id_in_sync_group 

3505 return constant_op.constant( 

3506 self._replica_id_in_sync_group, 

3507 dtypes.int32, 

3508 name="replica_id_in_sync_group") 

3509 

3510 @property 

3511 def _replica_id(self): 

3512 """This is the local replica id in a given sync group.""" 

3513 return self._local_replica_id 

3514 

3515 @property 

3516 def strategy(self): 

3517 """The current `tf.distribute.Strategy` object.""" 

3518 return self._strategy 

3519 

3520 @property 

3521 @deprecation.deprecated(None, "Please avoid relying on devices property.") 

3522 def devices(self): 

3523 """Returns the devices this replica is to be executed on, as a tuple of strings. 

3524 

3525 NOTE: For `tf.distribute.MirroredStrategy` and 

3526 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a 

3527 nested 

3528 list of device strings, e.g., [["GPU:0"]]. 

3529 """ 

3530 require_replica_context(self) 

3531 return (device_util.current(),) 

3532 

3533 def all_reduce(self, reduce_op, value, options=None): 

3534 """All-reduces `value` across all replicas. 

3535 

3536 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

3537 >>> def step_fn(): 

3538 ... ctx = tf.distribute.get_replica_context() 

3539 ... value = tf.identity(1.) 

3540 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value) 

3541 >>> strategy.experimental_local_results(strategy.run(step_fn)) 

3542 (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 

3543 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>) 

3544 

3545 It supports batched operations. You can pass a list of values and it 

3546 attempts to batch them when possible. You can also specify `options` 

3547 to indicate the desired batching behavior, e.g. batch the values into 

3548 multiple packs so that they can better overlap with computations. 

3549 

3550 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

3551 >>> def step_fn(): 

3552 ... ctx = tf.distribute.get_replica_context() 

3553 ... value1 = tf.identity(1.) 

3554 ... value2 = tf.identity(2.) 

3555 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2]) 

3556 >>> strategy.experimental_local_results(strategy.run(step_fn)) 

3557 ([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 

3558 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>], 

3559 [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 

3560 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>]) 

3561 

3562 Note that all replicas need to participate in the all-reduce, otherwise this 

3563 operation hangs. Note that if there are multiple all-reduces, they need to 

3564 execute in the same order on all replicas. Dispatching all-reduce based on 

3565 conditions is usually error-prone. 

3566 

3567 Known limitation: if `value` contains `tf.IndexedSlices`, attempting to 

3568 compute the gradient w.r.t. `value` would result in an error. 

3569 

3570 This API currently can only be called in the replica context. Other 

3571 variants to reduce values across replicas are: 

3572 * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API 

3573 in the cross-replica context. 

3574 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and 

3575 all-reduce API in the cross-replica context. 

3576 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 

3577 to the host in cross-replica context. 

3578 

3579 Args: 

3580 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 

3581 be combined. Allows using string representation of the enum such as 

3582 "SUM", "MEAN". 

3583 value: a potentially nested structure of `tf.Tensor` or `tf.IndexedSlices` which 

3584 `tf.nest.flatten` accepts. The structure and the shapes of `value` need to be 

3585 the same on all replicas. 

3586 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 

3587 perform collective operations. This overrides the default options if the 

3588 `tf.distribute.Strategy` takes one in the constructor. See 

3589 `tf.distribute.experimental.CommunicationOptions` for details of the 

3590 options. 

3591 

3592 Returns: 

3593 A nested structure of `tf.Tensor` with the reduced values. The structure 

3594 is the same as `value`. 

3595 """ 

3596 flattened_value = nest.flatten(value) 

3597 has_indexed_slices = False 

3598 

3599 for v in flattened_value: 

3600 if isinstance(v, indexed_slices.IndexedSlices): 

3601 has_indexed_slices = True 

3602 

3603 if isinstance(reduce_op, six.string_types): 

3604 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 

3605 if options is None: 

3606 options = collective_util.Options() 

3607 

3608 def batch_all_reduce(strategy, *value_flat): 

3609 return strategy.extended.batch_reduce_to( 

3610 reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], 

3611 options) 

3612 

3613 # Due to the use of `capture_call_time_value` in collective ops, we have 

3614 # to maintain two branches: one w/ merge_call and one w/o. Details can be 

3615 # found in b/184009754. 

3616 if self._strategy.extended._use_merge_call(): # pylint: disable=protected-access 

3617 # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. 

3618 if has_indexed_slices: 

3619 return nest.pack_sequence_as( 

3620 value, 

3621 self.merge_call(batch_all_reduce, args=flattened_value)) 

3622 

3623 @custom_gradient.custom_gradient 

3624 def grad_wrapper(*xs): 

3625 ys = self.merge_call(batch_all_reduce, args=xs) 

3626 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 

3627 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 

3628 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value)) 

3629 else: 

3630 if has_indexed_slices: 

3631 return nest.pack_sequence_as( 

3632 value, 

3633 self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access 

3634 reduce_op, flattened_value, options)) 

3635 

3636 @custom_gradient.custom_gradient 

3637 def grad_wrapper(*xs): 

3638 ys = self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access 

3639 reduce_op, xs, options) 

3640 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 

3641 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 

3642 

3643 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value)) 

3644 

3645 # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient 

3646 # all-reduce. It would return a function returning the result of reducing `t` 

3647 # across all replicas. The caller would wait to call this function until they 

3648 # needed the reduce result, allowing an efficient implementation: 

3649 # * With eager execution, the reduction could be performed asynchronously 

3650 # in the background, not blocking until the result was needed. 

3651 # * When constructing a graph, it could batch up all reduction requests up 

3652 # to that point that the first result is needed. Most likely this can be 

3653 # implemented in terms of `merge_call()` and `batch_reduce_to()`. 

3654 

3655 

3656@tf_export("distribute.ReplicaContext", v1=[]) 

3657class ReplicaContext(ReplicaContextBase): 

3658 

3659 __doc__ = ReplicaContextBase.__doc__ 

3660 

3661 def all_gather(self, value, axis, options=None): 

3662 """All-gathers `value` across all replicas along `axis`. 

3663 

3664 Note: An `all_gather` method can only be called in replica context. For 

3665 a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. 

3666 All replicas need to participate in the all-gather, otherwise this 

3667 operation hangs. So if `all_gather` is called in any replica, it must be 

3668 called in all replicas. 

3669 

3670 Note: If there are multiple `all_gather` calls, they need to be executed in 

3671 the same order on all replicas. Dispatching `all_gather` based on conditions 

3672 is usually error-prone. 

3673 

3674 For all strategies except `tf.distribute.TPUStrategy`, the input 

3675 `value` on different replicas must have the same rank, and their shapes must 

3676 be the same in all dimensions except the `axis`-th dimension. In other 

3677 words, their shapes cannot be different in a dimension `d` where `d` does 

3678 not equal the `axis` argument. For example, given a 

3679 `tf.distribute.DistributedValues` with component tensors of shape 

3680 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 

3681 `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` 

3682 or `all_gather(..., axis=2, ...)`. However, with 

3683 `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and 

3684 same shape. 

3685 

3686 Note: The input `value` must have a non-zero rank. Otherwise, consider using 

3687 `tf.expand_dims` before gathering it. 

3688 

3689 You can pass in a single tensor to all-gather: 

3690 

3691 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

3692 >>> @tf.function 

3693 ... def gather_value(): 

3694 ... ctx = tf.distribute.get_replica_context() 

3695 ... local_value = tf.constant([1, 2, 3]) 

3696 ... return ctx.all_gather(local_value, axis=0) 

3697 >>> result = strategy.run(gather_value) 

3698 >>> result 

3699 PerReplica:{ 

3700 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 

3701 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 

3702 } 

3703 >>> strategy.experimental_local_results(result) 

3704 (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 

3705 dtype=int32)>, 

3706 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 

3707 dtype=int32)>) 

3708 

3709 

3710 You can also pass in a nested structure of tensors to all-gather, say, a 

3711 list: 

3712 

3713 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 

3714 >>> @tf.function 

3715 ... def gather_nest(): 

3716 ... ctx = tf.distribute.get_replica_context() 

3717 ... value_1 = tf.constant([1, 2, 3]) 

3718 ... value_2 = tf.constant([[1, 2], [3, 4]]) 

3719 ... # all_gather a nest of `tf.distribute.DistributedValues` 

3720 ... return ctx.all_gather([value_1, value_2], axis=0) 

3721 >>> result = strategy.run(gather_nest) 

3722 >>> result 

3723 [PerReplica:{ 

3724 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 

3725 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 

3726 }, PerReplica:{ 

3727 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 

3728 array([[1, 2], 

3729 [3, 4], 

3730 [1, 2], 

3731 [3, 4]], dtype=int32)>, 

3732 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 

3733 array([[1, 2], 

3734 [3, 4], 

3735 [1, 2], 

3736 [3, 4]], dtype=int32)> 

3737 }] 

3738 >>> strategy.experimental_local_results(result) 

3739 ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 

3740 <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 

3741 array([[1, 2], 

3742 [3, 4], 

3743 [1, 2], 

3744 [3, 4]], dtype=int32)>], 

3745 [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 

3746 <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 

3747 array([[1, 2], 

3748 [3, 4], 

3749 [1, 2], 

3750 [3, 4]], dtype=int32)>]) 

3751 

3752 

3753 What if you are all-gathering tensors with different shapes on different 

3754 replicas? Consider the following example with two replicas, where you have 

3755 `value` as a nested structure consisting of two items to all-gather, `a` and 

3756 `b`. 

3757 

3758 * On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`. 

3759 * On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`. 

3760 * Result for `all_gather` with `axis=0` (on each of the replicas) is: 

3761 

3762 ``` 

3763 {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]} 

3764 ``` 

3765 

3766 Args: 

3767 value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, 

3768 or a `tf.distribute.DistributedValues` instance. The structure of the 

3769 `tf.Tensor` needs to be the same on all replicas. The underlying tensor 

3770 constructs can only be dense tensors with non-zero rank, NOT 

3771 `tf.IndexedSlices`. 

3772 axis: 0-D int32 Tensor. Dimension along which to gather. 

3773 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 

3774 perform collective operations. This overrides the default options if the 

3775 `tf.distribute.Strategy` takes one in the constructor. See 

3776 `tf.distribute.experimental.CommunicationOptions` for details of the 

3777 options. 

3778 

3779 Returns: 

3780 A nested structure of `tf.Tensor` with the gathered values. The structure 

3781 is the same as `value`. 

3782 """ 

3783 for v in nest.flatten(value): 

3784 if isinstance(v, indexed_slices.IndexedSlices): 

3785 raise NotImplementedError("all_gather does not support IndexedSlices") 

3786 

3787 if options is None: 

3788 options = collective_util.Options() 

3789 

3790 def batch_all_gather(strategy, *value_flat): 

3791 return strategy.extended._batch_gather_to( # pylint: disable=protected-access 

3792 [(v, _batch_reduce_destination(v)) for v in value_flat], axis, 

3793 options) 

3794 

3795 @custom_gradient.custom_gradient 

3796 def grad_wrapper(*xs): 

3797 ys = self.merge_call(batch_all_gather, args=xs) 

3798 
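      # Backward pass for all_gather: sum the upstream gradients across
      # replicas with an all-reduce, then have each replica slice out the
      # segment of that summed gradient along `axis` that corresponds to its
      # own input; the offsets come from all-gathering each replica's size
      # along `axis`.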

3799 def grad(*dy_s): 

3800 grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s) 

3801 new_grads = [] 

3802 for i, grad in enumerate(grads): 

3803 input_shape = array_ops.shape(xs[i]) 

3804 axis_dim = array_ops.reshape(input_shape[axis], [1]) 

3805 with ops.control_dependencies([array_ops.identity(grads)]): 

3806 d = self.all_gather(axis_dim, axis=0) 

3807 begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group]) 

3808 end_dim = begin_dim + array_ops.shape(xs[i])[axis] 

3809 new_grad = array_ops.gather( 

3810 grad, axis=axis, indices=math_ops.range(begin_dim, end_dim)) 

3811 new_grads.append(new_grad) 

3812 return new_grads 

3813 

3814 return ys, grad 

3815 

3816 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) 

3817 

3818 def _update(self, var, fn, args=(), kwargs=None, group=True): 

3819 """Run `fn` to update `var` with `args` and `kwargs` in replica context. 

3820 

3821 `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var` 

3822 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. 

3823 `fn` applies to each component variable of `var` with corresponding input 

3824 values from `args` and `kwargs`. 

3825 

3826 Example usage: 

3827 

3828 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas 

3829 >>> with strategy.scope(): 

3830 ... distributed_variable = tf.Variable(5.0) 

3831 >>> distributed_variable 

3832 MirroredVariable:{ 

3833 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>, 

3834 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0> 

3835 } 

3836 >>> def replica_fn(v): 

3837 ... value = tf.identity(1.0) 

3838 ... replica_context = tf.distribute.get_replica_context() 

3839 ... update_fn = lambda var, value: var.assign(value) 

3840 ... replica_context._update(v, update_fn, args=(value,)) 

3841 >>> strategy.run(replica_fn, args=(distributed_variable,)) 

3842 >>> distributed_variable 

3843 MirroredVariable:{ 

3844 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, 

3845 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> 

3846 } 

3847 

3848 This API must be called in a replica context. 

3849 

3850    Note that if `var` is a MirroredVariable (i.e., a variable created under 

3851    the scope of a synchronous strategy and synchronized on write; see 

3852    `tf.VariableSynchronization` for more information) and `args`/`kwargs` 

3853    contain different values on different replicas, `var` will become 

3854    dangerously out of sync. We therefore recommend `variable.assign(value)` 

3855    whenever possible, since it aggregates the updates under the hood and 

3856    guarantees synchronization. Use this API instead of `variable.assign(value)` 

3857    only when, before assigning `value` to the `variable`, you need to run some 

3858    pre-`assign` computation colocated with the variable devices (i.e. where 

3859    the variables reside; for MirroredStrategy these are the same as the 

3860    compute devices, for ParameterServerStrategy they are the parameter 

3861    servers). E.g., 

3862 

3863 ```python 

3864 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas 

3865 with strategy.scope(): 

3866 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) 

3867 def replica_fn(inputs): 

3868 value = computation(inputs) 

3869 replica_context = tf.distribute.get_replica_context() 

3870      reduced_value = replica_context.all_reduce(tf.distribute.ReduceOp.SUM, value) 

3871 

3872 def update_fn(var, value): 

3873 # this computation will colocate with `var`'s device 

3874 updated_value = post_reduce_pre_update_computation(value) 

3875        var.assign(updated_value) 

3876 

3877 replica_context._update(v, update_fn, args=(reduced_value,)) 

3878 

3879 strategy.run(replica_fn, args=(inputs,)) 

3880 ``` 

3881 

3882    This code snippet works identically across all strategies. If you instead 

3883    compute and `assign` directly in the replica context, without wrapping the 

3884    update in `_update`, then for strategies with fewer variable devices than 

3885    compute devices (e.g., parameter server strategy, typically) the 

3886    `post_reduce_pre_update_computation` runs once per compute device rather 

3887    than once per variable device, which is less performant. 

3888 

3889 

3890 Args: 

3891 var: Variable, possibly distributed to multiple devices, to operate on. 

3892 fn: Function to call. Should take the variable as the first argument. 

3893 args: Tuple or list. Additional positional arguments to pass to `fn()`. 

3894 kwargs: Dict with keyword arguments to pass to `fn()`. 

3895 group: Boolean. Defaults to True. Most strategies enter a merge_call to 

3896      conduct the update in cross-replica context, and group=True guarantees that 

3897      updates on all replicas are executed. 

3898 

3899 Returns: 

3900 The return value of `fn` for the local replica. 

3901 """ 

3902 if kwargs is None: 

3903 kwargs = {} 

3904 return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group) # pylint: disable=protected-access 

3905 

3906 

3907@tf_export(v1=["distribute.ReplicaContext"]) 

3908class ReplicaContextV1(ReplicaContextBase): 

3909 __doc__ = ReplicaContextBase.__doc__ 

3910 

3911 

3912def _batch_reduce_destination(x): 

3913 """Returns the destinations for batch all-reduce.""" 

3914 if isinstance(x, ops.Tensor): 

3915 # If this is a one device strategy. 

3916 return x.device 

3917 else: 

3918 return x 

3919# ------------------------------------------------------------------------------ 

3920 

3921 

3922class _DefaultDistributionStrategyV1(StrategyV1): 

3923 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 

3924 

3925 def __init__(self): 

3926 if not _creating_default_strategy_singleton: 

3927 raise RuntimeError("Should only create a single instance of " 

3928 "_DefaultDistributionStrategy") 

3929 super(_DefaultDistributionStrategyV1, 

3930 self).__init__(_DefaultDistributionExtended(self)) 

3931 

3932 def __deepcopy__(self, memo): 

3933 del memo 

3934 raise RuntimeError("Should only create a single instance of " 

3935 "_DefaultDistributionStrategy") 

3936 

3937 

3938class _DefaultDistributionStrategy(Strategy): 

3939 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 

3940 

3941 def __init__(self): 

3942 if not _creating_default_strategy_singleton: 

3943 raise RuntimeError("Should only create a single instance of " 

3944 "_DefaultDistributionStrategy") 

3945 super(_DefaultDistributionStrategy, self).__init__( 

3946 _DefaultDistributionExtended(self)) 

3947 

3948 def __deepcopy__(self, memo): 

3949 del memo 

3950 raise RuntimeError("Should only create a single instance of " 

3951 "_DefaultDistributionStrategy") 

3952 

3953 

3954class _DefaultDistributionContext(object): 

3955 """Context manager setting the default `tf.distribute.Strategy`.""" 

3956 

3957 __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"] 

3958 

3959 def __init__(self, strategy): 

3960 

3961 def creator(next_creator, **kwargs): 

3962 _require_strategy_scope_strategy(strategy) 

3963 return next_creator(**kwargs) 

3964 

3965 self._var_creator_scope = variable_scope.variable_creator_scope(creator) 

3966 self._strategy = strategy 

3967 self._nested_count = 0 

3968 

3969 def __enter__(self): 

3970    # Only allow entering when no explicit (non-default) strategy is already in scope. 

3971 if has_strategy(): 

3972 raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") 

3973 if self._nested_count == 0: 

3974 self._var_creator_scope.__enter__() 

3975 self._nested_count += 1 

3976 return self._strategy 

3977 

3978 def __exit__(self, exception_type, exception_value, traceback): 
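    # Tear down the variable creator scope only when the outermost nesting
    # level exits; a scope-nesting mismatch surfaces as the RuntimeError below.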

3979 self._nested_count -= 1 

3980 if self._nested_count == 0: 

3981 try: 

3982 self._var_creator_scope.__exit__( 

3983 exception_type, exception_value, traceback) 

3984 except RuntimeError as e: 

3985 six.raise_from( 

3986 RuntimeError("Variable creator scope nesting error: move call to " 

3987 "tf.distribute.set_strategy() out of `with` scope."), 

3988 e) 

3989 

3990 

3991class _DefaultDistributionExtended(StrategyExtendedV1): 

3992 """Implementation of _DefaultDistributionStrategy.""" 

3993 

3994 def __init__(self, container_strategy): 

3995 super(_DefaultDistributionExtended, self).__init__(container_strategy) 

3996 self._retrace_functions_for_each_device = False 

3997 

3998 def _scope(self, strategy): 

3999 """Context manager setting a variable creator and `self` as current.""" 

4000 return _DefaultDistributionContext(strategy) 

4001 

4002 def colocate_vars_with(self, colocate_with_variable): 

4003 """Does not require `self.scope`.""" 

4004 _require_strategy_scope_extended(self) 

4005 return ops.colocate_with(colocate_with_variable) 

4006 

4007 def variable_created_in_scope(self, v): 
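    # A variable counts as created in the default strategy's scope iff it was
    # not created under any explicit strategy (its `_distribute_strategy` is None).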

4008 return v._distribute_strategy is None # pylint: disable=protected-access 

4009 

4010 def _experimental_distribute_dataset(self, dataset, options): 

4011 return dataset 

4012 

4013 def _distribute_datasets_from_function(self, dataset_fn, options): 

4014 return dataset_fn(InputContext()) 

4015 

4016 def _experimental_distribute_values_from_function(self, value_fn): 

4017 return value_fn(ValueContext()) 

4018 

4019 def _make_dataset_iterator(self, dataset): 

4020 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 

4021 

4022 def _make_input_fn_iterator(self, 

4023 input_fn, 

4024 replication_mode=InputReplicationMode.PER_WORKER): 

4025 dataset = input_fn(InputContext()) 

4026 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 

4027 

4028 def _experimental_make_numpy_dataset(self, numpy_input, session): 

4029 numpy_flat = nest.flatten(numpy_input) 

4030 vars_flat = tuple( 

4031 variable_v1.VariableV1(array_ops.zeros(i.shape, i.dtype), 

4032 trainable=False, use_resource=True) 

4033 for i in numpy_flat 

4034 ) 

4035 for v, i in zip(vars_flat, numpy_flat): 

4036 numpy_dataset.init_var_from_numpy(v, i, session) 

4037 vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) 

4038 return dataset_ops.Dataset.from_tensor_slices(vars_nested) 

4039 

4040 def _broadcast_to(self, tensor, destinations): 

4041 if destinations is None: 

4042 return tensor 

4043 else: 

4044 raise NotImplementedError("TODO") 

4045 

4046 def _call_for_each_replica(self, fn, args, kwargs): 

4047 with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): 

4048 return fn(*args, **kwargs) 

4049 

4050 def _reduce_to(self, reduce_op, value, destinations, options): 

4051 # TODO(josh11b): Use destinations? 

4052 del reduce_op, destinations, options 

4053 return value 

4054 

4055 def _gather_to_implementation(self, value, destinations, axis, options): 

4056 del destinations, axis, options 

4057 return value 

4058 

4059 def _update(self, var, fn, args, kwargs, group): 

4060 # The implementations of _update() and _update_non_slot() are identical 

4061 # except _update() passes `var` as the first argument to `fn()`. 

4062 return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) 

4063 

4064 def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): 

4065 # TODO(josh11b): Figure out what we should be passing to UpdateContext() 

4066 # once that value is used for something. 

4067 with UpdateContext(colocate_with): 

4068 result = fn(*args, **kwargs) 

4069 if should_group: 

4070 return result 

4071 else: 

4072 return nest.map_structure(self._local_results, result) 

4073 

4074 def read_var(self, replica_local_var): 

4075 return array_ops.identity(replica_local_var) 

4076 

4077 def _local_results(self, distributed_value): 

4078 return (distributed_value,) 

4079 

4080 def value_container(self, value): 

4081 return value 

4082 

4083 @property 

4084 def _num_replicas_in_sync(self): 

4085 return 1 

4086 

4087 @property 

4088 def worker_devices(self): 

4089 raise RuntimeError("worker_devices() method unsupported by default " 

4090 "tf.distribute.Strategy.") 

4091 

4092 @property 

4093 def parameter_devices(self): 

4094 raise RuntimeError("parameter_devices() method unsupported by default " 

4095 "tf.distribute.Strategy.") 

4096 

4097 def non_slot_devices(self, var_list): 
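    # Deterministic choice: use the variable with the lexicographically
    # smallest name as the colocation target.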

4098 return min(var_list, key=lambda x: x.name) 

4099 

4100 def _in_multi_worker_mode(self): 

4101 """Whether this strategy indicates working in multi-worker settings.""" 

4102 # Default strategy doesn't indicate multi-worker training. 

4103 return False 

4104 

4105 @property 

4106 def should_checkpoint(self): 

4107 return True 

4108 

4109 @property 

4110 def should_save_summary(self): 

4111 return True 

4112 

4113 def _get_local_replica_id(self, replica_id_in_sync_group): 

4114 return replica_id_in_sync_group 

4115 

4116 def _get_replica_id_in_sync_group(self, replica_id): 

4117 return replica_id 

4118 

4119 # TODO(priyag): This should inherit from `InputIterator`, once dependency 

4120 # issues have been resolved. 

4121 class DefaultInputIterator(object): 

4122 """Default implementation of `InputIterator` for default strategy.""" 

4123 

4124 def __init__(self, dataset): 

4125 self._dataset = dataset 

4126 if eager_context.executing_eagerly(): 

4127 self._iterator = dataset_ops.make_one_shot_iterator(dataset) 

4128 else: 

4129 self._iterator = dataset_ops.make_initializable_iterator(dataset) 

4130 

4131 def get_next(self): 

4132 return self._iterator.get_next() 

4133 

4134 def get_next_as_optional(self): 

4135 return self._iterator.get_next_as_optional() 

4136 

4137 @deprecated(None, "Use the iterator's `initializer` property instead.") 

4138 def initialize(self): 

4139 """Initialize underlying iterators. 

4140 

4141 Returns: 

4142 A list of any initializer ops that should be run. 

4143 """ 

4144 if eager_context.executing_eagerly(): 

4145 self._iterator = self._dataset.make_one_shot_iterator() 

4146 return [] 

4147 else: 

4148 return [self._iterator.initializer] 

4149 

4150 @property 

4151 def initializer(self): 

4152 """Returns a list of ops that initialize the iterator.""" 

4153 return self.initialize() 

4154 

4155 # TODO(priyag): Delete this once all strategies use global batch size. 

4156 @property 

4157 def _global_batch_size(self): 

4158 """Global and per-replica batching are equivalent for this strategy.""" 

4159 return True 

4160 

4161 

4162class _DefaultReplicaContext(ReplicaContext): 

4163 """ReplicaContext for _DefaultDistributionStrategy.""" 

4164 

4165 @property 

4166 def replica_id_in_sync_group(self): 

4167 # Return 0 instead of a constant tensor to avoid creating a new node for 

4168 # users who don't use distribution strategy. 

4169 return 0 

4170 

4171 

4172# ------------------------------------------------------------------------------ 

4173# We haven't yet implemented deserialization for DistributedVariables. 

4174# So here we catch any attempts to deserialize variables 

4175# when using distribution strategies. 

4176# pylint: disable=protected-access 

4177_original_from_proto = ref_variable._from_proto_fn 

4178 

4179 

4180def _from_proto_fn(v, import_scope=None): 

4181 if has_strategy(): 

4182 raise NotImplementedError( 

4183 "Deserialization of variables is not yet supported when using a " 

4184 "tf.distribute.Strategy.") 

4185 else: 

4186 return _original_from_proto(v, import_scope=import_scope) 

4187 

4188ref_variable._from_proto_fn = _from_proto_fn 

4189# pylint: enable=protected-access 

4190 

4191 

4192def get_local_results_or_value_container(variable): 
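  # In replica context, resolve the watched variable to its value container;
  # in cross-replica context, resolve it to its local per-replica values.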

4193 strategy, context = get_strategy_and_replica_context() 

4194 if context: 

4195 return [strategy.extended.value_container(variable)] 

4196 else: 

4197 return strategy.experimental_local_results(variable) 

4198 

4199 

4200tape.register_watched_variable_resolver(get_local_results_or_value_container) 

4201 

4202 

4203# ------------------------------------------------------------------------------ 

4204# Metrics to track which distribution strategy is being called 

4205distribution_strategy_gauge = monitoring.StringGauge( 

4206 "/tensorflow/api/distribution_strategy", 

4207 "Gauge to track the type of distribution strategy used.", "TFVersion") 

4208distribution_strategy_replica_gauge = monitoring.IntGauge( 

4209 "/tensorflow/api/distribution_strategy/replica", 

4210    "Gauge to track the number of replicas each distribution strategy used.", 

4211 "CountType") 

4212distribution_strategy_input_api_counter = monitoring.Counter( 

4213 "/tensorflow/api/distribution_strategy/input_api", 

4214 "Counter to track the usage of the input APIs", "strategy", "api") 

4215distributed_variable_creation_time_counter = monitoring.Counter( 

4216 "/tensorflow/api/distribution_strategy/distributed_variable_creation_time_usecs", 

4217 "Time to create distributed variables (us).", "strategy", "if_graph_building")