Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu.py: 19%
500 statements
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16"""Library of TPU helper functions."""
18import collections
19import enum
20from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union
22from absl import logging
23import numpy as np
25from tensorflow.compiler.tf2xla.python import xla as tf2xla
26from tensorflow.core.framework import attr_value_pb2
27from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
28from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
29from tensorflow.python import tf2
30from tensorflow.python.compiler.xla import xla
31from tensorflow.python.framework import auto_control_deps
32from tensorflow.python.framework import composite_tensor
33from tensorflow.python.framework import config
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import func_graph
37from tensorflow.python.framework import function
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import array_ops_stack
42from tensorflow.python.ops import cond
43from tensorflow.python.ops import control_flow_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import variable_scope
46from tensorflow.python.tpu import device_assignment as device_assignment_lib
47from tensorflow.python.tpu import tensor_tracer
48from tensorflow.python.tpu import tpu_feed
49from tensorflow.python.tpu import tpu_function
50from tensorflow.python.tpu import tpu_name_util
51from tensorflow.python.tpu import tpu_replication
52from tensorflow.python.tpu.ops import tpu_ops
53from tensorflow.python.types import core as core_types
54from tensorflow.python.util import compat
55from tensorflow.python.util import nest
56from tensorflow.python.util import object_identity
57from tensorflow.python.util import traceback_utils
58from tensorflow.python.util import variable_utils
59from tensorflow.python.util.tf_export import tf_export
62ops.NotDifferentiable("TPUReplicatedInput")
64# Ops which can be safely pruned from XLA compile if they have no consumers.
65# These ops should also have no inputs.
66_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
68_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
69_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
70_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"
73core = tpu_name_util.core
76def _tpu_system_device_name(job: Optional[Text]) -> Text:
77 """Returns the device name for the TPU_SYSTEM device of `job`."""
78 if job is None:
79 return "/device:TPU_SYSTEM:0"
80 else:
81 return "/job:%s/device:TPU_SYSTEM:0" % job
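# Example (illustrative sketch, not part of the original module): the helper
# above only formats a device string, so its behavior can be checked directly.
def _example_tpu_system_device_names():  # Hypothetical helper for illustration.
  """Shows the strings produced by `_tpu_system_device_name`."""
  assert _tpu_system_device_name(None) == "/device:TPU_SYSTEM:0"
  assert _tpu_system_device_name("worker") == "/job:worker/device:TPU_SYSTEM:0"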
84@tf_export(v1=["tpu.initialize_system"])
85def initialize_system(
86 embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
87 job: Optional[Text] = None,
88 compilation_failure_closes_chips: bool = True,
89 tpu_cancellation_closes_chips: Optional[bool] = None,
90) -> core_types.Tensor:
91 """Initializes a distributed TPU system for use with TensorFlow.
93 Args:
94 embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
95 describing the desired configuration of the hardware embedding lookup
96 tables. If embedding_config is None, no hardware embeddings can be used.
97 job: The job (the XXX in TensorFlow device specification /job:XXX) that
98 contains the TPU devices that will be initialized. If job=None it is
99 assumed there is only one job in the TensorFlow flock, and an error will
100 be returned if this assumption does not hold.
 101 compilation_failure_closes_chips: Whether to close TPU chips when there
 102 is a compilation failure.
 103 tpu_cancellation_closes_chips: Whether to close TPU chips when a TPU
 104 execution is cancelled. If the value is None, the behavior will be
 105 determined by the command line flag `tpu_cancellation_closes_chips` for
 106 the TPU worker. WARNING: this argument only applies to the TFRT TPU
 107 runtime.
108 Returns:
109 A serialized `TopologyProto` that describes the TPU system. Note:
110 the topology must be evaluated using `Session.run` before it can be used.
111 """
112 config_string = ("" if embedding_config is None else
113 embedding_config.SerializeToString())
115 # The enum is defined in core/tpu/kernels/tpu_execute_op_options.h.
116 tpu_cancellation_closes_chips_enum = 0
117 if tpu_cancellation_closes_chips is not None:
118 if tpu_cancellation_closes_chips:
119 tpu_cancellation_closes_chips_enum = 1
120 else:
121 tpu_cancellation_closes_chips_enum = 2
123 with ops.device(_tpu_system_device_name(job)):
124 topology = tpu_ops.configure_distributed_tpu(
125 compilation_failure_closes_chips=compilation_failure_closes_chips,
126 tpu_cancellation_closes_chips=tpu_cancellation_closes_chips_enum,
127 )
129 if embedding_config is None:
130 return topology
132 # This set of control dependencies is needed as this function is expected to
133 # return an op which will return the topology when executed, but we need to
134 # call the embedding initialization op between initializing the TPU and
135 # returning the topology.
136 with ops.control_dependencies([topology]):
137 embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
138 with ops.control_dependencies([embedding_init]):
139 return array_ops.identity(topology, name="tpu_init_identity")
142def initialize_system_for_tpu_embedding(
143 embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
144 job: Optional[Text] = None,
145) -> ops.Operation:
146 """Initializes a distributed TPU Embedding system for use with TensorFlow.
148 The following two are equivalent:
149 1. initialize_system() with embedding_config.
150 2. initialize_system() without embedding_config, then
151 initialize_system_for_tpu_embedding().
152 initialize_system() should not be called with embedding_config if
153 initialize_system_for_tpu_embedding() is meant to be called later.
155 Args:
156 embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
157 configuration of the hardware embedding lookup tables.
158 job: The job (the XXX in TensorFlow device specification /job:XXX) that
159 contains the TPU devices that will be initialized. If job=None it is
160 assumed there is only one job in the TensorFlow flock, and an error will
161 be returned if this assumption does not hold.
163 Returns:
164 A no-op.
165 """
166 config_string = embedding_config.SerializeToString()
167 with ops.device(_tpu_system_device_name(job)):
168 return tpu_ops.configure_tpu_embedding(config=config_string)
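# Illustrative sketch (not part of the original module) of option 2 from the
# docstring above: initialize the TPU system without an embedding config, then
# configure the embedding engine separately. `config` is assumed to be a
# populated `TPUEmbeddingConfiguration` proto.
def _example_split_embedding_init(config, job=None):  # Hypothetical helper for illustration.
  """Builds the system-init op followed by the embedding-init op."""
  topology = initialize_system(job=job)  # no embedding_config passed here
  with ops.control_dependencies([topology]):
    embedding_init = initialize_system_for_tpu_embedding(config, job=job)
  return topology, embedding_init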
171@tf_export(v1=["tpu.shutdown_system"])
172def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
 173 """Shuts down a running distributed TPU system.
175 Args:
176 job: The job (the XXX in TensorFlow device specification /job:XXX) that
177 contains the TPU devices that will be shutdown. If job=None it is
178 assumed there is only one job in the TensorFlow flock, and an error will
179 be returned if this assumption does not hold.
180 """
181 with ops.device(_tpu_system_device_name(job)):
182 shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
183 return shutdown_distributed_tpu
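# Illustrative TF1-style usage sketch (not part of the original module),
# assuming graph mode and a local `Session`. The topology returned by
# `initialize_system` must be evaluated with `Session.run` before the TPU
# system can be used, and `shutdown_system` is typically run once work is done.
def _example_initialize_and_shutdown(job=None):  # Hypothetical helper for illustration.
  """Runs TPU system initialization and shutdown around an (omitted) workload."""
  from tensorflow.python.client import session as session_lib
  init_op = initialize_system(job=job)
  shutdown_op = shutdown_system(job=job)
  with session_lib.Session() as sess:
    serialized_topology = sess.run(init_op)  # serialized TopologyProto bytes
    # ... build and run the replicated computation here ...
    sess.run(shutdown_op)
  return serialized_topology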
186@auto_control_deps.register_acd_resource_resolver
187def tpu_replicated_input_resolver(
188 op: ops.Operation,
189 resource_reads: object_identity.ObjectIdentitySet,
190 resource_writes: object_identity.ObjectIdentitySet) -> bool:
191 """Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
192 # Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
193 # control deps on the replicated inputs.
194 if op.type == "TPUReplicatedInput":
195 if resource_reads or resource_writes:
196 resource_reads.clear()
197 resource_writes.clear()
198 return True
199 else:
200 return False
201 # Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput
 202 # with the actual replicated inputs. This allows ACD to correctly add control
203 # deps when there are multiple calls to `run` in a
204 # `tf.function`.
205 def replace_with_unreplicated_resources(resource_inputs):
206 """Replaces handles in `resource_inputs` with their unreplicated inputs."""
207 to_remove = []
208 to_add = []
209 for resource in resource_inputs:
210 if resource.op.type == "TPUReplicatedInput":
211 to_remove.append(resource)
212 to_add.extend(resource.op.inputs)
213 for t in to_remove:
214 resource_inputs.discard(t)
215 resource_inputs.update(to_add)
216 return to_add or to_remove
218 return bool(replace_with_unreplicated_resources(resource_reads) or
219 replace_with_unreplicated_resources(resource_writes))
222@tf_export(v1=["tpu.PaddingSpec"])
223class PaddingSpec(enum.IntEnum):
224 """Represents the type of padding policies for tpu.replicate."""
 225 # By default the policy is set to AUTO; the dynamic input shape dimension will
 226 # be padded to the maximum of all the replicas.
227 AUTO = 0
228 # Bucketize the dynamic input shape dimension into a power of 2.
229 POWER_OF_TWO = 1
232@tf_export("tpu.XLAOptions")
233class XLAOptions(
234 collections.namedtuple("XLAOptions", [
235 "use_spmd_for_xla_partitioning",
236 "enable_xla_dynamic_padder",
237 ])):
238 """XLA compilation options.
240 Attributes:
241 use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
242 partitioner instead of MPMD partitioner when compiler partitioning is
243 requested.
244 enable_xla_dynamic_padder: Boolean. Whether to enable XLA dynamic padder
 245 infrastructure to handle dynamic shaped inputs inside XLA. True by
 246 default. Disabling this may cause correctness issues with dynamic shaped
 247 inputs, as XLA will simply assume the inputs already have padded shapes.
 248 However, users can optionally set it to False to improve device time if
 249 masking is already handled on the user side.
250 """
252 def __new__(cls,
253 use_spmd_for_xla_partitioning=True,
254 enable_xla_dynamic_padder=True):
255 return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning,
256 enable_xla_dynamic_padder)
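# Illustrative sketch (not part of the original module): `XLAOptions` is a
# plain namedtuple, so callers construct it directly and pass it through the
# `xla_options` argument of APIs such as `tpu.replicate` and `tpu.shard`.
def _example_xla_options():  # Hypothetical helper for illustration.
  """Returns options keeping SPMD partitioning but disabling the dynamic padder."""
  return XLAOptions(use_spmd_for_xla_partitioning=True,
                    enable_xla_dynamic_padder=False)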
259@tf_export(v1=["tpu.replicate"])
260@traceback_utils.filter_traceback
261def replicate(
262 computation: Callable[..., Any],
263 inputs: Optional[List[List[core_types.Tensor]]] = None,
264 infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
265 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
266 name: Optional[Text] = None,
267 maximum_shapes: Optional[Any] = None,
268 padding_spec: Optional[PaddingSpec] = None,
269 xla_options: Optional[XLAOptions] = None) -> List[Any]:
270 """Builds a graph operator that runs a replicated TPU computation.
 272 Example of basic usage where `inputs` has a static shape:
274 ```python
276 def computation(x):
277 x = x + 1
278 return tf.math.reduce_mean(x)
280 x = tf.convert_to_tensor([1., 2., 3.])
281 y = tf.convert_to_tensor([4., 5., 6.])
282 tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
283 ```
 285 If the `inputs` have dynamic shapes and you would like to automatically
 286 bucketize them to avoid XLA recompilation, see the advanced example
 287 below:
289 ```python
291 def computation(x):
292 x = x + 1
293 return tf.math.reduce_mean(x)
295 # Assume input tensors in two replicas `x` and `y` both have dynamic shape
296 # ([None, 2]).
297 tf.compat.v1.tpu.replicate(
298 computation,
299 inputs=[x, y],
300 maximum_shapes=[tf.TensorShape([None, None])],
301 padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
302 ```
304 Args:
305 computation: A Python function that builds the computation to replicate.
306 inputs: A list of lists of input tensors or `None` (equivalent to
307 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
308 have the same number of inputs. Each input can be a nested structure
309 containing values that are convertible to tensors. Note that passing an
 310 N-dimensional list of compatible values will result in an N-dimensional list
 311 of scalar tensors rather than a single rank-N tensor. If you need different
312 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
313 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
314 of arguments as inputs to computation.
315 device_assignment: If not `None`, a `DeviceAssignment` describing the
316 mapping between logical cores in the computation with physical cores in
317 the TPU topology. Uses a default device assignment if `None`. The
318 `DeviceAssignment` may be omitted if each replica of the computation uses
319 only one core, and there is either only one replica, or the number of
320 replicas is equal to the number of cores in the TPU system.
321 name: (Deprecated) Does nothing.
322 maximum_shapes: A nested structure of tf.TensorShape representing the shape
323 to which the respective component of each input element in each replica
324 should be padded. Any unknown dimensions (e.g.
325 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
326 object) will be padded to the maximum size of that dimension over all
327 replicas. The structure of `maximum_shapes` needs to be the same as
328 `inputs[0]`.
 329 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
 330 padding policy when the `inputs` to `tpu.replicate` are dynamic.
 331 One usage is to enable automatic bucketizing on the inputs by setting the
 332 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
 333 recompilation on the XLA side.
334 xla_options: An instance of `tpu.XLAOptions` which indicates the options
335 passed to XLA compiler. Use `None` for default options.
336 Returns:
 337 A list of outputs, indexed by `[replica_num]`; each output can be a nested
 338 structure the same as what computation() returns, with a few exceptions.
340 Exceptions include:
341 1) None output: a NoOp would be returned which control-depends on
342 computation.
343 2) Single value output: A tuple containing the value would be returned.
344 3) Operation-only outputs: a NoOp would be returned which
345 control-depends on computation.
346 TODO(b/121383831): Investigate into removing these special cases.
348 Raises:
349 ValueError: If all replicas do not have equal numbers of input tensors.
350 ValueError: If the number of inputs per replica does not match
351 the number of formal parameters to `computation`.
352 ValueError: If the static `inputs` dimensions don't match with the values
353 given in `maximum_shapes`.
354 ValueError: If the structure of inputs per replica does not match
355 the structure of `maximum_shapes`.
356 """
357 return split_compile_and_replicate(
358 computation,
359 inputs,
360 infeed_queue,
361 device_assignment,
362 name,
363 maximum_shapes=maximum_shapes,
364 padding_spec=padding_spec,
365 xla_options=xla_options)[1]
368def _ceil_to_pow_of_n(x, n):
369 """Ceil input `x` to power of `n`."""
370 x = math_ops.cast(x, dtypes.float32)
371 lognx = math_ops.log(x) / math_ops.log(n * 1.0)
372 lognx = math_ops.ceil(lognx)
373 result = math_ops.pow(n * 1.0, lognx)
374 result = math_ops.cast(result, dtypes.int32)
375 return result
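# Illustrative sketch (not part of the original module): a pure-Python
# equivalent of the op-based helper above, handy for reasoning about
# POWER_OF_TWO bucketing (for n=2: 5 -> 8, 9 -> 16). As with the op version,
# floating-point rounding near exact powers of `n` may bump a value into the
# next bucket.
def _example_ceil_to_pow_of_n_py(x, n):  # Hypothetical helper for illustration.
  """Smallest integer power of `n` that is >= `x`, for x > 0 and n > 1."""
  import math
  return int(n ** math.ceil(math.log(x, n)))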
378def _pad_all_input(
379 inputs: Iterable[core_types.Tensor],
380 padded_shapes: List[Optional[tensor_shape.TensorShape]],
381 padding_spec: PaddingSpec
382) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
383 """Pad all input tensors given padded_shapes.
385 The real shape tensors will be concatenated with the padded original inputs.
387 Args:
388 inputs: The original inputs.
389 padded_shapes: A list of padded shapes for each input. If an entry is None,
390 no padding is performed.
 391 padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
 392 padding policy when the `inputs` to `tf.tpu.replicate` are dynamic.
 393 One usage is to enable automatic bucketizing on the inputs by setting the
 394 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
 395 recompilation on the XLA side.
397 Returns:
398 The padded inputs and a PaddingMap list which maps the padded input
399 dimension to the real shape argument index.
400 """
401 # maximum_static_shapes[idx][i] indicates the maximum static size of ith
402 # dimension of the idx input among all the replicas.
403 maximum_static_shapes = []
404 # need_padding[idx][i] indicates whether the ith dimension of the idx input
405 # needs padding.
406 need_padding = []
407 input_shape_tensors = []
408 for core_idx, inputs_per_core in enumerate(inputs):
409 for idx, input_tensor in enumerate(inputs_per_core):
410 input_shape = input_tensor.get_shape().as_list()
411 if core_idx == 0:
412 input_shape_tensors.append([])
413 maximum_static_shapes.append(input_shape)
414 need_padding.append(np.full_like(input_shape, False, dtype=bool))
415 else:
416 for i, s in enumerate(input_shape):
417 if s is None or s != maximum_static_shapes[idx][i]:
418 need_padding[idx][i] = True
419 maximum_static_shapes[idx] = max(input_shape,
420 maximum_static_shapes[idx])
422 # Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
423 real_input_shape = array_ops.shape(input_tensor)
424 real_input_shape.op._set_attr( # pylint: disable=protected-access
425 _POST_DEVICE_REWRITE_ATTR,
426 attr_value_pb2.AttrValue(b=True))
427 input_shape_tensors[idx].append(real_input_shape)
429 maximum_shapes = []
430 for shapes_per_input in input_shape_tensors:
431 maximum_shapes.append(
432 math_ops.reduce_max(array_ops_stack.stack(shapes_per_input), axis=0))
434 padded_inputs = []
435 real_shapes = []
436 padding_maps = []
437 for core_idx, inputs_per_core in enumerate(inputs):
438 padded_inputs.append([])
439 real_shapes.append([])
440 real_shape_idx = len(inputs_per_core) - 1
441 for idx, input_tensor in enumerate(inputs_per_core):
442 input_shape_tensor = input_shape_tensors[idx][core_idx]
443 input_shape = input_tensor.get_shape().as_list()
444 padded_shape = padded_shapes[idx]
446 # If we have no padded_shape, then skip padding.
447 if any(need_padding[idx]) and padded_shape is not None:
448 for i, s in enumerate(input_shape):
449 if need_padding[idx][i]:
450 if core_idx == 0:
451 real_shape_idx += 1
452 padding_map = dynamic_padding.PaddingMap()
453 padding_map.arg_index = idx
454 padding_map.shape_index = i
455 padding_map.padding_arg_index = real_shape_idx
456 padding_maps.append(padding_map)
457 real_shapes[core_idx].append(
458 math_ops.cast(input_shape_tensor[i], dtypes.int32))
460 paddings = []
461 for i, s in enumerate(padded_shape.dims):
462 if need_padding[idx][i]:
 463 # The minimum padded dimension size is 2, as XLA doesn't support
 464 # dynamic dimensions of size 1.
465 minimum_dynamic_dim_size = 2
466 if s.value is not None:
467 # Pad to the given maximum value.
468 max_dim_size = max(s.value, minimum_dynamic_dim_size)
469 else:
470 # If maximum value is not given, then pad to the maximum dimension
471 # among all the cores.
472 max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
473 minimum_dynamic_dim_size)
474 if padding_spec == PaddingSpec.POWER_OF_TWO:
475 max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
476 # Pad to the given maximum value.
477 padding = [0, max_dim_size - input_shape_tensor[i]]
478 else:
479 padding = [0, 0]
480 paddings.append(padding)
482 if input_tensor.get_shape().is_fully_defined():
483 # TODO(rxsang): This is a hack to make sure padded_input has dynamic
484 # shapes, so any tf.size/tf.shape op performed on it won't be constant
485 # folded. Do we have better ways to do it?
486 padded_input = cond.cond(
487 array_ops.constant(True),
488 lambda: array_ops.pad(input_tensor, paddings), # pylint: disable=cell-var-from-loop
489 lambda: input_tensor)
490 else:
491 padded_input = array_ops.pad(input_tensor, paddings)
493 # Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
494 padded_input.op._set_attr( # pylint: disable=protected-access
495 _POST_DEVICE_REWRITE_ATTR,
496 attr_value_pb2.AttrValue(b=True))
498 padded_inputs[core_idx].append(padded_input)
499 else:
500 padded_inputs[core_idx].append(input_tensor)
502 num_replicas = len(padded_inputs)
503 for i in range(num_replicas):
504 padded_inputs[i].extend(real_shapes[i])
506 return padded_inputs, padding_maps
509def _flatten_and_filter_composite(maybe_composite, non_composite_output,
510 composite_output=None):
 511 """For an input, replaces the input with a tuple if the input is composite.
 513 If `maybe_composite` is not composite, return the parameter
 514 `non_composite_output`; otherwise return a tuple consisting of the value of
 515 the parameter `composite_output` repeated once for each component of the
 516 composite tensor.
518 This is useful for computing a mask when flattening nested data with
519 `expand_composites=True`. For example
521 ```python
522 nest.flatten(data, expand_composites=True)
523 ```
525 and
527 ```python
 528 nest.flatten(nest.map_structure(
 529     lambda x: _flatten_and_filter_composite(x, False, True), data))
 530 ```
 532 will have the same length, and the second will be True where the tensor in
 533 the first is derived from expanding a composite tensor.
535 Args:
536 maybe_composite: A value to test for being a composite tensor.
537 non_composite_output: The value to return when `maybe_composite` is not a
538 composite.
539 composite_output: the value to fill the output tuple with if
540 `maybe_composite` is a composite.
542 Returns:
543 `non_composite_output` or a tuple with multiple copies of
544 `composite_output`.
545 """
547 if isinstance(maybe_composite, composite_tensor.CompositeTensor):
548 num_components = len(nest.flatten(maybe_composite, expand_composites=True))
549 return (composite_output,) * num_components
550 return non_composite_output
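# Illustrative sketch (not part of the original module) of the masking pattern
# described in the docstring above, assuming eager mode and the public
# `tf.nest` / `tf.ragged` APIs.
def _example_composite_mask():  # Hypothetical helper for illustration.
  """Builds a per-flattened-tensor mask marking components of composites."""
  import tensorflow as tf
  data = {"dense": tf.constant([1, 2]),
          "ragged": tf.ragged.constant([[1], [2, 3]])}
  flat = tf.nest.flatten(data, expand_composites=True)
  mask = tf.nest.flatten(tf.nest.map_structure(
      lambda x: _flatten_and_filter_composite(x, False, True), data))
  # Both lists line up element for element; here the ragged tensor expands to
  # two components, so mask works out to [False, True, True].
  assert len(flat) == len(mask)
  return flat, mask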
553def split_compile_and_replicate(
554 computation: Callable[..., Any],
555 inputs: Optional[List[List[core_types.Tensor]]] = None,
556 infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
557 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
558 name: Optional[Text] = None,
559 use_tpu: bool = True,
560 maximum_shapes: Optional[Any] = None,
561 padding_spec: Optional[PaddingSpec] = None,
562 xla_options: Optional[XLAOptions] = None,
563) -> List[List[core_types.Tensor]]:
 564 """Builds graph operators that run compilation and the replicated computation.
 566 This is a lower-level interface than `replicate` that returns separate compile
 567 and execute output tensors. In the generated graph the compile op feeds into
568 the execute op and no additional compilation is incurred when running the
569 compile op before the execute op. The compile op returns additional
570 information about the compilation but does not return the compiled program.
572 Args:
573 computation: A Python function that builds the computation to replicate.
574 inputs: A list of lists of input tensors or `None` (equivalent to
575 `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
576 have the same number of inputs. Each input can be a nested structure
577 containing values that are convertible to tensors. Note that passing an
 578 N-dimensional list of compatible values will result in an N-dimensional list
 579 of scalar tensors rather than a single rank-N tensor. If you need different
580 behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
581 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
582 of arguments as inputs to computation.
583 device_assignment: If not `None`, a `DeviceAssignment` describing the
584 mapping between logical cores in the computation with physical cores in
585 the TPU topology. Uses a default device assignment if `None`. The
586 `DeviceAssignment` may be omitted if each replica of the computation uses
587 only one core, and there is either only one replica, or the number of
588 replicas is equal to the number of cores in the TPU system.
589 name: (Deprecated) Does nothing.
590 use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
591 backends. Currently, only supports a default placement (computation is
592 placed on GPU if one is available, and on CPU if not).
593 maximum_shapes: A nested structure of tf.TensorShape representing the shape
594 to which the respective component of each input element in each replica
595 should be padded. Any unknown dimensions (e.g.
596 tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
597 object) will be padded to the maximum size of that dimension over all
598 replicas. The structure of `maximum_shapes` needs to be the same as
599 `inputs[0]`.
 600 padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
 601 padding policy when the `inputs` to `tf.tpu.replicate` are dynamic.
 602 One usage is to enable automatic bucketizing on the inputs by setting the
 603 value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
 604 recompilation on the XLA side.
605 xla_options: An instance of `tpu.XLAOptions` which indicates the options
606 passed to XLA compiler. Use `None` for default options.
608 Returns:
609 A list of lists with the first list corresponding to the compile op and the
610 second a list of output tensors, indexed by `[replica_num][output_num]`.
611 Raises:
612 ValueError: If all replicas do not have equal numbers of input tensors.
613 ValueError: If the number of inputs per replica does not match
614 the number of formal parameters to `computation`.
615 ValueError: If the static `inputs` dimensions don't match with the values
616 given in `maximum_shapes`.
617 ValueError: If the structure of inputs per replica does not match
618 the structure of `maximum_shapes`.
619 """
620 del name
621 inputs = [[]] if inputs is None else inputs
622 xla_options = xla_options or XLAOptions()
624 metadata_kwargs = {}
625 if device_assignment is not None:
626 # Turn the Numpy array into a flattened list so we can pass it as an
627 # operator attribute.
628 metadata_kwargs = {
629 "topology":
630 device_assignment.topology.serialized(),
631 "device_assignment":
632 device_assignment.core_assignment.flatten().tolist()
633 }
634 metadata_kwargs["num_cores_per_replica"] = (
635 device_assignment.num_cores_per_replica)
637 # This entry is used for enabling automatic outside compilation.
638 metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()
639 if config.get_soft_device_placement():
640 logging.info("Automatic outside compilation is enabled. "
641 "Ops without XLA kernels will be automatically "
642 "placed on CPU.")
644 if not isinstance(inputs, list):
645 raise TypeError("tpu.replicate() inputs must be a list of lists/tuples, "
646 f"received {type(inputs)}")
647 if any(not isinstance(inp, (list, tuple)) for inp in inputs):
648 raise TypeError(
649 "tpu.replicate() inputs must be a list of lists/tuples, "
650 f"received types: {[type(inp) for inp in inputs]}")
652 num_replicas = len(inputs)
654 # No replicas? Nothing to do.
655 if num_replicas == 0:
656 return []
658 # Checks all replicas have the same structure.
659 for i in range(1, num_replicas):
660 nest.assert_same_structure(inputs[0], inputs[i])
662 # Explicitly read variables.
663 inputs = variable_utils.convert_variables_to_tensors(inputs)
664 # Flatten inputs. This structure may contain None values, which will be
665 # handled later.
666 flat_inputs_with_nones = [
667 nest.flatten(per_replica_input, expand_composites=True)
668 for per_replica_input in inputs
669 ]
670 # Mask parallel to one replica's inputs with True for tensors coming from
671 # composites.
672 is_composite = nest.flatten(nest.map_structure(
673 lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))
675 # Converts inputs to Tensors, replacing Nones with a placeholder 0 since
676 # tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
677 flat_inputs = []
678 for inp in flat_inputs_with_nones:
679 flat_inputs.append([
680 constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
681 for x in inp
682 ])
684 # Verifies that all replicas have matching numbers and types of inputs
685 flat_input_types = [x.dtype for x in flat_inputs[0]]
686 input_arity = len(inputs[0])
687 flat_input_arity = len(flat_input_types)
688 for i in range(num_replicas):
689 if len(inputs[i]) != input_arity:
690 raise ValueError("Replicas must have the same number of inputs. "
691 "Replica 0 had {} inputs, replica {} had {} "
692 "inputs.".format(input_arity, i, len(inputs[i])))
694 types = [x.dtype for x in flat_inputs[i]]
695 if types != flat_input_types:
696 raise ValueError("Replicas must have matching input types. Replica 0 had "
697 "input types {}, replica {} had input types {}".format(
698 flat_input_types, i, types))
700 arg_error = xla.check_function_argument_count(
701 computation, input_arity, infeed_queue)
702 if arg_error is not None:
703 if infeed_queue is None:
704 raise TypeError(
705 "Supplied computation cannot be called with the specified inputs. "
706 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]}, "
707 f"but the computation needs {arg_error}")
708 else:
709 raise TypeError(
710 "Supplied computation cannot be called with the specified inputs. "
 711 f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]} "
712 f"and {infeed_queue.number_of_tuple_elements} additional inputs "
713 f"from infeed, but the computation needs {arg_error}")
715 dynamic_shape_inputs = False
716 if maximum_shapes:
717 if infeed_queue:
718 raise ValueError(
719 "Dynamic input shapes are not supported with infeed queues")
721 # Make sure maximum_shapes has the same structure as inputs.
722 nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)
724 # Flatten padded shapes:
725 # For composite tensor components, we don't want to pad them. For each
726 # entry of maximum_shapes that corresponds to a composite tensor, replace it
727 # by a tuple of Nones of the same length as the number of components of the
728 # composite tensor. When we flatten a second time, this makes
729 # flat_maximum_shapes have the same length as flat_inputs[i]. We can then
730 # avoid padding these tensors. The assumption is that they will be used by
731 # outside compilation or that the components are statically shaped and will
732 # be used by tpu compatible ops.
733 flat_maximum_shapes = nest.flatten(
734 [_flatten_and_filter_composite(x, y)
735 for x, y in zip(nest.flatten(inputs[0]),
736 nest.flatten(maximum_shapes))])
737 flat_maximum_shapes = [
738 tensor_shape.TensorShape(s) if s is not None else None
739 for s in flat_maximum_shapes
740 ]
741 nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes,
742 check_types=False)
744 unpadded_inputs = flat_inputs
745 flat_inputs, padding_maps = _pad_all_input(unpadded_inputs,
746 flat_maximum_shapes,
747 padding_spec)
748 if padding_maps:
749 dynamic_shape_inputs = True
750 logging.info("TPU has inputs with dynamic shapes: %s", inputs[0])
752 metadata_kwargs["step_marker_location"] = getattr(
753 computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
754 metadata_kwargs["use_spmd_for_xla_partitioning"] = \
755 xla_options.use_spmd_for_xla_partitioning
757 graph = ops.get_default_graph()
759 # Fan-in: Builds a TPUReplicatedInput node for each input.
760 flat_replicated_inputs = []
761 for i in range(0, len(flat_inputs[0])):
762 replicas = [flat_inputs[replica][i] for replica in range(num_replicas)]
763 flat_replicated_inputs.append(
764 tpu_ops.tpu_replicated_input(
765 replicas, name="input{}".format(i)))
766 if isinstance(graph, func_graph.FuncGraph):
767 # When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph
768 # object. If both outside graph and this function have a TPU cluster,
769 # they will have the same cluster name and it will cause problems (because
770 # we lower functional ops in Tensorflow 2.0). Append function name to
771 # 'cluster_name' to avoid cluster name collision.
772 cluster_name = graph.unique_name("cluster_" + graph.name)
773 else:
774 cluster_name = graph.unique_name("cluster")
775 pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
776 pivot._set_attr(_PIVOT_FOR_CLUSTER, # pylint: disable=protected-access
777 attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)))
778 context = tpu_replication.TPUReplicateContext(
779 name=cluster_name, num_replicas=num_replicas, pivot=pivot)
780 try:
781 context.Enter()
783 metadata = tpu_ops.tpu_replicate_metadata(
784 num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
786 with tpu_function.tpu_shard_context(
787 num_replicas), ops.control_dependencies([metadata]):
789 if dynamic_shape_inputs and xla_options.enable_xla_dynamic_padder:
790 for padding_map in padding_maps:
791 input_shape = flat_replicated_inputs[padding_map.arg_index].shape
792 flat_replicated_inputs[
793 padding_map.arg_index] = tf2xla.set_dynamic_dimension_size(
794 flat_replicated_inputs[padding_map.arg_index],
795 padding_map.shape_index,
796 flat_replicated_inputs[padding_map.padding_arg_index])
797 flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape)
799 # Add identity ops so even unused inputs are "consumed" by the
800 # computation. This is to avoid orphaned TPUReplicatedInput nodes.
801 # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
802 # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
803 flat_replicated_inputs = [
804 array_ops.identity(x, name="replicated_input_{}".format(i))
805 for i, x in enumerate(flat_replicated_inputs)
806 ]
807 for i, composite in zip(flat_replicated_inputs, is_composite):
808 # pylint: disable=protected-access
809 # Add an attribute to the identity node so that they could be removed in
810 # encapsulate TPU computation pass if unused. However we don't remove
811 # inputs when dynamic padding is enabled.
812 # TODO(rxsang): Use other ways except argument index in padding_map so
813 # outside compilation can work with dynamic padding correctly.
814 if not dynamic_shape_inputs or composite:
815 i.op._set_attr("_tpu_input_identity",
816 attr_value_pb2.AttrValue(b=True))
817 # pylint: enable=protected-access
819 # Clobber replicated placeholders with Nones.
820 computation_inputs = [
821 None if inp is None else replicated for replicated, inp in zip(
822 flat_replicated_inputs, flat_inputs_with_nones[0])
823 ]
825 # Unflatten the computation inputs to match original input structure.
826 computation_inputs = nest.pack_sequence_as(
827 structure=inputs[0],
828 flat_sequence=computation_inputs[:flat_input_arity],
829 expand_composites=True)
831 # If there is an infeed queue, adds the dequeued values to the
832 # computation's inputs.
833 if infeed_queue is not None:
834 infeed_queue.set_number_of_shards(num_replicas)
835 for t in infeed_queue.generate_dequeue_op():
836 computation_inputs.append(t)
838 # Only resource variables work inside a TPU computation, so turn on
839 # resource variables for the computation.
840 # TODO(phawkins): consider removing this code. It will
841 # be less confusing to clients if they knowingly choose to use resource
842 # variables.
843 # Partitioned variables is not supported (b/112311320).
844 vscope = variable_scope.get_variable_scope()
845 saved_use_resource = vscope.use_resource
846 saved_custom_getter = vscope.custom_getter
848 def custom_getter(getter, name, *args, **kwargs):
849 """Variables on TPU have a few restrictions."""
850 partitioner = kwargs.get("partitioner", None)
851 if partitioner is not None:
852 kwargs["partitioner"] = None
853 logging.warning(
854 "Partitioned variables are not supported on TPU. Got "
855 "`partitioner` that is %s for variable %s. "
856 "Setting `partitioner` to `None`.", partitioner, name)
857 if saved_custom_getter is None:
858 return getter(name, *args, **kwargs)
859 else:
860 return saved_custom_getter(getter, name, *args, **kwargs)
862 vscope.set_use_resource(True)
863 vscope.set_custom_getter(custom_getter)
865 outputs = computation(*computation_inputs)
867 vscope.set_use_resource(saved_use_resource)
868 vscope.set_custom_getter(saved_custom_getter)
870 outputs = variable_utils.convert_variables_to_tensors(outputs)
872 need_spmd_partitioning = (
873 xla_options.use_spmd_for_xla_partitioning and
874 device_assignment is not None and
875 device_assignment.num_cores_per_replica > 1)
876 outputs_is_flat = xla.is_flat(outputs)
877 if outputs_is_flat:
878 output_tensors, control_deps, pack_template = _postprocess_flat_outputs(
879 outputs, need_spmd_partitioning)
880 else:
881 output_tensors, control_deps, pack_template = (
882 _postprocess_non_flat_outputs(outputs, need_spmd_partitioning))
884 if tensor_tracer.TensorTracer.is_enabled():
885 if tf2.enabled():
886 logging.warn("TF API ver >= 2.0 detected. "
887 "Tensor Tracer v1 is not enabled.")
888 else:
889 tt = tensor_tracer.TensorTracer()
890 output_tensors = tt.trace_tpu(ops.get_default_graph(),
891 output_tensors, control_deps,
892 num_replicas)
894 context.ExitResult(output_tensors)
895 finally:
896 context.report_unsupported_operations()
897 context.Exit()
898 host_compute_core = context.HostComputeCore()
900 if host_compute_core:
901 attr_value = attr_value_pb2.AttrValue()
902 attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core)
903 metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access
905 with ops.control_dependencies([metadata]):
906 if use_tpu:
907 compile_status = tpu_ops.tpu_compilation_result()
908 op = compile_status.op
909 attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
910 op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access
911 else:
912 compile_status = control_flow_ops.no_op(name="compilation_status")
914 if not output_tensors:
915 # Returns a list of NoOps dependent on the replication Op, indexed by
916 # [replica_num].
917 return [
918 compile_status,
919 [
920 control_flow_ops.group(control_deps, name="shard_%d" % i)
921 for i in range(num_replicas)
922 ]
923 ]
925 # Fan-out: Builds a TPUReplicatedOutput node for each output.
926 replicated_outputs = [[] for i in range(num_replicas)]
927 for i, t in enumerate(output_tensors):
929 # None values returned by the computation can't be sent to
930 # tpu_ops.tpu_replicated_output(), we handle them specially here. We can
931 # avoid the placeholder 0 routine required on the inputs since outputs are
932 # replicated per-tensor, not per-replica, so we can skip replication.
933 if t is None:
934 for replica in range(num_replicas):
935 replicated_outputs[replica].append(None)
936 continue
938 # Fan-out: Builds a TPUReplicatedOutput node for each output.
939 ys = tpu_ops.tpu_replicated_output(
940 t, num_replicas, name="output{}".format(i))
942 # Wraps the outputs in identity operators so the names of any possible
943 # `fetch` nodes are preserved by the replication rewrite.
944 with ops.control_dependencies(control_deps):
945 for replica in range(num_replicas):
946 replicated_outputs[replica].append(
947 array_ops.identity(
948 ys[replica], name="output_%d_shard_%d" % (i, replica)))
950 replicated_outputs = [
951 nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True)
952 for replica_outs in replicated_outputs
953 ]
955 return [compile_status, replicated_outputs]
958def _postprocess_flat_outputs(
959 outputs: Any,
960 need_spmd_partitioning: bool
961) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
 962 """Validates flat outputs, adds back device assignments and other attrs.
964 Args:
965 outputs: Output from `computation` inside `tpu.rewrite`.
966 need_spmd_partitioning: Whether XLA SPMD partitioning is needed.
968 Returns:
969 - Tensors extracted from outputs.
970 - Operations extracted from outputs.
971 - A pack template for use with nest.pack_sequence_as to pack the tensors.
972 """
973 # Following code segment is to preserve legacy behavior. Previously we only
974 # supported flat outputs and thus for consistency it was nice to convert even
975 # single element into a tuple. But now that we support arbitrary output
976 # structure, this is no longer necessary.
977 # TODO(b/121383831): Migrate all legacy use cases and delete this special
978 # case.
979 # If the computation returns `None`, make it an empty tuple.
980 if outputs is None:
981 outputs = tuple()
983 # For legacy / backwards compatibility reasons we return a list for "flat"
984 # output values (even if the user's flat return value was a different type or
985 # even just a scalar value) so use nest.flatten to compute a flat list pack
986 # template.
987 pack_template = nest.flatten(outputs, expand_composites=False)
989 # Even though outputs is already "flat", we flatten any composites so their
990 # component tensors can be tagged and replicated. The pack_template will be
991 # used by the caller to repack the composite tensors.
992 outputs = nest.flatten(outputs, expand_composites=True)
994 # Append `no_op` here so that fetching any return value of this function
995 # will trigger TPUExecute node.
996 outputs += (control_flow_ops.no_op(),)
998 maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
999 try:
1000 if need_spmd_partitioning:
1001 outputs = [
1002 o if isinstance(o, ops.Operation) else maybe_convert(o)
1003 for o in outputs
1004 ]
1005 else:
1006 with ops.device(core(0)):
1007 outputs = [
1008 o if isinstance(o, ops.Operation) else maybe_convert(o)
1009 for o in outputs
1010 ]
1011 except Exception as e:
1012 raise ValueError(
1013 "TPU function return values must all either be Operations or "
1014 f"convertible to Tensors. Got error: {e}")
1016 # Separates the returned Operations and Tensors.
1017 output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
1018 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
1020 if outputs != output_tensors + output_operations:
1021 raise ValueError(
1022 "TPU functions must return zero or more Tensor values followed by "
1023 "zero or more Operations.")
1025 # Trim operations off the end of the pack template. output_operations has 1
1026 # extra element due to the no-op that is added.
1027 if len(output_operations) > 1:
1028 pack_template = pack_template[:1 - len(output_operations)]
1030 # Wraps outputs in Identity ops. Otherwise a replicated input copied
1031 # straight to an output would bypass the replicate(). This would be bad
1032 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1033 # be rewritten away, leading to a runtime error.
1034 # TODO(phawkins): extend the rewrite to elide these nodes instead.
1035 new_output_tensors = []
1036 for t in output_tensors:
1037 if t is None:
1038 new_output_tensors.append(None)
1039 elif need_spmd_partitioning:
1040 o = array_ops.identity(t)
1041 # pylint: disable=protected-access
1042 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1043 # pylint: enable=protected-access
1044 new_output_tensors.append(o)
1045 else:
1046 with ops.device(t.device if t.device else core(0)):
1047 o = array_ops.identity(t)
1048 # pylint: disable=protected-access
1049 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1050 # pylint: enable=protected-access
1051 new_output_tensors.append(o)
1052 return new_output_tensors, output_operations, pack_template
1055def _postprocess_non_flat_outputs(
1056 outputs: Any,
1057 need_spmd_partitioning: bool
1058) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
1059 """Validates non-flat outputs, adds back device assignments and other attrs.
1061 Args:
1062 outputs: Output from `computation` inside `tpu.rewrite`.
1063 need_spmd_partitioning: Whether XLA SPMD partitioning is needed.
1065 Returns:
1066 - Tensors extracted from outputs.
1067 - An empty Operations list because Operations are not allowed in non-flat
1068 outputs.
1069 - A pack template for use with nest.pack_sequence_as to pack the tensors.
1070 """
1072 # Flatten output items.
1073 flat_outputs = nest.flatten(outputs, expand_composites=True)
1075 # Convert all non-None non-Operation outputs to Tensors.
1076 for i, o in enumerate(flat_outputs):
1077 if o is None:
1078 flat_outputs[i] = None
1079 continue
1081 if isinstance(o, ops.Operation):
1082 raise ValueError(
1083 "tpu.rewrite does not support Operation as return value in non-flat "
1084 "output structure. You can set returned Operations as control "
1085 "dependencies of returned Tensors so Operations are triggered when "
1086 f'Tensors are evaluated. Operation found: "{o.name}"')
1088 try:
1089 o = ops.convert_to_tensor(o)
1090 except Exception as e:
1091 raise ValueError(
1092 "TPU function return values must all either be Operations or "
1093 f'convertible to Tensors. Got error: "{e}"')
1095 # Wraps outputs in Identity ops. Otherwise a replicated input copied
1096 # straight to an output would bypass the replicate(). This would be bad
1097 # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
1098 # be rewritten away, leading to a runtime error.
1099 # TODO(phawkins): extend the rewrite to elide these nodes instead.
1100 if need_spmd_partitioning:
1101 o = array_ops.identity(o)
1102 # pylint: disable=protected-access
1103 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1104 # pylint: enable=protected-access
1105 flat_outputs[i] = array_ops.identity(o)
1106 else:
1107 with ops.device(o.device if o.device else core(0)):
1108 o = array_ops.identity(o)
1109 # pylint: disable=protected-access
1110 o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
1111 # pylint: enable=protected-access
1112 flat_outputs[i] = array_ops.identity(o)
1114 # All flat_outputs are Tensors, and no Operations.
1115 return flat_outputs, [], outputs
1118def split_compile_and_shard(
1119 computation: Callable[..., Any],
1120 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None,
1121 num_shards: int = 1,
1122 input_shard_axes: Optional[List[int]] = None,
1123 outputs_from_all_shards: Union[bool, List[bool]] = True,
1124 output_shard_axes: Optional[List[int]] = None,
1125 infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1126 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1127 name: Optional[Text] = None,
1128 xla_options: Optional[XLAOptions] = None,
1129 ) -> Tuple[ops.Operation, List[core_types.Tensor]]:
1130 """Shards `computation` for parallel execution.
1132 `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1133 of which has a corresponding split axis (from `input_shard_axes`). Each input
1134 is split into `num_shards` pieces along the corresponding axis, and
1135 computation is applied to each shard in parallel.
1137 Tensors are broadcast to all shards if they are lexically captured by
1138 `computation`. e.g.,
1140 x = tf.constant(7)
1141 def computation():
1142 return x + 3
1143 ... = shard(computation, ...)
1145 If `outputs_from_all_shards` is true, the outputs from all shards of
1146 `computation` are concatenated back together along their `output_shard_axes`.
1147 Otherwise, each output is taken from an arbitrary shard.
1149 Inputs and outputs of the computation must be at least rank-1 Tensors.
1151 Args:
1152 computation: A Python function that builds a computation to apply to each
1153 shard of the input.
1154 inputs: A list of input tensors or None (equivalent to an empty list). Each
1155 input tensor has a corresponding shard axes, given by `input_shard_axes`,
1156 which must have size divisible by `num_shards`.
1157 num_shards: The number of shards.
1158 input_shard_axes: A list of dimensions along which to shard `inputs`, or
1159 `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1160 there must be one dimension per input.
1161 outputs_from_all_shards: Boolean or list of boolean. For each output, if
1162 `True`, outputs from all shards are concatenated along the corresponding
1163 `output_shard_axes` entry. Otherwise, each output is taken
1164 from an arbitrary shard. If the argument is a boolean, the argument's
1165 value is used for each output.
1166 output_shard_axes: A list of dimensions along which to concatenate the
1167 outputs of `computation`, or `None`. `None` means "concatenate all outputs
1168 along dimension 0". If not `None`, there must be one dimension per output.
1169 Ignored if `outputs_from_all_shards` is False.
1170 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1171 of `computation`.
1172 device_assignment: If not `None`, a `DeviceAssignment` describing the
1173 mapping between logical cores in the computation with physical cores in
1174 the TPU topology. Uses a default device assignment if `None`. The
1175 `DeviceAssignment` may be omitted if each shard of the computation uses
1176 only one core, and there is either only one shard, or the number of shards
1177 is equal to the number of cores in the TPU system.
1178 name: (Deprecated) Does nothing.
1179 xla_options: An instance of `tpu.XLAOptions` which indicates the options
1180 passed to XLA compiler. Use `None` for default options.
1181 Returns:
1182 A tuple of (compile op, [output tensors]).
1183 Raises:
1184 ValueError: If num_shards <= 0
1185 ValueError: If len(input_shard_axes) != len(inputs)
1186 ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1187 """
1188 # TODO(phawkins): consider adding support for broadcasting Tensors passed as
1189 # inputs.
1191 if num_shards <= 0:
1192 raise ValueError(
1193 f"num_shards must be a positive integer. Received {num_shards}")
1195 inputs = [] if inputs is None else inputs
1196 if not isinstance(inputs, list):
1197 raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None. "
1198 f"Received {type(inputs)}")
1200 # Converts inputs to Tensors.
1201 inputs = [ops.convert_to_tensor(x) for x in inputs]
1203 if input_shard_axes is None:
1204 input_shard_axes = [0] * len(inputs)
1205 if len(inputs) != len(input_shard_axes):
1206 raise ValueError("Length of input_shard_axes must be equal to the number "
1207 f"of inputs. Received {len(inputs)} inputs and "
1208 f"{len(input_shard_axes)} input_shard_axes.")
1210 if inputs:
1211 # Splits the `inputs` along the corresponding `input_shard_axes`, giving
1212 # lists with layout [input][shard]
1213 split_inputs = [
1214 array_ops.split(x, num_shards, axis=axis)
1215 for (axis, x) in zip(input_shard_axes, inputs)]
1217 # Transposes the input lists to have layout [shard][input]
1218 transposed_inputs = [list(i) for i in zip(*split_inputs)]
1219 else:
1220 transposed_inputs = [[]] * num_shards
1222 compile_op, outputs = split_compile_and_replicate(
1223 computation,
1224 transposed_inputs,
1225 infeed_queue=infeed_queue,
1226 device_assignment=device_assignment,
1227 name=name,
1228 xla_options=xla_options)
1230 # There must be at least one shard since num_shards > 0.
1231 # TODO(b/36647078) remove disable when pylint bug is fixed.
1232 # pylint: disable=indexing-exception
1233 if isinstance(outputs[0], ops.Operation):
1234 # pylint: enable=indexing-exception
1235 # There were no outputs from the computation and replicate returned a list
1236 # of NoOps with control dependencies on the computation. Return the first
1237 # one so it can be used as a control dependency or fetch node.
1238 # TODO(b/36647078) remove disable when pylint bug is fixed.
1239 # pylint: disable=indexing-exception
1240 return compile_op, [outputs[0]]
1241 # pylint: enable=indexing-exception
1243 # TODO(b/36647078) remove disable when pylint bug is fixed.
1244 # pylint: disable=indexing-exception
1245 num_outputs = len(outputs[0])
1246 # pylint: enable=indexing-exception
1248 if output_shard_axes is None:
1249 output_shard_axes = [0] * num_outputs
1250 if num_outputs != len(output_shard_axes):
1251 raise ValueError("Length of output_shard_axes must be equal to the number "
1252 f"of outputs. Received {num_outputs} outputs "
1253 f"and {len(output_shard_axes)} output_shard_axes.")
1255 if isinstance(outputs_from_all_shards, bool):
1256 outputs_from_all_shards = [outputs_from_all_shards] * num_outputs
1258 if num_outputs != len(outputs_from_all_shards):
1259 raise ValueError(
1260 "Length of outputs_from_all_shards must be equal to the number of "
1261 f"outputs. Received {num_outputs} outputs and "
1262 f"{len(outputs_from_all_shards)} outputs_from_all_shards.")
1264 results = []
1265 for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards,
1266 zip(*outputs)):
1267 if all_shards:
1268 # Concatenate all of the outputs together (use stack for scalars).
1269 shape = x[0].shape
1270 is_scalar = shape is not None and (shape.ndims == 0)
1271 results.append((array_ops_stack.stack(list(x)) if is_scalar
1272 else array_ops.concat(list(x), axis=axis)))
1273 else:
1274 # TODO(phawkins): use a smarter policy, e.g., round-robin across shards.
1275 results.append(x[0])
1277 return compile_op, results
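# Illustrative usage sketch (not part of the original module): this lower-level
# entry point returns the compile op separately from the sharded outputs, so a
# TF1-style caller can fetch compilation status before executing. The names and
# shapes below are assumptions for illustration only.
def _example_split_compile_and_shard():  # Hypothetical helper for illustration.
  """Builds a 2-way sharded computation; outputs concatenate along axis 0."""
  def computation(x):
    # Each shard sees a [4, 4] slice of the [8, 4] input.
    return math_ops.reduce_sum(x, axis=1, keepdims=True)
  x = array_ops.ones([8, 4])
  compile_op, outputs = split_compile_and_shard(
      computation, inputs=[x], num_shards=2)
  return compile_op, outputs  # outputs[0] has shape [8, 1]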
1280@tf_export(v1=["tpu.shard"])
1281@traceback_utils.filter_traceback
1282def shard(
1283 computation: Callable[..., Any],
1284 inputs: Optional[List[core_types.Tensor]] = None,
1285 num_shards: int = 1,
1286 input_shard_axes: Optional[List[int]] = None,
1287 outputs_from_all_shards: Union[bool, List[bool]] = True,
1288 output_shard_axes: Optional[List[int]] = None,
1289 infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1290 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1291 name: Optional[Text] = None,
1292 xla_options: Optional[XLAOptions] = None) -> List[core_types.Tensor]:
1293 """Shards `computation` for parallel execution.
1295 `inputs` must be a list of Tensors or None (equivalent to an empty list), each
1296 of which has a corresponding split axis (from `input_shard_axes`). Each input
1297 is split into `num_shards` pieces along the corresponding axis, and
1298 computation is applied to each shard in parallel.
1300 Tensors are broadcast to all shards if they are lexically captured by
1301 `computation`. e.g.,
1303 x = tf.constant(7)
1304 def computation():
1305 return x + 3
1306 ... = shard(computation, ...)
1308 TODO(phawkins): consider adding support for broadcasting Tensors passed
1309 as inputs.
1311 If `outputs_from_all_shards` is true, the outputs from all shards of
1312 `computation` are concatenated back together along their `output_shard_axes`.
1313 Otherwise, each output is taken from an arbitrary shard.
1315 Inputs and outputs of the computation must be at least rank-1 Tensors.
1317 Args:
1318 computation: A Python function that builds a computation to apply to each
1319 shard of the input.
1320 inputs: A list of input tensors or None (equivalent to an empty list). Each
1321 input tensor has a corresponding shard axes, given by `input_shard_axes`,
1322 which must have size divisible by `num_shards`.
1323 num_shards: The number of shards.
1324 input_shard_axes: A list of dimensions along which to shard `inputs`, or
1325 `None`. `None` means "shard all inputs along dimension 0". If not `None`,
1326 there must be one dimension per input.
1327 outputs_from_all_shards: Boolean or list of boolean. For each output, if
1328 `True`, outputs from all shards are concatenated along the corresponding
1329 `output_shard_axes` entry. Otherwise, each output is taken
1330 from an arbitrary shard. If the argument is a boolean, the argument's
1331 value is used for each output.
1332 output_shard_axes: A list of dimensions along which to concatenate the
1333 outputs of `computation`, or `None`. `None` means "concatenate all outputs
1334 along dimension 0". If not `None`, there must be one dimension per output.
1335 Ignored if `outputs_from_all_shards` is False.
1336 infeed_queue: If not `None`, the `InfeedQueue` to use to augment the inputs
1337 of `computation`.
1338 device_assignment: If not `None`, a `DeviceAssignment` describing the
1339 mapping between logical cores in the computation with physical cores in
1340 the TPU topology. Uses a default device assignment if `None`. The
1341 `DeviceAssignment` may be omitted if each shard of the computation uses
1342 only one core, and there is either only one shard, or the number of shards
1343 is equal to the number of cores in the TPU system.
1344 name: (Deprecated) Does nothing.
1345 xla_options: An instance of `tpu.XLAOptions` which indicates the options
1346 passed to the XLA compiler. Use `None` for default options.
1347 Returns:
1348 A list of output tensors.
1349 Raises:
1350 ValueError: If num_shards <= 0
1351 ValueError: If len(input_shard_axes) != len(inputs)
1352 ValueError: If len(output_shard_axes) != len(outputs from `computation`)
1353 """
1354 return split_compile_and_shard(
1355 computation,
1356 inputs=inputs,
1357 num_shards=num_shards,
1358 input_shard_axes=input_shard_axes,
1359 outputs_from_all_shards=outputs_from_all_shards,
1360 output_shard_axes=output_shard_axes,
1361 infeed_queue=infeed_queue,
1362 device_assignment=device_assignment,
1363 name=name,
1364 xla_options=xla_options)[1]
1367@tf_export(v1=["tpu.batch_parallel"])
1368@traceback_utils.filter_traceback
1369def batch_parallel(
1370 computation: Callable[..., Any],
1371 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None,
1372 num_shards: int = 1,
1373 infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1374 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1375 name: Optional[Text] = None,
1376 xla_options: Optional[XLAOptions] = None):
1377 """Shards `computation` along the batch dimension for parallel execution.
1379 Convenience wrapper around shard().
1381 `inputs` must be a list of Tensors or None (equivalent to an empty list).
1382 Each input is split into `num_shards` pieces along the 0-th dimension, and
1383 computation is applied to each shard in parallel.
1385 Tensors are broadcast to all shards if they are lexically captured by
1386 `computation`. e.g.,
1388 x = tf.constant(7)
1389 def computation():
1390 return x + 3
1391 ... = batch_parallel(computation, ...)
1393 The outputs from all shards are concatenated back together along their 0-th
1394 dimension.
1396 Inputs and outputs of the computation must be at least rank-1 Tensors.
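  For illustration, a minimal sketch, assuming an initialized TPU system, that
  splits an [8, 4] batch across two shards:

    def computation(x_shard):
      # Each shard receives a [4, 4] slice of the [8, 4] batch.
      return x_shard * 2.0

    x = tf.ones([8, 4])
    [y] = tf.compat.v1.tpu.batch_parallel(
        computation, inputs=[x], num_shards=2)
    # `y` is the [8, 4] concatenation of the per-shard outputs along axis 0.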
1398 Args:
1399 computation: A Python function that builds a computation to apply to each
1400 shard of the input.
1401 inputs: A list of input tensors or None (equivalent to an empty list). The
1402 0-th dimension of each Tensor must have size divisible by `num_shards`.
1403 num_shards: The number of shards.
1404 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1405 of arguments as inputs to `computation`.
1406 device_assignment: If not `None`, a `DeviceAssignment` describing the
1407 mapping between logical cores in the computation and physical cores in
1408 the TPU topology. Uses a default device assignment if `None`. The
1409 `DeviceAssignment` may be omitted if each shard of the computation uses
1410 only one core, and there is either only one shard, or the number of shards
1411 is equal to the number of cores in the TPU system.
1412 name: (Deprecated) Does nothing.
1413 xla_options: An instance of `tpu.XLAOptions` which indicates the options
1414 passed to the XLA compiler. Use `None` for default options.
1415 Returns:
1416 A list of output tensors.
1417 Raises:
1418 ValueError: If `num_shards <= 0`
1419 """
1420 return shard(
1421 computation,
1422 inputs,
1423 num_shards=num_shards,
1424 infeed_queue=infeed_queue,
1425 device_assignment=device_assignment,
1426 name=name,
1427 xla_options=xla_options)
1430@tf_export(v1=["tpu.rewrite"])
1431@traceback_utils.filter_traceback
1432def rewrite(
1433 computation: Callable[..., Any],
1434 inputs: Optional[List[List[Optional[core_types.Tensor]]]] = None,
1435 infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1436 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1437 name: Optional[Text] = None,
1438 xla_options: Optional[XLAOptions] = None) -> Any:
1439 """Rewrites `computation` for execution on a TPU system.
1441 Args:
1442 computation: A Python function that builds a computation to apply to the
1443 input. If the function takes n inputs, 'inputs' should be a list of n
1444 tensors.
1446 `computation` may return a list of operations and tensors. Tensors must
1447 come before operations in the returned list. The return value of
1448 `rewrite` is a list of tensors corresponding to the tensors from the
1449 output of `computation`.
1451 All `Operation`s constructed during `computation` will be executed when
1452 evaluating any of the returned output tensors, not just the ones returned.
1453 inputs: A list of input tensors or `None` (equivalent to an empty list).
1454 Each input can be a nested structure containing values that are
1455 convertible to tensors. Note that passing an N-dimensional list of
1456 compatible values will result in an N-dimensional list of scalar tensors
1457 rather than a single rank-N tensor. If you need different behavior,
1458 convert part of inputs to tensors with `tf.convert_to_tensor`.
1459 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1460 of arguments as inputs to `computation`.
1461 device_assignment: if not `None`, a `DeviceAssignment` describing the
1462 mapping between logical cores in the computation and physical cores in
1463 the TPU topology. May be omitted for a single-core computation, in which
1464 case the core attached to task 0, TPU device 0 is used.
1465 name: (Deprecated) Does nothing.
1466 xla_options: An instance of `tpu.XLAOptions` which indicates the options
1467 passed to the XLA compiler. Use `None` for default options.
1468 Returns:
1469 Same data structure as if computation(*inputs) is called directly with some
1470 exceptions for correctness. Exceptions include:
1471 1) None output: a NoOp would be returned which control-depends on
1472 computation.
1473 2) Single value output: A tuple containing the value would be returned.
1474 3) Operation-only outputs: a NoOp would be returned which
1475 control-depends on computation.
1476 TODO(b/121383831): Investigate into removing these special cases.
1477 """
1478 # TODO(b/36647078) remove disable when pylint bug is fixed.
1479 # pylint: disable=indexing-exception
1480 return replicate(
1481 computation,
1482 None if inputs is None else [inputs],
1483 infeed_queue=infeed_queue,
1484 device_assignment=device_assignment,
1485 name=name,
1486 xla_options=xla_options)[0]
1487 # pylint: enable=indexing-exception
1489 # Operations that indicate some error in the user's inference graph.
1492_DENYLISTED_INFERENCE_OPS = set([
1493 "ReadVariableOp",
1494 "AssignVariableOp",
1495 "AssignAddVariableOp",
1496 "AssignSubVariableOp",
1497 "VarHandleOp",
1498 "Variable",
1499 "VariableV2",
1500])
1503def under_tpu_inference_context() -> bool:
1504 """Check if it is currently under `_TPUInferenceContext`."""
1505 graph = ops.get_default_graph()
1506 while graph:
1507 context = graph._get_control_flow_context() # pylint: disable=protected-access
1508 while context:
1509 if isinstance(context, _TPUInferenceContext):
1510 return True
1511 context = context.outer_context
1512 if isinstance(graph, function._FuncGraph): # pylint: disable=protected-access
1513 graph = graph._outer_graph # pylint: disable=protected-access
1514 elif isinstance(graph, func_graph.FuncGraph):
1515 graph = graph.outer_graph
1516 else:
1517 return False
1520class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
1521 """A `ControlFlowContext` for nodes inside a TPU inference computation.
1523 The primary role of `_TPUInferenceContext` is to indicate the mode of
1524 operation and, when `check_ops` is True, to sanity-check operators inside a
1525 tpu.rewrite_for_inference() computation.
1526 """
1528 def __init__(self, name: Text, check_ops: bool = True):
1529 super(_TPUInferenceContext, self).__init__()
1530 self._name = name
1531 self._check_ops = check_ops
1533 def AddOp(self, op):
1534 self._AddOpInternal(op)
1536 def _AddOpInternal(self, op):
1537 # pylint: disable=protected-access
1538 if self._check_ops and op.type in _DENYLISTED_INFERENCE_OPS:
1539 raise NotImplementedError(
1540 f"Operation of type {op.type} ({op.name}) is not supported on the "
1541 "TPU for inference. Execution will fail if this op is used in the "
1542 "graph. Make sure your variables are using variable_scope.")
1543 if self._outer_context:
1544 self._outer_context.AddInnerOp(op)
1546 def AddValue(self, val):
1547 result = val
1548 if self._outer_context:
1549 result = self._outer_context.AddValue(val)
1550 return result
1552 def AddInnerOp(self, op):
1553 self._AddOpInternal(op)
1555 @property
1556 def grad_state(self):
1557 return None
1560def validate_inference_rewrite_for_variables(graph: ops.Graph):
1561 """Validates whether rewrite_for_inference() 'worked' for variables.
1563 The rewrite_for_inference() method is supposed to append GuaranteeConstOps
1564 after ReadVariableOps, but this mechanism works only if you are using
1565 tf.compat.v1.get_variable() to create and access variables in your TPU
1566 computation. This validation method can be called immediately after calling
1567 tpu.rewrite_for_inference() to check whether GuaranteeConstOps were added
1568 to the graph.
1570 Typical usages:
1571 tpu.validate_inference_rewrite_for_variables(
1572 tf.compat.v1.get_default_graph())
1574 tpu.validate_inference_rewrite_for_variables(sess.graph)
1576 Args:
1577 graph: The graph which needs to be validated.
1578 Raises:
1579 RuntimeError: if validation failed.
1580 """
1581 if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
1582 raise RuntimeError(
1583 "No GuaranteeConst ops found in the graph after running "
1584 "tpu.rewrite_for_inference(...). Please check that you are using "
1585 "tf.get_variable() to create and access variables in your tpu "
1586 "computation.")
1589def rewrite_for_inference(
1590 computation: Callable[..., Any],
1591 inputs: Optional[List[core_types.Tensor]] = None,
1592 infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
1593 device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
1594 name: Optional[Text] = None) -> List[core_types.Tensor]:
1595 """Rewrites `computation` for inference on a TPU system.
1597 In addition to rewriting the computation to run on a TPU, if the computation
1598 uses variables, this moves the ReadVariableOps outside the TPU computation
1599 and adds GuaranteeConst ops just after the ReadVariableOps.
1600 This mechanism works only if you are using tf.compat.v1.get_variable() to
1601 create and access variables in your TPU computation. You can validate
1602 whether this worked by calling the validate_inference_rewrite_for_variables()
1603 method immediately after this method, to check whether GuaranteeConstOps
1604 were added to the graph.
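  For illustration, a minimal usage sketch, assuming variables are created with
  tf.compat.v1.get_variable() and using example names:

    def inference_fn(images):
      w = tf.compat.v1.get_variable("w", shape=[784, 10])
      return tf.matmul(images, w)

    images = tf.compat.v1.placeholder(tf.float32, shape=[8, 784])
    [logits] = tpu.rewrite_for_inference(inference_fn, inputs=[images])
    tpu.validate_inference_rewrite_for_variables(
        tf.compat.v1.get_default_graph())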
1606 Args:
1607 computation: A Python function that builds a computation to apply to the
1608 input. If the function takes n inputs, 'inputs' should be a list of n
1609 tensors. If the function returns m outputs, rewrite will return a list of
1610 m tensors.
1611 inputs: A list of input tensors or `None` (equivalent to an empty list).
1612 infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
1613 of arguments as inputs to `computation`.
1614 device_assignment: if not `None`, a `DeviceAssignment` describing the
1615 mapping between logical cores in the computation and physical cores in
1616 the TPU topology. May be omitted for a single-core computation, in which
1617 case the core attached to task 0, TPU device 0 is used.
1618 name: The name of the operator.
1619 Returns:
1620 A list of output tensors.
1621 """
1623 def guarantee_const_getter(getter, name, *args, **kwargs):
1624 with ops.control_dependencies(None):
1625 return array_ops.guarantee_const(
1626 getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
1628 def wrapped_computation(*args, **kwargs):
1629 """Execute computation under `_TPUInferenceContext`."""
1630 context = _TPUInferenceContext(
1631 name=ops.get_default_graph().unique_name("rewrite_for_inference"))
1632 try:
1633 context.Enter()
1635 vscope = variable_scope.get_variable_scope()
1636 prev_custom_getter = vscope.custom_getter
1637 prev_caching_device = vscope.caching_device
1638 vscope.set_custom_getter(guarantee_const_getter)
1639 vscope.set_caching_device(lambda op: op.device)
1641 result = computation(*args, **kwargs)
1643 vscope.set_custom_getter(prev_custom_getter)
1644 vscope.set_caching_device(prev_caching_device)
1645 finally:
1646 context.Exit()
1647 return result
1649 # pylint: disable=undefined-variable
1650 return rewrite(
1651 wrapped_computation,
1652 inputs=inputs,
1653 infeed_queue=infeed_queue,
1654 device_assignment=device_assignment,
1655 name=name)
1656 # pylint: enable=undefined-variable
1659def prune_unconnected_ops_from_xla(prune_graph: ops.Graph):
1660 """Prunes unconnected ops as listed in _UNCONNECTED_OPS_TO_PRUNE.
1662 Args:
1663 prune_graph: A tensorflow graph from which we wish to prune unconnected ops
1664 as listed in _UNCONNECTED_OPS_TO_PRUNE. In general, these ops should have
1665 no inputs and no consumers. These can often be left behind due to graph
1666 construction rewiring (for instance TF-Hub). While they never execute,
1667 they would cause XLA compilation to fail, so we exclude them from XLA
1668 compilation by removing the tpu_replicate attribute.
1669 """
1670 # Scan over the top level graph and all function graphs.
1671 for graph in [prune_graph] + [
1672 f for f in prune_graph._functions.values() # pylint: disable=protected-access
1673 ]:
1674 if not isinstance(graph, ops.Graph):
1675 continue
1676 for op in graph.get_operations():
1677 if op.type not in _UNCONNECTED_OPS_TO_PRUNE:
1678 continue
1679 outputs_consumed = False
1680 for output in op.outputs:
1681 if output.consumers():
1682 outputs_consumed = True
1683 break
1684 if not outputs_consumed:
1685 logging.info(
1686 "Pruning OP %s of type %s from XLA Compile due to "
1687 "it being disconnected.", op.name, op.type)
1688 op._clear_attr(tpu_replication._TPU_REPLICATE_ATTR) # pylint: disable=protected-access