Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py: 36%
983 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=line-too-long
16"""Library for running a computation across multiple devices.
18The intent of this library is that you can write an algorithm in a stylized way
19and it will be usable with a variety of different `tf.distribute.Strategy`
20implementations. Each descendant will implement a different strategy for
21distributing the algorithm across multiple devices/machines. Furthermore, these
22changes can be hidden inside the specific layers and other library classes that
23need special treatment to run in a distributed setting, so that most users'
24model definition code can run unchanged. The `tf.distribute.Strategy` API works
25the same way with eager and graph execution.
27*Guides*
29* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
30* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)
32*Tutorials*
34* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/)
36 The tutorials cover how to use `tf.distribute.Strategy` to do distributed
37 training with native Keras APIs, custom training loops,
38 and Estimator APIs. They also cover how to save/load models when using
39 `tf.distribute.Strategy`.
41*Glossary*
43* _Data parallelism_ is where we run multiple copies of the model
44 on different slices of the input data. This is in contrast to
45 _model parallelism_ where we divide up a single copy of a model
46 across multiple devices.
47 Note: we only support data parallelism for now, but
48 hope to add support for model parallelism in the future.
49* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
50 TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
51 devices on a single machine, or be connected to devices on multiple
52 machines. Devices used to run computations are called _worker devices_.
53 Devices used to store variables are _parameter devices_. For some strategies,
54 such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
55 will be the same (see mirrored variables below). For others they will be
56 different. For example, `tf.distribute.experimental.CentralStorageStrategy`
57 puts the variables on a single device (which may be a worker device or may be
58 the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
59 variables on separate machines called _parameter servers_ (see below).
60* A _replica_ is one copy of the model, running on one slice of the
61 input data. Right now each replica is executed on its own
62 worker device, but once we add support for model parallelism
63 a replica may span multiple worker devices.
64* A _host_ is the CPU device on a machine with worker devices, typically
65 used for running input pipelines.
66* A _worker_ is defined to be the physical machine(s) containing the physical
67 devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
68 worker may contain one or more replicas, but contains at least one
69 replica. Typically one worker will correspond to one machine, but in the case
70 of very large models with model parallelism, one worker may span multiple
71 machines. We typically run one input pipeline per worker, feeding all the
72 replicas on that worker.
73* _Synchronous_, or more commonly _sync_, training is where the updates from
74 each replica are aggregated together before updating the model variables. This
75 is in contrast to _asynchronous_, or _async_ training, where each replica
76 updates the model variables independently. You may also have replicas
77 partitioned into groups which are in sync within each group but async between
78 groups.
79* _Parameter servers_: These are machines that hold a single copy of
80 parameters/variables, used by some strategies (right now just
81 `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
82 to operate on a variable retrieve it at the beginning of a step and send an
83 update to be applied at the end of the step. These can in principle support
84 either sync or async training, but right now we only have support for async
85 training with parameter servers. Compare to
86 `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
87 on a single device on the same machine (and does sync training), and
88 `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
89 (see below).
91* _Replica context_ vs. _Cross-replica context_ vs _Update context_
93 A _replica context_ applies
94 when you execute the computation function that was called with `strategy.run`.
95 Conceptually, you're in replica context when executing the computation
96 function that is being replicated.
98 An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
99 call.
101 A _cross-replica context_ is entered when you enter a `strategy.scope`. This
102 is useful for calling `tf.distribute.Strategy` methods which operate across
103 the replicas (like `reduce_to()`). By default you start in a _replica context_
104 (the "default single _replica context_") and then some methods can switch you
105 back and forth.
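For illustration, here is a minimal sketch of how these contexts switch
(assuming a `MirroredStrategy` over two local GPUs):
```
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  # Entering the scope puts us in cross-replica context.
  assert tf.distribute.get_replica_context() is None
  def fn():
    # `strategy.run` executes `fn` in replica context.
    assert tf.distribute.get_replica_context() is not None
  strategy.run(fn)
```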
107* _Distributed value_: A distributed value is represented by the base class
108 `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
109 to represent values on multiple devices, and it contains a map from replica id
110 to values. Two representative types of `tf.distribute.DistributedValues`
111 are `tf.types.experimental.PerReplica` and `tf.types.experimental.Mirrored`
112 values.
114 `PerReplica` values exist on the worker devices, with a different value for
115 each replica. They are produced by iterating through a distributed dataset
116 returned by `tf.distribute.Strategy.experimental_distribute_dataset` and
117 `tf.distribute.Strategy.distribute_datasets_from_function`. They are also the
118 typical result returned by `tf.distribute.Strategy.run`.
120 `Mirrored` values are like `PerReplica` values, except we know that the values
121 on all replicas are the same. `Mirrored` values are kept synchronized by the
122 distribution strategy in use, while `PerReplica` values are left
123 unsynchronized. `Mirrored` values typically represent model weights. We can
124 safely read a `Mirrored` value in a cross-replica context by using the value
125 on any replica, while PerReplica values can only be read within a replica
126 context.
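As a rough sketch (two local GPUs assumed), iterating a distributed dataset
yields `PerReplica` values that can be passed to `strategy.run`:
```
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.range(8).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
for x in dist_dataset:
  # `x` is a PerReplica value: one component per replica.
  per_replica_result = strategy.run(lambda v: v * 2, args=(x,))
```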
128* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
129 replicas, like `strategy.run(fn, args=[w])` with an
130 argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
131 have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
132 `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
133 device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return
134 values from `fn()`, which leads to one common object if the returned values
135 are the same object from every replica, or a `DistributedValues` object
136 otherwise.
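A small sketch of this unwrapping, where `w` is assumed to be an existing
`tf.distribute.DistributedValues` (e.g. one element of a distributed dataset):
```
def fn(w_component):
  # Each replica sees only its own component of `w`.
  return w_component + 1.0
# `strategy.run` unwraps `w`, calls `fn` on each replica, then merges results.
merged = strategy.run(fn, args=(w,))
```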
138* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
139 multiple values into one value, like "sum" or "mean". If a strategy is doing
140 sync training, we will perform a reduction on the gradients to a parameter
141 from all replicas before applying the update. _All-reduce_ is an algorithm for
142 performing a reduction on values from multiple devices and making the result
143 available on all of those devices.
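For example (a hedged sketch; `replica_fn` and `inputs` are assumed to be
defined elsewhere), a sync-training step typically ends with a reduction of
the per-replica results:
```
per_replica_losses = strategy.run(replica_fn, args=(inputs,))
# Aggregate the per-replica values into a single tensor.
total_loss = strategy.reduce(
    tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
```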
145* _Mirrored variables_: These are variables that are created on multiple
146 devices, where we keep the variables in sync by applying the same
147 updates to every copy. Mirrored variables are created with
148 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
149 Normally they are only used in synchronous training.
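A minimal sketch of creating a mirrored variable inside a strategy scope:
```
with strategy.scope():
  # One copy per device, kept in sync on every write.
  w = tf.Variable(1.0,
                  synchronization=tf.VariableSynchronization.ON_WRITE,
                  aggregation=tf.VariableAggregation.MEAN)
```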
151* _SyncOnRead variables_
153 _SyncOnRead variables_ are created by
154 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
155 they are created on multiple devices. In replica context, each
156 component variable on the local replica can perform reads and writes without
157 synchronization with each other. When the
158 _SyncOnRead variable_ is read in cross-replica context, the values from
159 component variables are aggregated and returned.
161 _SyncOnRead variables_ bring a lot of custom configuration difficulty to the
162 underlying logic, so we do not encourage users to instantiate and use
163 _SyncOnRead variable_ on their own. We have mainly used _SyncOnRead
164 variables_ for use cases such as batch norm and metrics. For performance
165 reasons, we often don't need to keep these statistics in sync every step and
166 they can be accumulated on each replica independently. The only time we want
167 to sync them is reporting or checkpointing, which typically happens in
168 cross-replica context. _SyncOnRead variables_ are also often used by advanced
169 users who want to control when variable values are aggregated. For example,
170 users sometimes want to maintain gradients independently on each replica for a
171 couple of steps without aggregation.
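A small sketch (not a recommended pattern for most users, as noted above) of
creating such a variable:
```
with strategy.scope():
  # Each replica accumulates locally; the components are aggregated (summed
  # here) when the variable is read in cross-replica context.
  counter = tf.Variable(0.0,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        aggregation=tf.VariableAggregation.SUM)
```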
173* _Distribute-aware layers_
175 Layers are generally called in a replica context, except when defining a
176 Keras functional model. `tf.distribute.in_cross_replica_context` will let you
177 determine which case you are in. If in a replica context,
178 the `tf.distribute.get_replica_context` function will return the default
179 replica context outside a strategy scope, `None` within a strategy scope, and
180 a `tf.distribute.ReplicaContext` object inside a strategy scope and within a
181 `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an
182 `all_reduce` method for aggregating across all replicas.
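As a hedged sketch of a distribute-aware computation (`per_replica_x` is
assumed to be a per-replica value such as a distributed dataset element):
```
def replica_fn(x):
  ctx = tf.distribute.get_replica_context()
  # Combine this replica's value with the corresponding values on all
  # other replicas.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, x)
result = strategy.run(replica_fn, args=(per_replica_x,))
```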
185Note that we provide a default version of `tf.distribute.Strategy` that is
186used when no other strategy is in scope and provides the same API with
187reasonable default behavior.
188"""
189# pylint: enable=line-too-long
191import collections
192import contextlib
193import copy
194import enum # pylint: disable=g-bad-import-order
195import functools
196import threading
197import weakref
199import six
201from tensorflow.python import tf2
202from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
203from tensorflow.python.autograph.impl import api as autograph
204from tensorflow.python.data.ops import dataset_ops
205from tensorflow.python.distribute import collective_util
206from tensorflow.python.distribute import device_util
207from tensorflow.python.distribute import numpy_dataset
208from tensorflow.python.distribute import reduce_util
209from tensorflow.python.eager import context as eager_context
210from tensorflow.python.eager import def_function
211from tensorflow.python.eager import monitoring
212from tensorflow.python.eager import tape
213from tensorflow.python.framework import constant_op
214from tensorflow.python.framework import dtypes
215from tensorflow.python.framework import indexed_slices
216from tensorflow.python.framework import ops
217from tensorflow.python.framework import tensor_shape
218from tensorflow.python.framework import tensor_util
219from tensorflow.python.ops import array_ops
220from tensorflow.python.ops import control_flow_ops
221from tensorflow.python.ops import custom_gradient
222from tensorflow.python.ops import math_ops
223from tensorflow.python.ops import ref_variable
224from tensorflow.python.ops import summary_ops_v2
225from tensorflow.python.ops import variable_scope
226from tensorflow.python.ops import variable_v1
227from tensorflow.python.platform import tf_logging
228from tensorflow.python.trackable import base as trackable
229from tensorflow.python.types import distribute as ds_types
230from tensorflow.python.util import deprecation
231from tensorflow.python.util import nest
232from tensorflow.python.util import tf_contextlib
233from tensorflow.python.util.deprecation import deprecated
234from tensorflow.python.util.tf_export import tf_export
235from tensorflow.tools.docs import doc_controls
237# ------------------------------------------------------------------------------
238# Context tracking whether in a strategy.update() or .update_non_slot() call.
241_update_replica_id = threading.local()
244def get_update_replica_id():
245 """Get the current device if in a `tf.distribute.Strategy.update()` call."""
246 try:
247 return _update_replica_id.current
248 except AttributeError:
249 return None
252class UpdateContext(object):
253 """Context manager when you are in `update()` or `update_non_slot()`."""
255 __slots__ = ["_replica_id", "_old_replica_id"]
257 def __init__(self, replica_id):
258 self._replica_id = replica_id
259 self._old_replica_id = None
261 def __enter__(self):
262 self._old_replica_id = get_update_replica_id()
263 _update_replica_id.current = self._replica_id
265 def __exit__(self, exception_type, exception_value, traceback):
266 del exception_type, exception_value, traceback
267 _update_replica_id.current = self._old_replica_id
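# A minimal usage sketch (illustrative only, not part of the public API):
# within an `UpdateContext`, `get_update_replica_id()` reports the replica
# being updated, and the previous value is restored on exit, e.g.
#
#   assert get_update_replica_id() is None
#   with UpdateContext(replica_id=0):
#     assert get_update_replica_id() == 0
#   assert get_update_replica_id() is None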
270# ------------------------------------------------------------------------------
271# Internal API for validating the current thread mode
274def _require_cross_replica_or_default_context_extended(extended,
275 error_message=None):
276 """Verify in cross-replica context."""
277 context = _get_per_thread_mode()
278 cross_replica = context.cross_replica_context
279 if cross_replica is not None and cross_replica.extended is extended:
280 return
281 if context is _get_default_replica_mode():
282 return
283 strategy = extended._container_strategy() # pylint: disable=protected-access
284 # We have an error to report, figure out the right message.
285 if context.strategy is not strategy:
286 _wrong_strategy_scope(strategy, context)
287 assert cross_replica is None
288 if not error_message:
289 error_message = ("Method requires being in cross-replica context, use "
290 "get_replica_context().merge_call()")
291 raise RuntimeError(error_message)
294def _wrong_strategy_scope(strategy, context):
295 # Figure out the right error message.
296 if not has_strategy():
297 raise RuntimeError(
298 'Need to be inside "with strategy.scope()" for %s' %
299 (strategy,))
300 else:
301 raise RuntimeError(
302 "Mixing different tf.distribute.Strategy objects: %s is not %s" %
303 (context.strategy, strategy))
306def require_replica_context(replica_ctx):
307 """Verify in `replica_ctx` replica context."""
308 context = _get_per_thread_mode()
309 if context.replica_context is replica_ctx: return
310 # We have an error to report, figure out the right message.
311 if context.replica_context is None:
312 raise RuntimeError("Need to be inside `call_for_each_replica()`")
313 if context.strategy is replica_ctx.strategy:
314 # Two different ReplicaContexts with the same tf.distribute.Strategy.
315 raise RuntimeError("Mismatching ReplicaContext.")
316 raise RuntimeError(
317 "Mismatching tf.distribute.Strategy objects: %s is not %s." %
318 (context.strategy, replica_ctx.strategy))
321def _require_strategy_scope_strategy(strategy):
322 """Verify in a `strategy.scope()` in this thread."""
323 context = _get_per_thread_mode()
324 if context.strategy is strategy: return
325 _wrong_strategy_scope(strategy, context)
328def _require_strategy_scope_extended(extended):
329 """Verify in a `distribution_strategy.scope()` in this thread."""
330 context = _get_per_thread_mode()
331 if context.strategy.extended is extended: return
332 # Report error.
333 strategy = extended._container_strategy() # pylint: disable=protected-access
334 _wrong_strategy_scope(strategy, context)
337_creating_default_strategy_singleton = False
339# ------------------------------------------------------------------------------
340# Internal API for setting the current thread mode as being either in a
341# replica or cross-replica context for a particular tf.distribute.Strategy.
344class _ThreadMode(object):
346 def __init__(self, dist, cross, replica):
347 self.strategy = dist
348 self.cross_replica_context = cross
349 self.replica_context = replica
352class _CrossReplicaThreadMode(_ThreadMode):
354 def __init__(self, strategy):
355 _ThreadMode.__init__(self, strategy, strategy, None)
358class _InReplicaThreadMode(_ThreadMode):
360 def __init__(self, replica_ctx):
361 _ThreadMode.__init__(self, replica_ctx.strategy, None, replica_ctx)
364def _push_per_thread_mode(context):
365 ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
368def _pop_per_thread_mode():
369 ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
372class _DefaultReplicaThreadMode(_ThreadMode):
373 """Type of default value returned by `_get_per_thread_mode()`.
375 Used when the thread-local stack is empty.
376 """
378 def __init__(self):
379 _ThreadMode.__init__(self, _get_default_strategy(), None,
380 _get_default_replica_context())
383def _get_per_thread_mode():
384 try:
385 return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
386 except (AttributeError, IndexError):
387 return _get_default_replica_mode()
390_variable_sync_on_read_context = threading.local()
393@tf_export("__internal__.distribute.variable_sync_on_read_context", v1=[])
394@contextlib.contextmanager
395def variable_sync_on_read_context():
396 """A context that forces SyncOnReadVariable to aggregate upon reading.
398 This context is useful if one wants to read the aggregated value out of a
399 SyncOnReadVariable in replica context. By default the aggregation is turned
400 off per the definition of SyncOnReadVariable.
402 When reading a SyncOnReadVariable in cross-replica context, aggregation is
403 always turned on so there is no need for such context.
405 By reading a SyncOnReadVariable, we mean:
406 1. Converting the variable to a tensor using `convert_to_tensor`.
407 2. Calling `variable.value()` or `variable.read_value()`.
409 Example usage:
411 ```
412 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
413 with strategy.scope():
414 v = tf.Variable(1.0, synchronization=tf.VariableSynchronization.ON_READ,
415 aggregation=tf.VariableAggregation.SUM)
417 def replica_fn():
418 return v + 10.0
420 non_aggregated = strategy.run(replica_fn)
421 print(non_aggregated) # PerReplica: {0: 11.0, 1: 11.0}
423 def replica_fn():
424 with variable_sync_on_read_context():
425 return v + 10.0
427 aggregated = strategy.run(replica_fn)
428 print(aggregated) # PerReplica: {0: 12.0, 1: 12.0}
429 ```
431 Yields:
432 Context manager for aggregating SyncOnReadVariable upon reading.
433 """
434 try:
435 _variable_sync_on_read_context.entered = True
436 yield
437 finally:
438 _variable_sync_on_read_context.entered = False
441def in_variable_sync_on_read_context():
442 try:
443 return _variable_sync_on_read_context.entered
444 except AttributeError:
445 return False
447# ------------------------------------------------------------------------------
448# Public API for accessing the current thread mode
451@tf_export("distribute.get_replica_context")
452def get_replica_context():
453 """Returns the current `tf.distribute.ReplicaContext` or `None`.
455 Returns `None` if in a cross-replica context.
457 Note that execution:
459 1. starts in the default (single-replica) replica context (this function
460 will return the default `ReplicaContext` object);
461 2. switches to cross-replica context (in which case this will return
462 `None`) when entering a `with tf.distribute.Strategy.scope():` block;
463 3. switches to a (non-default) replica context inside `strategy.run(fn, ...)`;
464 4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
465 inside `merge_fn` you are back in the cross-replica context (and again
466 this function will return `None`).
468 Most `tf.distribute.Strategy` methods may only be executed in
469 a cross-replica context, in a replica context you should use the
470 API of the `tf.distribute.ReplicaContext` object returned by this
471 method instead.
473 ```
474 assert tf.distribute.get_replica_context() is not None # default
475 with strategy.scope():
476 assert tf.distribute.get_replica_context() is None
478 def f():
479 replica_context = tf.distribute.get_replica_context() # for strategy
480 assert replica_context is not None
481 tf.print("Replica id: ", replica_context.replica_id_in_sync_group,
482 " of ", replica_context.num_replicas_in_sync)
484 strategy.run(f)
485 ```
487 Returns:
488 The current `tf.distribute.ReplicaContext` object when in a replica context
489 scope, else `None`.
491 Within a particular block, exactly one of these two things will be true:
493 * `get_replica_context()` returns non-`None`, or
494 * `tf.distribute.is_cross_replica_context()` returns True.
495 """
496 return _get_per_thread_mode().replica_context
499def get_cross_replica_context():
500 """Returns the current tf.distribute.Strategy if in a cross-replica context.
502 DEPRECATED: Please use `in_cross_replica_context()` and
503 `get_strategy()` instead.
505 Returns:
506 Returns the current `tf.distribute.Strategy` object in a cross-replica
507 context, or `None`.
509 Exactly one of `get_replica_context()` and `get_cross_replica_context()`
510 will return `None` in a particular block.
511 """
512 return _get_per_thread_mode().cross_replica_context
515@tf_export("distribute.in_cross_replica_context")
516def in_cross_replica_context():
517 """Returns `True` if in a cross-replica context.
519 See `tf.distribute.get_replica_context` for details.
521 ```
522 assert not tf.distribute.in_cross_replica_context()
523 with strategy.scope():
524 assert tf.distribute.in_cross_replica_context()
526 def f():
527 assert not tf.distribute.in_cross_replica_context()
529 strategy.run(f)
530 ```
532 Returns:
533 `True` if in a cross-replica context (`get_replica_context()` returns
534 `None`), or `False` if in a replica context (`get_replica_context()` returns
535 non-`None`).
536 """
537 return _get_per_thread_mode().cross_replica_context is not None
540@tf_export("distribute.get_strategy")
541def get_strategy():
542 """Returns the current `tf.distribute.Strategy` object.
544 Typically only used in a cross-replica context:
546 ```
547 if tf.distribute.in_cross_replica_context():
548 strategy = tf.distribute.get_strategy()
549 ...
550 ```
552 Returns:
553 A `tf.distribute.Strategy` object. Inside a `with strategy.scope()` block,
554 it returns `strategy`, otherwise it returns the default (single-replica)
555 `tf.distribute.Strategy` object.
556 """
557 return _get_per_thread_mode().strategy
560@tf_export("distribute.has_strategy")
561def has_strategy():
562 """Return if there is a current non-default `tf.distribute.Strategy`.
564 ```
565 assert not tf.distribute.has_strategy()
566 with strategy.scope():
567 assert tf.distribute.has_strategy()
568 ```
570 Returns:
571 True if inside a `with strategy.scope():`.
572 """
573 return get_strategy() is not _get_default_strategy()
576def get_strategy_and_replica_context():
577 per_thread_mode = _get_per_thread_mode()
578 return (per_thread_mode.strategy, per_thread_mode.replica_context)
581@tf_export("distribute.experimental_set_strategy")
582def experimental_set_strategy(strategy):
583 """Set a `tf.distribute.Strategy` as current without `with strategy.scope()`.
585 ```
586 tf.distribute.experimental_set_strategy(strategy1)
587 f()
588 tf.distribute.experimental_set_strategy(strategy2)
589 g()
590 tf.distribute.experimental_set_strategy(None)
591 h()
592 ```
594 is equivalent to:
596 ```
597 with strategy1.scope():
598 f()
599 with strategy2.scope():
600 g()
601 h()
602 ```
604 In general, you should use the `with strategy.scope():` API, but this
605 alternative may be convenient in notebooks where you would have to put
606 each cell in a `with strategy.scope():` block.
608 Note: This should only be called outside of any TensorFlow scope to
609 avoid improper nesting.
611 Args:
612 strategy: A `tf.distribute.Strategy` object or None.
614 Raises:
615 RuntimeError: If called inside a `with strategy.scope():`.
616 """
617 old_scope = ops.get_default_graph()._global_distribute_strategy_scope # pylint: disable=protected-access
618 if old_scope is not None:
619 old_scope.__exit__(None, None, None)
620 ops.get_default_graph()._global_distribute_strategy_scope = None # pylint: disable=protected-access
621 if has_strategy():
622 raise RuntimeError(
623 "Must not be called inside a `tf.distribute.Strategy` scope.")
624 if strategy is not None:
625 new_scope = strategy.scope()
626 new_scope.__enter__()
627 ops.get_default_graph()._global_distribute_strategy_scope = new_scope # pylint: disable=protected-access
630# ------------------------------------------------------------------------------
631# Internal helpers.
634@contextlib.contextmanager
635def enter_or_assert_strategy(strategy):
636 if has_strategy():
637 _assert_strategy(strategy)
638 yield
639 else:
640 with strategy.scope():
641 yield
644# ------------------------------------------------------------------------------
645# Defaults that are used when no tf.distribute.Strategy is explicitly created.
646 # We create them lazily in a function so that we can work around the circular
647# dependency on distribute_lib. See lazy loader at the top of this file.
649_defaults = {
650 "strategy": None,
651 "replica_context": None,
652 "replica_mode": None
653}
654# Note: These need to be different locks since _get_default_replica_context
655 # calls _get_default_strategy inside its lock, and using the same lock for both
656# can lead to deadlock.
657_default_strategy_lock = threading.Lock()
658_default_replica_context_lock = threading.Lock()
659_default_replica_mode_lock = threading.Lock()
662def _assert_strategy(strategy):
663 if not has_strategy():
664 raise RuntimeError('Need to be inside "with strategy.scope()" for %s' %
665 (strategy,))
666 current_strategy = get_strategy()
667 if current_strategy is not strategy:
668 raise RuntimeError(
669 "Mixing different tf.distribute.Strategy objects: %s is not %s" %
670 (current_strategy, strategy))
673def _get_default_strategy():
674 if _defaults["strategy"] is None:
675 # Avoid race condition causing two defaults to be created
676 with _default_strategy_lock:
677 if _defaults["strategy"] is None:
678 # pylint: disable=protected-access
679 # Make sure distribute_lib module is loaded by accessing some member.
680 global _creating_default_strategy_singleton
681 _creating_default_strategy_singleton = True
682 if tf2.enabled():
683 _defaults["strategy"] = _DefaultDistributionStrategy()
684 else:
685 _defaults["strategy"] = (
686 _DefaultDistributionStrategyV1())
687 _creating_default_strategy_singleton = False
688 # pylint: enable=protected-access
689 return _defaults["strategy"]
692def _get_default_replica_context():
693 if _defaults["replica_context"] is None:
694 # Avoid race condition causing two defaults to be created
695 with _default_replica_context_lock:
696 if _defaults["replica_context"] is None:
697 # pylint: disable=protected-access
698 _defaults["replica_context"] = _DefaultReplicaContext(
699 _get_default_strategy(), replica_id_in_sync_group=0)
700 # pylint: enable=protected-access
701 return _defaults["replica_context"]
704def _get_default_replica_mode():
705 if _defaults["replica_mode"] is None:
706 # Avoid race condition causing two defaults to be created
707 with _default_replica_mode_lock:
708 if _defaults["replica_mode"] is None:
709 _defaults["replica_mode"] = _DefaultReplicaThreadMode()
710 return _defaults["replica_mode"]
713# Aliases for compatibility with old names.
714get_distribution_strategy = get_strategy
715has_distribution_strategy = has_strategy
718# ------------------------------------------------------------------------------
719# Internal context managers used to implement the DistributionStrategy
720# base class
723class _CurrentDistributionContext(object):
724 """Context manager setting the current `tf.distribute.Strategy`.
726 Also: overrides the variable creator and optionally the current device.
727 """
729 def __init__(self,
730 strategy,
731 var_creator_scope,
732 var_scope=None,
733 resource_creator_scope=None,
734 default_device=None):
735 self._context = _CrossReplicaThreadMode( # pylint: disable=protected-access
736 strategy)
737 self._var_creator_scope = var_creator_scope
738 self._var_scope = var_scope
739 self._resource_creator_scope = resource_creator_scope
740 if default_device:
741 self._device_scope = ops.device(default_device)
742 else:
743 self._device_scope = None
744 self._same_scope_again_count = 0
746 def __enter__(self):
747 # Allow this scope to be entered if this strategy is already in scope.
748 if has_strategy():
749 _require_cross_replica_or_default_context_extended(
750 self._context.strategy.extended)
751 self._same_scope_again_count += 1
752 else:
753 _push_per_thread_mode(self._context)
754 if self._var_scope:
755 self._var_scope.__enter__()
756 self._var_creator_scope.__enter__()
757 if self._resource_creator_scope:
758 nest.map_structure(lambda scope: scope.__enter__(),
759 self._resource_creator_scope)
760 if self._device_scope:
761 self._device_scope.__enter__()
762 return self._context.strategy
764 def __exit__(self, exception_type, exception_value, traceback):
765 if self._same_scope_again_count > 0:
766 self._same_scope_again_count -= 1
767 return
768 if self._device_scope:
769 try:
770 self._device_scope.__exit__(exception_type, exception_value, traceback)
771 except RuntimeError as e:
772 six.raise_from(
773 RuntimeError("Device scope nesting error: move call to "
774 "tf.distribute.set_strategy() out of `with` scope."),
775 e)
777 try:
778 self._var_creator_scope.__exit__(
779 exception_type, exception_value, traceback)
780 except RuntimeError as e:
781 six.raise_from(
782 RuntimeError("Variable creator scope nesting error: move call to "
783 "tf.distribute.set_strategy() out of `with` scope."),
784 e)
786 if self._resource_creator_scope:
787 try:
788 if isinstance(self._resource_creator_scope, list):
789 reversed_resource_creator_scope = self._resource_creator_scope[::-1]
790 nest.map_structure(
791 lambda scope: scope.__exit__(exception_type, exception_value, # pylint:disable=g-long-lambda
792 traceback),
793 reversed_resource_creator_scope)
795 else:
796 self._resource_creator_scope.__exit__(exception_type, exception_value,
797 traceback)
798 except RuntimeError as e:
799 six.raise_from(
800 RuntimeError("Resource creator scope nesting error: move call "
801 "to tf.distribute.set_strategy() out of `with` "
802 "scope."), e)
804 if self._var_scope:
805 try:
806 self._var_scope.__exit__(exception_type, exception_value, traceback)
807 except RuntimeError as e:
808 six.raise_from(
809 RuntimeError("Variable scope nesting error: move call to "
810 "tf.distribute.set_strategy() out of `with` scope."),
811 e)
812 _pop_per_thread_mode()
815# TODO(yuefengz): add more replication modes.
816@tf_export("distribute.InputReplicationMode")
817class InputReplicationMode(enum.Enum):
818 """Replication mode for input function.
820 * `PER_WORKER`: The input function will be called on each worker
821 independently, creating as many input pipelines as there are workers.
822 Replicas will dequeue from the local Dataset on their worker.
823 `tf.distribute.Strategy` doesn't manage any state sharing between such
824 separate input pipelines.
825 * `PER_REPLICA`: The input function will be called on each replica separately.
826 `tf.distribute.Strategy` doesn't manage any state sharing between such
827 separate input pipelines.
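For example, a hedged sketch of selecting a replication mode through
`tf.distribute.InputOptions` (here `dataset_fn` is assumed to be a function
taking a `tf.distribute.InputContext` and returning a `tf.data.Dataset`):
```
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn, options)
```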
828 """
829 PER_WORKER = "PER_WORKER"
830 PER_REPLICA = "PER_REPLICA"
833@tf_export("distribute.InputContext")
834class InputContext(object):
835 """A class wrapping information needed by an input function.
837 This is a context class that is passed to the user's input function and
838 contains information about the compute replicas and input pipelines. The
839 number of compute replicas (in sync training) helps compute the local batch
840 size from the desired global batch size for each replica. The input pipeline
841 information can be used to return a different subset of the input in each
842 replica (e.g. shard the input pipeline, use a different input
843 source, etc.).
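A rough sketch of an input function that uses this context to shard and
batch (assuming `global_batch_size` is defined by the caller and
`make_dataset()` is a hypothetical helper returning a `tf.data.Dataset`):
```
def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = make_dataset()  # hypothetical helper
  d = d.shard(input_context.num_input_pipelines,
              input_context.input_pipeline_id)
  return d.batch(batch_size)
dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```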
844 """
846 __slots__ = [
847 "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
848 ]
850 def __init__(self,
851 num_input_pipelines=1,
852 input_pipeline_id=0,
853 num_replicas_in_sync=1):
854 """Initializes an InputContext object.
856 Args:
857 num_input_pipelines: the number of input pipelines in a cluster.
858 input_pipeline_id: the current input pipeline id, should be an int in
859 [0,`num_input_pipelines`).
860 num_replicas_in_sync: the number of replicas that are in sync.
861 """
862 self._num_input_pipelines = num_input_pipelines
863 self._input_pipeline_id = input_pipeline_id
864 self._num_replicas_in_sync = num_replicas_in_sync
866 @property
867 def num_replicas_in_sync(self):
868 """Returns the number of compute replicas in sync."""
869 return self._num_replicas_in_sync
871 @property
872 def input_pipeline_id(self):
873 """Returns the input pipeline ID."""
874 return self._input_pipeline_id
876 @property
877 def num_input_pipelines(self):
878 """Returns the number of input pipelines."""
879 return self._num_input_pipelines
881 def get_per_replica_batch_size(self, global_batch_size):
882 """Returns the per-replica batch size.
884 Args:
885 global_batch_size: the global batch size which should be divisible by
886 `num_replicas_in_sync`.
888 Returns:
889 the per-replica batch size.
891 Raises:
892 ValueError: if `global_batch_size` not divisible by
893 `num_replicas_in_sync`.
894 """
895 if global_batch_size % self._num_replicas_in_sync != 0:
896 raise ValueError("The `global_batch_size` %r is not divisible by "
897 "`num_replicas_in_sync` %r " %
898 (global_batch_size, self._num_replicas_in_sync))
899 return global_batch_size // self._num_replicas_in_sync
901 def __str__(self):
902 return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
903 self.input_pipeline_id, self.num_input_pipelines)
906@tf_export("distribute.experimental.ValueContext", v1=[])
907class ValueContext(object):
908 """A class wrapping information needed by a distribute function.
910 This is a context class that is passed to the `value_fn` in
911 `strategy.experimental_distribute_values_from_function` and contains
912 information about the compute replicas. The `num_replicas_in_sync` and
913 `replica_id` can be used to customize the value on each replica.
915 Example usage:
917 1. Directly constructed.
919 >>> def value_fn(context):
920 ... return context.replica_id_in_sync_group/context.num_replicas_in_sync
921 >>> context = tf.distribute.experimental.ValueContext(
922 ... replica_id_in_sync_group=2, num_replicas_in_sync=4)
923 >>> per_replica_value = value_fn(context)
924 >>> per_replica_value
925 0.5
927 2. Passed in by `experimental_distribute_values_from_function`. {: value=2}
929 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
930 >>> def value_fn(value_context):
931 ... return value_context.num_replicas_in_sync
932 >>> distributed_values = (
933 ... strategy.experimental_distribute_values_from_function(
934 ... value_fn))
935 >>> local_result = strategy.experimental_local_results(distributed_values)
936 >>> local_result
937 (2, 2)
939 """
941 __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
943 def __init__(self,
944 replica_id_in_sync_group=0,
945 num_replicas_in_sync=1):
946 """Initializes an ValueContext object.
948 Args:
949 replica_id_in_sync_group: the current replica_id, should be an int in
950 [0,`num_replicas_in_sync`).
951 num_replicas_in_sync: the number of replicas that are in sync.
952 """
953 self._replica_id_in_sync_group = replica_id_in_sync_group
954 self._num_replicas_in_sync = num_replicas_in_sync
956 @property
957 def num_replicas_in_sync(self):
958 """Returns the number of compute replicas in sync."""
959 return self._num_replicas_in_sync
961 @property
962 def replica_id_in_sync_group(self):
963 """Returns the replica ID."""
964 return self._replica_id_in_sync_group
966 def __str__(self):
967 return (("tf.distribute.ValueContext(replica id {}, "
968 " total replicas in sync: ""{})")
969 .format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
972@tf_export("distribute.RunOptions")
973class RunOptions(
974 collections.namedtuple("RunOptions", [
975 "experimental_enable_dynamic_batch_size",
976 "experimental_bucketizing_dynamic_shape",
977 "experimental_xla_options",
978 ])):
979 """Run options for `strategy.run`.
981 This can be used to hold some strategy specific configs.
983 Attributes:
984 experimental_enable_dynamic_batch_size: Boolean. Only applies to
985 TPUStrategy. Defaults to True. If True, TPUStrategy will enable the dynamic
986 padder to support dynamic batch size for the inputs. Otherwise only static
987 shape inputs are allowed.
988 experimental_bucketizing_dynamic_shape: Boolean. Only applies to
989 TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
990 bucketize inputs passed into `run` if the input shape is
991 dynamic. This is a performance optimization to reduce XLA recompilation,
992 which should not have an impact on correctness.
993 experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
994 TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to None.
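For illustration, a hedged sketch of passing run options to `strategy.run`
on a TPU strategy (`replica_fn` and `inputs` are assumed to be defined
elsewhere):
```
options = tf.distribute.RunOptions(
    experimental_enable_dynamic_batch_size=False)
per_replica_outputs = strategy.run(replica_fn, args=(inputs,), options=options)
```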
995 """
997 def __new__(cls,
998 experimental_enable_dynamic_batch_size=True,
999 experimental_bucketizing_dynamic_shape=False,
1000 experimental_xla_options=None):
1001 return super(RunOptions,
1002 cls).__new__(cls, experimental_enable_dynamic_batch_size,
1003 experimental_bucketizing_dynamic_shape,
1004 experimental_xla_options)
1007@tf_export("distribute.InputOptions", v1=[])
1008class InputOptions(
1009 collections.namedtuple("InputOptions", [
1010 "experimental_fetch_to_device",
1011 "experimental_replication_mode",
1012 "experimental_place_dataset_on_device",
1013 "experimental_per_replica_buffer_size",
1014 ])):
1015 """Run options for `experimental_distribute_dataset(s_from_function)`.
1017 This can be used to hold some strategy specific configs.
1019 ```python
1020 # Setup TPUStrategy
1021 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
1022 tf.config.experimental_connect_to_cluster(resolver)
1023 tf.tpu.experimental.initialize_tpu_system(resolver)
1024 strategy = tf.distribute.TPUStrategy(resolver)
1026 dataset = tf.data.Dataset.range(16)
1027 distributed_dataset_on_host = (
1028 strategy.experimental_distribute_dataset(
1029 dataset,
1030 tf.distribute.InputOptions(
1031 experimental_replication_mode=
1032 tf.distribute.InputReplicationMode.PER_WORKER,
1033 experimental_place_dataset_on_device=False,
1034 experimental_per_replica_buffer_size=1)))
1035 ```
1037 Attributes:
1038 experimental_fetch_to_device: Boolean. If True, dataset
1039 elements will be prefetched to accelerator device memory. When False,
1040 dataset elements are prefetched to host device memory. Must be False when
1041 using TPUEmbedding API. experimental_fetch_to_device can only be used
1042 with experimental_replication_mode=PER_WORKER. Default behavior is same as
1043 setting it to True.
1044 experimental_replication_mode: Replication mode for the input function.
1045 Currently, InputReplicationMode.PER_REPLICA is only supported with
1046 tf.distribute.MirroredStrategy's
1047 experimental_distribute_datasets_from_function.
1048 The default value is InputReplicationMode.PER_WORKER.
1049 experimental_place_dataset_on_device: Boolean. Default to False. When True,
1050 dataset will be placed on the device, otherwise it will remain on the
1051 host. experimental_place_dataset_on_device=True can only be used with
1052 experimental_replication_mode=PER_REPLICA
1053 experimental_per_replica_buffer_size: Integer. Default to 1. Indicates the
1054 prefetch buffer size in the replica device memory. Users can set it
1055 to 0 to completely disable prefetching behavior, or a number greater than
1056 1 to enable larger buffer size. Note that this option is still
1057 valid with `experimental_fetch_to_device=False`.
1058 """
1060 def __new__(cls,
1061 experimental_fetch_to_device=None,
1062 experimental_replication_mode=InputReplicationMode.PER_WORKER,
1063 experimental_place_dataset_on_device=False,
1064 experimental_per_replica_buffer_size=1):
1065 if experimental_fetch_to_device is None:
1066 experimental_fetch_to_device = True
1068 return super(InputOptions,
1069 cls).__new__(cls, experimental_fetch_to_device,
1070 experimental_replication_mode,
1071 experimental_place_dataset_on_device,
1072 experimental_per_replica_buffer_size)
1074# ------------------------------------------------------------------------------
1075# Base classes for all distribution strategies.
1078# Base class for v1 Strategy and v2 Strategy classes. For API's specific to
1079# v1/v2 Strategy, add to implementing classes of StrategyBase.
1080# pylint: disable=line-too-long
1081class StrategyBase(object):
1082 """A state & compute distribution policy on a list of devices.
1084 See [the guide](https://www.tensorflow.org/guide/distributed_training)
1085 for overview and examples. See `tf.distribute.StrategyExtended` and
1086 [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
1087 for a glossary of concepts mentioned on this page such as "per-replica",
1088 _replica_, and _reduce_.
1090 In short:
1092 * To use it with Keras `compile`/`fit`,
1093 [please
1094 read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
1095 * You may pass a descendant of `tf.distribute.Strategy` to
1096 `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
1097 should distribute its computation. See
1098 [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
1099 * Otherwise, use `tf.distribute.Strategy.scope` to specify that a
1100 strategy should be used when building and executing your model.
1101 (This puts you in the "cross-replica context" for this strategy, which
1102 means the strategy is put in control of things like variable placement.)
1103 * If you are writing a custom training loop, you will need to call a few more
1104 methods,
1105 [see the
1106 guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):
1108 * Start by creating a `tf.data.Dataset` normally.
1109 * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
1110 a `tf.data.Dataset` to something that produces "per-replica" values.
1111 If you want to manually specify how the dataset should be partitioned
1112 across replicas, use
1113 `tf.distribute.Strategy.distribute_datasets_from_function`
1114 instead.
1115 * Use `tf.distribute.Strategy.run` to run a function
1116 once per replica, taking values that may be "per-replica" (e.g.
1117 from a `tf.distribute.DistributedDataset` object) and returning
1118 "per-replica" values.
1119 This function is executed in "replica context", which means each
1120 operation is performed separately on each replica.
1121 * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
1122 convert the resulting "per-replica" values into ordinary `Tensor`s.
1124 A custom training loop can be as simple as:
1126 ```
1127 with my_strategy.scope():
1128 @tf.function
1129 def distribute_train_epoch(dataset):
1130 def replica_fn(input):
1131 # process input and return result
1132 return result
1134 total_result = 0
1135 for x in dataset:
1136 per_replica_result = my_strategy.run(replica_fn, args=(x,))
1137 total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
1138 per_replica_result, axis=None)
1139 return total_result
1141 dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
1142 for _ in range(EPOCHS):
1143 train_result = distribute_train_epoch(dist_dataset)
1144 ```
1146 This takes an ordinary `dataset` and `replica_fn` and runs it
1147 distributed using a particular `tf.distribute.Strategy` named
1148 `my_strategy` above. Any variables created in `replica_fn` are created
1149 using `my_strategy`'s policy, and library functions called by
1150 `replica_fn` can use the `get_replica_context()` API to implement
1151 distributed-specific behavior.
1153 You can use the `reduce` API to aggregate results across replicas and use
1154 this as a return value from one iteration over a
1155 `tf.distribute.DistributedDataset`. Or
1156 you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
1157 accumulate metrics across steps in a given epoch.
1159 See the
1160 [custom training loop
1161 tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
1162 for a more detailed example.
1164 Note: `tf.distribute.Strategy` does not currently support TensorFlow's
1165 partitioned variables (where a single variable is split across multiple
1166 devices).
1167 """
1168 # pylint: enable=line-too-long
1170 # TODO(josh11b): Partitioned computations, state; sharding
1171 # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
1173 def __init__(self, extended):
1174 self._extended = extended
1176 # Flag that is used to indicate whether distribution strategy is used with
1177 # Estimator. This is required for backward compatibility of loss scaling
1178 # when using v1 optimizer with estimator.
1179 self._scale_loss_for_estimator = False
1181 if not hasattr(extended, "_retrace_functions_for_each_device"):
1182 # pylint: disable=protected-access
1183 # `extended._retrace_functions_for_each_device` dictates
1184 # whether the same function will be retraced when it is called on
1185 # different devices.
1186 try:
1187 extended._retrace_functions_for_each_device = (
1188 len(extended.worker_devices) > 1)
1189 distribution_strategy_replica_gauge.get_cell("num_replicas").set(
1190 self.num_replicas_in_sync)
1191 except: # pylint: disable=bare-except
1192 # Default for the case where extended.worker_devices can't return
1193 # a sensible value.
1194 extended._retrace_functions_for_each_device = True
1196 # Below are the dicts of axis(int) -> `tf.function`.
1197 self._mean_reduce_helper_fns = {}
1198 self._reduce_sum_fns = {}
1200 # Whether this strategy is designed to work with `ClusterCoordinator`.
1201 self._should_use_with_coordinator = False
1203 @property
1204 def extended(self):
1205 """`tf.distribute.StrategyExtended` with additional methods."""
1206 return self._extended
1208 @tf_contextlib.contextmanager
1209 def _scale_loss_for_estimator_enabled(self):
1210 """Scope which sets a flag used for scaling losses in optimizer.
1212 Yields:
1213 `_scale_loss_for_estimator_enabled` is a context manager with a
1214 side effect, but doesn't return a value.
1215 """
1216 self._scale_loss_for_estimator = True
1217 try:
1218 yield
1219 finally:
1220 self._scale_loss_for_estimator = False
1222 # pylint: disable=line-too-long
1223 def scope(self):
1224 """Context manager to make the strategy current and distribute variables.
1226 This method returns a context manager, and is used as follows:
1228 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1229 >>> # Variable created inside scope:
1230 >>> with strategy.scope():
1231 ... mirrored_variable = tf.Variable(1.)
1232 >>> mirrored_variable
1233 MirroredVariable:{
1234 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
1235 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
1236 }
1237 >>> # Variable created outside scope:
1238 >>> regular_variable = tf.Variable(1.)
1239 >>> regular_variable
1240 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
1242 _What happens when Strategy.scope is entered?_
1244 * `strategy` is installed in the global context as the "current" strategy.
1245 Inside this scope, `tf.distribute.get_strategy()` will now return this
1246 strategy. Outside this scope, it returns the default no-op strategy.
1247 * Entering the scope also enters the "cross-replica context". See
1248 `tf.distribute.StrategyExtended` for an explanation on cross-replica and
1249 replica contexts.
1250 * Variable creation inside `scope` is intercepted by the strategy. Each
1251 strategy defines how it wants to affect the variable creation. Sync
1252 strategies like `MirroredStrategy`, `TPUStrategy` and
1253 `MultiWorkerMirroredStrategy` create variables replicated on each replica,
1254 whereas `ParameterServerStrategy` creates variables on the parameter
1255 servers. This is done using a custom `tf.variable_creator_scope`.
1256 * In some strategies, a default device scope may also be entered: in
1257 `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
1258 entered on each worker.
1260 Note: Entering a scope does not automatically distribute a computation, except
1261 in the case of a high level training framework like Keras `model.fit`. If
1262 you're not using `model.fit`, you
1263 need to use the `strategy.run` API to explicitly distribute that computation.
1264 See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
1267 _What should be in scope and what should be outside?_
1269 There are a number of requirements on what needs to happen inside the scope.
1270 However, in places where we have information about which strategy is in use,
1271 we often enter the scope for the user, so they don't have to do it
1272 explicitly (i.e. calling those either inside or outside the scope is OK).
1274 * Anything that creates variables that should be distributed variables
1275 must be called in a `strategy.scope`. This can be accomplished either by
1276 directly calling the variable creating function within the scope context,
1277 or by relying on another API like `strategy.run` or `keras.Model.fit` to
1278 automatically enter it for you. Any variable that is created outside scope
1279 will not be distributed and may have performance implications. Some common
1280 objects that create variables in TF are Models, Optimizers, Metrics. Such
1281 objects should always be initialized in the scope, and any functions
1282 that may lazily create variables (e.g., `Model.__call__()`, tracing a
1283 `tf.function`, etc.) should similarly be called within scope. Another
1284 source of variable creation can be a checkpoint restore - when variables
1285 are created lazily. Note that any variable created inside a strategy
1286 captures the strategy information. So reading and writing to these
1287 variables outside the `strategy.scope` can also work seamlessly, without
1288 the user having to enter the scope.
1289 * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which
1290 require to be in a strategy's scope, enter the scope automatically, which
1291 means when using those APIs you don't need to explicitly enter the scope
1292 yourself.
1293 * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
1294 object captures the scope information. When high level training framework
1295 methods such as `model.compile`, `model.fit`, etc. are then called, the
1296 captured scope will be automatically entered, and the associated strategy
1297 will be used to distribute the training etc. See a detailed example in
1298 [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
1299 WARNING: Simply calling `model(..)` does not automatically enter the
1300 captured scope -- only high level training framework APIs support this
1301 behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
1302 and `model.save` can all be called inside or outside the scope.
1303 * The following can be either inside or outside the scope:
1304 * Creating the input datasets
1305 * Defining `tf.function`s that represent your training step
1306 * Saving APIs such as `tf.saved_model.save`. Loading creates variables,
1307 so that should go inside the scope if you want to train the model in a
1308 distributed way.
1309 * Checkpoint saving. As mentioned above - `checkpoint.restore` may
1310 sometimes need to be inside scope if it creates variables.
1312 Returns:
1313 A context manager.
1314 """
1315 return self._extended._scope(self) # pylint: disable=protected-access
1316 # pylint: enable=line-too-long
1318 @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended`
1319 @deprecated(None, "use extended.colocate_vars_with() instead.")
1320 def colocate_vars_with(self, colocate_with_variable):
1321 """DEPRECATED: use extended.colocate_vars_with() instead."""
1322 return self._extended.colocate_vars_with(colocate_with_variable)
1324 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
1325 def make_dataset_iterator(self, dataset):
1326 """DEPRECATED TF 1.x ONLY."""
1327 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
1329 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
1330 def make_input_fn_iterator(self,
1331 input_fn,
1332 replication_mode=InputReplicationMode.PER_WORKER):
1333 """DEPRECATED TF 1.x ONLY."""
1334 if replication_mode != InputReplicationMode.PER_WORKER:
1335 raise ValueError(
1336 "Input replication mode not supported: %r" % replication_mode)
1337 with self.scope():
1338 return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
1339 input_fn, replication_mode=replication_mode)
1341 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
1342 @deprecated(None, "use run() instead")
1343 def experimental_run(self, fn, input_iterator=None):
1344 """DEPRECATED TF 1.x ONLY."""
1345 with self.scope():
1346 args = (input_iterator.get_next(),) if input_iterator is not None else ()
1347 return self.run(fn, args=args)
1349 def experimental_distribute_dataset(self, dataset, options=None):
1350 # pylint: disable=line-too-long
1351 """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
1353 The returned `tf.distribute.DistributedDataset` can be iterated over
1354 similar to regular datasets.
1355 NOTE: The user cannot add any more transformations to a
1356 `tf.distribute.DistributedDataset`. You can only create an iterator or
1357 examine the `tf.TypeSpec` of the data generated by it. See API docs of
1358 `tf.distribute.DistributedDataset` to learn more.
1360 The following is an example:
1362 >>> global_batch_size = 2
1363 >>> # Passing the devices is optional.
1364 ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
1365 >>> # Create a dataset
1366 ... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
1367 >>> # Distribute that dataset
1368 ... dist_dataset = strategy.experimental_distribute_dataset(dataset)
1369 >>> @tf.function
1370 ... def replica_fn(input):
1371 ... return input*2
1372 >>> result = []
1373 >>> # Iterate over the `tf.distribute.DistributedDataset`
1374 ... for x in dist_dataset:
1375 ... # process dataset elements
1376 ... result.append(strategy.run(replica_fn, args=(x,)))
1377 >>> print(result)
1378 [PerReplica:{
1379 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
1380 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
1381 }, PerReplica:{
1382 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
1383 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
1384 }]
1387 Three key actions happening under the hood of this method are batching,
1388 sharding, and prefetching.
1390 In the code snippet above, `dataset` is batched by `global_batch_size`, and
1391 calling `experimental_distribute_dataset` on it rebatches `dataset` to a
1392 new batch size that is equal to the global batch size divided by the number
1393 of replicas in sync. We iterate through it using a Pythonic for loop.
1394 `x` is a `tf.distribute.DistributedValues` containing data for all replicas,
1395 and each replica gets data of the new batch size.
1396 `tf.distribute.Strategy.run` will take care of feeding the right per-replica
1397 data in `x` to the right `replica_fn` executed on each replica.
1399 Sharding includes autosharding across multiple workers and within every
1400 worker. First, in multi-worker distributed training (i.e. when you use
1401 `tf.distribute.experimental.MultiWorkerMirroredStrategy`
1402 or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of
1403 workers means that each worker is assigned a subset of the entire dataset
1404 (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to
1405 ensure that at each step, a global batch size of non-overlapping dataset
1406 elements will be processed by each worker. Autosharding has a couple of
1407 different options that can be specified using
1408 `tf.data.experimental.DistributeOptions`. Then, sharding within each worker
1409 means the method will split the data among all the worker devices (if more
1410 than one is present). This will happen regardless of multi-worker
1411 autosharding.
1413 Note: for autosharding across multiple workers, the default mode is
1414 `tf.data.experimental.AutoShardPolicy.AUTO`. This mode
1415 will attempt to shard the input dataset by files if the dataset is
1416 being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
1417 `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
1418 where each of the workers will read the entire dataset and only process the
1419 shard assigned to it. However, if you have fewer than one input file per
1420 worker, we suggest that you disable dataset autosharding across workers by
1421 setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be
1422 `tf.data.experimental.AutoShardPolicy.OFF`.
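For example, a minimal sketch of turning autosharding off (assuming `dataset`
has already been created as above):

```python
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
# Apply the options before distributing the dataset.
dataset = dataset.with_options(options)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```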
1424 By default, this method adds a prefetch transformation at the end of the
1425 user provided `tf.data.Dataset` instance. The `buffer_size` argument to
1426 the prefetch transformation is equal to the number of replicas in
1427 sync.
1429 If the above batch splitting and dataset sharding logic is undesirable,
1430 please use
1431 `tf.distribute.Strategy.distribute_datasets_from_function`
1432 instead, which does not do any automatic batching or sharding for you.
1434 Note: If you are using TPUStrategy, the order in which the data is processed
1435 by the workers when using
1436 `tf.distribute.Strategy.experimental_distribute_dataset` or
1437 `tf.distribute.Strategy.distribute_datasets_from_function` is
1438 not guaranteed. A deterministic order is typically required if you are using
1439 `tf.distribute` to scale prediction. You can however insert an index for
1440 each element in the batch and order outputs accordingly. Refer to [this
1441 snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
1442 for an example of how to order outputs.
1444 Note: Stateful dataset transformations are currently not supported with
1445 `tf.distribute.experimental_distribute_dataset` or
1446 `tf.distribute.distribute_datasets_from_function`. Any stateful
1447 ops that the dataset may have are currently ignored. For example, if your
1448 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
1449 then you have a dataset graph that depends on state (i.e. the random seed) on
1450 the local machine where the python process is being executed.
1452 For a tutorial on more usage and properties of this method, refer to the
1453 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
1454 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
1456 Args:
1457 dataset: `tf.data.Dataset` that will be sharded across all replicas using
1458 the rules stated above.
1459 options: `tf.distribute.InputOptions` used to control options on how this
1460 dataset is distributed.
1462 Returns:
1463 A `tf.distribute.DistributedDataset`.
1464 """
1465 distribution_strategy_input_api_counter.get_cell(
1466 self.__class__.__name__, "distribute_dataset").increase_by(1)
1467 # pylint: enable=line-too-long
1468 return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access
1470 def distribute_datasets_from_function(self, dataset_fn, options=None):
1471 # pylint: disable=line-too-long
1472 """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
1474 The argument `dataset_fn` that users pass in is an input function that has a
1475 `tf.distribute.InputContext` argument and returns a `tf.data.Dataset`
1476 instance. It is expected that the returned dataset from `dataset_fn` is
1477 already batched by per-replica batch size (i.e. global batch size divided by
1478 the number of replicas in sync) and sharded.
1479 `tf.distribute.Strategy.distribute_datasets_from_function` does
1480 not batch or shard the `tf.data.Dataset` instance
1481 returned from the input function. `dataset_fn` will be called on the CPU
1482 device of each of the workers and each generates a dataset where every
1483 replica on that worker will dequeue one batch of inputs (i.e. if a worker
1484 has two replicas, two batches will be dequeued from the `Dataset` every
1485 step).
1487 This method can be used for several purposes. First, it allows you to
1488 specify your own batching and sharding logic. (In contrast,
1489 `tf.distribute.experimental_distribute_dataset` does batching and sharding
1490 for you.) For example, where
1491 `experimental_distribute_dataset` is unable to shard the input files, this
1492 method might be used to manually shard the dataset (avoiding the slow
1493 fallback behavior in `experimental_distribute_dataset`). In cases where the
1494 dataset is infinite, this sharding can be done by creating dataset replicas
1495 that differ only in their random seed.
1497 The `dataset_fn` should take a `tf.distribute.InputContext` instance where
1498 information about batching and input replication can be accessed.
1500 You can use `element_spec` property of the
1501 `tf.distribute.DistributedDataset` returned by this API to query the
1502 `tf.TypeSpec` of the elements returned by the iterator. This can be used to
1503 set the `input_signature` property of a `tf.function`. Follow
1504 `tf.distribute.DistributedDataset.element_spec` to see an example.
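As a short sketch (assuming `dist_dataset` is a
`tf.distribute.DistributedDataset` returned by this API and `replica_fn` is a
per-replica step function defined elsewhere):

```python
@tf.function(input_signature=[dist_dataset.element_spec])
def train_step(dist_inputs):
  # `dist_inputs` carries one component per replica; `run` routes each to
  # the matching replica.
  return strategy.run(replica_fn, args=(dist_inputs,))
```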
1506 IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
1507 per-replica batch size, unlike `experimental_distribute_dataset`, which uses
1508 the global batch size. This may be computed using
1509 `input_context.get_per_replica_batch_size`.
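For example, a short sketch of a `dataset_fn` that batches by the per-replica
batch size and shards by input pipeline (`global_batch_size` is assumed to be
defined by the caller):

```python
def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.range(100).batch(batch_size)
  # Shard by input pipeline so each worker reads a disjoint subset.
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```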
1511 Note: If you are using TPUStrategy, the order in which the data is processed
1512 by the workers when using
1513 `tf.distribute.Strategy.experimental_distribute_dataset` or
1514 `tf.distribute.Strategy.distribute_datasets_from_function` is
1515 not guaranteed. A deterministic order is typically required if you are using
1516 `tf.distribute` to scale prediction. You can however insert an index for
1517 each element in the batch and order outputs accordingly. Refer to [this
1518 snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
1519 for an example of how to order outputs.
1521 Note: Stateful dataset transformations are currently not supported with
1522 `tf.distribute.experimental_distribute_dataset` or
1523 `tf.distribute.distribute_datasets_from_function`. Any stateful
1524 ops that the dataset may have are currently ignored. For example, if your
1525 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
1526 then you have a dataset graph that depends on state (i.e. the random seed) on
1527 the local machine where the python process is being executed.
1529 For a tutorial on more usage and properties of this method, refer to the
1530 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
1531 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
1533 Args:
1534 dataset_fn: A function taking a `tf.distribute.InputContext` instance and
1535 returning a `tf.data.Dataset`.
1536 options: `tf.distribute.InputOptions` used to control options on how this
1537 dataset is distributed.
1539 Returns:
1540 A `tf.distribute.DistributedDataset`.
1541 """
1542 distribution_strategy_input_api_counter.get_cell(
1543 self.__class__.__name__,
1544 "distribute_datasets_from_function").increase_by(1)
1545 # pylint: enable=line-too-long
1546 return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access
1547 dataset_fn, options)
1549 # TODO(b/162776748): Remove deprecated symbol.
1550 @doc_controls.do_not_doc_inheritable
1551 @deprecation.deprecated(None, "rename to distribute_datasets_from_function")
1552 def experimental_distribute_datasets_from_function(self,
1553 dataset_fn,
1554 options=None):
1555 return self.distribute_datasets_from_function(dataset_fn, options)
1557 def run(self, fn, args=(), kwargs=None, options=None):
1558 """Invokes `fn` on each replica, with the given arguments.
1560 This method is the primary way to distribute your computation with a
1561 tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
1562 have `tf.distribute.DistributedValues`, such as those produced by a
1563 `tf.distribute.DistributedDataset` from
1564 `tf.distribute.Strategy.experimental_distribute_dataset` or
1565 `tf.distribute.Strategy.distribute_datasets_from_function`,
1566 when `fn` is executed on a particular replica, it will be executed with the
1567 component of `tf.distribute.DistributedValues` that corresponds to that
1568 replica.
1570 `fn` is invoked under a replica context. `fn` may call
1571 `tf.distribute.get_replica_context()` to access members such as
1572 `all_reduce`. Please see the module-level docstring of `tf.distribute` for the
1573 concept of replica context.
1575 All arguments in `args` or `kwargs` can be a nested structure of tensors,
1576 e.g. a list of tensors, in which case `args` and `kwargs` will be passed to
1577 the `fn` invoked on each replica. Or `args` or `kwargs` can be
1578 `tf.distribute.DistributedValues` containing tensors or composite tensors,
1579 i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call
1580 will get the component of a `tf.distribute.DistributedValues` corresponding
1581 to its replica. Note that arbitrary Python values that are not of the types
1582 above are not supported.
1584 IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
1585 whether eager execution is enabled, `fn` may be called one or more times. If
1586 `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
1587 called inside a `tf.function` (eager execution is disabled inside a
1588 `tf.function` by default), `fn` is called once per replica to generate a
1589 TensorFlow graph, which will then be reused for execution with new inputs.
1590 Otherwise, if eager execution is enabled, `fn` will be called once per
1591 replica every step just like regular python code.
1593 Example usage:
1595 1. Constant tensor input.
1597 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1598 >>> tensor_input = tf.constant(3.0)
1599 >>> @tf.function
1600 ... def replica_fn(input):
1601 ... return input*2.0
1602 >>> result = strategy.run(replica_fn, args=(tensor_input,))
1603 >>> result
1604 PerReplica:{
1605 0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
1606 1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
1607 }
1609 2. DistributedValues input. {: value=2}
1611 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1612 >>> @tf.function
1613 ... def run():
1614 ... def value_fn(value_context):
1615 ... return value_context.num_replicas_in_sync
1616 ... distributed_values = (
1617 ... strategy.experimental_distribute_values_from_function(
1618 ... value_fn))
1619 ... def replica_fn2(input):
1620 ... return input*2
1621 ... return strategy.run(replica_fn2, args=(distributed_values,))
1622 >>> result = run()
1623 >>> result
1624 <tf.Tensor: shape=(), dtype=int32, numpy=4>
1626 3. Use `tf.distribute.ReplicaContext` to allreduce values. {: value=3}
1628 >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
1629 >>> @tf.function
1630 ... def run():
1631 ... def value_fn(value_context):
1632 ... return tf.constant(value_context.replica_id_in_sync_group)
1633 ... distributed_values = (
1634 ... strategy.experimental_distribute_values_from_function(
1635 ... value_fn))
1636 ... def replica_fn(input):
1637 ... return tf.distribute.get_replica_context().all_reduce(
1638 ... "sum", input)
1639 ... return strategy.run(replica_fn, args=(distributed_values,))
1640 >>> result = run()
1641 >>> result
1642 PerReplica:{
1643 0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
1644 1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
1645 }
1647 Args:
1648 fn: The function to run on each replica.
1649 args: Optional positional arguments to `fn`. Its element can be a tensor,
1650 a nested structure of tensors or a `tf.distribute.DistributedValues`.
1651 kwargs: Optional keyword arguments to `fn`. Its element can be a tensor,
1652 a nested structure of tensors or a `tf.distribute.DistributedValues`.
1653 options: An optional instance of `tf.distribute.RunOptions` specifying
1654 the options to run `fn`.
1656 Returns:
1657 Merged return value of `fn` across replicas. The structure of the return
1658 value is the same as the return value from `fn`. Each element in the
1659 structure can either be `tf.distribute.DistributedValues` or `Tensor`
1660 objects (for example, if running on a single replica).
1661 """
1662 del options
1664 if not isinstance(args, (list, tuple)):
1665 raise ValueError(
1666 "positional args must be a list or tuple, got {}".format(type(args)))
1668 with self.scope():
1669 # tf.distribute supports Eager functions, so AutoGraph should not be
1670 # applied when the caller is also in Eager mode.
1671 fn = autograph.tf_convert(
1672 fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
1673 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
1675 def reduce(self, reduce_op, value, axis):
1676 """Reduce `value` across replicas and return result on current device.
1678 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
1679 >>> def step_fn():
1680 ... i = tf.distribute.get_replica_context().replica_id_in_sync_group
1681 ... return tf.identity(i)
1682 >>>
1683 >>> per_replica_result = strategy.run(step_fn)
1684 >>> total = strategy.reduce("SUM", per_replica_result, axis=None)
1685 >>> total
1686 <tf.Tensor: shape=(), dtype=int32, numpy=1>
1688 To see how this would look with multiple replicas, consider the same
1689 example with MirroredStrategy with 2 GPUs:
1691 ```python
1692 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
1693 def step_fn():
1694 i = tf.distribute.get_replica_context().replica_id_in_sync_group
1695 return tf.identity(i)
1697 per_replica_result = strategy.run(step_fn)
1698 # Check devices on which per replica result is:
1699 strategy.experimental_local_results(per_replica_result)[0].device
1700 # /job:localhost/replica:0/task:0/device:GPU:0
1701 strategy.experimental_local_results(per_replica_result)[1].device
1702 # /job:localhost/replica:0/task:0/device:GPU:1
1704 total = strategy.reduce("SUM", per_replica_result, axis=None)
1705 # Check device on which reduced result is:
1706 total.device
1707 # /job:localhost/replica:0/task:0/device:CPU:0
1709 ```
1711 This API is typically used for aggregating the results returned from
1712 different replicas, for reporting etc. For example, loss computed from
1713 different replicas can be averaged using this API before printing.
1715 Note: The result is copied to the "current" device, which would typically
1716 be the CPU of the worker on which the program is running. For `TPUStrategy`,
1717 it is the first TPU host. For multi-client `MultiWorkerMirroredStrategy`,
1718 this is the CPU of each worker.
1720 There are a number of different tf.distribute APIs for reducing values
1721 across replicas:
1722 * `tf.distribute.ReplicaContext.all_reduce`: This differs from
1723 `Strategy.reduce` in that it is for replica context and does
1724 not copy the results to the host device. `all_reduce` should be typically
1725 used for reductions inside the training step such as gradients.
1726 * `tf.distribute.StrategyExtended.reduce_to` and
1727 `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more
1728 advanced versions of `Strategy.reduce` as they allow customizing the
1729 destination of the result. They are also called in cross replica context.
1731 _What should axis be?_
1733 Given a per-replica value returned by `run`, say a
1734 per-example loss, the batch will be divided across all the replicas. This
1735 function allows you to aggregate across replicas and optionally also across
1736 batch elements by specifying the axis parameter accordingly.
1738 For example, if you have a global batch size of 8 and 2
1739 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and
1740 `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will
1741 aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`.
1742 This is useful when each replica is computing a scalar or some other value
1743 that doesn't have a "batch" dimension (like a gradient or loss).
1744 ```
1745 strategy.reduce("sum", per_replica_result, axis=None)
1746 ```
1748 Sometimes, you will want to aggregate across both the global batch _and_
1749 all replicas. You can get this behavior by specifying the batch
1750 dimension as the `axis`, typically `axis=0`. In this case it would return a
1751 scalar `0+1+2+3+4+5+6+7`.
1752 ```
1753 strategy.reduce("sum", per_replica_result, axis=0)
1754 ```
1756 If there is a last partial batch, you will need to specify an axis so
1757 that the resulting shape is consistent across replicas. So if the last
1758 batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you
1759 would get a shape mismatch unless you specify `axis=0`. If you specify
1760 `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct
1761 denominator of 6. Contrast this with using `reduce_mean` to get a
1762 scalar value on each replica and then using this function to average those
1763 means, which would weight some values by `1/8` and others by `1/4`.
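As a sketch of the partial-batch case (assuming `per_replica_result` holds
per-example values whose last global batch has size 6):

```python
# `axis=0` folds the batch dimension into the reduction, so MEAN divides by
# the true number of examples (6) rather than a fixed per-replica batch size.
mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_result, axis=0)
```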
1765 Args:
1766 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
1767 be combined. Allows using string representation of the enum such as
1768 "SUM", "MEAN".
1769 value: a `tf.distribute.DistributedValues` instance, e.g. returned by
1770 `Strategy.run`, to be combined into a single tensor. It can also be a
1771 regular tensor when used with `OneDeviceStrategy` or default strategy.
1772 axis: specifies the dimension to reduce along within each
1773 replica's tensor. Should typically be set to the batch dimension, or
1774 `None` to only reduce across replicas (e.g. if the tensor has no batch
1775 dimension).
1777 Returns:
1778 A `Tensor`.
1779 """
1780 # TODO(josh11b): support `value` being a nest.
1781 _require_cross_replica_or_default_context_extended(self._extended)
1782 if isinstance(reduce_op, six.string_types):
1783 reduce_op = reduce_util.ReduceOp(reduce_op.upper())
1784 if axis is None:
1785 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
1786 if reduce_op == reduce_util.ReduceOp.SUM:
1788 def reduce_sum(v):
1789 return math_ops.reduce_sum(v, axis=axis)
1791 if eager_context.executing_eagerly():
1792 # As some strategies (e.g. TPUStrategy) don't support pure eager
1793 # execution, wrap `reduce_sum` in a `tf.function` so it can be
1794 # run from eager mode. Cache the tf.function by `axis` to avoid
1795 # retracing the same function.
1796 if axis not in self._reduce_sum_fns:
1797 self._reduce_sum_fns[axis] = def_function.function(reduce_sum)
1798 value = self.run(self._reduce_sum_fns[axis], args=(value,))
1799 else:
1800 value = self.run(reduce_sum, args=(value,))
1802 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access
1803 if reduce_op != reduce_util.ReduceOp.MEAN:
1804 raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, "
1805 "not: %r" % reduce_op)
1807 def mean_reduce_helper(v, axes=axis):
1808 """Computes the numerator and denominator on each replica."""
1809 numer = math_ops.reduce_sum(v, axis=axes)
1810 def dimension(axis):
1811 if v.shape.rank is not None:
1812 # Note(joshl): We support axis < 0 to be consistent with the
1813 # tf.math.reduce_* operations.
1814 if axis < 0:
1815 if axis + v.shape.rank < 0:
1816 raise ValueError(
1817 "`axis` = %r out of range for `value` with rank %d" %
1818 (axis, v.shape.rank))
1819 axis += v.shape.rank
1820 elif axis >= v.shape.rank:
1821 raise ValueError(
1822 "`axis` = %r out of range for `value` with rank %d" %
1823 (axis, v.shape.rank))
1824 # TF v2 returns `None` for unknown dimensions and an integer for
1825 # known dimension, whereas TF v1 returns tensor_shape.Dimension(None)
1826 # or tensor_shape.Dimension(integer). `dimension_value` hides this
1827 # difference, always returning `None` or an integer.
1828 dim = tensor_shape.dimension_value(v.shape[axis])
1829 if dim is not None:
1830 # By returning a python value in the static shape case, we can
1831 # maybe get a fast path for reducing the denominator.
1832 # TODO(b/151871486): Remove array_ops.identity after we fallback to
1833 # simple reduction if inputs are all on CPU.
1834 return array_ops.identity(
1835 constant_op.constant(dim, dtype=dtypes.int64))
1836 elif axis < 0:
1837 axis = axis + array_ops.rank(v)
1838 # TODO(b/151871486): Remove array_ops.identity after we fallback to
1839 # simple reduction if inputs are all on CPU.
1840 return array_ops.identity(
1841 array_ops.shape_v2(v, out_type=dtypes.int64)[axis])
1842 if isinstance(axis, six.integer_types):
1843 denom = dimension(axis)
1844 elif isinstance(axis, (tuple, list)):
1845 denom = math_ops.reduce_prod([dimension(a) for a in axes])
1846 else:
1847 raise TypeError(
1848 "Expected `axis` to be an integer, tuple or list not: %r" % axis)
1849 # TODO(josh11b): Should we cast denom to v.dtype here instead of after the
1850 # reduce is complete?
1851 return numer, denom
1853 if eager_context.executing_eagerly():
1854 # As some strategies (e.g. TPUStrategy) don't support pure eager
1855 # execution, wrap `mean_reduce_helper` in a `tf.function` so it can
1856 # be run from eager mode. Cache the tf.function by `axis` to avoid
1857 # retracing the same function.
1858 if axis not in self._mean_reduce_helper_fns:
1859 self._mean_reduce_helper_fns[axis] = def_function.function(
1860 mean_reduce_helper)
1861 numer, denom = self.run(self._mean_reduce_helper_fns[axis], args=(value,))
1862 else:
1863 numer, denom = self.run(mean_reduce_helper, args=(value,))
1865 # TODO(josh11b): Should batch reduce here instead of doing two.
1866 numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access
1867 denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access
1868 denom = math_ops.cast(denom, numer.dtype)
1869 return math_ops.truediv(numer, denom)
1871 @doc_controls.do_not_doc_inheritable # DEPRECATED
1872 @deprecated(None, "use `experimental_local_results` instead.")
1873 def unwrap(self, value):
1874 """Returns the list of all local per-replica values contained in `value`.
1876 DEPRECATED: Please use `experimental_local_results` instead.
1878 Note: This only returns values on the workers initiated by this client.
1879 When using a `tf.distribute.Strategy` like
1880 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1881 will be its own client, and this function will only return values
1882 computed on that worker.
1884 Args:
1885 value: A value returned by `experimental_run()`,
1886 `extended.call_for_each_replica()`, or a variable created in `scope`.
1888 Returns:
1889 A tuple of values contained in `value`. If `value` represents a single
1890 value, this returns `(value,).`
1891 """
1892 return self._extended._local_results(value) # pylint: disable=protected-access
1894 def experimental_local_results(self, value):
1895 """Returns the list of all local per-replica values contained in `value`.
1897 Note: This only returns values on the worker initiated by this client.
1898 When using a `tf.distribute.Strategy` like
1899 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker
1900 will be its own client, and this function will only return values
1901 computed on that worker.
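For example, a minimal sketch (assuming two local replicas):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
per_replica = strategy.run(lambda: tf.constant(1.0))
# Returns a tuple with one element per local replica, e.g. (1.0, 1.0).
local = strategy.experimental_local_results(per_replica)
```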
1903 Args:
1904 value: A value returned by `experimental_run()`, `run()`, or a variable
1905 created in `scope`.
1907 Returns:
1908 A tuple of values contained in `value` where the ith element corresponds
1909 to the ith replica. If `value` represents a single value, this returns
1910 `(value,).`
1911 """
1912 return self._extended._local_results(value) # pylint: disable=protected-access
1914 @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only
1915 def group(self, value, name=None):
1916 """Shortcut for `tf.group(self.experimental_local_results(value))`."""
1917 return self._extended._group(value, name) # pylint: disable=protected-access
1919 @property
1920 def num_replicas_in_sync(self):
1921 """Returns number of replicas over which gradients are aggregated."""
1922 return self._extended._num_replicas_in_sync # pylint: disable=protected-access
1924 @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string
1925 @deprecated(None, "use `update_config_proto` instead.")
1926 def configure(self,
1927 session_config=None,
1928 cluster_spec=None,
1929 task_type=None,
1930 task_id=None):
1931 # pylint: disable=g-doc-return-or-yield,g-doc-args
1932 """DEPRECATED: use `update_config_proto` instead.
1934 Configures the strategy class.
1936 DEPRECATED: This method's functionality has been split into the strategy
1937 constructor and `update_config_proto`. In the future, we will allow passing
1938 cluster and config_proto to the constructor to configure the strategy. And
1939 `update_config_proto` can be used to update the config_proto based on the
1940 specific strategy.
1941 """
1942 return self._extended._configure( # pylint: disable=protected-access
1943 session_config, cluster_spec, task_type, task_id)
1945 @doc_controls.do_not_generate_docs # DEPRECATED
1946 def update_config_proto(self, config_proto):
1947 """DEPRECATED TF 1.x ONLY."""
1948 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access
1950 def __deepcopy__(self, memo):
1951 # First do a regular deepcopy of `self`.
1952 cls = self.__class__
1953 result = cls.__new__(cls)
1954 memo[id(self)] = result
1955 for k, v in self.__dict__.items():
1956 setattr(result, k, copy.deepcopy(v, memo))
1957 # One little fix-up: we want `result._extended` to reference `result`
1958 # instead of `self`.
1959 result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access
1960 return result
1962 def __copy__(self):
1963 raise RuntimeError("Must only deepcopy DistributionStrategy.")
1965 @property
1966 def cluster_resolver(self):
1967 """Returns the cluster resolver associated with this strategy.
1969 In general, when using a multi-worker `tf.distribute` strategy such as
1970 `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
1971 `tf.distribute.TPUStrategy()`, there is a
1972 `tf.distribute.cluster_resolver.ClusterResolver` associated with the
1973 strategy used, and such an instance is returned by this property.
1975 Strategies that intend to have an associated
1976 `tf.distribute.cluster_resolver.ClusterResolver` must set the
1977 relevant attribute, or override this property; otherwise, `None` is returned
1978 by default. Those strategies should also provide information regarding what
1979 is returned by this property.
1981 Single-worker strategies usually do not have a
1982 `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this
1983 property will return `None`.
1985 The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the
1986 user needs to access information such as the cluster spec, task type or task
1987 id. For example,
1989 ```python
1991 os.environ['TF_CONFIG'] = json.dumps({
1992 'cluster': {
1993 'worker': ["localhost:12345", "localhost:23456"],
1994 'ps': ["localhost:34567"]
1995 },
1996 'task': {'type': 'worker', 'index': 0}
1997 })
1999 # This implicitly uses TF_CONFIG for the cluster and current task info.
2000 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
2002 ...
2004 if strategy.cluster_resolver.task_type == 'worker':
2005 # Perform something that's only applicable on workers. Since we set this
2006 # as a worker above, this block will run on this particular instance.
2007 elif strategy.cluster_resolver.task_type == 'ps':
2008 # Perform something that's only applicable on parameter servers. Since we
2009 # set this as a worker above, this block will not run on this particular
2010 # instance.
2011 ```
2013 For more information, please see
2014 `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring.
2016 Returns:
2017 The cluster resolver associated with this strategy. Returns `None` if a
2018 cluster resolver is not applicable or available in this strategy.
2019 """
2020 if hasattr(self.extended, "_cluster_resolver"):
2021 return self.extended._cluster_resolver # pylint: disable=protected-access
2022 return None
2025@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring
2026class Strategy(StrategyBase):
2028 __doc__ = StrategyBase.__doc__
2030 def experimental_distribute_values_from_function(self, value_fn):
2031 """Generates `tf.distribute.DistributedValues` from `value_fn`.
2033 This function is to generate `tf.distribute.DistributedValues` to pass
2034 into `run`, `reduce`, or other methods that take
2035 distributed values when not using datasets.
2037 Args:
2038 value_fn: The function to run to generate values. It is called for
2039 each replica with `tf.distribute.ValueContext` as the sole argument. It
2040 must return a Tensor or a type that can be converted to a Tensor.
2041 Returns:
2042 A `tf.distribute.DistributedValues` containing a value for each replica.
2044 Example usage:
2046 1. Return constant value per replica:
2048 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
2049 >>> def value_fn(ctx):
2050 ... return tf.constant(1.)
2051 >>> distributed_values = (
2052 ... strategy.experimental_distribute_values_from_function(
2053 ... value_fn))
2054 >>> local_result = strategy.experimental_local_results(
2055 ... distributed_values)
2056 >>> local_result
2057 (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
2058 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
2060 2. Distribute values in array based on replica_id: {: value=2}
2062 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
2063 >>> array_value = np.array([3., 2., 1.])
2064 >>> def value_fn(ctx):
2065 ... return array_value[ctx.replica_id_in_sync_group]
2066 >>> distributed_values = (
2067 ... strategy.experimental_distribute_values_from_function(
2068 ... value_fn))
2069 >>> local_result = strategy.experimental_local_results(
2070 ... distributed_values)
2071 >>> local_result
2072 (3.0, 2.0)
2074 3. Specify values using num_replicas_in_sync: {: value=3}
2076 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
2077 >>> def value_fn(ctx):
2078 ... return ctx.num_replicas_in_sync
2079 >>> distributed_values = (
2080 ... strategy.experimental_distribute_values_from_function(
2081 ... value_fn))
2082 >>> local_result = strategy.experimental_local_results(
2083 ... distributed_values)
2084 >>> local_result
2085 (2, 2)
2087 4. Place values on devices and distribute: {: value=4}
2089 ```
2090 strategy = tf.distribute.TPUStrategy()
2091 worker_devices = strategy.extended.worker_devices
2092 multiple_values = []
2093 for i in range(strategy.num_replicas_in_sync):
2094 with tf.device(worker_devices[i]):
2095 multiple_values.append(tf.constant(1.0))
2097 def value_fn(ctx):
2098 return multiple_values[ctx.replica_id_in_sync_group]
2100 distributed_values = (
2101     strategy.experimental_distribute_values_from_function(
2102         value_fn))
2103 ```
2105 """
2106 return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
2107 value_fn)
2109 def gather(self, value, axis):
2110 # pylint: disable=line-too-long, protected-access
2111 """Gather `value` across replicas along `axis` to the current device.
2113 Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like
2114 object `value`, this API gathers and concatenates `value` across replicas
2115 along the `axis`-th dimension. The result is copied to the "current" device,
2116 which would typically be the CPU of the worker on which the program is
2117 running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For
2118 multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of
2119 each worker.
2121 This API can only be called in the cross-replica context. For a counterpart
2122 in the replica context, see `tf.distribute.ReplicaContext.all_gather`.
2124 Note: For all strategies except `tf.distribute.TPUStrategy`, the input
2125 `value` on different replicas must have the same rank, and their shapes must
2126 be the same in all dimensions except the `axis`-th dimension. In other
2127 words, their shapes cannot be different in a dimension `d` where `d` does
2128 not equal to the `axis` argument. For example, given a
2129 `tf.distribute.DistributedValues` with component tensors of shape
2130 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
2131 `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or
2132 `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`,
2133 all tensors must have exactly the same rank and same shape.
2135 Note: Given a `tf.distribute.DistributedValues` `value`, its component
2136 tensors must have a non-zero rank. Otherwise, consider using
2137 `tf.expand_dims` before gathering them.
2139 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
2140 >>> # A DistributedValues with component tensor of shape (2, 1) on each replica
2141 ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]])))
2142 >>> @tf.function
2143 ... def run():
2144 ... return strategy.gather(distributed_values, axis=0)
2145 >>> run()
2146 <tf.Tensor: shape=(4, 1), dtype=int32, numpy=
2147 array([[1],
2148 [2],
2149 [1],
2150 [2]], dtype=int32)>
2153 Consider the following example for more combinations:
2155 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
2156 >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3))
2157 >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor))
2158 >>> @tf.function
2159 ... def run(axis):
2160 ... return strategy.gather(distributed_values, axis=axis)
2161 >>> axis=0
2162 >>> run(axis)
2163 <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy=
2164 array([[[0, 1, 2],
2165 [3, 4, 5]],
2166 [[0, 1, 2],
2167 [3, 4, 5]],
2168 [[0, 1, 2],
2169 [3, 4, 5]],
2170 [[0, 1, 2],
2171 [3, 4, 5]]], dtype=int32)>
2172 >>> axis=1
2173 >>> run(axis)
2174 <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy=
2175 array([[[0, 1, 2],
2176 [3, 4, 5],
2177 [0, 1, 2],
2178 [3, 4, 5],
2179 [0, 1, 2],
2180 [3, 4, 5],
2181 [0, 1, 2],
2182 [3, 4, 5]]], dtype=int32)>
2183 >>> axis=2
2184 >>> run(axis)
2185 <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy=
2186 array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
2187 [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)>
2190 Args:
2191 value: a `tf.distribute.DistributedValues` instance, e.g. returned by
2192 `Strategy.run`, to be combined into a single tensor. It can also be a
2193 regular tensor when used with `tf.distribute.OneDeviceStrategy` or the
2194 default strategy. The tensors that constitute the DistributedValues
2195 can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`.
2196 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
2197 range [0, rank(value)).
2199 Returns:
2200 A `Tensor` that's the concatenation of `value` across replicas along
2201 `axis` dimension.
2202 """
2203 # pylint: enable=line-too-long
2204 error_message = ("tf.distribute.Strategy.gather method requires "
2205 "cross-replica context, use "
2206 "get_replica_context().all_gather() instead")
2207 _require_cross_replica_or_default_context_extended(self._extended,
2208 error_message)
2209 dst = device_util.current(
2210 ) or self._extended._default_device or "/device:CPU:0"
2211 if isinstance(value, indexed_slices.IndexedSlices):
2212 raise NotImplementedError("gather does not support IndexedSlices")
2213 return self._extended._local_results(
2214 self._extended._gather_to(value, dst, axis))[0]
2217# TF v1.x version has additional deprecated APIs
2218@tf_export(v1=["distribute.Strategy"])
2219class StrategyV1(StrategyBase):
2220 """A list of devices with a state & compute distribution policy.
2222 See [the guide](https://www.tensorflow.org/guide/distribute_strategy)
2223 for overview and examples.
2225 Note: Not all `tf.distribute.Strategy` implementations currently support
2226 TensorFlow's partitioned variables (where a single variable is split across
2227 multiple devices).
2228 """
2230 def make_dataset_iterator(self, dataset):
2231 """Makes an iterator for input provided via `dataset`.
2233 DEPRECATED: This method is not available in TF 2.x.
2235 Data from the given dataset will be distributed evenly across all the
2236 compute replicas. We will assume that the input dataset is batched by the
2237 global batch size. With this assumption, we will make a best effort to
2238 divide each batch across all the replicas (one or more workers).
2239 If this effort fails, an error will be raised, and the user should instead
2240 use `make_input_fn_iterator`, which provides more control and
2241 does not try to divide a batch across replicas.
2243 The user could also use `make_input_fn_iterator` if they want to
2244 customize which input is fed to which replica/worker etc.
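A minimal TF 1.x-style sketch (assuming `strategy`, `global_batch_size` and
`replica_fn` are already defined; they are placeholders here):

```python
dataset = tf.data.Dataset.range(100).batch(global_batch_size)
iterator = strategy.make_dataset_iterator(dataset)
# Initialize the iterator as described in Returns below, then feed it to the
# per-replica computation, e.g. strategy.experimental_run(replica_fn, iterator).
```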
2246 Args:
2247 dataset: `tf.data.Dataset` that will be distributed evenly across all
2248 replicas.
2250 Returns:
2251 A `tf.distribute.InputIterator` which returns inputs for each step of the
2252 computation. The user should call `initialize` on the returned iterator.
2253 """
2254 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
2256 def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation
2257 input_fn,
2258 replication_mode=InputReplicationMode.PER_WORKER):
2259 """Returns an iterator split across replicas created from an input function.
2261 DEPRECATED: This method is not available in TF 2.x.
2263 The `input_fn` should take a `tf.distribute.InputContext` object where
2264 information about batching and input sharding can be accessed:
2266 ```
2267 def input_fn(input_context):
2268 batch_size = input_context.get_per_replica_batch_size(global_batch_size)
2269 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
2270 return d.shard(input_context.num_input_pipelines,
2271 input_context.input_pipeline_id)
2272 with strategy.scope():
2273 iterator = strategy.make_input_fn_iterator(input_fn)
2274 replica_results = strategy.experimental_run(replica_fn, iterator)
2275 ```
2277 The `tf.data.Dataset` returned by `input_fn` should have a per-replica
2278 batch size, which may be computed using
2279 `input_context.get_per_replica_batch_size`.
2281 Args:
2282 input_fn: A function taking a `tf.distribute.InputContext` object and
2283 returning a `tf.data.Dataset`.
2284 replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
2285 Only `PER_WORKER` is supported currently, which means there will be
2286 a single call to `input_fn` per worker. Replicas will dequeue from the
2287 local `tf.data.Dataset` on their worker.
2289 Returns:
2290 An iterator object that should first be `.initialize()`-ed. It may then
2291 either be passed to `strategy.experimental_run()` or you can
2292 `iterator.get_next()` to get the next value to pass to
2293 `strategy.extended.call_for_each_replica()`.
2294 """
2295 return super(StrategyV1, self).make_input_fn_iterator(
2296 input_fn, replication_mode)
2298 def experimental_make_numpy_dataset(self, numpy_input, session=None):
2299 """Makes a tf.data.Dataset for input provided via a numpy array.
2301 This avoids adding `numpy_input` as a large constant in the graph,
2302 and copies the data to the machine or machines that will be processing
2303 the input.
2305 Note that you will likely need to use
2306 `tf.distribute.Strategy.experimental_distribute_dataset`
2307 with the returned dataset to further distribute it with the strategy.
2309 Example:
2310 ```
2311 numpy_input = np.ones([10], dtype=np.float32)
2312 dataset = strategy.experimental_make_numpy_dataset(numpy_input)
2313 dist_dataset = strategy.experimental_distribute_dataset(dataset)
2314 ```
2316 Args:
2317 numpy_input: A nest of NumPy input arrays that will be converted into a
2318 dataset. Note that lists of NumPy arrays are stacked, as that is normal
2319 `tf.data.Dataset` behavior.
2320 session: (TensorFlow v1.x graph execution only) A session used for
2321 initialization.
2323 Returns:
2324 A `tf.data.Dataset` representing `numpy_input`.
2325 """
2326 return self.extended.experimental_make_numpy_dataset(
2327 numpy_input, session=session)
2329 @deprecated(
2330 None,
2331 "This method is not available in TF 2.x. Please switch to using `run` instead."
2332 )
2333 def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation
2334 """Runs ops in `fn` on each replica, with inputs from `input_iterator`.
2336 DEPRECATED: This method is not available in TF 2.x. Please switch
2337 to using `run` instead.
2339 When eager execution is enabled, executes ops specified by `fn` on each
2340 replica. Otherwise, builds a graph to execute the ops on each replica.
2342 Each replica will take a single, different input from the inputs provided by
2343 one `get_next` call on the input iterator.
2345 `fn` may call `tf.distribute.get_replica_context()` to access members such
2346 as `replica_id_in_sync_group`.
2348 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being
2349 used, and whether eager execution is enabled, `fn` may be called one or more
2350 times (once for each replica).
2352 Args:
2353 fn: The function to run. The inputs to the function must match the outputs
2354 of `input_iterator.get_next()`. The output must be a `tf.nest` of
2355 `Tensor`s.
2356 input_iterator: (Optional) input iterator from which the inputs are taken.
2358 Returns:
2359 Merged return value of `fn` across replicas. The structure of the return
2360 value is the same as the return value from `fn`. Each element in the
2361 structure can either be `PerReplica` (if the values are unsynchronized),
2362 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a
2363 single replica).
2364 """
2365 return super(StrategyV1, self).experimental_run(
2366 fn, input_iterator)
2368 def reduce(self, reduce_op, value, axis=None):
2369 return super(StrategyV1, self).reduce(reduce_op, value, axis)
2371 reduce.__doc__ = StrategyBase.reduce.__doc__
2373 def update_config_proto(self, config_proto):
2374 """Returns a copy of `config_proto` modified for use with this strategy.
2376 DEPRECATED: This method is not available in TF 2.x.
2378 The updated config has something needed to run a strategy, e.g.
2379 configuration to run collective ops, or device filters to improve
2380 distributed training performance.
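A minimal TF 1.x sketch (assuming `strategy` has already been constructed):

```python
config = tf.compat.v1.ConfigProto()
# Returns a copy with strategy-specific settings (e.g. collective ops or
# device filters) filled in; pass the result to the session.
config = strategy.update_config_proto(config)
sess = tf.compat.v1.Session(config=config)
```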
2382 Args:
2383 config_proto: a `tf.ConfigProto` object.
2385 Returns:
2386 The updated copy of the `config_proto`.
2387 """
2388 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access
2391# NOTE(josh11b): For any strategy that needs to support tf.compat.v1,
2392# instead descend from StrategyExtendedV1.
2393@tf_export("distribute.StrategyExtended", v1=[])
2394class StrategyExtendedV2(object):
2395 """Additional APIs for algorithms that need to be distribution-aware.
2397 Note: For most usage of `tf.distribute.Strategy`, there should be no need to
2398 call these methods, since TensorFlow libraries (such as optimizers) already
2399 call these methods when needed on your behalf.
2402 Some common use cases of functions on this page:
2404 * _Locality_
2406 `tf.distribute.DistributedValues` can have the same _locality_ as a
2407 _distributed variable_, which leads to a mirrored value residing on the same
2408 devices as the variable (as opposed to the compute devices). Such values may
2409 be passed to a call to `tf.distribute.StrategyExtended.update` to update the
2410 value of a variable. You may use
2411 `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the
2412 same locality as another variable. You may convert a "PerReplica" value to a
2413 variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or
2414 `tf.distribute.StrategyExtended.batch_reduce_to`.
2416 * _How to update a distributed variable_
2418 A distributed variable is a variable created on multiple devices. As discussed
2419 in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute),
2420 mirrored variables and SyncOnRead variables are two examples. The standard
2421 pattern for updating distributed variables is to:
2423 1. In your function passed to `tf.distribute.Strategy.run`,
2424 compute a list of (update, variable) pairs. For example, the update might
2425 be a gradient of the loss with respect to the variable.
2426 2. Switch to cross-replica mode by calling
2427 `tf.distribute.get_replica_context().merge_call()` with the updates and
2428 variables as arguments.
2429 3. Call
2430 `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)`
2431 (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to`
2432 (for a list of variables) to sum the updates.
2433 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update
2434 its value.
2436 Steps 2 through 4 are done automatically by class
2437 `tf.keras.optimizers.Optimizer` if you call its
2438 `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context; a sketch of the manual pattern follows.
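A minimal sketch of steps 1 through 4 above (assuming a variable `v` created
under `strategy.scope()`; `compute_update` and `input_iterator` are
placeholders, not part of the API):

```python
def step_fn(inputs):
  update = compute_update(inputs, v)             # step 1: per-replica update

  def merge_fn(strategy, update, var):           # step 2: cross-replica mode
    reduced = strategy.extended.reduce_to(       # step 3: sum across replicas
        tf.distribute.ReduceOp.SUM, update, destinations=var)
    strategy.extended.update(                    # step 4: apply to the variable
        var, lambda var, u: var.assign_sub(u), args=(reduced,))

  tf.distribute.get_replica_context().merge_call(merge_fn, args=(update, v))

strategy.run(step_fn, args=(next(input_iterator),))
```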
2440 In fact, a higher-level solution to update a distributed variable is to
2441 call `assign` on the variable as you would with a regular `tf.Variable`.
2442 You can call the method in both _replica context_ and _cross-replica context_.
2443 For a _mirrored variable_, calling `assign` in _replica context_ requires you
2444 to specify the `aggregation` type in the variable constructor. In that case,
2445 the context switching and sync described in steps 2 through 4 are handled for
2446 you. If you call `assign` on _mirrored variable_ in _cross-replica context_,
2447 you can only assign a single value or assign values from another mirrored
2448 variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead
2449 variable_, in _replica context_, you can simply call `assign` on it and no
2450 aggregation happens under the hood. In _cross-replica context_, you can only
2451 assign a single value to a SyncOnRead variable. One example case is restoring
2452 from a checkpoint: if the `aggregation` type of the variable is
2453 `tf.VariableAggregation.SUM`, it is assumed that replica values were added
2454 before checkpointing, so at the time of restoring, the value is divided by
2455 the number of replicas and then assigned to each replica; if the `aggregation`
2456 type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica
2457 directly.
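For example, a minimal sketch of the `assign_add` path in replica context
(assuming two replicas; the `aggregation` argument in the constructor supplies
the cross-replica reduction described above):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  acc = tf.Variable(0., aggregation=tf.VariableAggregation.SUM)

def step_fn():
  # Each replica adds 1; with SUM aggregation the variable ends up at 2.
  acc.assign_add(1.)

strategy.run(step_fn)
```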
2459 """
2461 def __init__(self, container_strategy):
2462 self._container_strategy_weakref = weakref.ref(container_strategy)
2463 self._default_device = None
2464 # This property is used to determine if we should set drop_remainder=True
2465 # when creating Datasets from numpy array inputs.
2466 self._require_static_shapes = False
2468 def _resource_creator_scope(self):
2469 """Returns one or a list of ops.resource_creator_scope for some Strategy."""
2470 return None
2472 def _container_strategy(self):
2473 """Get the containing `tf.distribute.Strategy`.
2475 This should not generally be needed except when creating a new
2476 `ReplicaContext` and to validate that the caller is in the correct
2477 `scope()`.
2479 Returns:
2480 The `tf.distribute.Strategy` such that `strategy.extended` is `self`.
2481 """
2482 container_strategy = self._container_strategy_weakref()
2483 assert container_strategy is not None
2484 return container_strategy
2486 def _scope(self, strategy):
2487 """Implementation of tf.distribute.Strategy.scope()."""
2489 def creator_with_resource_vars(next_creator, **kwargs):
2490 """Variable creator to use in `_CurrentDistributionContext`."""
2491 if ops.inside_function():
2492 if_graph_building = "graph_building"
2493 else:
2494 if_graph_building = "not_graph_building"
2496 with monitoring.MonitoredTimer(distributed_variable_creation_time_counter.get_cell(strategy.__class__.__name__, if_graph_building)):
2497 _require_strategy_scope_extended(self)
2498 kwargs["use_resource"] = True
2499 kwargs["distribute_strategy"] = strategy
2501 # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
2502 # dereferencing a `Tensor` that is without a `name`. We still need to
2503 # propagate the metadata it's holding.
2504 if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
2505 checkpoint_restore_uid = kwargs[
2506 "initial_value"].checkpoint_position.restore_uid
2507 kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
2508 elif isinstance(kwargs["initial_value"],
2509 trackable.CheckpointInitialValueCallable):
2510 checkpoint_restore_uid = kwargs[
2511 "initial_value"].checkpoint_position.restore_uid
2512 elif (isinstance(kwargs["initial_value"], functools.partial) and
2513 isinstance(kwargs["initial_value"].func,
2514 trackable.CheckpointInitialValueCallable)):
2515 # Some libraries (e.g. Keras) create a partial function out of the
2516 # initializer to bind shape/dtype, for example:
2517 # initial_val = functools.partial(initializer, shape, dtype=dtype)
2518 # Therefore to get the restore_uid we need to examine the "func" of
2519 # the partial function.
2520 checkpoint_restore_uid = kwargs[
2521 "initial_value"].func.checkpoint_position.restore_uid
2522 else:
2523 checkpoint_restore_uid = None
2525 created = self._create_variable(next_creator, **kwargs)
2527 if checkpoint_restore_uid is not None:
2528 # pylint: disable=protected-access
2529 # Let the checkpointing infrastructure know that the variable was
2530 # already restored so it doesn't waste memory loading the value again.
2531 # In this case of CheckpointInitialValueCallable this may already be
2532 # done by the final variable creator, but it doesn't hurt to do it
2533 # again.
2534 created._maybe_initialize_trackable()
2535 created._update_uid = checkpoint_restore_uid
2536 # pylint: enable=protected-access
2537 return created
2539 def distributed_getter(getter, *args, **kwargs):
2540 if not self._allow_variable_partition():
2541 if kwargs.pop("partitioner", None) is not None:
2542 tf_logging.log_first_n(
2543 tf_logging.WARN, "Partitioned variables are disabled when using "
2544 "current tf.distribute.Strategy.", 1)
2545 return getter(*args, **kwargs)
2547 return _CurrentDistributionContext(
2548 strategy,
2549 variable_scope.variable_creator_scope(creator_with_resource_vars),
2550 variable_scope.variable_scope(
2551 variable_scope.get_variable_scope(),
2552 custom_getter=distributed_getter),
2553 strategy.extended._resource_creator_scope(), # pylint: disable=protected-access
2554 self._default_device)
2556 def _allow_variable_partition(self):
2557 return False
2559 def _create_variable(self, next_creator, **kwargs):
2560 # Note: should support "colocate_with" argument.
2561 raise NotImplementedError("must be implemented in descendants")
2563 def variable_created_in_scope(self, v):
2564 """Tests whether `v` was created while this strategy scope was active.
2566 Variables created inside the strategy scope are "owned" by it:
2568 >>> strategy = tf.distribute.MirroredStrategy()
2569 >>> with strategy.scope():
2570 ... v = tf.Variable(1.)
2571 >>> strategy.extended.variable_created_in_scope(v)
2572 True
2574 Variables created outside the strategy are not owned by it:
2576 >>> strategy = tf.distribute.MirroredStrategy()
2577 >>> v = tf.Variable(1.)
2578 >>> strategy.extended.variable_created_in_scope(v)
2579 False
2581 Args:
2582 v: A `tf.Variable` instance.
2584 Returns:
2585 True if `v` was created inside the scope, False if not.
2586 """
2587 return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access
2589 def colocate_vars_with(self, colocate_with_variable):
2590 """Scope that controls which devices variables will be created on.
2592 No operations should be added to the graph inside this scope; it
2593 should only be used when creating variables (some implementations
2594 work by changing variable creation, others work by using a
2595 tf.compat.v1.colocate_with() scope).
2597 This may only be used inside `self.scope()`.
2599 Example usage:
2601 ```
2602 with strategy.scope():
2603 var1 = tf.Variable(...)
2604 with strategy.extended.colocate_vars_with(var1):
2605 # var2 and var3 will be created on the same device(s) as var1
2606 var2 = tf.Variable(...)
2607 var3 = tf.Variable(...)
2609 def fn(v1, v2, v3):
2610 # operates on v1 from var1, v2 from var2, and v3 from var3
2612 # `fn` runs on every device `var1` is on, `var2` and `var3` will be there
2613 # too.
2614 strategy.extended.update(var1, fn, args=(var2, var3))
2615 ```
2617 Args:
2618 colocate_with_variable: A variable created in this strategy's `scope()`.
2619 Variables created while in the returned context manager will be on the
2620 same set of devices as `colocate_with_variable`.
2622 Returns:
2623 A context manager.
2624 """
2626 def create_colocated_variable(next_creator, **kwargs):
2627 _require_strategy_scope_extended(self)
2628 kwargs["use_resource"] = True
2629 kwargs["colocate_with"] = colocate_with_variable
2630 return next_creator(**kwargs)
2632 _require_strategy_scope_extended(self)
2633 self._validate_colocate_with_variable(colocate_with_variable)
2634 return variable_scope.variable_creator_scope(create_colocated_variable)
2636 def _validate_colocate_with_variable(self, colocate_with_variable):
2637 """Validate `colocate_with_variable` argument to `colocate_vars_with`."""
2638 pass
2640 def _make_dataset_iterator(self, dataset):
2641 raise NotImplementedError("must be implemented in descendants")
2643 def _make_input_fn_iterator(self, input_fn, replication_mode):
2644 raise NotImplementedError("must be implemented in descendants")
2646 def _experimental_distribute_dataset(self, dataset, options):
2647 raise NotImplementedError("must be implemented in descendants")
2649 def _distribute_datasets_from_function(self, dataset_fn, options):
2650 raise NotImplementedError("must be implemented in descendants")
2652 def _experimental_distribute_values_from_function(self, value_fn):
2653 raise NotImplementedError("must be implemented in descendants")
2655 def _reduce(self, reduce_op, value):
2656 # Default implementation until we have an implementation for each strategy.
2657 dst = device_util.current() or self._default_device or "/device:CPU:0"
2658 return self._local_results(self.reduce_to(reduce_op, value, dst))[0]
2660 def reduce_to(self, reduce_op, value, destinations, options=None):
2661 """Combine (via e.g. sum or mean) values across replicas.
2663 `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed
2664 variables. It supports both dense values and `tf.IndexedSlices`.
2666 This API currently can only be called in cross-replica context. Other
2667 variants to reduce values across replicas are:
2668 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of
2669 this API.
2670 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2671 in replica context. It supports both batched and non-batched all-reduce.
2672 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2673 to the host in cross-replica context.
2675 `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can
2676 also pass in a `Tensor`, and the destinations will be the device of that
2677 tensor. For all-reduce, pass the same to `value` and `destinations`.
2679 It can be used in `tf.distribute.ReplicaContext.merge_call` to write code
2680 that works for all `tf.distribute.Strategy`.
2682 @tf.function
2683 def step_fn(var):
2685 def merge_fn(strategy, value, var):
2686 # All-reduce the value. Note that `value` here is a
2687 # `tf.distribute.DistributedValues`.
2688 reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM,
2689 value, destinations=var)
2690 strategy.extended.update(var, lambda var, value: var.assign(value),
2691 args=(reduced,))
2693 value = tf.identity(1.)
2694 tf.distribute.get_replica_context().merge_call(merge_fn,
2695 args=(value, var))
2697 def run(strategy):
2698 with strategy.scope():
2699 v = tf.Variable(0.)
2700 strategy.run(step_fn, args=(v,))
2701 return v
2703 run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2704 MirroredVariable:{
2705 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2706 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2707 }
2708 run(tf.distribute.experimental.CentralStorageStrategy(
2709 compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2710 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2711 run(tf.distribute.OneDeviceStrategy("GPU:0"))
2712 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2714 Args:
2715 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2716 be combined. Allows using string representation of the enum such as
2717 "SUM", "MEAN".
2718 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`-like object.
2719 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
2720 `tf.Tensor`-like object, or a device string. It specifies the devices
2721 to reduce to. To perform an all-reduce, pass the same to `value` and
2722 `destinations`. Note that if it's a `tf.Variable`, the value is reduced
2723 to the devices of that variable, and this method doesn't update the
2724 variable.
2725 options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2726 perform collective operations. This overrides the default options if the
2727 `tf.distribute.Strategy` takes one in the constructor. See
2728 `tf.distribute.experimental.CommunicationOptions` for details of the
2729 options.
2731 Returns:
2732 A tensor or value reduced to `destinations`.
2733 """
2734 if options is None:
2735 options = collective_util.Options()
2736 _require_cross_replica_or_default_context_extended(self)
2737 assert not isinstance(destinations, (list, tuple))
2738 assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2739 if isinstance(reduce_op, six.string_types):
2740 reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2741 assert (reduce_op == reduce_util.ReduceOp.SUM or
2742 reduce_op == reduce_util.ReduceOp.MEAN)
2743 return self._reduce_to(reduce_op, value, destinations, options)
2745 def _reduce_to(self, reduce_op, value, destinations, options):
2746 raise NotImplementedError("must be implemented in descendants")
2748 def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None):
2749 """Combine multiple `reduce_to` calls into one for faster execution.
2751 Similar to `reduce_to`, but accepts a list of (value, destinations) pairs.
2752 It's more efficient than reducing each value separately.
2754 This API currently can only be called in cross-replica context. Other
2755 variants to reduce values across replicas are:
2756 * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of
2757 this API.
2758 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API
2759 in replica context. It supports both batched and non-batched all-reduce.
2760 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
2761 to the host in cross-replica context.
2763 See `reduce_to` for more information.
2765 @tf.function
2766 def step_fn(var):
2768 def merge_fn(strategy, value, var):
2769 # All-reduce the value. Note that `value` here is a
2770 # `tf.distribute.DistributedValues`.
2771 reduced = strategy.extended.batch_reduce_to(
2772 tf.distribute.ReduceOp.SUM, [(value, var)])[0]
2773 strategy.extended.update(var, lambda var, value: var.assign(value),
2774 args=(reduced,))
2776 value = tf.identity(1.)
2777 tf.distribute.get_replica_context().merge_call(merge_fn,
2778 args=(value, var))
2780 def run(strategy):
2781 with strategy.scope():
2782 v = tf.Variable(0.)
2783 strategy.run(step_fn, args=(v,))
2784 return v
2786 run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]))
2787 MirroredVariable:{
2788 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>,
2789 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0>
2790 }
2791 run(tf.distribute.experimental.CentralStorageStrategy(
2792 compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0"))
2793 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>
2794 run(tf.distribute.OneDeviceStrategy("GPU:0"))
2795 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
2797 Args:
2798 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
2799 be combined. Allows using string representation of the enum such as
2800 "SUM", "MEAN".
2801 value_destination_pairs: a sequence of (value, destinations) pairs. See
2802 `tf.distribute.StrategyExtended.reduce_to` for descriptions.
2803 options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2804 perform collective operations. This overrides the default options if the
2805 `tf.distribute.Strategy` takes one in the constructor. See
2806 `tf.distribute.experimental.CommunicationOptions` for details of the
2807 options.
2809 Returns:
2810 A list of reduced values, one per pair in `value_destination_pairs`.
2811 """
2812 if options is None:
2813 options = collective_util.Options()
2814 _require_cross_replica_or_default_context_extended(self)
2815 assert not isinstance(reduce_op, variable_scope.VariableAggregation)
2816 if isinstance(reduce_op, six.string_types):
2817 reduce_op = reduce_util.ReduceOp(reduce_op.upper())
2818 return self._batch_reduce_to(reduce_op, value_destination_pairs, options)
2820 def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
2821 return [
2822 self.reduce_to(reduce_op, t, destinations=v, options=options)
2823 for t, v in value_destination_pairs
2824 ]
2826 def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
2827 """All-reduce `value` across all replicas so that all get the final result.
2829 If `value` is a nested structure of tensors, all-reduces of these tensors
2830 will be batched when possible. `options` can be set to hint the batching
2831 behavior.
2833 This API must be called in a replica context.
2835 Args:
2836 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
2837 be combined.
2838 value: Value to be reduced. A tensor or a nested structure of tensors.
2839 options: A `tf.distribute.experimental.CommunicationOptions`. Options to
2840 perform collective operations. This overrides the default options if the
2841 `tf.distribute.Strategy` takes one in the constructor.
2843 Returns:
2844 A tensor or a nested structure of tensors with the reduced values. The
2845 structure is the same as `value`.
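For illustration, a minimal sketch of how this path is normally reached
through the public `tf.distribute.ReplicaContext.all_reduce` API (a
two-replica `MirroredStrategy` and the device names are assumptions made
for the example):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def step_fn():
  ctx = tf.distribute.get_replica_context()
  # Each replica contributes 1.0; every replica receives the summed result.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.identity(1.))

per_replica = strategy.run(step_fn)
# strategy.experimental_local_results(per_replica) -> (2.0, 2.0)
```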
2846 """
2847 if options is None:
2848 options = collective_util.Options()
2849 replica_context = get_replica_context()
2850 assert replica_context, (
2851 "`StrategyExtended._replica_ctx_all_reduce` must be called in"
2852 " a replica context")
2854 def merge_fn(_, flat_value):
2855 return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
2856 options)
2858 reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
2859 return nest.pack_sequence_as(value, reduced)
2861 def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True):
2862 """Run `fn` with `args` and `kwargs` to update `var`."""
2863 # This method is called by ReplicaContext.update. Strategies that would like to
2864 # remove merge_call in this path should override this method.
2865 replica_context = get_replica_context()
2866 if not replica_context:
2867 raise ValueError("`StrategyExtended._replica_ctx_update` must be called "
2868 "in a replica context.")
2870 def merge_fn(_, *merged_args, **merged_kwargs):
2871 return self.update(var, fn, merged_args, merged_kwargs, group=group)
2873 return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)
2875 def _gather_to(self, value, destinations, axis, options=None):
2876 """Gather `value` across replicas along axis-th dimension to `destinations`.
2878 `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like
2879 objects along the `axis`-th dimension. It supports only dense tensors, NOT
2880 sparse tensors. This API can only be called in cross-replica context.
2882 Args:
2883 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`-like object.
2884 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
2885 `tf.Tensor`-like object, or a device string. It specifies the devices
2886 to gather to. To perform an all-gather, pass the same to `value` and
2887 `destinations`. Note that if it's a `tf.Variable`, the value is gathered
2888 to the devices of that variable, and this method doesn't update the
2889 variable.
2890 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
2891 range [0, rank(value)).
2892 options: a `tf.distribute.experimental.CommunicationOptions`. Options to
2893 perform collective operations. This overrides the default options if the
2894 `tf.distribute.Strategy` takes one in the constructor. See
2895 `tf.distribute.experimental.CommunicationOptions` for details of the
2896 options.
2898 Returns:
2899 A tensor or value gathered to `destinations`.
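The public counterpart of this cross-replica gather is
`tf.distribute.Strategy.gather`. A minimal sketch of that entry point (a
two-replica `MirroredStrategy` is assumed; device names are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def value_fn():
  ctx = tf.distribute.get_replica_context()
  # Each replica produces a length-1 vector holding its replica id.
  return tf.reshape(ctx.replica_id_in_sync_group, [1])

per_replica = strategy.run(value_fn)
gathered = strategy.gather(per_replica, axis=0)  # e.g. [0, 1] with two replicas
```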
2900 """
2901 _require_cross_replica_or_default_context_extended(self)
2902 assert not isinstance(destinations, (list, tuple))
2903 if options is None:
2904 options = collective_util.Options()
2905 return self._gather_to_implementation(value, destinations, axis, options)
2907 def _gather_to_implementation(self, value, destinations, axis, options):
2908 raise NotImplementedError("_gather_to must be implemented in descendants")
2910 def _batch_gather_to(self, value_destination_pairs, axis, options=None):
2911 _require_cross_replica_or_default_context_extended(self)
2912 if options is None:
2913 options = collective_util.Options()
2914 return [
2915 self._gather_to(t, destinations=v, axis=axis, options=options)
2916 for t, v in value_destination_pairs
2917 ]
2919 def update(self, var, fn, args=(), kwargs=None, group=True):
2920 """Run `fn` to update `var` using inputs mirrored to the same devices.
2922 `tf.distribute.StrategyExtended.update` takes a distributed variable `var`
2923 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It
2924 applies `fn` to each component variable of `var` and passes corresponding
2925 values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain
2926 per-replica values. If they contain mirrored values, they will be unwrapped
2927 before calling `fn`. For example, `fn` can be `assign_add` and `args` can be
2928 a mirrored DistributedValues where each component contains the value to be
2929 added to this mirrored variable `var`. Calling `update` will call
2930 `assign_add` on each component variable of `var` with the corresponding
2931 tensor value on that device.
2933 Example usage:
2935 ```python
2936 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])  # With 2 devices
2938 with strategy.scope():
2939 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
2940 def update_fn(v):
2941 return v.assign(1.0)
2942 result = strategy.extended.update(v, update_fn)
2943 # result is
2944 # Mirrored:{
2945 # 0: tf.Tensor(1.0, shape=(), dtype=float32),
2946 # 1: tf.Tensor(1.0, shape=(), dtype=float32)
2947 # }
2948 ```
2950 If `var` is mirrored across multiple devices, then this method implements
2951 logic as follows:
2953 ```python
2954 results = {}
2955 for device, v in var:
2956 with tf.device(device):
2957 # args and kwargs will be unwrapped if they are mirrored.
2958 results[device] = fn(v, *args, **kwargs)
2959 return merged(results)
2960 ```
2962 Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with
2963 `var`.
2965 Args:
2966 var: Variable, possibly mirrored to multiple devices, to operate on.
2967 fn: Function to call. Should take the variable as the first argument.
2968 args: Tuple or list. Additional positional arguments to pass to `fn()`.
2969 kwargs: Dict with keyword arguments to pass to `fn()`.
2970 group: Boolean. Defaults to True. If False, the return value will be
2971 unwrapped.
2973 Returns:
2974 By default, the merged return value of `fn` across all replicas. The
2975 merged result has dependencies to make sure that if it is evaluated at
2976 all, the side effects (updates) will happen on every replica. If instead
2977 "group=False" is specified, this function will return a nest of lists
2978 where each list has an element per replica, and the caller is responsible
2979 for ensuring all elements are executed.
2980 """
2981 # TODO(b/178944108): Update the documentation to reflect the fact that
2982 # `update` can be called in a replica context.
2983 if kwargs is None:
2984 kwargs = {}
2985 replica_context = get_replica_context()
2986 # pylint: disable=protected-access
2987 if (replica_context is None or replica_context is
2988 _get_default_replica_context()):
2989 fn = autograph.tf_convert(
2990 fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
2991 with self._container_strategy().scope():
2992 return self._update(var, fn, args, kwargs, group)
2993 else:
2994 return self._replica_ctx_update(
2995 var, fn, args=args, kwargs=kwargs, group=group)
2997 def _update(self, var, fn, args, kwargs, group):
2998 raise NotImplementedError("must be implemented in descendants")
3000 def _local_results(self, val):
3001 """Returns local results per replica as a tuple."""
3002 if isinstance(val, ds_types.DistributedValues):
3003 return val._values # pylint: disable=protected-access
3005 if nest.is_nested(val):
3006 replica_values = []
3008 def get_values(x, index):
3009 if isinstance(x, ds_types.DistributedValues):
3010 return x._values[index] # pylint: disable=protected-access
3011 return x
3013 for i in range(len(self.worker_devices)):
3014 replica_values.append(
3015 nest.map_structure(
3016 lambda x: get_values(x, i), # pylint: disable=cell-var-from-loop
3017 val))
3018 return tuple(replica_values)
3019 return (val,)
3021 def value_container(self, value):
3022 """Returns the container that this per-replica `value` belongs to.
3024 Args:
3025 value: A value returned by `run()` or a variable created in `scope()`.
3027 Returns:
3028 A container that `value` belongs to.
3029 If value does not belong to any container (including the case where
3030 the container has been destroyed), returns the value itself.
3031 `value in experimental_local_results(value_container(value))` will
3032 always be true.
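For illustration, a minimal sketch of the invariant above (a two-replica
`MirroredStrategy` is assumed; the component picked is arbitrary):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(1.)  # a mirrored variable with one component per replica

component = strategy.experimental_local_results(v)[0]
container = strategy.extended.value_container(component)  # the mirrored `v`
assert component in strategy.experimental_local_results(container)
```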
3033 """
3034 raise NotImplementedError("must be implemented in descendants")
3036 def _group(self, value, name=None):
3037 """Implementation of `group`."""
3038 value = nest.flatten(self._local_results(value))
3040 if len(value) != 1 or name is not None:
3041 return control_flow_ops.group(value, name=name)
3042 # Special handling for the common case of one op.
3043 v, = value
3044 if hasattr(v, "op"):
3045 v = v.op
3046 return v
3048 @property
3049 def experimental_require_static_shapes(self):
3050 """Returns `True` if static shape is required; `False` otherwise."""
3051 return self._require_static_shapes
3053 @property
3054 def _num_replicas_in_sync(self):
3055 """Returns number of replicas over which gradients are aggregated."""
3056 raise NotImplementedError("must be implemented in descendants")
3058 @property
3059 def worker_devices(self):
3060 Returns the tuple of all devices used for compute replica execution.
3061 """
3062 # TODO(josh11b): More docstring
3063 raise NotImplementedError("must be implemented in descendants")
3065 @property
3066 def parameter_devices(self):
3067 """Returns the tuple of all devices used to place variables."""
3068 # TODO(josh11b): More docstring
3069 raise NotImplementedError("must be implemented in descendants")
3071 def _configure(self,
3072 session_config=None,
3073 cluster_spec=None,
3074 task_type=None,
3075 task_id=None):
3076 """Configures the strategy class."""
3077 del session_config, cluster_spec, task_type, task_id
3079 def _update_config_proto(self, config_proto):
3080 return copy.deepcopy(config_proto)
3082 def _in_multi_worker_mode(self):
3083 """Whether this strategy indicates working in multi-worker settings.
3085 Multi-worker training refers to the setup where the training is
3086 distributed across multiple workers, as opposed to the case where
3087 only a local process performs the training. This function is
3088 used by higher-level APIs such as Keras' `model.fit()` to infer,
3089 for example, whether or not a distribute coordinator should be run
3090 (and thus whether TensorFlow servers should be started for communication
3091 with other servers in the cluster), or whether or not saving/restoring
3092 checkpoints is relevant for preemption fault tolerance.
3094 Subclasses should override this to indicate whether the strategy is
3095 currently in a multi-worker setup.
3097 Experimental. Signature and implementation are subject to change.
3098 """
3099 raise NotImplementedError("must be implemented in descendants")
3102@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring
3103class StrategyExtendedV1(StrategyExtendedV2):
3105 __doc__ = StrategyExtendedV2.__doc__
3107 def experimental_make_numpy_dataset(self, numpy_input, session=None):
3108 """Makes a dataset for input provided via a numpy array.
3110 This avoids adding `numpy_input` as a large constant in the graph,
3111 and copies the data to the machine or machines that will be processing
3112 the input.
3114 Args:
3115 numpy_input: A nest of NumPy input arrays that will be distributed evenly
3116 across all replicas. Note that lists of NumPy arrays are stacked, as
3117 that is normal `tf.data.Dataset` behavior.
3118 session: (TensorFlow v1.x graph execution only) A session used for
3119 initialization.
3121 Returns:
3122 A `tf.data.Dataset` representing `numpy_input`.
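A minimal usage sketch (TF 1.x graph execution, a two-GPU
`tf.compat.v1.distribute.MirroredStrategy`, and the array shapes are
assumptions made for the example):

```python
import numpy as np
import tensorflow as tf

strategy = tf.compat.v1.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
features = np.random.rand(64, 10).astype(np.float32)
labels = np.random.randint(0, 2, size=(64, 1)).astype(np.int64)

with tf.compat.v1.Session() as sess:
  # Copies the arrays to the input device(s) instead of embedding them as
  # large graph constants.
  dataset = strategy.extended.experimental_make_numpy_dataset(
      (features, labels), session=sess)
  dataset = dataset.batch(8)
```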
3123 """
3124 _require_cross_replica_or_default_context_extended(self)
3125 return self._experimental_make_numpy_dataset(numpy_input, session=session)
3127 def _experimental_make_numpy_dataset(self, numpy_input, session):
3128 raise NotImplementedError("must be implemented in descendants")
3130 def broadcast_to(self, tensor, destinations):
3131 """Mirror a tensor on one device to all worker devices.
3133 Args:
3134 tensor: A Tensor value to broadcast.
3135 destinations: A mirrored variable or device string specifying the
3136 destination devices to copy `tensor` to.
3138 Returns:
3139 A value mirrored to `destinations` devices.
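For illustration, a minimal sketch (a two-replica `MirroredStrategy` is
assumed; the mirrored variable only describes the destination devices and
is not modified by the broadcast):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(0.)  # mirrored; used here only as the destination spec

t = tf.constant([1., 2., 3.])
# Copies `t` to every device that holds a component of `v`.
mirrored_t = strategy.extended.broadcast_to(t, destinations=v)
```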
3140 """
3141 assert destinations is not None # from old strategy.broadcast()
3142 # TODO(josh11b): More docstring
3143 _require_cross_replica_or_default_context_extended(self)
3144 assert not isinstance(destinations, (list, tuple))
3145 return self._broadcast_to(tensor, destinations)
3147 def _broadcast_to(self, tensor, destinations):
3148 raise NotImplementedError("must be implemented in descendants")
3150 @deprecated(None, "please use `run` instead.")
3151 def experimental_run_steps_on_iterator(self,
3152 fn,
3153 iterator,
3154 iterations=1,
3155 initial_loop_values=None):
3156 """DEPRECATED: please use `run` instead.
3158 Run `fn` with input from `iterator` `iterations` times.
3160 This method can be used to run a step function for training a number of
3161 times using input from a dataset.
3163 Args:
3164 fn: function to run using this distribution strategy. The function must
3165 have the following signature: `def fn(context, inputs)`. `context` is an
3166 instance of `MultiStepContext` that will be passed when `fn` is run.
3167 `context` can be used to specify the outputs to be returned from `fn`
3168 by calling `context.set_last_step_output`. It can also be used to
3169 capture non-tensor outputs via `context.set_non_tensor_output`. See
3170 `MultiStepContext` documentation for more information. `inputs` will
3171 have same type/structure as `iterator.get_next()`. Typically, `fn`
3172 will use `call_for_each_replica` method of the strategy to distribute
3173 the computation over multiple replicas.
3174 iterator: Iterator of a dataset that represents the input for `fn`. The
3175 caller is responsible for initializing the iterator as needed.
3176 iterations: (Optional) Number of iterations that `fn` should be run.
3177 Defaults to 1.
3178 initial_loop_values: (Optional) Initial values to be passed into the
3179 loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
3180 initial_loop_values argument when we have a mechanism to infer the
3181 outputs of `fn`.
3183 Returns:
3184 Returns the `MultiStepContext` object which has the following properties,
3185 among other things:
3186 - run_op: An op that runs `fn` `iterations` times.
3187 - last_step_outputs: A dictionary containing tensors set using
3188 `context.set_last_step_output`. Evaluating this returns the value of
3189 the tensors after the last iteration.
3190 - non_tensor_outputs: A dictionary containing anything that was set by
3191 `fn` by calling `context.set_non_tensor_output`.
3192 """
3193 _require_cross_replica_or_default_context_extended(self)
3194 with self._container_strategy().scope():
3195 return self._experimental_run_steps_on_iterator(fn, iterator, iterations,
3196 initial_loop_values)
3198 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
3199 initial_loop_values):
3200 raise NotImplementedError("must be implemented in descendants")
3202 def call_for_each_replica(self, fn, args=(), kwargs=None):
3203 """Run `fn` once per replica.
3205 `fn` may call `tf.get_replica_context()` to access methods such as
3206 `replica_id_in_sync_group` and `merge_call()`.
3208 `merge_call()` is used to communicate between the replicas and
3209 re-enter the cross-replica context. All replicas pause their execution
3210 when they encounter a call to `merge_call()`. After that the
3211 `merge_fn` function is executed. Its results are then unwrapped and
3212 given back to each replica call. After that, execution resumes until
3213 `fn` is complete or encounters another `merge_call()`. Example:
3215 ```python
3216 # Called once in "cross-replica" context.
3217 def merge_fn(distribution, three_plus_replica_id):
3218 # sum the values across replicas
3219 return sum(distribution.experimental_local_results(three_plus_replica_id))
3221 # Called once per replica in `distribution`, in a "replica" context.
3222 def fn(three):
3223 replica_ctx = tf.get_replica_context()
3224 v = three + replica_ctx.replica_id_in_sync_group
3225 # Computes the sum of the `v` values across all replicas.
3226 s = replica_ctx.merge_call(merge_fn, args=(v,))
3227 return s + v
3229 with distribution.scope():
3230 # in "cross-replica" context
3231 ...
3232 merged_results = distribution.run(fn, args=[3])
3233 # merged_results has the values from every replica execution of `fn`.
3234 # This statement prints a list:
3235 print(distribution.experimental_local_results(merged_results))
3236 ```
3238 Args:
3239 fn: function to run (will be run once per replica).
3240 args: Tuple or list with positional arguments for `fn`.
3241 kwargs: Dict with keyword arguments for `fn`.
3243 Returns:
3244 Merged return value of `fn` across all replicas.
3245 """
3246 _require_cross_replica_or_default_context_extended(self)
3247 if kwargs is None:
3248 kwargs = {}
3249 with self._container_strategy().scope():
3250 return self._call_for_each_replica(fn, args, kwargs)
3252 def _call_for_each_replica(self, fn, args, kwargs):
3253 raise NotImplementedError("must be implemented in descendants")
3255 def read_var(self, v):
3256 """Reads the value of a variable.
3258 Returns the aggregate value of a replica-local variable, or the
3259 (read-only) value of any other variable.
3261 Args:
3262 v: A variable allocated within the scope of this `tf.distribute.Strategy`.
3264 Returns:
3265 A tensor representing the value of `v`, aggregated across replicas if
3266 necessary.
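For illustration, a minimal sketch with a sync-on-read (replica-local)
variable (a two-replica `MirroredStrategy` and the SUM aggregation are
assumptions made for the example):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(
      1.,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)

# Returns the SUM-aggregated value across the per-replica components.
value = strategy.extended.read_var(v)
```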
3267 """
3268 raise NotImplementedError("must be implemented in descendants")
3270 def update_non_slot(
3271 self, colocate_with, fn, args=(), kwargs=None, group=True):
3272 """Runs `fn(*args, **kwargs)` on `colocate_with` devices.
3274 Used to update non-slot variables.
3276 DEPRECATED: TF 1.x ONLY.
3278 Args:
3279 colocate_with: Devices returned by `non_slot_devices()`.
3280 fn: Function to execute.
3281 args: Tuple or list. Positional arguments to pass to `fn()`.
3282 kwargs: Dict with keyword arguments to pass to `fn()`.
3283 group: Boolean. Defaults to True. If False, the return value will be
3284 unwrapped.
3286 Returns:
3287 Return value of `fn`, possibly merged across devices.
3288 """
3289 _require_cross_replica_or_default_context_extended(self)
3290 if kwargs is None:
3291 kwargs = {}
3292 fn = autograph.tf_convert(
3293 fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
3294 with self._container_strategy().scope():
3295 return self._update_non_slot(colocate_with, fn, args, kwargs, group)
3297 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
3298 raise NotImplementedError("must be implemented in descendants")
3300 def non_slot_devices(self, var_list):
3301 """Device(s) for non-slot variables.
3303 DEPRECATED: TF 1.x ONLY.
3305 This method returns non-slot devices where non-slot variables are placed.
3306 Users can create non-slot variables on these devices by using a block:
3308 ```python
3309 with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)):
3310 ...
3311 ```
3313 Args:
3314 var_list: The list of variables being optimized, needed with the
3315 default `tf.distribute.Strategy`.
3316 Returns:
3317 A sequence of devices for non-slot variables.
3318 """
3319 raise NotImplementedError("must be implemented in descendants")
3321 def _use_merge_call(self):
3322 """Whether to use merge-calls inside the distributed strategy."""
3323 return True
3325 @property
3326 def experimental_between_graph(self):
3327 """Whether the strategy uses between-graph replication or not.
3329 This is expected to return a constant value that will not be changed
3330 throughout its life cycle.
3331 """
3332 raise NotImplementedError("must be implemented in descendants")
3334 @property
3335 def experimental_should_init(self):
3336 """Whether initialization is needed."""
3337 raise NotImplementedError("must be implemented in descendants")
3339 @property
3340 def should_checkpoint(self):
3341 """Whether checkpointing is needed."""
3342 raise NotImplementedError("must be implemented in descendants")
3344 @property
3345 def should_save_summary(self):
3346 """Whether saving summaries is needed."""
3347 raise NotImplementedError("must be implemented in descendants")
3350# A note about the difference between the context managers
3351# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
3352# (defined above) used by `tf.distribute.Strategy.scope()`:
3353#
3354# * a ReplicaContext is only present during a `run()`
3355# call (except during a `merge_run` call) and in such a scope it
3356# will be returned by calls to `get_replica_context()`. Implementers of new
3357# Strategy descendants will frequently also need to
3358# define a descendant of ReplicaContext, and are responsible for
3359# entering and exiting this context.
3360#
3361# * Strategy.scope() sets up a variable_creator scope that
3362# changes variable creation calls (e.g. to make mirrored
3363# variables). This is intended as an outer scope that users enter once
3364# around their model creation and graph definition. There is no
3365# anticipated need to define descendants of _CurrentDistributionContext.
3366# It sets the current Strategy for purposes of
3367# `get_strategy()` and `has_strategy()`
3368# and switches the thread mode to a "cross-replica context".
3369class ReplicaContextBase(object):
3370 """A class with a collection of APIs that can be called in a replica context.
3372 You can use `tf.distribute.get_replica_context` to get an instance of
3373 `ReplicaContext`, which can only be called inside the function passed to
3374 `tf.distribute.Strategy.run`.
3376 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
3377 >>> def func():
3378 ... replica_context = tf.distribute.get_replica_context()
3379 ... return replica_context.replica_id_in_sync_group
3380 >>> strategy.run(func)
3381 PerReplica:{
3382 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
3383 1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
3384 }
3385 """
3387 def __init__(self, strategy, replica_id_in_sync_group):
3388 """Creates a ReplicaContext.
3390 Args:
3391 strategy: A `tf.distribute.Strategy`.
3392 replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an
3393 integer whenever possible to avoid issues with nested `tf.function`. It
3394 accepts a `Tensor` only to be compatible with `tpu.replicate`.
3395 """
3396 self._strategy = strategy
3397 self._thread_context = _InReplicaThreadMode( # pylint: disable=protected-access
3398 self)
3399 if not (replica_id_in_sync_group is None or
3400 tensor_util.is_tf_type(replica_id_in_sync_group) or
3401 isinstance(replica_id_in_sync_group, int)):
3402 raise ValueError(
3403 "replica_id_in_sync_group can only be an integer, a Tensor or None.")
3404 self._replica_id_in_sync_group = replica_id_in_sync_group
3405 # We need this check because TPUContext extends from ReplicaContext and
3406 # does not pass a strategy object since it is used by TPUEstimator.
3407 if strategy:
3408 self._local_replica_id = strategy.extended._get_local_replica_id(
3409 replica_id_in_sync_group)
3410 self._summary_recording_distribution_strategy = None
3412 @doc_controls.do_not_generate_docs
3413 def __enter__(self):
3414 _push_per_thread_mode(self._thread_context)
3416 def replica_id_is_zero():
3417 return math_ops.equal(self.replica_id_in_sync_group,
3418 constant_op.constant(0))
3420 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
3421 self._summary_recording_distribution_strategy = (
3422 summary_state.is_recording_distribution_strategy)
3423 summary_state.is_recording_distribution_strategy = replica_id_is_zero
3425 @doc_controls.do_not_generate_docs
3426 def __exit__(self, exception_type, exception_value, traceback):
3427 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
3428 summary_state.is_recording_distribution_strategy = (
3429 self._summary_recording_distribution_strategy)
3430 _pop_per_thread_mode()
3432 def merge_call(self, merge_fn, args=(), kwargs=None):
3433 """Merge args across replicas and run `merge_fn` in a cross-replica context.
3435 This allows communication and coordination when there are multiple calls
3436 to the step_fn triggered by a call to `strategy.run(step_fn, ...)`.
3438 See `tf.distribute.Strategy.run` for an explanation.
3440 If not inside a distributed scope, this is equivalent to:
3442 ```
3443 strategy = tf.distribute.get_strategy()
3444 with cross-replica-context(strategy):
3445 return merge_fn(strategy, *args, **kwargs)
3446 ```
3448 Args:
3449 merge_fn: Function that joins arguments from threads that are given as
3450 PerReplica. It accepts a `tf.distribute.Strategy` object as
3451 the first argument.
3452 args: List or tuple with positional per-thread arguments for `merge_fn`.
3453 kwargs: Dict with keyword per-thread arguments for `merge_fn`.
3455 Returns:
3456 The return value of `merge_fn`, except for `PerReplica` values which are
3457 unpacked.
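For illustration, a minimal sketch of calling `merge_call` from a step
function (a two-replica `MirroredStrategy` is assumed; `merge_fn` runs once
in cross-replica context and its result is passed back to every replica):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def merge_fn(strategy, per_replica_value):
  # `per_replica_value` holds the values passed in by all replicas.
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_value, axis=None)

def step_fn():
  ctx = tf.distribute.get_replica_context()
  return ctx.merge_call(merge_fn, args=(tf.identity(1.),))

result = strategy.run(step_fn)  # each replica receives the sum, 2.0
```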
3458 """
3459 require_replica_context(self)
3460 if kwargs is None:
3461 kwargs = {}
3463 merge_fn = autograph.tf_convert(
3464 merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)
3465 return self._merge_call(merge_fn, args, kwargs)
3467 def _merge_call(self, merge_fn, args, kwargs):
3468 """Default implementation for single replica."""
3469 _push_per_thread_mode( # thread-local, so not needed with multiple threads
3470 _CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access
3471 try:
3472 return merge_fn(self._strategy, *args, **kwargs)
3473 finally:
3474 _pop_per_thread_mode()
3476 @property
3477 def num_replicas_in_sync(self):
3478 """Returns number of replicas that are kept in sync."""
3479 return self._strategy.num_replicas_in_sync
3481 @property
3482 def replica_id_in_sync_group(self):
3483 """Returns the id of the replica.
3485 This identifies the replica among all replicas that are kept in sync. The
3486 value of the replica id can range from 0 to
3487 `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1.
3489 NOTE: This is not guaranteed to be the same ID as the XLA replica ID used
3490 for low-level operations such as collective_permute.
3492 Returns:
3493 a `Tensor`.
3494 """
3495 # It's important to prefer making the Tensor at call time whenever possible.
3496 # Keeping Tensors in global states doesn't work well with nested
3497 # tf.function, since it's possible that the tensor is generated in one func
3498 # graph, and gets captured by another, which will result in a subtle "An op
3499 # outside of the function building code is being passed a Graph tensor"
3500 # error. Making the tensor at call time ensures it is in the same graph where
3501 # it's used. However, to be compatible with tpu.replicate(),
3502 # self._replica_id_in_sync_group can also be a Tensor.
3503 if tensor_util.is_tf_type(self._replica_id_in_sync_group):
3504 return self._replica_id_in_sync_group
3505 return constant_op.constant(
3506 self._replica_id_in_sync_group,
3507 dtypes.int32,
3508 name="replica_id_in_sync_group")
3510 @property
3511 def _replica_id(self):
3512 """This is the local replica id in a given sync group."""
3513 return self._local_replica_id
3515 @property
3516 def strategy(self):
3517 """The current `tf.distribute.Strategy` object."""
3518 return self._strategy
3520 @property
3521 @deprecation.deprecated(None, "Please avoid relying on devices property.")
3522 def devices(self):
3523 """Returns the devices this replica is to be executed on, as a tuple of strings.
3525 NOTE: For `tf.distribute.MirroredStrategy` and
3526 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a
3527 nested list of device strings, e.g., [["GPU:0"]].
3529 """
3530 require_replica_context(self)
3531 return (device_util.current(),)
3533 def all_reduce(self, reduce_op, value, options=None):
3534 """All-reduces `value` across all replicas.
3536 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3537 >>> def step_fn():
3538 ... ctx = tf.distribute.get_replica_context()
3539 ... value = tf.identity(1.)
3540 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
3541 >>> strategy.experimental_local_results(strategy.run(step_fn))
3542 (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3543 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
3545 It supports batched operations. You can pass a list of values and it
3546 attempts to batch them when possible. You can also specify `options`
3547 to indicate the desired batching behavior, e.g. batch the values into
3548 multiple packs so that they can better overlap with computations.
3550 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3551 >>> def step_fn():
3552 ... ctx = tf.distribute.get_replica_context()
3553 ... value1 = tf.identity(1.)
3554 ... value2 = tf.identity(2.)
3555 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
3556 >>> strategy.experimental_local_results(strategy.run(step_fn))
3557 ([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3558 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>],
3559 [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
3560 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>])
3562 Note that all replicas need to participate in the all-reduce, otherwise this
3563 operation hangs. Note that if there are multiple all-reduces, they need to
3564 execute in the same order on all replicas. Dispatching all-reduce based on
3565 conditions is usually error-prone.
3567 Known limitation: if `value` contains `tf.IndexedSlices`, attempting to
3568 compute gradients w.r.t. `value` would result in an error.
3570 This API currently can only be called in the replica context. Other
3571 variants to reduce values across replicas are:
3572 * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
3573 in the cross-replica context.
3574 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
3575 all-reduce API in the cross-replica context.
3576 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
3577 to the host in cross-replica context.
3579 Args:
3580 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should
3581 be combined. Allows using string representation of the enum such as
3582 "SUM", "MEAN".
3583 value: a potentially nested structure of `tf.Tensor` or `tf.IndexedSlices` which
3584 `tf.nest.flatten` accepts. The structure and the shapes of `value` need to
3585 be the same on all replicas.
3586 options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3587 perform collective operations. This overrides the default options if the
3588 `tf.distribute.Strategy` takes one in the constructor. See
3589 `tf.distribute.experimental.CommunicationOptions` for details of the
3590 options.
3592 Returns:
3593 A nested structure of `tf.Tensor` with the reduced values. The structure
3594 is the same as `value`.
3595 """
3596 flattened_value = nest.flatten(value)
3597 has_indexed_slices = False
3599 for v in flattened_value:
3600 if isinstance(v, indexed_slices.IndexedSlices):
3601 has_indexed_slices = True
3603 if isinstance(reduce_op, six.string_types):
3604 reduce_op = reduce_util.ReduceOp(reduce_op.upper())
3605 if options is None:
3606 options = collective_util.Options()
3608 def batch_all_reduce(strategy, *value_flat):
3609 return strategy.extended.batch_reduce_to(
3610 reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat],
3611 options)
3613 # Due to the use of `capture_call_time_value` in collective ops, we have
3614 # to maintain two branches: one w/ merge_call and one w/o. Details can be
3615 # found in b/184009754.
3616 if self._strategy.extended._use_merge_call(): # pylint: disable=protected-access
3617 # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad.
3618 if has_indexed_slices:
3619 return nest.pack_sequence_as(
3620 value,
3621 self.merge_call(batch_all_reduce, args=flattened_value))
3623 @custom_gradient.custom_gradient
3624 def grad_wrapper(*xs):
3625 ys = self.merge_call(batch_all_reduce, args=xs)
3626 # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
3627 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
3628 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
3629 else:
3630 if has_indexed_slices:
3631 return nest.pack_sequence_as(
3632 value,
3633 self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access
3634 reduce_op, flattened_value, options))
3636 @custom_gradient.custom_gradient
3637 def grad_wrapper(*xs):
3638 ys = self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access
3639 reduce_op, xs, options)
3640 # The gradient of an all-sum is itself an all-sum (all-mean, likewise).
3641 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s)
3643 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value))
3645 # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
3646 # all-reduce. It would return a function returning the result of reducing `t`
3647 # across all replicas. The caller would wait to call this function until they
3648 # needed the reduce result, allowing an efficient implementation:
3649 # * With eager execution, the reduction could be performed asynchronously
3650 # in the background, not blocking until the result was needed.
3651 # * When constructing a graph, it could batch up all reduction requests up
3652 # to that point that the first result is needed. Most likely this can be
3653 # implemented in terms of `merge_call()` and `batch_reduce_to()`.
3656@tf_export("distribute.ReplicaContext", v1=[])
3657class ReplicaContext(ReplicaContextBase):
3659 __doc__ = ReplicaContextBase.__doc__
3661 def all_gather(self, value, axis, options=None):
3662 """All-gathers `value` across all replicas along `axis`.
3664 Note: An `all_gather` method can only be called in replica context. For
3665 a cross-replica context counterpart, see `tf.distribute.Strategy.gather`.
3666 All replicas need to participate in the all-gather, otherwise this
3667 operation hangs. So if `all_gather` is called in any replica, it must be
3668 called in all replicas.
3670 Note: If there are multiple `all_gather` calls, they need to be executed in
3671 the same order on all replicas. Dispatching `all_gather` based on conditions
3672 is usually error-prone.
3674 For all strategies except `tf.distribute.TPUStrategy`, the input
3675 `value` on different replicas must have the same rank, and their shapes must
3676 be the same in all dimensions except the `axis`-th dimension. In other
3677 words, their shapes cannot be different in a dimension `d` where `d` does
3678 not equal the `axis` argument. For example, given a
3679 `tf.distribute.DistributedValues` with component tensors of shape
3680 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call
3681 `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)`
3682 or `all_gather(..., axis=2, ...)`. However, with
3683 `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and
3684 same shape.
3686 Note: The input `value` must have a non-zero rank. Otherwise, consider using
3687 `tf.expand_dims` before gathering it.
3689 You can pass in a single tensor to all-gather:
3691 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3692 >>> @tf.function
3693 ... def gather_value():
3694 ... ctx = tf.distribute.get_replica_context()
3695 ... local_value = tf.constant([1, 2, 3])
3696 ... return ctx.all_gather(local_value, axis=0)
3697 >>> result = strategy.run(gather_value)
3698 >>> result
3699 PerReplica:{
3700 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3701 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3702 }
3703 >>> strategy.experimental_local_results(result)
3704 (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3705 dtype=int32)>,
3706 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
3707 dtype=int32)>)
3710 You can also pass in a nested structure of tensors to all-gather, say, a
3711 list:
3713 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
3714 >>> @tf.function
3715 ... def gather_nest():
3716 ... ctx = tf.distribute.get_replica_context()
3717 ... value_1 = tf.constant([1, 2, 3])
3718 ... value_2 = tf.constant([[1, 2], [3, 4]])
3719 ... # all_gather a nest of `tf.distribute.DistributedValues`
3720 ... return ctx.all_gather([value_1, value_2], axis=0)
3721 >>> result = strategy.run(gather_nest)
3722 >>> result
3723 [PerReplica:{
3724 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3725 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
3726 }, PerReplica:{
3727 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3728 array([[1, 2],
3729 [3, 4],
3730 [1, 2],
3731 [3, 4]], dtype=int32)>,
3732 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3733 array([[1, 2],
3734 [3, 4],
3735 [1, 2],
3736 [3, 4]], dtype=int32)>
3737 }]
3738 >>> strategy.experimental_local_results(result)
3739 ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3740 <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3741 array([[1, 2],
3742 [3, 4],
3743 [1, 2],
3744 [3, 4]], dtype=int32)>],
3745 [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
3746 <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
3747 array([[1, 2],
3748 [3, 4],
3749 [1, 2],
3750 [3, 4]], dtype=int32)>])
3753 What if you are all-gathering tensors with different shapes on different
3754 replicas? Consider the following example with two replicas, where you have
3755 `value` as a nested structure consisting of two items to all-gather, `a` and
3756 `b`.
3758 * On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`.
3759 * On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`.
3760 * Result for `all_gather` with `axis=0` (on each of the replicas) is:
3762 ```
3763 {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}
3764 ```
3766 Args:
3767 value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts,
3768 or a `tf.distribute.DistributedValues` instance. The structure of the
3769 `tf.Tensor` needs to be the same on all replicas. The underlying tensor
3770 constructs can only be dense tensors with non-zero rank, NOT
3771 `tf.IndexedSlices`.
3772 axis: 0-D int32 Tensor. Dimension along which to gather.
3773 options: a `tf.distribute.experimental.CommunicationOptions`. Options to
3774 perform collective operations. This overrides the default options if the
3775 `tf.distribute.Strategy` takes one in the constructor. See
3776 `tf.distribute.experimental.CommunicationOptions` for details of the
3777 options.
3779 Returns:
3780 A nested structure of `tf.Tensor` with the gathered values. The structure
3781 is the same as `value`.
3782 """
3783 for v in nest.flatten(value):
3784 if isinstance(v, indexed_slices.IndexedSlices):
3785 raise NotImplementedError("all_gather does not support IndexedSlices")
3787 if options is None:
3788 options = collective_util.Options()
3790 def batch_all_gather(strategy, *value_flat):
3791 return strategy.extended._batch_gather_to( # pylint: disable=protected-access
3792 [(v, _batch_reduce_destination(v)) for v in value_flat], axis,
3793 options)
3795 @custom_gradient.custom_gradient
3796 def grad_wrapper(*xs):
3797 ys = self.merge_call(batch_all_gather, args=xs)
3799 def grad(*dy_s):
3800 grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s)
3801 new_grads = []
3802 for i, grad in enumerate(grads):
3803 input_shape = array_ops.shape(xs[i])
3804 axis_dim = array_ops.reshape(input_shape[axis], [1])
3805 with ops.control_dependencies([array_ops.identity(grads)]):
3806 d = self.all_gather(axis_dim, axis=0)
3807 begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group])
3808 end_dim = begin_dim + array_ops.shape(xs[i])[axis]
3809 new_grad = array_ops.gather(
3810 grad, axis=axis, indices=math_ops.range(begin_dim, end_dim))
3811 new_grads.append(new_grad)
3812 return new_grads
3814 return ys, grad
3816 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value)))
3818 def _update(self, var, fn, args=(), kwargs=None, group=True):
3819 """Run `fn` to update `var` with `args` and `kwargs` in replica context.
3821 `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var`
3822 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`.
3823 `fn` applies to each component variable of `var` with corresponding input
3824 values from `args` and `kwargs`.
3826 Example usage:
3828 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
3829 >>> with strategy.scope():
3830 ... distributed_variable = tf.Variable(5.0)
3831 >>> distributed_variable
3832 MirroredVariable:{
3833 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>,
3834 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0>
3835 }
3836 >>> def replica_fn(v):
3837 ... value = tf.identity(1.0)
3838 ... replica_context = tf.distribute.get_replica_context()
3839 ... update_fn = lambda var, value: var.assign(value)
3840 ... replica_context._update(v, update_fn, args=(value,))
3841 >>> strategy.run(replica_fn, args=(distributed_variable,))
3842 >>> distributed_variable
3843 MirroredVariable:{
3844 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
3845 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
3846 }
3848 This API must be called in a replica context.
3850 Note that if `var` is a MirroredVariable (i.e., the type of variable created
3851 under the scope of a synchronous strategy, and is synchronized on-write, see
3852 `tf.VariableSynchronization` for more information) and `args`/`kwargs`
3853 contain different values for different replicas, `var` will be dangerously
3854 out of synchronization. Thus we recommend using `variable.assign(value)` as
3855 long as you can, which under the hood aggregates the updates and guarantees
3856 the synchronization. The case where you actually want this API instead of
3857 `variable.assign(value)` is when, before assigning `value` to the `variable`,
3858 you'd like to conduct some pre-`assign` computation colocated with the
3859 variable devices (i.e. where variables reside; for MirroredStrategy they are
3860 the same as the compute devices, and for ParameterServerStrategy they refer to
3861 parameter servers). E.g.,
3863 ```python
3864 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas
3865 with strategy.scope():
3866 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM)
3867 def replica_fn(inputs):
3868 value = computation(inputs)
3869 replica_context = tf.distribute.get_replica_context()
3870 reduced_value = replica_context.all_reduce(value)
3872 def update_fn(var, value):
3873 # this computation will colocate with `var`'s device
3874 updated_value = post_reduce_pre_update_computation(value)
3875 var.assign(updated_value)
3877 replica_context._update(v, update_fn, args=(reduced_value,))
3879 strategy.run(replica_fn, args=(inputs,))
3880 ```
3882 This code snippet is consistent across all strategies. If you directly
3883 compute and use `assign` in the replica context instead of wrapping it with
3884 `update`, for strategies with fewer variable devices than compute devices
3885 (e.g., parameter server strategy, usually), the
3886 `post_reduce_pre_update_computation` will happen
3887 N == number_of_compute_devices times, which is less performant.
3890 Args:
3891 var: Variable, possibly distributed to multiple devices, to operate on.
3892 fn: Function to call. Should take the variable as the first argument.
3893 args: Tuple or list. Additional positional arguments to pass to `fn()`.
3894 kwargs: Dict with keyword arguments to pass to `fn()`.
3895 group: Boolean. Defaults to True. Most strategies enter a merge_call to
3896 conduct the update in cross-replica context, and group=True guarantees updates
3897 on all replicas are executed.
3899 Returns:
3900 The return value of `fn` for the local replica.
3901 """
3902 if kwargs is None:
3903 kwargs = {}
3904 return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group) # pylint: disable=protected-access
3907@tf_export(v1=["distribute.ReplicaContext"])
3908class ReplicaContextV1(ReplicaContextBase):
3909 __doc__ = ReplicaContextBase.__doc__
3912def _batch_reduce_destination(x):
3913 """Returns the destinations for batch all-reduce."""
3914 if isinstance(x, ops.Tensor):
3915 # If this is a one device strategy.
3916 return x.device
3917 else:
3918 return x
3919# ------------------------------------------------------------------------------
3922class _DefaultDistributionStrategyV1(StrategyV1):
3923 """Default `tf.distribute.Strategy` if none is explicitly selected."""
3925 def __init__(self):
3926 if not _creating_default_strategy_singleton:
3927 raise RuntimeError("Should only create a single instance of "
3928 "_DefaultDistributionStrategy")
3929 super(_DefaultDistributionStrategyV1,
3930 self).__init__(_DefaultDistributionExtended(self))
3932 def __deepcopy__(self, memo):
3933 del memo
3934 raise RuntimeError("Should only create a single instance of "
3935 "_DefaultDistributionStrategy")
3938class _DefaultDistributionStrategy(Strategy):
3939 """Default `tf.distribute.Strategy` if none is explicitly selected."""
3941 def __init__(self):
3942 if not _creating_default_strategy_singleton:
3943 raise RuntimeError("Should only create a single instance of "
3944 "_DefaultDistributionStrategy")
3945 super(_DefaultDistributionStrategy, self).__init__(
3946 _DefaultDistributionExtended(self))
3948 def __deepcopy__(self, memo):
3949 del memo
3950 raise RuntimeError("Should only create a single instance of "
3951 "_DefaultDistributionStrategy")
3954class _DefaultDistributionContext(object):
3955 """Context manager setting the default `tf.distribute.Strategy`."""
3957 __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"]
3959 def __init__(self, strategy):
3961 def creator(next_creator, **kwargs):
3962 _require_strategy_scope_strategy(strategy)
3963 return next_creator(**kwargs)
3965 self._var_creator_scope = variable_scope.variable_creator_scope(creator)
3966 self._strategy = strategy
3967 self._nested_count = 0
3969 def __enter__(self):
3970 # Re-entering the default strategy scope is allowed, but entering it while a
3970 # non-default strategy is already in scope is not.
3971 if has_strategy():
3972 raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
3973 if self._nested_count == 0:
3974 self._var_creator_scope.__enter__()
3975 self._nested_count += 1
3976 return self._strategy
3978 def __exit__(self, exception_type, exception_value, traceback):
3979 self._nested_count -= 1
3980 if self._nested_count == 0:
3981 try:
3982 self._var_creator_scope.__exit__(
3983 exception_type, exception_value, traceback)
3984 except RuntimeError as e:
3985 six.raise_from(
3986 RuntimeError("Variable creator scope nesting error: move call to "
3987 "tf.distribute.set_strategy() out of `with` scope."),
3988 e)
3991class _DefaultDistributionExtended(StrategyExtendedV1):
3992 """Implementation of _DefaultDistributionStrategy."""
3994 def __init__(self, container_strategy):
3995 super(_DefaultDistributionExtended, self).__init__(container_strategy)
3996 self._retrace_functions_for_each_device = False
3998 def _scope(self, strategy):
3999 """Context manager setting a variable creator and `self` as current."""
4000 return _DefaultDistributionContext(strategy)
4002 def colocate_vars_with(self, colocate_with_variable):
4003 """Does not require `self.scope`."""
4004 _require_strategy_scope_extended(self)
4005 return ops.colocate_with(colocate_with_variable)
4007 def variable_created_in_scope(self, v):
4008 return v._distribute_strategy is None # pylint: disable=protected-access
4010 def _experimental_distribute_dataset(self, dataset, options):
4011 return dataset
4013 def _distribute_datasets_from_function(self, dataset_fn, options):
4014 return dataset_fn(InputContext())
4016 def _experimental_distribute_values_from_function(self, value_fn):
4017 return value_fn(ValueContext())
4019 def _make_dataset_iterator(self, dataset):
4020 return _DefaultDistributionExtended.DefaultInputIterator(dataset)
4022 def _make_input_fn_iterator(self,
4023 input_fn,
4024 replication_mode=InputReplicationMode.PER_WORKER):
4025 dataset = input_fn(InputContext())
4026 return _DefaultDistributionExtended.DefaultInputIterator(dataset)
4028 def _experimental_make_numpy_dataset(self, numpy_input, session):
4029 numpy_flat = nest.flatten(numpy_input)
4030 vars_flat = tuple(
4031 variable_v1.VariableV1(array_ops.zeros(i.shape, i.dtype),
4032 trainable=False, use_resource=True)
4033 for i in numpy_flat
4034 )
4035 for v, i in zip(vars_flat, numpy_flat):
4036 numpy_dataset.init_var_from_numpy(v, i, session)
4037 vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
4038 return dataset_ops.Dataset.from_tensor_slices(vars_nested)
4040 def _broadcast_to(self, tensor, destinations):
4041 if destinations is None:
4042 return tensor
4043 else:
4044 raise NotImplementedError("TODO")
4046 def _call_for_each_replica(self, fn, args, kwargs):
4047 with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0):
4048 return fn(*args, **kwargs)
4050 def _reduce_to(self, reduce_op, value, destinations, options):
4051 # TODO(josh11b): Use destinations?
4052 del reduce_op, destinations, options
4053 return value
4055 def _gather_to_implementation(self, value, destinations, axis, options):
4056 del destinations, axis, options
4057 return value
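# Illustrative sketch, not part of the original module: with a single replica
# there is nothing to combine, so cross-replica reduce/gather simply return
# their input:
#
#   strategy = tf.distribute.get_strategy()
#   total = strategy.reduce(tf.distribute.ReduceOp.SUM,
#                           tf.constant(3.0), axis=None)     # -> 3.0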
4059 def _update(self, var, fn, args, kwargs, group):
4060 # The implementations of _update() and _update_non_slot() are identical
4061 # except _update() passes `var` as the first argument to `fn()`.
4062 return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
4064 def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
4065 # TODO(josh11b): Figure out what we should be passing to UpdateContext()
4066 # once that value is used for something.
4067 with UpdateContext(colocate_with):
4068 result = fn(*args, **kwargs)
4069 if should_group:
4070 return result
4071 else:
4072 return nest.map_structure(self._local_results, result)
4074 def read_var(self, replica_local_var):
4075 return array_ops.identity(replica_local_var)
4077 def _local_results(self, distributed_value):
4078 return (distributed_value,)
4080 def value_container(self, value):
4081 return value
4083 @property
4084 def _num_replicas_in_sync(self):
4085 return 1
4087 @property
4088 def worker_devices(self):
4089 raise RuntimeError("worker_devices() method unsupported by default "
4090 "tf.distribute.Strategy.")
4092 @property
4093 def parameter_devices(self):
4094 raise RuntimeError("parameter_devices() method unsupported by default "
4095 "tf.distribute.Strategy.")
4097 def non_slot_devices(self, var_list):
4098 return min(var_list, key=lambda x: x.name)
4100 def _in_multi_worker_mode(self):
4101 """Whether this strategy indicates working in multi-worker settings."""
4102 # Default strategy doesn't indicate multi-worker training.
4103 return False
4105 @property
4106 def should_checkpoint(self):
4107 return True
4109 @property
4110 def should_save_summary(self):
4111 return True
4113 def _get_local_replica_id(self, replica_id_in_sync_group):
4114 return replica_id_in_sync_group
4116 def _get_replica_id_in_sync_group(self, replica_id):
4117 return replica_id
4119 # TODO(priyag): This should inherit from `InputIterator`, once dependency
4120 # issues have been resolved.
4121 class DefaultInputIterator(object):
4122 """Default implementation of `InputIterator` for default strategy."""
4124 def __init__(self, dataset):
4125 self._dataset = dataset
4126 if eager_context.executing_eagerly():
4127 self._iterator = dataset_ops.make_one_shot_iterator(dataset)
4128 else:
4129 self._iterator = dataset_ops.make_initializable_iterator(dataset)
4131 def get_next(self):
4132 return self._iterator.get_next()
4134 def get_next_as_optional(self):
4135 return self._iterator.get_next_as_optional()
4137 @deprecated(None, "Use the iterator's `initializer` property instead.")
4138 def initialize(self):
4139 """Initialize underlying iterators.
4141 Returns:
4142 A list of any initializer ops that should be run.
4143 """
4144 if eager_context.executing_eagerly():
4145 self._iterator = self._dataset.make_one_shot_iterator()
4146 return []
4147 else:
4148 return [self._iterator.initializer]
4150 @property
4151 def initializer(self):
4152 """Returns a list of ops that initialize the iterator."""
4153 return self.initialize()
4155 # TODO(priyag): Delete this once all strategies use global batch size.
4156 @property
4157 def _global_batch_size(self):
4158 """Global and per-replica batching are equivalent for this strategy."""
4159 return True
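# Illustrative sketch, not part of the original module: DefaultInputIterator
# wraps a one-shot iterator in eager mode and an initializable iterator in
# graph mode. With a V1-style strategy (tf.compat.v1 / graph mode), usage looks
# roughly like:
#
#   ds = tf.data.Dataset.range(4).batch(2)
#   it = strategy.make_dataset_iterator(ds)      # StrategyV1 / compat.v1 path
#   # graph mode: sess.run(it.initializer); batch = sess.run(it.get_next())
#   # eager mode: batch = it.get_next()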
4162class _DefaultReplicaContext(ReplicaContext):
4163 """ReplicaContext for _DefaultDistributionStrategy."""
4165 @property
4166 def replica_id_in_sync_group(self):
4167 # Return 0 instead of a constant tensor to avoid creating a new node for
4168 # users who don't use distribution strategy.
4169 return 0
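# Illustrative sketch, not part of the original module: for users who never use
# a distribution strategy, `tf.distribute.get_replica_context()` returns this
# default replica context, so the replica id is a plain Python 0 (no constant
# op is added to the graph) and can drive ordinary Python control flow:
#
#   ctx = tf.distribute.get_replica_context()
#   if ctx.replica_id_in_sync_group == 0:
#     pass  # e.g. log only from the single replica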
4172# ------------------------------------------------------------------------------
4173# We haven't yet implemented deserialization for DistributedVariables.
4174# So here we catch any attempts to deserialize variables
4175# when using distribution strategies.
4176# pylint: disable=protected-access
4177_original_from_proto = ref_variable._from_proto_fn
4180def _from_proto_fn(v, import_scope=None):
4181 if has_strategy():
4182 raise NotImplementedError(
4183 "Deserialization of variables is not yet supported when using a "
4184 "tf.distribute.Strategy.")
4185 else:
4186 return _original_from_proto(v, import_scope=import_scope)
4188ref_variable._from_proto_fn = _from_proto_fn
4189# pylint: enable=protected-access
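# Illustrative note, not part of the original module: because of the patch
# above, graph-deserialization paths that rebuild variables from their proto
# representation (for example tf.compat.v1.train.import_meta_graph) fail fast
# with NotImplementedError when invoked inside a strategy scope, e.g.:
#
#   with tf.distribute.MirroredStrategy().scope():
#     tf.compat.v1.train.import_meta_graph("model.meta")  # "model.meta" is a
#                                                         # placeholder path;
#                                                         # raises NotImplementedError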
4192def get_local_results_or_value_container(variable):
4193 strategy, context = get_strategy_and_replica_context()
4194 if context:
4195 return [strategy.extended.value_container(variable)]
4196 else:
4197 return strategy.experimental_local_results(variable)
4200tape.register_watched_variable_resolver(get_local_results_or_value_container)
4203# ------------------------------------------------------------------------------
4204# Metrics to track which distribution strategy is being used
4205distribution_strategy_gauge = monitoring.StringGauge(
4206 "/tensorflow/api/distribution_strategy",
4207 "Gauge to track the type of distribution strategy used.", "TFVersion")
4208distribution_strategy_replica_gauge = monitoring.IntGauge(
4209 "/tensorflow/api/distribution_strategy/replica",
4210 "Gauge to track the number of replica each distribution strategy used.",
4211 "CountType")
4212distribution_strategy_input_api_counter = monitoring.Counter(
4213 "/tensorflow/api/distribution_strategy/input_api",
4214 "Counter to track the usage of the input APIs", "strategy", "api")
4215distributed_variable_creation_time_counter = monitoring.Counter(
4216 "/tensorflow/api/distribution_strategy/distributed_variable_creation_time_usecs",
4217 "Time to create distributed variables (us).", "strategy", "if_graph_building")