Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/tpu_strategy.py: 22%
667 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU Strategy."""
17import atexit
18import collections
19import contextlib
20import copy
21import functools
22import weakref
24from absl import logging
25import numpy as np
27from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
28from tensorflow.python.autograph.impl import api as autograph
29from tensorflow.python.compiler.xla.experimental import xla_sharding
30from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
31from tensorflow.python.distribute import device_util
32from tensorflow.python.distribute import distribute_lib
33from tensorflow.python.distribute import distribute_utils
34from tensorflow.python.distribute import input_lib
35from tensorflow.python.distribute import input_util
36from tensorflow.python.distribute import numpy_dataset
37from tensorflow.python.distribute import reduce_util
38from tensorflow.python.distribute import tpu_replicated_variable
39from tensorflow.python.distribute import tpu_util
40from tensorflow.python.distribute import tpu_values
41from tensorflow.python.distribute import values
42from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
43from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
44from tensorflow.python.eager import context
45from tensorflow.python.eager import def_function
46from tensorflow.python.eager import function
47from tensorflow.python.framework import constant_op
48from tensorflow.python.framework import device as tf_device
49from tensorflow.python.framework import device_spec
50from tensorflow.python.framework import dtypes
51from tensorflow.python.framework import indexed_slices
52from tensorflow.python.framework import ops
53from tensorflow.python.framework import sparse_tensor
54from tensorflow.python.framework import tensor_shape
55from tensorflow.python.framework import tensor_util
56from tensorflow.python.ops import array_ops
57from tensorflow.python.ops import control_flow_ops
58from tensorflow.python.ops import math_ops
59from tensorflow.python.ops import resource_variable_ops
60from tensorflow.python.ops import variables as variables_lib
61from tensorflow.python.ops.ragged import ragged_tensor
62from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import
63from tensorflow.python.tpu import tpu
64from tensorflow.python.tpu import tpu_hardware_feature
65from tensorflow.python.tpu import tpu_strategy_util
66from tensorflow.python.tpu import training_loop
67from tensorflow.python.tpu.ops import tpu_ops
68from tensorflow.python.util import deprecation
69from tensorflow.python.util import nest
70from tensorflow.python.util import tf_inspect
71from tensorflow.python.util.tf_export import tf_export
74_XLA_OP_BY_OP_INPUTS_LIMIT = 200
77@contextlib.contextmanager
78def maybe_init_scope():
79 if ops.executing_eagerly_outside_functions():
80 yield
81 else:
82 with ops.init_scope():
83 yield
86def validate_run_function(fn):
87 """Validate the function passed into strategy.run."""
89 # We allow three types of functions/objects passed into TPUStrategy
90 # run in eager mode:
91 # 1. a user annotated tf.function
92 # 2. a ConcreteFunction, this is mostly what you get from loading a saved
93 # model.
94 # 3. a callable object and the `__call__` method itself is a tf.function.
95 #
96 # Otherwise we return an error, because we don't support eagerly running
97 # run in TPUStrategy.
99 if context.executing_eagerly() \
100 and not isinstance(fn, def_function.Function) \
101 and not isinstance(fn, function.ConcreteFunction) \
102 and not (callable(fn) and isinstance(fn.__call__, def_function.Function)):
103 raise NotImplementedError(
104 "TPUStrategy.run(fn, ...) does not support pure eager "
105 "execution. please make sure the function passed into "
106 "`strategy.run` is a `tf.function` or "
107 "`strategy.run` is called inside a `tf.function` if "
108 "eager behavior is enabled.")
111def _maybe_partial_apply_variables(fn, args, kwargs):
112 """Inspects arguments to partially apply any DistributedVariable.
114 This avoids an automatic cast of the current variable value to tensor.
116 Note that a variable may be captured implicitly with Python scope instead of
117 passing it to run(), but supporting run() keeps behavior consistent
118 with MirroredStrategy.
120 Since positional arguments must be applied from left to right, this function
121 does some tricky function inspection to move variable positional arguments
122 into kwargs. As a result of this, we can't support passing Variables as *args,
123 nor as args to functions which combine both explicit positional arguments and
124 *args.
126 Args:
127 fn: The function to run, as passed to run().
128 args: Positional arguments to fn, as passed to run().
129 kwargs: Keyword arguments to fn, as passed to run().
131 Returns:
132 A tuple of the function (possibly wrapped), args, kwargs (both
133 possibly filtered, with members of args possibly moved to kwargs).
134 If no variables are found, this function is a noop.
136 Raises:
137 ValueError: If the function signature makes unsupported use of *args, or if
138 too many arguments are passed.
139 """
141 def is_distributed_var(x):
142 flat = nest.flatten(x)
143 return flat and isinstance(flat[0], values.DistributedVariable)
145 # We will split kwargs into two dicts, one of which will be applied now.
146 var_kwargs = {}
147 nonvar_kwargs = {}
149 if kwargs:
150 var_kwargs = {k: v for k, v in kwargs.items() if is_distributed_var(v)}
151 if var_kwargs:
152 nonvar_kwargs = {
153 k: v for k, v in kwargs.items() if not is_distributed_var(v)
154 }
156 # Dump the argument names of `fn` to a list. This will include both positional
157 # and keyword arguments, but since positional arguments come first we can
158 # look up names of positional arguments by index.
159 positional_args = []
160 index_of_star_args = None
161 for i, p in enumerate(tf_inspect.signature(fn).parameters.values()):
162 # Class methods define "self" as first argument, but we don't pass "self".
163 # Note that this is a heuristic, as a method can name its first argument
164 # something else, and a function can define a first argument "self" as well.
165 # In both of these cases, using a Variable will fail with an unfortunate
166 # error about the number of arguments.
167 # inspect.is_method() seems not to work here, possibly due to the use of
168 # tf.function().
169 if i == 0 and p.name == "self":
170 continue
172 if p.kind == tf_inspect.Parameter.POSITIONAL_OR_KEYWORD:
173 positional_args.append(p.name)
175 elif p.kind == tf_inspect.Parameter.VAR_POSITIONAL:
176 # We'll raise an error later if a variable is passed to *args, since we
177 # can neither pass it by name nor partially apply it. This case only
178 # happens once at most.
179 index_of_star_args = i
181 elif p.kind == tf_inspect.Parameter.POSITIONAL_ONLY:
182 # This is a rare Python feature, indicating a / in the arg list.
183 if var_kwargs or any(is_distributed_var(a) for a in args):
184 raise ValueError(
185 "Mixing Variables and positional-only parameters not supported by "
186 f"TPUStrategy. Received {len(var_kwargs)} DistributedVariables in "
187 f"**kwargs and {sum(is_distributed_var(a) for a in args)} in *args,"
188 " expected zero for both."
189 )
190 return fn, args, kwargs
192 star_args = []
193 have_seen_var_arg = False
195 for i, a in enumerate(args):
196 if is_distributed_var(a):
197 if index_of_star_args is not None and i >= index_of_star_args:
198 raise ValueError(
199 "TPUStrategy.run() cannot handle Variables passed to *args. "
200 "Either name the function argument, or capture the Variable "
201 "implicitly.")
202 if len(positional_args) <= i:
203 raise ValueError(
204 "Too many positional arguments passed to call to TPUStrategy.run()."
205 )
206 var_kwargs[positional_args[i]] = a
207 have_seen_var_arg = True
208 else:
209 if index_of_star_args is not None and i >= index_of_star_args:
210 if have_seen_var_arg:
211 raise ValueError(
212 "TPUStrategy.run() cannot handle both Variables and a mix of "
213 "positional args and *args. Either remove the *args, or capture "
214 "the Variable implicitly.")
215 else:
216 star_args.append(a)
217 continue
219 if len(positional_args) <= i:
220 raise ValueError(
221 "Too many positional arguments passed to call to TPUStrategy.run()."
222 )
223 nonvar_kwargs[positional_args[i]] = a
225 if var_kwargs:
226 return functools.partial(fn, **var_kwargs), star_args, nonvar_kwargs
227 return fn, args, kwargs
230@tf_export("distribute.TPUStrategy", v1=[])
231class TPUStrategyV2(distribute_lib.Strategy):
232 """Synchronous training on TPUs and TPU Pods.
234 To construct a TPUStrategy object, you need to run the
235 initialization code as below:
237 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
238 >>> tf.config.experimental_connect_to_cluster(resolver)
239 >>> tf.tpu.experimental.initialize_tpu_system(resolver)
240 >>> strategy = tf.distribute.TPUStrategy(resolver)
242 While using distribution strategies, the variables created within the
243 strategy's scope will be replicated across all the replicas and can be kept in
244 sync using all-reduce algorithms.
246 To run TF2 programs on TPUs, you can either use `.compile` and
247 `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
248 training loop by calling `strategy.run` directly. Note that
249 TPUStrategy doesn't support pure eager execution, so please make sure the
250 function passed into `strategy.run` is a `tf.function` or
251 `strategy.run` is called inside a `tf.function` if eager
252 behavior is enabled. See more details in https://www.tensorflow.org/guide/tpu.
254 `distribute_datasets_from_function` and
255 `experimental_distribute_dataset` APIs can be used to distribute the dataset
256 across the TPU workers when writing your own training loop. If you are using
257 `fit` and `compile` methods available in `tf.keras.Model`, then Keras will
258 handle the distribution for you.
260 An example of writing customized training loop on TPUs:
262 >>> with strategy.scope():
263 ... model = tf.keras.Sequential([
264 ... tf.keras.layers.Dense(2, input_shape=(5,)),
265 ... ])
266 ... optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
268 >>> def dataset_fn(ctx):
269 ... x = np.random.random((2, 5)).astype(np.float32)
270 ... y = np.random.randint(2, size=(2, 1))
271 ... dataset = tf.data.Dataset.from_tensor_slices((x, y))
272 ... return dataset.repeat().batch(1, drop_remainder=True)
273 >>> dist_dataset = strategy.distribute_datasets_from_function(
274 ... dataset_fn)
275 >>> iterator = iter(dist_dataset)
277 >>> @tf.function()
278 ... def train_step(iterator):
279 ...
280 ... def step_fn(inputs):
281 ... features, labels = inputs
282 ... with tf.GradientTape() as tape:
283 ... logits = model(features, training=True)
284 ... loss = tf.keras.losses.sparse_categorical_crossentropy(
285 ... labels, logits)
286 ...
287 ... grads = tape.gradient(loss, model.trainable_variables)
288 ... optimizer.apply_gradients(zip(grads, model.trainable_variables))
289 ...
290 ... strategy.run(step_fn, args=(next(iterator),))
292 >>> train_step(iterator)
294 For the advanced use cases like model parallelism, you can set
295 `experimental_device_assignment` argument when creating TPUStrategy to specify
296 number of replicas and number of logical devices. Below is an example to
297 initialize TPU system with 2 logical devices and 1 replica.
299 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
300 >>> tf.config.experimental_connect_to_cluster(resolver)
301 >>> topology = tf.tpu.experimental.initialize_tpu_system(resolver)
302 >>> device_assignment = tf.tpu.experimental.DeviceAssignment.build(
303 ... topology,
304 ... computation_shape=[1, 1, 1, 2],
305 ... num_replicas=1)
306 >>> strategy = tf.distribute.TPUStrategy(
307 ... resolver, experimental_device_assignment=device_assignment)
309 Then you can run a `tf.add` operation only on logical device 0.
311 >>> @tf.function()
312 ... def step_fn(inputs):
313 ... features, _ = inputs
314 ... output = tf.add(features, features)
315 ...
316 ... # Add operation will be executed on logical device 0.
317 ... output = strategy.experimental_assign_to_logical_device(output, 0)
318 ... return output
319 >>> dist_dataset = strategy.distribute_datasets_from_function(
320 ... dataset_fn)
321 >>> iterator = iter(dist_dataset)
322 >>> strategy.run(step_fn, args=(next(iterator),))
324 `experimental_spmd_xla_partitioning` enables the experimental XLA SPMD feature
325 for model parallelism. This flag can reduce the compilation time and HBM
326 requirements. When running in this mode, every input tensor must either be
327 partitioned (via `strategy.experimental_split_to_logical_devices`) or fully
328 replicated (via `strategy.experimental_replicate_to_logical_devices`) to all
329 logical devices. And calling `strategy.experimental_assign_to_logical_device`
330 will result in a ValueError in this mode.
331 """
333 def __init__(self,
334 tpu_cluster_resolver=None,
335 experimental_device_assignment=None,
336 experimental_spmd_xla_partitioning=False):
337 """Synchronous training in TPU donuts or Pods.
339 Args:
340 tpu_cluster_resolver: A
341 `tf.distribute.cluster_resolver.TPUClusterResolver` instance, which
342 provides information about the TPU cluster. If None, it will assume
343 running on a local TPU worker.
344 experimental_device_assignment: Optional
345 `tf.tpu.experimental.DeviceAssignment` to specify the placement of
346 replicas on the TPU cluster.
347 experimental_spmd_xla_partitioning: If True, enable the SPMD (Single
348 Program Multiple Data) mode in XLA compiler. This flag only affects the
349 performance of XLA compilation and the HBM requirement of the compiled
350 TPU program. Ceveat: if this flag is True, calling
351 `tf.distribute.TPUStrategy.experimental_assign_to_logical_device` will
352 result in a ValueError.
353 """
354 super(TPUStrategyV2, self).__init__(
355 TPUExtended(
356 self,
357 tpu_cluster_resolver,
358 device_assignment=experimental_device_assignment,
359 use_spmd_for_xla_partitioning=experimental_spmd_xla_partitioning,
360 enable_data_reorder=experimental_device_assignment is not None,
361 )
362 )
363 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
364 distribute_lib.distribution_strategy_replica_gauge.get_cell(
365 "num_workers").set(self.extended.num_hosts)
366 distribute_lib.distribution_strategy_replica_gauge.get_cell(
367 "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
368 # Packed variable is used to reduce the overhead of function execution.
369 # For a DistributedVariable, only one variable handle is captured into a
370 # function graph. It's only supported in eager mode.
371 self._enable_packed_variable_in_eager_mode = True
373 def run(self, fn, args=(), kwargs=None, options=None):
374 """Run the computation defined by `fn` on each TPU replica.
376 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
377 `tf.distribute.DistributedValues`, such as those produced by a
378 `tf.distribute.DistributedDataset` from
379 `tf.distribute.Strategy.experimental_distribute_dataset` or
380 `tf.distribute.Strategy.distribute_datasets_from_function`,
381 when `fn` is executed on a particular replica, it will be executed with the
382 component of `tf.distribute.DistributedValues` that correspond to that
383 replica.
385 `fn` may call `tf.distribute.get_replica_context()` to access members such
386 as `all_reduce`.
388 All arguments in `args` or `kwargs` should either be nest of tensors or
389 `tf.distribute.DistributedValues` containing tensors or composite tensors.
391 Example usage:
393 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
394 >>> tf.config.experimental_connect_to_cluster(resolver)
395 >>> tf.tpu.experimental.initialize_tpu_system(resolver)
396 >>> strategy = tf.distribute.TPUStrategy(resolver)
397 >>> @tf.function
398 ... def run():
399 ... def value_fn(value_context):
400 ... return value_context.num_replicas_in_sync
401 ... distributed_values = (
402 ... strategy.experimental_distribute_values_from_function(value_fn))
403 ... def replica_fn(input):
404 ... return input * 2
405 ... return strategy.run(replica_fn, args=(distributed_values,))
406 >>> result = run()
408 Args:
409 fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
410 args: (Optional) Positional arguments to `fn`.
411 kwargs: (Optional) Keyword arguments to `fn`.
412 options: (Optional) An instance of `tf.distribute.RunOptions` specifying
413 the options to run `fn`.
415 Returns:
416 Merged return value of `fn` across replicas. The structure of the return
417 value is the same as the return value from `fn`. Each element in the
418 structure can either be `tf.distribute.DistributedValues`, `Tensor`
419 objects, or `Tensor`s (for example, if running on a single replica).
420 """
421 validate_run_function(fn)
423 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
425 # Note: the target function is converted to graph even when in Eager mode,
426 # so autograph is on by default here.
427 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
428 options = options or distribute_lib.RunOptions()
429 return self.extended.tpu_run(fn, args, kwargs, options)
431 @property
432 def cluster_resolver(self):
433 """Returns the cluster resolver associated with this strategy.
435 `tf.distribute.TPUStrategy` provides the associated
436 `tf.distribute.cluster_resolver.ClusterResolver`. If the user provides one
437 in `__init__`, that instance is returned; if the user does not, a default
438 `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
439 """
440 return self.extended._tpu_cluster_resolver # pylint: disable=protected-access
442 def experimental_assign_to_logical_device(self, tensor, logical_device_id):
443 """Adds annotation that `tensor` will be assigned to a logical device.
445 This adds an annotation to `tensor` specifying that operations on
446 `tensor` will be invoked on logical core device id `logical_device_id`.
447 When model parallelism is used, the default behavior is that all ops
448 are placed on zero-th logical device.
450 ```python
452 # Initializing TPU system with 2 logical devices and 4 replicas.
453 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
454 tf.config.experimental_connect_to_cluster(resolver)
455 topology = tf.tpu.experimental.initialize_tpu_system(resolver)
456 device_assignment = tf.tpu.experimental.DeviceAssignment.build(
457 topology,
458 computation_shape=[1, 1, 1, 2],
459 num_replicas=4)
460 strategy = tf.distribute.TPUStrategy(
461 resolver, experimental_device_assignment=device_assignment)
462 iterator = iter(inputs)
464 @tf.function()
465 def step_fn(inputs):
466 output = tf.add(inputs, inputs)
468 # Add operation will be executed on logical device 0.
469 output = strategy.experimental_assign_to_logical_device(output, 0)
470 return output
472 strategy.run(step_fn, args=(next(iterator),))
473 ```
475 Args:
476 tensor: Input tensor to annotate.
477 logical_device_id: Id of the logical core to which the tensor will be
478 assigned.
480 Raises:
481 ValueError: The logical device id presented is not consistent with total
482 number of partitions specified by the device assignment or the TPUStrategy
483 is constructed with `experimental_spmd_xla_partitioning=True`.
485 Returns:
486 Annotated tensor with identical value as `tensor`.
487 """
488 if self.extended._use_spmd_for_xla_partitioning: # pylint: disable=protected-access
489 raise ValueError(
490 "Cannot assign a tensor to a logical device in SPMD mode. To disable "
491 "SPMD, Please construct the TPUStrategy with "
492 "`experimental_spmd_xla_partitioning=False`")
494 num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access
495 if (logical_device_id < 0 or
496 logical_device_id >= num_logical_devices_per_replica):
497 raise ValueError("`logical_core_id` to assign must be lower then total "
498 "number of logical devices per replica. Received "
499 "logical device id {} but there are only total of {} "
500 "logical devices in replica.".format(
501 logical_device_id, num_logical_devices_per_replica))
502 return xla_sharding.assign_device(
503 tensor, logical_device_id, use_sharding_op=True)
505 def experimental_split_to_logical_devices(self, tensor, partition_dimensions):
506 """Adds annotation that `tensor` will be split across logical devices.
508 This adds an annotation to tensor `tensor` specifying that operations on
509 `tensor` will be split among multiple logical devices. Tensor `tensor` will
510 be split across dimensions specified by `partition_dimensions`.
511 The dimensions of `tensor` must be divisible by corresponding value in
512 `partition_dimensions`.
514 For example, for system with 8 logical devices, if `tensor` is an image
515 tensor with shape (batch_size, width, height, channel) and
516 `partition_dimensions` is [1, 2, 4, 1], then `tensor` will be split
517 2 in width dimension and 4 way in height dimension and the split
518 tensor values will be fed into 8 logical devices.
520 ```python
521 # Initializing TPU system with 8 logical devices and 1 replica.
522 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
523 tf.config.experimental_connect_to_cluster(resolver)
524 topology = tf.tpu.experimental.initialize_tpu_system(resolver)
525 device_assignment = tf.tpu.experimental.DeviceAssignment.build(
526 topology,
527 computation_shape=[1, 2, 2, 2],
528 num_replicas=1)
529 # Construct the TPUStrategy. Since we are going to split the image across
530 # logical devices, here we set `experimental_spmd_xla_partitioning=True`
531 # so that the partitioning can be compiled in SPMD mode, which usually
532 # results in faster compilation and smaller HBM requirement if the size of
533 # input and activation tensors are much bigger than that of the model
534 # parameters. Note that this flag is suggested but not a hard requirement
535 # for `experimental_split_to_logical_devices`.
536 strategy = tf.distribute.TPUStrategy(
537 resolver, experimental_device_assignment=device_assignment,
538 experimental_spmd_xla_partitioning=True)
540 iterator = iter(inputs)
542 @tf.function()
543 def step_fn(inputs):
544 inputs = strategy.experimental_split_to_logical_devices(
545 inputs, [1, 2, 4, 1])
547 # model() function will be executed on 8 logical devices with `inputs`
548 # split 2 * 4 ways.
549 output = model(inputs)
550 return output
552 strategy.run(step_fn, args=(next(iterator),))
553 ```
554 Args:
555 tensor: Input tensor to annotate.
556 partition_dimensions: An unnested list of integers with the size equal to
557 rank of `tensor` specifying how `tensor` will be partitioned. The
558 product of all elements in `partition_dimensions` must be equal to the
559 total number of logical devices per replica.
561 Raises:
562 ValueError: 1) If the size of partition_dimensions does not equal to rank
563 of `tensor` or 2) if product of elements of `partition_dimensions` does
564 not match the number of logical devices per replica defined by the
565 implementing DistributionStrategy's device specification or
566 3) if a known size of `tensor` is not divisible by corresponding
567 value in `partition_dimensions`.
569 Returns:
570 Annotated tensor with identical value as `tensor`.
571 """
572 num_logical_devices_per_replica = self.extended._tpu_devices.shape[1] # pylint: disable=protected-access
573 num_partition_splits = np.prod(partition_dimensions)
574 input_shape = tensor.shape
575 tensor_rank = len(input_shape)
577 if tensor_rank != len(partition_dimensions):
578 raise ValueError("Length of `partition_dimensions` must equal to the "
579 "rank of `tensor.shape` ({}). Received "
580 "len(partition_dimensions)={}.".format(
581 tensor_rank, len(partition_dimensions)))
583 for dim_index, dim_size in enumerate(input_shape):
584 if dim_size is None:
585 continue
587 split_size = partition_dimensions[dim_index]
588 if dim_size % split_size != 0:
589 raise ValueError("Tensor shape at `partition_dimensions[{}]` must be "
590 "divisible by corresponding value specified "
591 "by `partition_dimensions` ({}). Received: {}.".format(
592 dim_index, split_size, dim_size))
594 if num_partition_splits != num_logical_devices_per_replica:
595 raise ValueError(
596 "The product of `partition_dimensions` should be the same as the "
597 "number of logical devices (={}). Received `partition_dimensions`={},"
598 "and their product is {}.".format(num_logical_devices_per_replica,
599 partition_dimensions,
600 num_partition_splits))
602 tile_assignment = np.arange(num_partition_splits).reshape(
603 partition_dimensions)
604 return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
606 def experimental_replicate_to_logical_devices(self, tensor):
607 """Adds annotation that `tensor` will be replicated to all logical devices.
609 This adds an annotation to tensor `tensor` specifying that operations on
610 `tensor` will be invoked on all logical devices.
612 ```python
613 # Initializing TPU system with 2 logical devices and 4 replicas.
614 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
615 tf.config.experimental_connect_to_cluster(resolver)
616 topology = tf.tpu.experimental.initialize_tpu_system(resolver)
617 device_assignment = tf.tpu.experimental.DeviceAssignment.build(
618 topology,
619 computation_shape=[1, 1, 1, 2],
620 num_replicas=4)
621 strategy = tf.distribute.TPUStrategy(
622 resolver, experimental_device_assignment=device_assignment)
624 iterator = iter(inputs)
626 @tf.function()
627 def step_fn(inputs):
628 images, labels = inputs
629 images = strategy.experimental_split_to_logical_devices(
630 inputs, [1, 2, 4, 1])
632 # model() function will be executed on 8 logical devices with `inputs`
633 # split 2 * 4 ways.
634 output = model(inputs)
636 # For loss calculation, all logical devices share the same logits
637 # and labels.
638 labels = strategy.experimental_replicate_to_logical_devices(labels)
639 output = strategy.experimental_replicate_to_logical_devices(output)
640 loss = loss_fn(labels, output)
642 return loss
644 strategy.run(step_fn, args=(next(iterator),))
645 ```
646 Args:
647 tensor: Input tensor to annotate.
649 Returns:
650 Annotated tensor with identical value as `tensor`.
651 """
652 return xla_sharding.replicate(tensor, use_sharding_op=True)
655@tf_export("distribute.experimental.TPUStrategy", v1=[])
656@deprecation.deprecated_endpoints("distribute.experimental.TPUStrategy")
657class TPUStrategy(distribute_lib.Strategy):
658 """Synchronous training on TPUs and TPU Pods.
660 To construct a TPUStrategy object, you need to run the
661 initialization code as below:
663 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
664 >>> tf.config.experimental_connect_to_cluster(resolver)
665 >>> tf.tpu.experimental.initialize_tpu_system(resolver)
666 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
668 While using distribution strategies, the variables created within the
669 strategy's scope will be replicated across all the replicas and can be kept in
670 sync using all-reduce algorithms.
672 To run TF2 programs on TPUs, you can either use `.compile` and
673 `.fit` APIs in `tf.keras` with TPUStrategy, or write your own customized
674 training loop by calling `strategy.run` directly. Note that
675 TPUStrategy doesn't support pure eager execution, so please make sure the
676 function passed into `strategy.run` is a `tf.function` or
677 `strategy.run` is called inside a `tf.function` if eager
678 behavior is enabled.
679 """
681 def __init__(self,
682 tpu_cluster_resolver=None,
683 device_assignment=None):
684 """Synchronous training in TPU donuts or Pods.
686 Args:
687 tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
688 which provides information about the TPU cluster.
689 device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
690 specify the placement of replicas on the TPU cluster.
691 """
692 logging.warning(
693 "`tf.distribute.experimental.TPUStrategy` is deprecated, please use "
694 "the non-experimental symbol `tf.distribute.TPUStrategy` instead.")
696 super(TPUStrategy, self).__init__(
697 TPUExtended(
698 self,
699 tpu_cluster_resolver,
700 device_assignment=device_assignment,
701 enable_data_reorder=device_assignment is not None,
702 )
703 )
704 distribute_lib.distribution_strategy_gauge.get_cell("V2").set("TPUStrategy")
705 distribute_lib.distribution_strategy_replica_gauge.get_cell(
706 "num_workers").set(self.extended.num_hosts)
707 distribute_lib.distribution_strategy_replica_gauge.get_cell(
708 "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
709 # Packed variable is used to reduce the overhead of function execution.
710 # For a DistributedVariable, only one variable handle is captured into a
711 # function graph. It's only supported in eager mode.
712 self._enable_packed_variable_in_eager_mode = True
714 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
715 # can use the default implementation.
716 # This implementation runs a single step. It does not use infeed or outfeed.
717 def run(self, fn, args=(), kwargs=None, options=None):
718 """See base class."""
719 validate_run_function(fn)
721 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
723 # Note: the target function is converted to graph even when in Eager mode,
724 # so autograph is on by default here.
725 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
726 options = options or distribute_lib.RunOptions()
727 return self.extended.tpu_run(fn, args, kwargs, options)
729 @property
730 def cluster_resolver(self):
731 """Returns the cluster resolver associated with this strategy.
733 `tf.distribute.experimental.TPUStrategy` provides the
734 associated `tf.distribute.cluster_resolver.ClusterResolver`. If the user
735 provides one in `__init__`, that instance is returned; if the user does
736 not, a default
737 `tf.distribute.cluster_resolver.TPUClusterResolver` is provided.
738 """
739 return self.extended._tpu_cluster_resolver # pylint: disable=protected-access
742@tf_export(v1=["distribute.experimental.TPUStrategy"])
743class TPUStrategyV1(distribute_lib.StrategyV1):
744 """TPU distribution strategy implementation."""
746 def __init__(self,
747 tpu_cluster_resolver=None,
748 steps_per_run=None,
749 device_assignment=None):
750 """Initializes the TPUStrategy object.
752 Args:
753 tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
754 which provides information about the TPU cluster.
755 steps_per_run: Number of steps to run on device before returning to the
756 host. Note that this can have side-effects on performance, hooks,
757 metrics, summaries etc.
758 This parameter is only used when Distribution Strategy is used with
759 estimator or keras.
760 device_assignment: Optional `tf.tpu.experimental.DeviceAssignment` to
761 specify the placement of replicas on the TPU cluster. Currently only
762 supports the usecase of using a single core within a TPU cluster.
763 """
764 super(TPUStrategyV1, self).__init__(TPUExtended(
765 self, tpu_cluster_resolver, steps_per_run, device_assignment))
766 distribute_lib.distribution_strategy_gauge.get_cell("V1").set("TPUStrategy")
767 distribute_lib.distribution_strategy_replica_gauge.get_cell(
768 "num_workers").set(self.extended.num_hosts)
769 distribute_lib.distribution_strategy_replica_gauge.get_cell(
770 "num_replicas_per_worker").set(self.extended.num_replicas_per_host)
771 # Packed variable is used to reduce the overhead of function execution.
772 # For a DistributedVariable, only one variable handle is captured into a
773 # function graph. It's only supported in eager mode.
774 self._enable_packed_variable_in_eager_mode = True
776 @property
777 def steps_per_run(self):
778 """DEPRECATED: use .extended.steps_per_run instead."""
779 return self._extended.steps_per_run
781 # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
782 # can use the default implementation.
783 # This implementation runs a single step. It does not use infeed or outfeed.
784 def run(self, fn, args=(), kwargs=None, options=None):
785 """Run `fn` on each replica, with the given arguments.
787 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
788 "per-replica" values, such as those produced by a "distributed `Dataset`",
789 when `fn` is executed on a particular replica, it will be executed with the
790 component of those "per-replica" values that correspond to that replica.
792 `fn` may call `tf.distribute.get_replica_context()` to access members such
793 as `all_reduce`.
795 All arguments in `args` or `kwargs` should either be nest of tensors or
796 per-replica objects containing tensors or composite tensors.
798 Users can pass strategy specific options to `options` argument. An example
799 to enable bucketizing dynamic shapes in `TPUStrategy.run`
800 is:
802 >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
803 >>> tf.config.experimental_connect_to_cluster(resolver)
804 >>> tf.tpu.experimental.initialize_tpu_system(resolver)
805 >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
807 >>> options = tf.distribute.RunOptions(
808 ... experimental_bucketizing_dynamic_shape=True)
810 >>> dataset = tf.data.Dataset.range(
811 ... strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
812 ... strategy.num_replicas_in_sync, drop_remainder=True)
813 >>> input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
815 >>> @tf.function()
816 ... def step_fn(inputs):
817 ... output = tf.reduce_sum(inputs)
818 ... return output
820 >>> strategy.run(step_fn, args=(next(input_iterator),), options=options)
822 Args:
823 fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
824 args: (Optional) Positional arguments to `fn`.
825 kwargs: (Optional) Keyword arguments to `fn`.
826 options: (Optional) An instance of `tf.distribute.RunOptions` specifying
827 the options to run `fn`.
829 Returns:
830 Merged return value of `fn` across replicas. The structure of the return
831 value is the same as the return value from `fn`. Each element in the
832 structure can either be "per-replica" `Tensor` objects or `Tensor`s
833 (for example, if running on a single replica).
834 """
835 validate_run_function(fn)
837 fn, args, kwargs = _maybe_partial_apply_variables(fn, args, kwargs)
839 fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
840 options = options or distribute_lib.RunOptions()
841 return self.extended.tpu_run(fn, args, kwargs, options)
844# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
845class TPUExtended(distribute_lib.StrategyExtendedV1):
846 """Implementation of TPUStrategy."""
848 def __init__(
849 self,
850 container_strategy,
851 tpu_cluster_resolver=None,
852 steps_per_run=None,
853 device_assignment=None,
854 use_spmd_for_xla_partitioning=False,
855 enable_data_reorder=False,
856 ):
857 super(TPUExtended, self).__init__(container_strategy)
859 if tpu_cluster_resolver is None:
860 tpu_cluster_resolver = TPUClusterResolver("")
862 if steps_per_run is None:
863 # TODO(frankchn): Warn when we are being used by DS/Keras and this is
864 # not specified.
865 steps_per_run = 1
867 # `self._tpu_function_cache` is a dict of `tf.function`s, thus if a
868 # `tf.function` is passed into `strategy.run` in eager mode, the
869 # `tf.function` won't get retraced.
870 self._tpu_function_cache = weakref.WeakKeyDictionary()
872 self._tpu_cluster_resolver = tpu_cluster_resolver
873 self._tpu_metadata = self._tpu_cluster_resolver.get_tpu_system_metadata()
874 self._device_assignment = device_assignment
876 tpu_devices_flat = [
877 d.name for d in self._tpu_metadata.devices if "device:TPU:" in d.name]
879 # `self._tpu_devices` is a two-dimensional NumPy array of strings. It is
880 # indexed using `[replica_id][logical_device_id]`.
881 if device_assignment is None:
882 self._tpu_devices = np.array(
883 [[d] for d in tpu_devices_flat], dtype=object)
884 else:
885 job_name = device_spec.DeviceSpecV2.from_string(tpu_devices_flat[0]).job
887 tpu_devices = []
888 for replica_id in range(device_assignment.num_replicas):
889 replica_devices = []
891 for logical_core in range(device_assignment.num_cores_per_replica):
892 replica_devices.append(
893 device_util.canonicalize(
894 device_assignment.tpu_device(
895 replica=replica_id,
896 logical_core=logical_core,
897 job=job_name)))
899 tpu_devices.append(replica_devices)
900 self._tpu_devices = np.array(tpu_devices, dtype=object)
902 self._host_device = device_util.get_host_for_device(self._tpu_devices[0][0])
904 # Preload the data onto the TPUs. Currently we always preload onto logical
905 # device 0 for each replica.
906 # TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
907 # input onto a different logical device?
908 self._device_input_worker_devices = collections.OrderedDict()
909 self._host_input_worker_devices = collections.OrderedDict()
910 for tpu_device in self._tpu_devices[:, 0]:
911 host_device = device_util.get_host_for_device(tpu_device)
912 self._device_input_worker_devices.setdefault(host_device, [])
913 self._device_input_worker_devices[host_device].append(tpu_device)
914 self._host_input_worker_devices.setdefault(host_device, [])
915 self._host_input_worker_devices[host_device].append(host_device)
917 # Create the replica order based on the assigned device order.
918 # This replica order will be used to match the IteratorGetNext ops
919 # with the device assigment.
920 self._replica_order = (
921 self._get_replica_order(self._tpu_devices[:, 0])
922 if enable_data_reorder
923 else None
924 )
926 # TODO(sourabhbajaj): Remove this once performance of running one step
927 # at a time is comparable to multiple steps.
928 self.steps_per_run = steps_per_run
929 self._require_static_shapes = True
931 self.experimental_enable_get_next_as_optional = True
933 self._logical_device_stack = [0]
935 if context.executing_eagerly():
936 # In async remote eager, we want to sync the executors before exiting the
937 # program.
938 atexit.register(context.async_wait)
940 # Flag to turn on VariablePolicy. Var policy is deprecated because there is
941 # another effort unifying DistributedVariables (see values_v2.py). SPMD XLA
942 # partitioning is not implemented for var policies.
943 # TODO(b/202048882): remove var policy from TPUStrategy.
944 self._use_var_policy = not use_spmd_for_xla_partitioning
946 # Flag to enable XLA SPMD partitioning.
947 self._use_spmd_for_xla_partitioning = use_spmd_for_xla_partitioning
949 def _get_replica_order(self, tpu_devices):
950 """Get the replica order based on the tpu device order.
952 For example, if the tpu_devices are:
953 '/job:worker/replica:0/task:0/device:TPU:0',
954 '/job:worker/replica:0/task:0/device:TPU:2',
955 '/job:worker/replica:0/task:1/device:TPU:0',
956 '/job:worker/replica:0/task:1/device:TPU:2',
957 '/job:worker/replica:0/task:1/device:TPU:6',
958 '/job:worker/replica:0/task:1/device:TPU:4',
959 '/job:worker/replica:0/task:0/device:TPU:6',
960 '/job:worker/replica:0/task:0/device:TPU:4',
962 the returned replica order will be:
963 [0, 1, 7, 6, 2, 3, 5, 4]
965 This replica order will be used to reorder the data returned by the
966 iterators,
967 so that they can be placed on the same node as their computation graphs.
969 Args:
970 tpu_devices (List[str]): A list of tpu device names in the order of
971 replicas.
973 Returns:
974 A list containing the order ids of corresponding TPU devices.
975 """
976 devices_with_ids = []
977 for i, tpu_device in enumerate(tpu_devices):
978 spec = tf_device.DeviceSpec.from_string(tpu_device)
979 devices_with_ids.append((
980 (
981 spec.job,
982 spec.replica,
983 spec.device_type,
984 spec.task,
985 spec.device_index,
986 ),
987 i,
988 ))
989 return [i for _, i in sorted(devices_with_ids)]
991 def _validate_colocate_with_variable(self, colocate_with_variable):
992 distribute_utils.validate_colocate(colocate_with_variable, self)
994 def _make_dataset_iterator(self, dataset):
995 """Make iterators for each of the TPU hosts."""
996 input_workers = input_lib.InputWorkers(
997 tuple(self._device_input_worker_devices.items()))
998 return input_lib_v1.DatasetIterator(
999 dataset,
1000 input_workers,
1001 self._container_strategy(),
1002 num_replicas_in_sync=self._num_replicas_in_sync)
1004 def _make_input_fn_iterator(
1005 self,
1006 input_fn,
1007 replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
1008 input_contexts = []
1009 input_workers = input_lib.InputWorkers(
1010 tuple(self._device_input_worker_devices.items()))
1011 num_workers = input_workers.num_workers
1012 for i in range(num_workers):
1013 input_contexts.append(
1014 distribute_lib.InputContext(
1015 num_input_pipelines=num_workers,
1016 input_pipeline_id=i,
1017 num_replicas_in_sync=self._num_replicas_in_sync))
1018 return input_lib_v1.InputFunctionIterator(input_fn, input_workers,
1019 input_contexts,
1020 self._container_strategy())
1022 def _experimental_make_numpy_dataset(self, numpy_input, session):
1023 return numpy_dataset.one_host_numpy_dataset(
1024 numpy_input, numpy_dataset.SingleDevice(self._host_device),
1025 session)
1027 def _get_input_workers(self, options):
1028 if not options or options.experimental_fetch_to_device:
1029 return input_lib.InputWorkers(
1030 tuple(self._device_input_worker_devices.items()))
1031 else:
1032 return input_lib.InputWorkers(
1033 tuple(self._host_input_worker_devices.items()))
1035 def _check_spec(self, element_spec):
1036 if isinstance(element_spec, values.PerReplicaSpec):
1037 element_spec = element_spec._component_specs # pylint: disable=protected-access
1038 specs = nest.flatten_with_joined_string_paths(element_spec)
1039 for path, spec in specs:
1040 if isinstance(spec, (sparse_tensor.SparseTensorSpec,
1041 ragged_tensor.RaggedTensorSpec)):
1042 raise ValueError(
1043 "Found tensor {} with spec {}. TPUStrategy does not support "
1044 "distributed datasets with device prefetch when using sparse or "
1045 "ragged tensors. If you intend to use sparse or ragged tensors, "
1046 "please pass a tf.distribute.InputOptions object with "
1047 "experimental_fetch_to_device set to False to your dataset "
1048 "distribution function.".format(path, type(spec)))
1050 def _experimental_distribute_dataset(self, dataset, options):
1051 if (options and options.experimental_replication_mode ==
1052 distribute_lib.InputReplicationMode.PER_REPLICA):
1053 raise NotImplementedError(
1054 "InputReplicationMode.PER_REPLICA "
1055 "is only supported in "
1056 "`experimental_distribute_datasets_from_function`."
1057 )
1058 if options is None or options.experimental_fetch_to_device:
1059 self._check_spec(dataset.element_spec)
1061 return input_util.get_distributed_dataset(
1062 dataset,
1063 self._get_input_workers(options),
1064 self._container_strategy(),
1065 num_replicas_in_sync=self._num_replicas_in_sync,
1066 options=options,
1067 replica_order=self._replica_order,
1068 )
1070 def _distribute_datasets_from_function(self, dataset_fn, options):
1071 if (options and options.experimental_replication_mode ==
1072 distribute_lib.InputReplicationMode.PER_REPLICA):
1073 raise NotImplementedError(
1074 "InputReplicationMode.PER_REPLICA "
1075 "is only supported in "
1076 " `experimental_distribute_datasets_from_function` "
1077 "of tf.distribute.MirroredStrategy")
1078 input_workers = self._get_input_workers(options)
1079 input_contexts = []
1080 num_workers = input_workers.num_workers
1081 for i in range(num_workers):
1082 input_contexts.append(distribute_lib.InputContext(
1083 num_input_pipelines=num_workers,
1084 input_pipeline_id=i,
1085 num_replicas_in_sync=self._num_replicas_in_sync))
1087 distributed_dataset = input_util.get_distributed_datasets_from_function(
1088 dataset_fn,
1089 input_workers,
1090 input_contexts,
1091 self._container_strategy(),
1092 options=options,
1093 replica_order=self._replica_order,
1094 )
1096 # We can only check after the dataset_fn is called.
1097 if options is None or options.experimental_fetch_to_device:
1098 self._check_spec(distributed_dataset.element_spec)
1099 return distributed_dataset
1101 def _experimental_distribute_values_from_function(self, value_fn):
1102 per_replica_values = []
1103 for replica_id in range(self._num_replicas_in_sync):
1104 per_replica_values.append(
1105 value_fn(distribute_lib.ValueContext(replica_id,
1106 self._num_replicas_in_sync)))
1107 return distribute_utils.regroup(per_replica_values, always_wrap=True)
1109 # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
1110 # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
1111 # a mechanism to infer the outputs of `fn`. Pending b/110550782.
1112 def _experimental_run_steps_on_iterator(
1113 self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
1114 # Wrap `fn` for repeat.
1115 if initial_loop_values is None:
1116 initial_loop_values = {}
1117 initial_loop_values = nest.flatten(initial_loop_values)
1118 ctx = input_lib.MultiStepContext()
1120 def run_fn(inputs):
1121 """Single step on the TPU device."""
1122 fn_result = fn(ctx, inputs)
1123 flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
1124 if flat_last_step_outputs:
1125 with ops.control_dependencies([fn_result]):
1126 return [array_ops.identity(f) for f in flat_last_step_outputs]
1127 else:
1128 return fn_result
1130 # We capture the control_flow_context at this point, before we run `fn`
1131 # inside a while_loop and TPU replicate context. This is useful in cases
1132 # where we might need to exit these contexts and get back to the outer
1133 # context to do some things, for e.g. create an op which should be
1134 # evaluated only once at the end of the loop on the host. One such usage
1135 # is in creating metrics' value op.
1136 self._outer_control_flow_context = (
1137 ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
1139 def rewrite_fn(*args):
1140 """The rewritten step fn running on TPU."""
1141 del args
1143 per_replica_inputs = multi_worker_iterator.get_next()
1144 replicate_inputs = []
1145 for replica_id in range(self._num_replicas_in_sync):
1146 select_replica = lambda x: distribute_utils.select_replica( # pylint: disable=g-long-lambda
1147 replica_id, x) # pylint: disable=cell-var-from-loop
1148 replicate_inputs.append((nest.map_structure(
1149 select_replica, per_replica_inputs),))
1151 replicate_outputs = tpu.replicate(
1152 run_fn,
1153 replicate_inputs,
1154 device_assignment=self._device_assignment,
1155 xla_options=tpu.XLAOptions(use_spmd_for_xla_partitioning=self
1156 ._use_spmd_for_xla_partitioning))
1157 # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
1158 # will flatten it in this case. If run_fn has no tensor outputs,
1159 # tpu.replicate returns a list of no_ops, we will keep the output as it
1160 # is.
1161 if isinstance(replicate_outputs[0], list):
1162 replicate_outputs = nest.flatten(replicate_outputs)
1164 return replicate_outputs
1166 # TODO(sourabhbajaj): The input to while loop should be based on the
1167 # output type of the step_fn
1168 assert isinstance(initial_loop_values, list)
1169 initial_loop_values = initial_loop_values * self._num_replicas_in_sync
1171 # Put the while loop op on TPU host 0.
1172 with ops.device(self._host_device):
1173 if self.steps_per_run == 1:
1174 replicate_outputs = rewrite_fn()
1175 else:
1176 replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
1177 initial_loop_values)
1179 del self._outer_control_flow_context
1180 ctx.run_op = control_flow_ops.group(replicate_outputs)
1182 if isinstance(replicate_outputs, list):
1183 # Filter out any ops from the outputs, typically this would be the case
1184 # when there were no tensor outputs.
1185 last_step_tensor_outputs = [
1186 x for x in replicate_outputs if not isinstance(x, ops.Operation)
1187 ]
1189 # Outputs are currently of the structure (flattened)
1190 # [output0_device0, output1_device0, output2_device0,
1191 # output0_device1, output1_device1, output2_device1,
1192 # ...]
1193 # Convert this to the following structure instead: (grouped by output)
1194 # [[output0_device0, output0_device1],
1195 # [output1_device0, output1_device1],
1196 # [output2_device0, output2_device1]]
1197 output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
1198 last_step_tensor_outputs = [
1199 last_step_tensor_outputs[i::output_num] for i in range(output_num)
1200 ]
1201 else:
1202 # no tensors returned.
1203 last_step_tensor_outputs = []
1205 _set_last_step_outputs(ctx, last_step_tensor_outputs)
1206 return ctx
1208 def _call_for_each_replica(self, fn, args, kwargs):
1209 # TODO(jhseu): Consider making it so call_for_each_replica implies that
1210 # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
1211 with _TPUReplicaContext(self._container_strategy()):
1212 return fn(*args, **kwargs)
1214 @contextlib.contextmanager
1215 def experimental_logical_device(self, logical_device_id):
1216 """Places variables and ops on the specified logical device."""
1217 num_logical_devices_per_replica = self._tpu_devices.shape[1]
1218 if logical_device_id >= num_logical_devices_per_replica:
1219 raise ValueError(
1220 "`logical_device_id` not in range (was {}, but there are only {} "
1221 "logical devices per replica).".format(
1222 logical_device_id, num_logical_devices_per_replica))
1224 self._logical_device_stack.append(logical_device_id)
1225 try:
1226 if tpu_util.enclosing_tpu_context() is None:
1227 yield
1228 else:
1229 with ops.device(tpu.core(logical_device_id)):
1230 yield
1231 finally:
1232 self._logical_device_stack.pop()
1234 def _experimental_initialize_system(self):
1235 """Experimental method added to be used by Estimator.
1237 This is a private method only to be used by Estimator. Other frameworks
1238 should directly be calling `tf.tpu.experimental.initialize_tpu_system`
1239 """
1240 tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
1242 def _create_variable(self, next_creator, **kwargs):
1243 """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
1244 if kwargs.pop("skip_mirrored_creator", False):
1245 return next_creator(**kwargs)
1247 colocate_with = kwargs.pop("colocate_with", None)
1248 if colocate_with is None:
1249 devices = self._tpu_devices[:, self._logical_device_stack[-1]]
1250 elif isinstance(colocate_with, numpy_dataset.SingleDevice):
1251 with ops.device(colocate_with.device):
1252 return next_creator(**kwargs)
1253 else:
1254 devices = colocate_with._devices # pylint: disable=protected-access
1256 num_replicas, num_cores_per_replica = self._tpu_devices.shape
1258 def _create_mirrored_tpu_variables(**kwargs):
1259 """Returns a list of `tf.Variable`s.
1261 The list contains `number_replicas` `tf.Variable`s and can be used to
1262 initialize a `TPUMirroredVariable`.
1264 Args:
1265 **kwargs: the keyword arguments for creating a variable
1266 """
1267 initial_value = None
1268 value_list = []
1269 for i, d in enumerate(devices):
1270 with ops.device(d):
1271 if i == 0:
1272 initial_value = kwargs["initial_value"]
1273 # Note: some v1 code expects variable initializer creation to happen
1274 # inside a init_scope.
1275 with maybe_init_scope():
1276 initial_value = initial_value() if callable(
1277 initial_value) else initial_value
1279 if i > 0:
1280 # Give replicas meaningful distinct names:
1281 var0name = value_list[0].name.split(":")[0]
1282 # We append a / to variable names created on replicas with id > 0 to
1283 # ensure that we ignore the name scope and instead use the given
1284 # name as the absolute name of the variable.
1285 kwargs["name"] = "%s/replica_%d/" % (var0name, i)
1286 kwargs["initial_value"] = initial_value
1288 with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
1289 v = next_creator(**kwargs)
1291 assert not isinstance(v, tpu_values.TPUMirroredVariable)
1292 value_list.append(v)
1293 return value_list
1295 def _create_mirrored_tpu_replicated_variables(**kwargs):
1296 """Returns a list of `TPUReplicatedVariable`s.
1298 The list consists of `num_replicas` `TPUReplicatedVariable`s and can be
1299 used to initialize a `TPUMirroredVariable`. Each `TPUReplicatedVariable`
1300 contains a list of `tf.Variable`s which are replicated to
1301 `num_cores_per_replica` logical cores to enable XLA SPMD compilation.
1303 Args:
1304 **kwargs: the keyword arguments for creating a variable
1305 """
1306 initial_value = kwargs["initial_value"]
1307 # Note: some v1 code expects variable initializer creation to happen
1308 # inside a init_scope.
1309 with maybe_init_scope():
1310 initial_value = initial_value() if callable(
1311 initial_value) else initial_value
1313 mirrored_replicated_var_list = []
1315 for replica_id in range(num_replicas):
1316 replicated_var_list = []
1317 for logic_core_id in range(num_cores_per_replica):
1318 with ops.device(self._tpu_devices[replica_id][logic_core_id]):
1319 kwargs["initial_value"] = initial_value
1320 v = next_creator(**kwargs)
1321 replicated_var_list.append(v)
1322 replica_name = "{}/r:{}".format(kwargs["name"], replica_id)
1323 tpu_replicated_var = tpu_replicated_variable.TPUReplicatedVariable(
1324 variables=replicated_var_list, name=replica_name)
1326 mirrored_replicated_var_list.append(tpu_replicated_var)
1327 return mirrored_replicated_var_list
1329 if self._use_spmd_for_xla_partitioning and num_cores_per_replica > 1:
1330 real_creator = _create_mirrored_tpu_replicated_variables
1331 else:
1332 real_creator = _create_mirrored_tpu_variables
1334 return distribute_utils.create_mirrored_variable(
1335 self._container_strategy(), real_creator,
1336 distribute_utils.TPU_VARIABLE_CLASS_MAPPING,
1337 distribute_utils.TPU_VARIABLE_POLICY_MAPPING, **kwargs)
1339 def _resource_creator_scope(self):
1341 def lookup_creator(next_creator, *args, **kwargs):
1342 host_to_table = collections.OrderedDict()
1343 for host_device in self._device_input_worker_devices.keys():
1344 with ops.device(host_device):
1345 host_to_table[host_device] = next_creator(*args, **kwargs)
1347 return values.PerWorkerResource(self._container_strategy(), host_to_table)
1349 # TODO(b/194362531): Define creator(s) for other resources.
1350 return ops.resource_creator_scope("StaticHashTable", lookup_creator)
1352 def _gather_to_implementation(self, value, destinations, axis, options):
1353 if not isinstance(value, values.DistributedValues):
1354 return value
1356 value_list = list(value.values)
1357 # pylint: disable=protected-access
1358 if isinstance(
1359 value,
1360 values.DistributedVariable) and value._packed_variable is not None:
1361 value_list = list(
1362 value._packed_variable.on_device(d)
1363 for d in value._packed_variable.devices)
1364 # pylint: enable=protected-access
1366 # Currently XLA op by op mode has a limit for the number of inputs for a
1367 # single op, thus we break one `add_n` op into a group of `add_n` ops to
1368 # work around the constraint.
1369 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
1370 output = array_ops.concat(value_list, axis=axis)
1371 else:
1372 output = array_ops.concat(
1373 value_list[:_XLA_OP_BY_OP_INPUTS_LIMIT], axis=axis)
1374 for i in range(_XLA_OP_BY_OP_INPUTS_LIMIT, len(value_list),
1375 _XLA_OP_BY_OP_INPUTS_LIMIT - 1):
1376 output = array_ops.concat(
1377 [output] + value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT - 1],
1378 axis=axis)
1380 output = self._broadcast_output(destinations, output)
1381 return output
1383 def _broadcast_output(self, destinations, output):
1384 devices = cross_device_ops_lib.get_devices_from(destinations)
1386 if len(devices) == 1:
1387 # If necessary, copy to requested destination.
1388 dest_canonical = device_util.canonicalize(devices[0])
1389 host_canonical = device_util.canonicalize(self._host_device)
1391 if dest_canonical != host_canonical:
1392 with ops.device(dest_canonical):
1393 output = array_ops.identity(output)
1394 else:
1395 output = cross_device_ops_lib.simple_broadcast(output, destinations)
1397 return output
1399 def _reduce_to(self, reduce_op, value, destinations, options):
1400 if (isinstance(value, values.DistributedValues) or
1401 tensor_util.is_tf_type(value)
1402 ) and tpu_util.enclosing_tpu_context() is not None:
1403 if reduce_op == reduce_util.ReduceOp.MEAN:
1404 # TODO(jhseu): Revisit once we support model-parallelism.
1405 # scalar_mul maintains the type of value: tensor or IndexedSlices.
1406 value = math_ops.scalar_mul((1./self._num_replicas_in_sync), value)
1407 elif reduce_op != reduce_util.ReduceOp.SUM:
1408 raise NotImplementedError(
1409 f"`reduce_op`={reduce_op} is not supported. Currently we only "
1410 "support ReduceOp.SUM and ReduceOp.MEAN in TPUStrategy.")
1411 return tpu_ops.cross_replica_sum(value)
1413 if not isinstance(value, values.DistributedValues):
1414 # This function handles reducing values that are not PerReplica or
1415 # Mirrored values. For example, the same value could be present on all
1416 # replicas in which case `value` would be a single value or value could
1417 # be 0.
1418 return cross_device_ops_lib.reduce_non_distributed_value(
1419 reduce_op, value, destinations, self._num_replicas_in_sync)
1421 value_list = value.values
1422 # pylint: disable=protected-access
1423 if isinstance(
1424 value,
1425 values.DistributedVariable) and value._packed_variable is not None:
1426 value_list = tuple(
1427 value._packed_variable.on_device(d)
1428 for d in value._packed_variable.devices)
1429 # pylint: enable=protected-access
1431 # Currently XLA op by op mode has a limit for the number of inputs for a
1432 # single op, thus we break one `add_n` op into a group of `add_n` ops to
1433 # work around the constraint.
1434 # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
1435 if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
1436 output = math_ops.add_n(value_list)
1437 else:
1438 output = array_ops.zeros_like(value_list[0], dtype=value_list[0].dtype)
1439 for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
1440 output += math_ops.add_n(value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])
1442 if reduce_op == reduce_util.ReduceOp.MEAN:
1443 output *= (1. / len(value_list))
1445 output = self._broadcast_output(destinations, output)
1446 return output
1448 def _update(self, var, fn, args, kwargs, group):
1449 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
1450 var, resource_variable_ops.BaseResourceVariable)
1451 if tpu_util.enclosing_tpu_context() is not None:
1452 if group:
1453 return fn(var, *args, **kwargs)
1454 else:
1455 return (fn(var, *args, **kwargs),)
1457 # Inside `tf.function`, we don't expand PackedVariable in python as it will
1458 # be expanded later during function instantiation in the runtime.
1459 packed_var = var._packed_variable # pylint: disable=protected-access
1460 if packed_var is not None and not context.executing_eagerly():
1461 if group:
1462 return fn(packed_var, *args, **kwargs)
1463 else:
1464 return (fn(packed_var, *args, **kwargs),)
1466 # Otherwise, we revert to MirroredStrategy behavior and update the variable
1467 # on each replica directly.
1468 updates = []
1469 values_and_devices = []
1470 if packed_var is not None:
1471 for device in packed_var.devices:
1472 values_and_devices.append((packed_var, device))
1473 else:
1474 for value in var.values:
1475 values_and_devices.append((value, value.device))
1477 if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
1478 var.aggregation != variables_lib.VariableAggregation.NONE):
1479 distribute_utils.assert_mirrored(args)
1480 distribute_utils.assert_mirrored(kwargs)
1481 for i, value_and_device in enumerate(values_and_devices):
1482 value = value_and_device[0]
1483 device = value_and_device[1]
1484 name = "update_%d" % i
1485 with ops.device(device), \
1486 distribute_lib.UpdateContext(i), \
1487 ops.name_scope(name):
1488 # If args and kwargs are not mirrored, the value is returned as is.
1489 updates.append(
1490 fn(value, *distribute_utils.select_replica(i, args),
1491 **distribute_utils.select_replica(i, kwargs)))
1492 return distribute_utils.update_regroup(self, updates, group)
1494 def read_var(self, var):
1495 assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
1496 var, resource_variable_ops.BaseResourceVariable)
1497 return var.read_value()
1499 def value_container(self, value):
1500 return value
1502 def _broadcast_to(self, tensor, destinations):
1503 del destinations
1504 # This is both a fast path for Python constants, and a way to delay
1505 # converting Python values to a tensor until we know what type it
1506 # should be converted to. Otherwise we have trouble with:
1507 # global_step.assign_add(1)
1508 # since the `1` gets broadcast as an int32 but global_step is int64.
1509 if isinstance(tensor, (float, int)):
1510 return tensor
1511 if tpu_util.enclosing_tpu_context() is not None:
1512 broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
1513 result = tpu_ops.all_to_all(
1514 broadcast_tensor,
1515 concat_dimension=0,
1516 split_dimension=0,
1517 split_count=self._num_replicas_in_sync)
1519 # This uses the broadcasted value from the first replica because the only
1520 # caller of this is for ONLY_FIRST_REPLICA variables aggregation.
1521 return result[0]
1522 return tensor
1524 @property
1525 def num_hosts(self):
1526 if self._device_assignment is None:
1527 return self._tpu_metadata.num_hosts
1529 return len(set([self._device_assignment.host_device(r)
1530 for r in range(self._device_assignment.num_replicas)]))
1532 @property
1533 def num_replicas_per_host(self):
1534 if self._device_assignment is None:
1535 return self._tpu_metadata.num_of_cores_per_host
1537 # TODO(sourabhbajaj): Remove this method we use inputs and remove infeed
1538 # as the computation of num_replicas_per_host is not a constant
1539 # when using device_assignment. This is a temporary workaround to support
1540 # StatefulRNN as everything is 1 in that case.
1541 # This method needs to take host_id as input for correct computation.
1542 max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
1543 self._device_assignment.num_cores_per_replica)
1544 return min(self._device_assignment.num_replicas, max_models_per_host)
1546 @property
1547 def _num_replicas_in_sync(self):
1548 if self._device_assignment is None:
1549 return self._tpu_metadata.num_cores
1550 return self._device_assignment.num_replicas
1552 @property
1553 def experimental_between_graph(self):
1554 return False
1556 @property
1557 def experimental_should_init(self):
1558 return True
1560 @property
1561 def should_checkpoint(self):
1562 return True
1564 @property
1565 def should_save_summary(self):
1566 return True
1568 @property
1569 def worker_devices(self):
1570 return tuple(self._tpu_devices[:, self._logical_device_stack[-1]])
1572 @property
1573 def parameter_devices(self):
1574 return self.worker_devices
1576 @property
1577 def tpu_hardware_feature(self):
1578 """Return the `tf.tpu.experimental.HardwareFeature` class."""
1579 return tpu_hardware_feature.HardwareFeature(
1580 self._tpu_cluster_resolver.tpu_hardware_feature)
1582 def non_slot_devices(self, var_list):
1583 return self._host_device
1585 def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
1586 del colocate_with
1587 with ops.device(self._host_device), distribute_lib.UpdateContext(None):
1588 result = fn(*args, **kwargs)
1589 if group:
1590 return result
1591 else:
1592 return nest.map_structure(self._local_results, result)
1594 def _configure(self,
1595 session_config=None,
1596 cluster_spec=None,
1597 task_type=None,
1598 task_id=None):
1599 del cluster_spec, task_type, task_id
1600 if session_config:
1601 session_config.CopyFrom(self._update_config_proto(session_config))
1603 def _update_config_proto(self, config_proto):
1604 updated_config = copy.deepcopy(config_proto)
1605 updated_config.isolate_session_state = True
1606 cluster_spec = self._tpu_cluster_resolver.cluster_spec()
1607 if cluster_spec:
1608 updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
1609 return updated_config
1611 # TODO(priyag): Delete this once all strategies use global batch size.
1612 @property
1613 def _global_batch_size(self):
1614 """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
1616 `make_input_fn_iterator` assumes per-replica batching.
1618 Returns:
1619 Boolean.
1620 """
1621 return True
1623 def tpu_run(self, fn, args, kwargs, options=None):
1624 func = self._tpu_function_creator(fn, options)
1625 return func(args, kwargs)
1627 def _tpu_function_creator(self, fn, options):
1628 if context.executing_eagerly() and fn in self._tpu_function_cache:
1629 return self._tpu_function_cache[fn]
1631 strategy = self._container_strategy()
1633 def tpu_function(args, kwargs):
1634 """TF Function used to replicate the user computation."""
1635 logging.vlog(1,
1636 "`TPUStrategy.run` is called with [args: %s] [kwargs: %s]",
1637 args, kwargs)
1639 if kwargs is None:
1640 kwargs = {}
1642 # Used to re-structure flattened output tensors from `tpu.replicate()`
1643 # into a structured format.
1644 result = [[]]
1646 def replicated_fn(replica_id, replica_args, replica_kwargs):
1647 """Wraps user function to provide replica ID and `Tensor` inputs."""
1648 with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
1649 result[0] = fn(*replica_args, **replica_kwargs)
1650 return result[0]
1652 replicate_inputs = [] # By replica.
1653 for i in range(strategy.num_replicas_in_sync):
1654 replicate_inputs.append(
1655 [constant_op.constant(i, dtype=dtypes.int32),
1656 distribute_utils.select_replica(i, args),
1657 distribute_utils.select_replica(i, kwargs)])
1659 # Construct and pass `maximum_shapes` so that we could support dynamic
1660 # shapes using dynamic padder.
1661 if options.experimental_enable_dynamic_batch_size and replicate_inputs:
1662 maximum_shapes = []
1663 flattened_list = nest.flatten(replicate_inputs[0])
1664 for input_tensor in flattened_list:
1665 if tensor_util.is_tf_type(input_tensor):
1666 rank = input_tensor.shape.rank
1667 else:
1668 rank = np.ndim(input_tensor)
1669 if rank is None:
1670 raise ValueError(
1671 "input tensor {} to TPUStrategy.run() has unknown rank, "
1672 "which is not allowed".format(input_tensor))
1673 maximum_shape = tensor_shape.TensorShape([None] * rank)
1674 maximum_shapes.append(maximum_shape)
1675 maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
1676 maximum_shapes)
1677 else:
1678 maximum_shapes = None
1680 if options.experimental_bucketizing_dynamic_shape:
1681 padding_spec = tpu.PaddingSpec.POWER_OF_TWO
1682 else:
1683 padding_spec = None
1685 with strategy.scope():
1686 xla_options = options.experimental_xla_options or tpu.XLAOptions(
1687 use_spmd_for_xla_partitioning=self._use_spmd_for_xla_partitioning)
1688 replicate_outputs = tpu.replicate(
1689 replicated_fn,
1690 replicate_inputs,
1691 device_assignment=self._device_assignment,
1692 maximum_shapes=maximum_shapes,
1693 padding_spec=padding_spec,
1694 xla_options=xla_options)
1696 # Remove all no ops that may have been added during 'tpu.replicate()'
1697 filter_ops = lambda x: [o for o in x if not isinstance(o, ops.Operation)]
1698 if isinstance(result[0], list):
1699 result[0] = filter_ops(result[0])
1701 # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
1702 if result[0] is None or isinstance(result[0], ops.Operation):
1703 replicate_outputs = [None] * len(replicate_outputs)
1704 else:
1705 replicate_outputs = [
1706 nest.pack_sequence_as(result[0], filter_ops(nest.flatten(output)))
1707 for output in replicate_outputs
1708 ]
1709 return distribute_utils.regroup(replicate_outputs)
1711 if context.executing_eagerly():
1712 tpu_function = def_function.function(tpu_function)
1713 self._tpu_function_cache[fn] = tpu_function
1714 return tpu_function
1716 def _in_multi_worker_mode(self):
1717 """Whether this strategy indicates working in multi-worker settings."""
1718 # TPUStrategy has different distributed training structure that the whole
1719 # cluster should be treated as single worker from higher-level (e.g. Keras)
1720 # library's point of view.
1721 # TODO(rchao): Revisit this as we design a fault-tolerance solution for
1722 # TPUStrategy.
1723 return False
1725 def _get_local_replica_id(self, replica_id_in_sync_group):
1726 return replica_id_in_sync_group
1729def _make_axis_nonnegative(axis, rank):
1730 # Convert a potentially negative `axis` to a non-negative one.
1731 if isinstance(axis, int):
1732 if axis >= 0:
1733 return axis
1734 else:
1735 return axis + rank
1736 else:
1737 return array_ops.where_v2(
1738 math_ops.greater_equal(axis, 0),
1739 axis,
1740 axis + rank)
1743# List of Tensor dtypes supported by cross_replica_sum().
1744_DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM = (
1745 dtypes.bfloat16,
1746 dtypes.float16,
1747 dtypes.float32,
1748 dtypes.float64,
1749 dtypes.int32,
1750 dtypes.uint32,
1751)
1754class _TPUReplicaContext(distribute_lib.ReplicaContext):
1755 """Replication Context class for TPU Strategy."""
1757 # TODO(sourabhbajaj): Call for each replica should be updating this.
1758 # TODO(b/118385803): Always properly initialize replica_id.
1759 def __init__(self, strategy, replica_id_in_sync_group=0):
1760 distribute_lib.ReplicaContext.__init__(
1761 self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)
1763 @property
1764 def devices(self):
1765 distribute_lib.require_replica_context(self)
1766 ds = self._strategy
1767 replica_id = tensor_util.constant_value(self.replica_id_in_sync_group)
1769 if replica_id is None: # Non-constant `Tensor` inside `tpu.replicate`.
1770 # TODO(cjfj): Return other devices when model parallelism is supported.
1771 return (tpu.core(0),)
1772 else:
1773 return (ds.extended.worker_devices[replica_id],)
1775 def experimental_logical_device(self, logical_device_id):
1776 """Places variables and ops on the specified logical device."""
1777 return self.strategy.extended.experimental_logical_device(logical_device_id)
1779 def _compute_all_gather_output_shape(self, value_shape, value_rank, axis):
1780 if isinstance(value_rank, int):
1781 output_shape = list(value_shape)
1782 output_shape[axis] *= self.num_replicas_in_sync
1783 else:
1784 output_shape = array_ops.where_v2(
1785 math_ops.equal(math_ops.range(value_rank), axis),
1786 value_shape * context.num_replicas_in_sync,
1787 value_shape)
1788 return output_shape
1790 def all_gather(self, value, axis, experimental_hints=None):
1791 del experimental_hints
1792 for v in nest.flatten(value):
1793 if isinstance(v, indexed_slices.IndexedSlices):
1794 raise NotImplementedError("all_gather does not support IndexedSlices")
1796 def _all_gather_tensor(value, axis):
1797 value = ops.convert_to_tensor(value)
1799 # Compute the shape and rank and rank of the input tensor. Use static
1800 # shapes when possible to help with shape inference in graph mode, but
1801 # fall back on dynamic shapes when necessary.
1802 if value.shape.rank is None:
1803 value_rank = array_ops.rank(value)
1804 value_shape = array_ops.shape(value)
1805 else:
1806 value_rank = value.shape.rank
1807 value_shape = value.shape.as_list()
1808 value_shape_tensor = array_ops.shape(value)
1809 for i in range(len(value_shape)):
1810 if value_shape[i] is None:
1811 value_shape[i] = value_shape_tensor[i]
1813 # In the code below, we will insert a new "replica" dimension immediately
1814 # *before* `axis`. To ensure that it's inserted before and not after, we
1815 # must make `axis` non-negative.
1816 axis = _make_axis_nonnegative(axis, value_rank)
1818 # Create a list or 1D int Tensor such as
1819 # [1, 1, ..., 1, num_replicas_in_sync, 1, ..., 1],
1820 # which is equal to `num_replicas_in_sync` at index `axis`
1821 # and is equal to 1 everywhere else.
1822 if isinstance(value_rank, int):
1823 replica_broadcast_shape = [1] * (value_rank + 1)
1824 replica_broadcast_shape[axis] = self.num_replicas_in_sync
1825 else:
1826 replica_broadcast_shape = array_ops.where_v2(
1827 math_ops.equal(math_ops.range(value_rank+1), axis),
1828 self.num_replicas_in_sync,
1829 1)
1831 output_shape = self._compute_all_gather_output_shape(
1832 value_shape, value_rank, axis)
1834 if value.dtype in _DTYPES_SUPPORTED_BY_CROSS_REPLICA_SUM:
1835 # optimized all_gather implementation based on cross_replica_sum().
1836 replica_id_mask = array_ops.one_hot(
1837 self.replica_id_in_sync_group, self.num_replicas_in_sync)
1838 replica_id_mask = array_ops.reshape(
1839 replica_id_mask, replica_broadcast_shape)
1840 replica_id_mask = math_ops.cast(replica_id_mask, value.dtype)
1842 gathered_value = array_ops.expand_dims(value, axis) * replica_id_mask
1843 gathered_value = self.all_reduce(
1844 reduce_util.ReduceOp.SUM, gathered_value)
1845 return array_ops.reshape(gathered_value, output_shape)
1846 else:
1847 # value.dtype isn't supported by cross_replica_sum(), so we fall back
1848 # on a less efficient implementation based on all_to_all().
1850 # The underlying AllToAllOp first do a split of the input value and then
1851 # cross-replica communication and concatenation of the result. So we
1852 # concatenate the local tensor here first.
1853 inputs = array_ops.expand_dims(value, axis=axis)
1854 inputs = array_ops.tile(inputs, replica_broadcast_shape)
1855 unordered_output = tpu_ops.all_to_all(
1856 inputs,
1857 concat_dimension=axis,
1858 split_dimension=axis,
1859 split_count=self.num_replicas_in_sync)
1861 # Re-order since xla.replica_id and ReplicaContext.replica_id mismatch.
1862 # Start by computing a permutation -- a 1D Tensor which maps
1863 # tensor[xla.replica_id] = ReplicaContext.replica_id
1864 concat_replica_id = array_ops.reshape(
1865 self.replica_id_in_sync_group, [1])
1866 concat_replica_id = array_ops.tile(
1867 concat_replica_id, [self.num_replicas_in_sync])
1868 xla_to_replica_context_id = tpu_ops.all_to_all(
1869 concat_replica_id,
1870 concat_dimension=0,
1871 split_dimension=0,
1872 split_count=self.num_replicas_in_sync)
1874 # Now invert the mapping to get
1875 # tensor[ReplicaContext.replica_id] = xla.replica_id
1876 replica_context_to_xla_id = math_ops.argmax(
1877 array_ops.one_hot(xla_to_replica_context_id,
1878 self.num_replicas_in_sync),
1879 axis=0)
1881 # Reorder the output elements so that they're sorted based on
1882 # ReplicaContext.replica_id instead of xla.replica_id.
1883 sorted_with_extra_dim = array_ops.gather(
1884 unordered_output, replica_context_to_xla_id, axis=axis)
1885 return array_ops.reshape(sorted_with_extra_dim, output_shape)
1887 ys = [_all_gather_tensor(t, axis=axis) for t in nest.flatten(value)]
1888 return nest.pack_sequence_as(value, ys)
1891def _set_last_step_outputs(ctx, last_step_tensor_outputs):
1892 """Sets the last step outputs on the given context."""
1893 # Convert replicate_outputs to the original dict structure of
1894 # last_step_outputs.
1895 last_step_tensor_outputs_dict = nest.pack_sequence_as(
1896 ctx.last_step_outputs, last_step_tensor_outputs)
1898 for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
1899 output = last_step_tensor_outputs_dict[name]
1900 # For outputs that aren't reduced, return a PerReplica of all values. Else
1901 # take the first value from the list as each value should be the same.
1902 if reduce_op is None:
1903 last_step_tensor_outputs_dict[name] = values.PerReplica(output)
1904 else:
1905 # TODO(priyag): Should this return the element or a list with 1 element
1906 last_step_tensor_outputs_dict[name] = output[0]
1907 ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access