Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cross_device_utils.py: 18% (294 statements), coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_device_ops."""

import copy
import threading
from typing import Callable, List, Optional, Union

from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core

INSTANCE_KEY_START_NUMBER = 100


def aggregate_gradients_using_nccl(replica_grads):
  """Aggregate gradients using nccl allreduce."""
  agg_all_g_and_v = []
  for single_g_and_v in zip(*replica_grads):
    single_grads = [g for g, _ in single_g_and_v]
    agg_grads = nccl_ops.all_sum(single_grads)
    agg_all_g_and_v.append(
        [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])

  agg_all_g_and_v = list(zip(*agg_all_g_and_v))

  return agg_all_g_and_v
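

# Illustrative sketch (not part of the original module): the data-structure
# transformation performed by `aggregate_gradients_using_nccl`, using a plain
# CPU sum as a stand-in for `nccl_ops.all_sum`, which requires GPUs. The
# helper name `_example_nccl_style_aggregation` is hypothetical.
def _example_nccl_style_aggregation():
  import tensorflow as tf  # Assumes TF 2.x eager execution.
  # `replica_grads` is a list over replicas; each inner list holds
  # (gradient, variable) pairs in the same variable order on every replica.
  v0, v1 = tf.Variable(1.0), tf.Variable(2.0)
  replica_grads = [
      [(tf.constant(1.0), v0), (tf.constant(10.0), v1)],  # replica 0
      [(tf.constant(3.0), v0), (tf.constant(20.0), v1)],  # replica 1
  ]
  aggregated = []
  for single_g_and_v in zip(*replica_grads):  # iterate per variable
    summed = tf.add_n([g for g, _ in single_g_and_v])  # stand-in for all_sum
    aggregated.append([(summed, v) for _, v in single_g_and_v])
  # Transpose back to the per-replica layout returned by the real function.
  return list(zip(*aggregated))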


def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
  """Aggregate gradients using hierarchical copies.

  Args:
    avail_devices: available GPU devices.
    replica_grads: List of lists of (gradient, variable) tuples. The outer list
      is over replicas. The inner list is over individual gradients.

  Returns:
    The list of (aggregated_gradient, variable), where the gradient has been
      summed across all replicas and the variable is chosen from the first
      replica.
  """
  # This only works for DGX-1 type of machine topology
  # Device peer to peer matrix
  # DMA: 0 1 2 3 4 5 6 7
  # 0:   Y Y Y Y Y N N N
  # 1:   Y Y Y Y N Y N N
  # 2:   Y Y Y Y N N Y N
  # 3:   Y Y Y Y N N N Y
  # 4:   Y N N N Y Y Y Y
  # 5:   N Y N N Y Y Y Y
  # 6:   N N Y N Y Y Y Y
  # 7:   N N N Y Y Y Y Y
  agg_grads = []
  num_devices = len(avail_devices)
  # In the special case of DGX-1 machine topology, the two groups have equal
  # size.
  group_size = num_devices // 2
  for i, single_grads in enumerate(zip(*replica_grads)):
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    if group_0_main_device < group_size:
      group_0_begin = 0
      group_1_begin = group_size
    else:
      group_0_begin = group_size
      group_1_begin = 0

    # Aggregate the first group.
    group_0_device_grads = single_grads[group_0_begin:
                                        group_0_begin + group_size]
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_0_device_grads, False, False)

    # Aggregate the second group.
    group_1_device_grads = single_grads[group_1_begin:
                                        group_1_begin + group_size]
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_1_device_grads, False, False)

    # Aggregate between the groups.
    with ops.device(avail_devices[group_0_main_device]):
      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
          [group_0_agg_grads, group_1_agg_grads], False, False)

    # Broadcast the result back into the root of each group.
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)

    agg_grads_bcast = []
    for j in range(len(single_grads)):
      with ops.device(avail_devices[j]):
        # Broadcast the result back to each member in the group from the root.
        if (group_0_main_device < group_size) == (j < group_size):
          src_device_grad = group_0_agg_grads_bcast
        else:
          src_device_grad = group_1_agg_grads_bcast
        agg_grads_bcast.append(array_ops.identity(src_device_grad))

    agg_grads.append(
        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])

  agg_grads = list(zip(*agg_grads))

  return agg_grads
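

# Illustrative sketch: how the hierarchical copy above splits 8 devices into
# two halves and rotates the per-group root device with the gradient index
# `i`. Pure Python; the helper name `_example_group_assignment` is
# hypothetical.
def _example_group_assignment(num_devices=8):
  group_size = num_devices // 2
  assignments = []
  for i in range(num_devices):  # one entry per gradient index
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    assignments.append((group_0_main_device, group_1_main_device))
  # For num_devices=8 this yields (0, 4), (1, 5), ..., (7, 3): the roots cycle
  # through all devices so no single GPU becomes the reduction bottleneck.
  return assignments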


def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Calculate the average gradient for a shared variable across all replicas.

  Note that this function provides a synchronization point across all replicas.

  Args:
    grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
      (gradient, variable) pair within the outer list represents the gradient
      of the variable calculated for a single replica, and the number of pairs
      equals the number of replicas.
    use_mean: if True, mean is taken, else sum of gradients is taken.
    check_inf_nan: check grads for nans and infs.

  Returns:
    The tuple ([(average_gradient, variable),], has_nan_or_inf) where the
      gradient has been averaged across all replicas. The variable is chosen
      from the first replica. has_nan_or_inf indicates whether the grads have
      nan or inf.
  """
  grads = [g for g, _ in grad_and_vars]
  grad = math_ops.add_n(grads)

  if use_mean and len(grads) > 1:
    grad = array_ops.multiply(grad, 1.0 / len(grads))

  v = grad_and_vars[0][1]
  if check_inf_nan:
    has_nan_or_inf = array_ops.logical_not(
        array_ops.reduce_all(array_ops.is_finite(grads)))
    return (grad, v), has_nan_or_inf
  else:
    return (grad, v), None
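

# Illustrative sketch: the sum/mean and NaN/Inf check performed by
# `aggregate_single_gradient_using_copy`, expressed with the public TF API so
# it runs on CPU. The helper name `_example_average_gradient` is hypothetical.
def _example_average_gradient(use_mean=True):
  import tensorflow as tf  # Assumes TF 2.x eager execution.
  v = tf.Variable([0.0, 0.0])
  grad_and_vars = [(tf.constant([1.0, 2.0]), v),   # replica 0
                   (tf.constant([3.0, 4.0]), v)]   # replica 1
  grads = [g for g, _ in grad_and_vars]
  grad = tf.add_n(grads)                  # [4.0, 6.0]
  if use_mean and len(grads) > 1:
    grad = grad * (1.0 / len(grads))      # [2.0, 3.0]
  # Equivalent of the check_inf_nan branch: True if any grad is NaN/Inf.
  has_nan_or_inf = tf.logical_not(tf.reduce_all(tf.math.is_finite(grads)))
  return (grad, grad_and_vars[0][1]), has_nan_or_inf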


# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
  """Class that manages collective keys.

  We need to manage two different keys for collective:

  *Group key*: an integer key to identify the set of cooperative devices.
  Collective ops that work under the same set of devices must use the same
  group key.

  *Instance key*: an integer key to identify the set of counterpart tensors on
  different devices in a device group that need to be all-reduced.

  This class is thread safe.
  """

  def __init__(self, group_key_start=1):
    """Initializes the object.

    Args:
      group_key_start: the starting integer of group key.
    """
    self._group_key = group_key_start
    self._instance_key_table = {}
    self._lock = threading.Lock()
    self._known_groups = {}

  def get_group_key(self, devices):
    """Returns a group key for the list of local devices.

    The same group key is returned if the list of local devices is the same.

    Args:
      devices: a list of local canonical device strings in a collective group.

    Returns:
      a group key.
    """
    with self._lock:
      devices_key = ','.join(devices)
      if devices_key not in self._known_groups:
        self._known_groups[devices_key] = self._get_new_group_key(devices)
      return self._known_groups[devices_key]

  def _get_new_group_key(self, devices):
    """Returns a new group key.

    The caller should store and reuse the same group key for the same set of
    devices. Calling this method always returns a new group key.

    This method is not thread-safe.

    Args:
      devices: a list of canonical device strings in a collective group.

    Returns:
      a new group key.
    """
    new_key = self._group_key
    self._group_key += 1
    self._instance_key_table[new_key] = {}
    for device in devices:
      self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
    return new_key

  def get_instance_key(self, group_key, device):
    """Returns a new instance key for use in defining a collective op.

    You should call this once for each collective op of a collective instance.

    Args:
      group_key: the group key returned by get_group_key(). You should not
        assign the group key yourself.
      device: a canonical device string. It should be the device this
        collective op is on.

    Returns:
      a new instance key.

    Raises:
      ValueError: when the group key is invalid or the device is not in the
        group.
    """
    with self._lock:
      group = self._instance_key_table.get(group_key, None)
      if group is None:
        raise ValueError(f'Group {group_key} is not found.')
      if device not in group:
        raise ValueError(
            f'Device {device} is not present in group {group_key}')
      v = group[device]
      group[device] += 1
      return v

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveKeys needs to support deep copy as well.
    copied = CollectiveKeys()
    copied._group_key = self._group_key
    copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
    return copied
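

# Illustrative sketch: how group and instance keys are handed out. Pure
# bookkeeping, so this runs without any actual collective ops; the helper name
# `_example_collective_keys` is hypothetical.
def _example_collective_keys():
  keys = CollectiveKeys(group_key_start=1)
  devices = ['/job:worker/replica:0/task:0/device:GPU:0',
             '/job:worker/replica:0/task:0/device:GPU:1']
  group_key = keys.get_group_key(devices)          # 1; stable for this set
  assert keys.get_group_key(devices) == group_key  # same devices, same key
  # Each call returns a fresh per-device instance key, starting at
  # INSTANCE_KEY_START_NUMBER (100) and incrementing by one.
  first = keys.get_instance_key(group_key, devices[0])   # 100
  second = keys.get_instance_key(group_key, devices[0])  # 101
  return group_key, first, second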


class CollectiveReplicaLauncher(object):
  """Launch collectives on one replica."""

  _prefer_unique_instance_key = True
  _prefer_ordering_token = True

  def __init__(self, group_key: int, group_size: int,
               collective_keys: CollectiveKeys, device: str,
               options: collective_util.Options):
    self._group_key = group_key
    self._group_size = group_size
    self._collective_keys = collective_keys
    self._device = device
    self._options = options
    if self._use_ordering_token():
      with ops.init_scope(), ops.device(device):
        self._ordering_token = resource_variable_ops.ResourceVariable(0.)
    else:
      self._ordering_token = None

  def _control_input(self, control_input: Union[core.TensorLike,
                                                ops.Operation]):
    if control_input is not None and not self._use_ordering_token():
      return ops.control_dependencies([control_input])
    return ops.NullContextmanager()

  def _use_unique_instance_key(self):
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_unique_instance_key

  def _use_ordering_token(self):
    # We rely on auto control dep to insert control edges between NCCL calls,
    # but for tf1 graph mode auto control dep is not used.
    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_ordering_token

  def _next_instance_key(self):
    """Returns the next instance key."""
    if self._use_unique_instance_key():
      # Assigning instance keys at function building time has issues since
      # different workers may retrace the function at different times. With
      # collective V2 we can use capture_call_time_value to use a placeholder
      # as the instance key and feed it at function call time. In this way we
      # also don't reuse instance keys, which allows for per-instance
      # cancellation.
      graph = ops.get_default_graph()
      # Control flow ops don't work with capture_call_time_value, so we put the
      # capture in the function graph of that control flow op.
      while getattr(graph, 'is_control_flow_graph', False):
        graph = graph.outer_graph
      if not context.executing_eagerly() and graph.building_function:
        with graph.as_default():
          # Capture self._next_instance_key so that when building a function
          # that calls another tf.function, the instance key assignment is
          # further delayed until we actually call the function in eager. Note
          # that capture_call_time_value doesn't automatically propagate the
          # deferred capture to the outer function.
          return graph.capture_call_time_value(
              self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32))
      else:
        instance_key = self._collective_keys.get_instance_key(
            self._group_key, self._device)
        with ops.device('CPU:0'):
          return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
    else:
      return self._collective_keys.get_instance_key(self._group_key,
                                                    self._device)

  def _get_ordering_token(self):
    if self._use_ordering_token():
      return self._ordering_token.handle  # pytype: disable=attribute-error

  def can_order_nccl(self):
    """Whether this launcher can order NCCL operations."""
    return self._use_ordering_token()

  def all_reduce(
      self,
      input_tensor: core.TensorLike,
      control_input: Optional[Union[core.TensorLike, ops.Operation]] = None,
      options: Optional[collective_util.Options] = None) -> core.Tensor:
    """All-reduce a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all
        replicas.
      control_input: if not None, add control edges between control_input and
        the all-reduce.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The reduced tensor.
    """
    instance_key = self._next_instance_key()
    options = self._options.merge(options)
    ordering_token = self._get_ordering_token()
    with ops.device(self._device), \
         self._control_input(control_input):
      return collective_ops.all_reduce_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=options.implementation.value,
          timeout=options.timeout_seconds,
          ordering_token=ordering_token)

  def _all_gather(self, input_tensor: core.TensorLike,
                  options: Optional[collective_util.Options]) -> core.Tensor:
    """All-gather a dense tensor.

    Args:
      input_tensor: a dense tensor. It must have the same shape on all
        replicas.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The gathered tensor.
    """
    instance_key = self._next_instance_key()
    options = self._options.merge(options)
    ordering_token = self._get_ordering_token()
    with ops.device(self._device):
      return collective_ops.all_gather_v2(
          input_tensor,
          self._group_size,
          self._group_key,
          instance_key,
          communication_hint=options.implementation.value,
          timeout=options.timeout_seconds,
          ordering_token=ordering_token)

  def batch_all_reduce(
      self,
      input_tensor_packs: List[List[core.TensorLike]],
      options: Optional[collective_util.Options] = None) -> core.Tensor:
    """Batch all-reduce dense tensors.

    This takes a list of batches of tensors. Using multiple batches has the
    benefit that it doesn't need to wait for all inputs to be ready to start
    the all-reduce.

    Args:
      input_tensor_packs: a list of lists of dense tensors.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      A flat list of reduced tensors.
    """
    options = self._options.merge(options)
    outputs = []
    for pack in input_tensor_packs:
      if context.executing_eagerly():
        # We don't batch in eager as it sometimes makes the performance worse
        # due to the concat/split ops.
        for input_tensor in pack:
          outputs.append(self.all_reduce(input_tensor, None, options))
      else:
        # TODO(b/169168846): insert a parallel all_gather to verify packings
        # are the same on each replica.
        with ops.device(self._device):
          flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
          shapes = [array_ops.shape(t) for t in pack]
          if (options.implementation
              == collective_util.CommunicationImplementation.NCCL and outputs):
            control_input = outputs[-1]
          else:
            control_input = None
          reduced = self.all_reduce(
              array_ops.concat(flat_tensors, axis=0), control_input, options)
          num_elements = [math_ops.reduce_prod(s) for s in shapes]
          flat_outputs = array_ops.split(reduced, num_elements, axis=0)
          for shape, flat_output in zip(shapes, flat_outputs):
            outputs.append(array_ops.reshape(flat_output, shape))

    return outputs

  def all_gather(
      self,
      input_tensor: core.TensorLike,
      axis: core.TensorLike,
      options: Optional[collective_util.Options] = None) -> core.Tensor:
    """All-gather a dense tensor.

    This method must be called inside a tf.function.

    Args:
      input_tensor: a dense tensor. It must have the same rank on all replicas,
        and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(value)).
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The gathered Tensor.

    Raises:
      RuntimeError: if called in eager mode.
    """
    if context.executing_eagerly():
      raise RuntimeError('all_gather is not supported in eager mode.')

    with ops.device(self._device), \
         ops.control_dependencies([array_ops.identity(input_tensor)]):
      # 1. Transpose
      # E.g. given an input_tensor with shape [2,2,5,1] and axis=3 to gather,
      # we use perm_pre=[3 0 1 2] to transpose it to [1,2,2,5], which
      # brings the 3rd dim first; afterwards we use perm_after=[1,2,3,0] to
      # place it back.
      perm_pre = array_ops.concat(
          ([axis], math_ops.range(axis),
           math_ops.range(axis + 1, array_ops.rank(input_tensor))),
          axis=0)
      input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
      # 2. Pad
      gathered_shape = self._all_gather(
          array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
          options)
      first_dims = gathered_shape[:, 0]
      full_axis_dim = math_ops.reduce_max(first_dims)
      padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)

      # 3. Gather
      gather_padded_out_tensor = self._all_gather(padded_input_tensor, options)
      # 4. Unpad
      split_tensors = []
      for i in range(self._group_size):
        start_pos = i * full_axis_dim
        split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
                                                      first_dims[i]])
      out_tensor_t = array_ops.concat(split_tensors, 0)

      # 5. Transpose back
      perm_after = array_ops.concat(
          (math_ops.range(1, axis + 1), [0],
           math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
          axis=0)
      return array_ops.transpose(out_tensor_t, perm=perm_after)

  def all_reduce_indexed_slices(
      self,
      input_slices: indexed_slices.IndexedSlices,
      options: Optional[collective_util.Options] = None
  ) -> indexed_slices.IndexedSlices:
    """All-reduce an IndexedSlices.

    This method can be called outside tf.function.

    Args:
      input_slices: an IndexedSlices.
      options: an optional tf.distribute.experimental.CommunicationOptions. If
        provided, it overrides the default options.

    Returns:
      The reduced IndexedSlices.
    """
    # Current CollectiveAllGather implementations require input IndexedSlices
    # to have consistent length across the board, so we handle the reduction
    # of IndexedSlices as follows:
    #   1. Gather the lengths of IndexedSlices from all participants.
    #   2. If they have consistent length, apply all_gather.
    #   3. Otherwise pad IndexedSlices to be the same length across all
    #      participants and then apply all_gather.
    options = self._options.merge(options)
    with ops.device(self._device):

      def all_gather_indexed_slices(
          all_gather_fn: Callable[
              [core.TensorLike, Optional[collective_util.Options]], core.Tensor]
      ) -> indexed_slices.IndexedSlices:
        """Use all_gather_fn to aggregate `IndexedSlices`."""
        all_values = all_gather_fn(input_slices.values, options)
        # Add control dependency to order the all-gather.
        if (options.implementation ==
            collective_util.CommunicationImplementation.NCCL):
          control = [all_values]
        else:
          control = []
        with ops.control_dependencies(control):
          all_indices = all_gather_fn(input_slices.indices, options)
        return indexed_slices.IndexedSlices(
            values=all_values,
            indices=all_indices,
            dense_shape=input_slices.dense_shape)

      length = array_ops.shape(input_slices.indices)
      all_lengths = self._all_gather(length, options)

      def all_gather_with_padding(
          input_tensor: core.TensorLike,
          options: Optional[collective_util.Options]) -> core.Tensor:
        """all_gather tensors of different sizes using padding."""
        max_length = math_ops.reduce_max(all_lengths)
        padded_tensor = _pad_util(input_tensor, max_length)
        all_padded_tensors = self._all_gather(padded_tensor, options)
        split_tensors = []
        for i in range(self._group_size):
          start_pos = i * max_length
          split_tensors.append(all_padded_tensors[start_pos:start_pos +
                                                  all_lengths[i]])
        return array_ops.concat(split_tensors, 0)

      return cond.cond(
          math_ops.equal(
              math_ops.reduce_max(all_lengths),
              math_ops.reduce_min(all_lengths)),
          lambda: all_gather_indexed_slices(self._all_gather),
          lambda: all_gather_indexed_slices(all_gather_with_padding))
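

# Illustrative sketch: the pad / gather / unpad scheme used by
# `CollectiveReplicaLauncher.all_gather` and `all_reduce_indexed_slices`,
# simulated on one device with `tf.concat` standing in for the collective
# all-gather. The helper name `_example_pad_gather_unpad` is hypothetical.
def _example_pad_gather_unpad():
  import tensorflow as tf  # Assumes TF 2.x eager execution.
  # Two "replicas" hold tensors of different first-dimension lengths.
  replica_tensors = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0, 5.0])]
  lengths = [int(tf.shape(t)[0]) for t in replica_tensors]   # [2, 3]
  full_dim = max(lengths)                                    # 3
  # Pad every tensor's first dimension to the max length, as _pad_util does.
  padded = [tf.pad(t, [[0, full_dim - n]])
            for t, n in zip(replica_tensors, lengths)]
  gathered = tf.concat(padded, axis=0)      # stand-in for all_gather_v2
  # Unpad: slice each replica's chunk back to its true length.
  pieces = [gathered[i * full_dim:i * full_dim + lengths[i]]
            for i in range(len(replica_tensors))]
  return tf.concat(pieces, axis=0)          # [1., 2., 3., 4., 5.]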


def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
  if any(isinstance(v, indexed_slices.IndexedSlices) for v in values):
    return backprop_util.AggregateIndexedSlicesGradients(values)
  else:
    return accumulation_fn(values)


def divide_by_n_tensors_or_indexed_slices(value, n):
  if isinstance(value, indexed_slices.IndexedSlices):
    value = backprop_util.FlattenNestedIndexedSlices(value)
    return indexed_slices.IndexedSlices(value.values / n, value.indices,
                                        value.dense_shape)
  else:
    return value / n
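

# Illustrative sketch: aggregating then averaging gradients with the two
# helpers above, using the public TF API for the inputs. The helper name
# `_example_aggregate_and_divide` is hypothetical.
def _example_aggregate_and_divide():
  import tensorflow as tf  # Assumes TF 2.x eager execution.
  dense_grads = [tf.constant([2.0, 4.0]), tf.constant([6.0, 8.0])]
  total = aggregate_tensors_or_indexed_slices(dense_grads)   # [8., 12.]
  mean = divide_by_n_tensors_or_indexed_slices(total, len(dense_grads))
  # IndexedSlices inputs take the sparse path instead, which aggregates via
  # concat (per the docstring above) rather than an elementwise add_n.
  sparse = tf.IndexedSlices(values=tf.constant([[1.0]]),
                            indices=tf.constant([0]),
                            dense_shape=tf.constant([4, 1]))
  sparse_total = aggregate_tensors_or_indexed_slices([sparse, sparse])
  return mean, sparse_total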


def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copies a tensor or IndexedSlices to a device."""
  with ops.device(device):
    if isinstance(value, indexed_slices.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      if value.dense_shape is not None:
        copied_shape = array_ops.identity(value.dense_shape)
      else:
        copied_shape = None
      result = indexed_slices.IndexedSlices(copied_values, copied_indices,
                                            copied_shape)
    else:
      result = array_ops.identity(value)
  return result


def is_indexed_slices(value):
  if isinstance(value, indexed_slices.IndexedSlices):
    return True
  if isinstance(value, value_lib.DistributedValues):
    return all(
        isinstance(v, indexed_slices.IndexedSlices) for v in value.values)
  return False


def split_by_sparsity(values):
  """Split values into dense and sparse values.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    Four lists:
      a list of dense values, a list of their indices in `values` and
      a list of sparse values, a list of their indices in `values`.
  """
  dense_values = []
  dense_indices = []
  sparse_values = []
  sparse_indices = []
  for i, v in enumerate(values):
    if is_indexed_slices(v):
      sparse_values.append(v)
      sparse_indices.append(i)
    else:
      dense_values.append(v)
      dense_indices.append(i)
  return dense_values, dense_indices, sparse_values, sparse_indices


def stitch_values(values_and_indices_list):
  """Stitch values together according to their indices.

  Args:
    values_and_indices_list: a list of tuples of values and indices indicating
      the values and positions in the returned list.

  Returns:
    a stitched list of values.
  """
  length = 0
  for values_and_indices in values_and_indices_list:
    length += len(values_and_indices[0])

  result = [None] * length
  for values_and_indices in values_and_indices_list:
    if values_and_indices and values_and_indices[0]:
      for v, i in zip(*values_and_indices):
        assert result[i] is None
        result[i] = v
  return result
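

# Illustrative sketch: `split_by_sparsity` and `stitch_values` form a
# round-trip, so values can be routed through dense and sparse reduction paths
# and then reassembled in their original order. The helper name
# `_example_split_and_stitch` is hypothetical.
def _example_split_and_stitch():
  import tensorflow as tf  # Assumes TF 2.x eager execution.
  sparse = tf.IndexedSlices(values=tf.constant([[1.0]]),
                            indices=tf.constant([0]),
                            dense_shape=tf.constant([4, 1]))
  values = [tf.constant(1.0), sparse, tf.constant(2.0)]
  dense_vals, dense_idx, sparse_vals, sparse_idx = split_by_sparsity(values)
  # dense_idx == [0, 2], sparse_idx == [1]; reduce each list separately, then:
  restored = stitch_values(
      [(dense_vals, dense_idx), (sparse_vals, sparse_idx)])
  return restored  # same order as `values`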


def group_by_size(input_tensors, bytes_per_pack):
  """Groups `input_tensors` into chunks of `bytes_per_pack`.

  The method preserves the original order of `input_tensors`. The grouping is
  best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
  It only groups values with known shape.

  Args:
    input_tensors: a list of Tensor.
    bytes_per_pack: an integer.

  Returns:
    A list of packs of Tensor. All values are grouped into one pack if
    `bytes_per_pack` is zero or any of the values has an unknown shape.
  """
  if bytes_per_pack == 0:
    return [input_tensors]
  packs = []
  last_pack_size = 0
  for value in input_tensors:
    num_elements = value.shape.num_elements()
    if num_elements is None:
      # Can't pack values with unknown shape.
      logging.warning(
          'not packing values due to the unknown or inconsistent shape of %s',
          value)
      return [input_tensors]
    size = num_elements * value.dtype.size
    # Try to keep each pack as close to bytes_per_pack as possible, while each
    # pack is at least bytes_per_pack large. I.e. we err on the side of having
    # few but large packs.
    if not packs or last_pack_size > bytes_per_pack:
      packs.append([])
      last_pack_size = 0
    packs[-1].append(value)
    last_pack_size += size
  return packs
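

# Illustrative sketch: how `group_by_size` packs tensors by byte size. With
# three 4-byte float32 scalars and bytes_per_pack=6, a new pack is started
# only once the current pack already exceeds the budget, giving
# [[t0, t1], [t2]]. The helper name `_example_group_by_size` is hypothetical.
def _example_group_by_size():
  import tensorflow as tf  # Assumes TF 2.x eager execution.
  tensors = [tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)]  # 4B each
  packs = group_by_size(tensors, bytes_per_pack=6)
  assert [len(p) for p in packs] == [2, 1]
  # bytes_per_pack=0 (or any tensor with unknown shape) disables packing and
  # returns a single pack containing every input tensor.
  return packs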


def _pad_util(input_tensor, full_axis_dim):
  """Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
  missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
  tensor_rank = array_ops.rank(input_tensor)
  paddings_axis = [[0, missing_axis_dim]]
  paddings = array_ops.concat([
      paddings_axis,
      array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
  ],
                              axis=0)
  padded_input_tensor = array_ops.pad(input_tensor, paddings)
  return padded_input_tensor
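

# Illustrative sketch: `_pad_util` zero-pads only the first dimension. Written
# with the public TF API so it runs standalone; the helper name
# `_example_pad_first_dim` is hypothetical.
def _example_pad_first_dim():
  import tensorflow as tf  # Assumes TF 2.x eager execution.
  x = tf.constant([[1.0, 2.0]])             # shape [1, 2]
  full_axis_dim = 3
  missing = full_axis_dim - tf.shape(x)[0]  # 2 rows of zeros to append
  paddings = tf.concat(
      [[[0, missing]], tf.zeros((tf.rank(x) - 1, 2), dtype=tf.int32)], axis=0)
  return tf.pad(x, paddings)                # shape [3, 2]; last two rows zero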