# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Operations for TPUs."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops import gen_tpu_ops
from tensorflow.python.ops.gen_tpu_ops import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.util.tf_export import tf_export


def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment


def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the subgroup
      size (group_assignment.get_shape()[1]).
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` containing data concatenated from the different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)
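
# Example (illustrative sketch, not from the original source): with two
# replicas, all_to_all splits each replica's tensor along split_dimension
# and concatenates the exchanged pieces along concat_dimension. The
# tf.compat.v1.tpu.replicate driver and shapes below are assumptions.
#
#   def computation(x):
#     # Each replica holds a [2, 4] tensor. With split_dimension=1,
#     # concat_dimension=0 and split_count=2, each replica ends up with a
#     # [4, 2] tensor assembled from both replicas' halves.
#     return all_to_all(
#         x, concat_dimension=0, split_dimension=1, split_count=2,
#         group_assignment=[[0, 1]])
#
#   outputs = tf.compat.v1.tpu.replicate(computation, [[x0], [x1]])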


@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with the
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]


@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
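
# Example (illustrative sketch, not from the original source): summing a
# per-replica value across all replicas. The tf.compat.v1.tpu.replicate
# driver and constants below are assumptions.
#
#   def computation(x):
#     # With two replicas holding 1.0 and 2.0, each replica receives 3.0.
#     return cross_replica_sum(x)
#
#   outputs = tf.compat.v1.tpu.replicate(
#       computation, [[tf.constant(1.0)], [tf.constant(2.0)]])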


def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, we send replica a's input to replica b.
  Each replica id must appear at most once in the source column, and likewise
  at most once in the target column.
  For a replica id that does not appear in the target column, this op returns
  a zero tensor with the same shape and dtype as the input x.

  For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: 2d int lists with shape [num_pairs, 2].
      source_target_pairs[i][0] represents the source replica id and
      source_target_pairs[i][1] represents the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
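
# Example (illustrative sketch, not from the original source): shifting data
# one replica to the right across four replicas, matching the `[0, A, B, C]`
# example in the docstring above. Replica 0 receives zeros because no pair
# targets it.
#
#   def computation(x):
#     return collective_permute(x, [[0, 1], [1, 2], [2, 3]])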


@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  # The gradient of a collective permute operation is also a collective
  # permute, but with the source/target pairs reversed. The gradient with
  # respect to the input argument `source_target_pairs` is `None`.
  source_target_pairs = op.inputs[1][:, ::-1]
  return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross-replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]


# This extra type checking exists to give a more helpful error message.
_SUPPORTED_INFEED_DTYPES = frozenset([
    dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32,
    dtypes.complex64, dtypes.uint32, dtypes.uint8, dtypes.int8
])


@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection."""
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref("tpu_embedding_gradients_table_%d" %
                                         table_id)

  if not table_gradients:
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training mode. "
        "This is not expected. Consider putting your Optimizer.minimize code "
        "behind the training mode condition check. For Estimator, you can "
        "do \n\n"
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "    train_op = opt.minimize(loss)\n"
        "\n")

  if lookup_id < 0 or lookup_id >= len(table_gradients):
    raise RuntimeError(
        "Gradients (w.r.t. TPUEmbedding activations) generated for table_id "
        "{} and lookup_id {}. The lookup_id attribute is outside the expected "
        "range [0, {}).".format(table_id, lookup_id, len(table_gradients)))

  if table_gradients[lookup_id] is not None:
    raise RuntimeError(
        "Duplicate gradients (w.r.t. TPUEmbedding activations) generated for "
        "table_id {} and lookup_id {}. This happens when there are multiple "
        "calls to tf.gradients in a graph containing TPU embeddings. "
        "TF cannot identify which gradient to use for updating the embedding "
        "variables. Consider placing tf.StopGradient around tensors where "
        "variable update is not required. Previous gradients were generated "
        "by the following callstack: {}.".format(
            table_id, lookup_id, table_gradients[lookup_id].op.traceback))

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires that a value be returned for all inputs.
      # Since the first argument (tpu_gradient_variable_{table_name}) has
      # shape [1], we return zeros(shape=[1]). The actual gradient w.r.t. the
      # embedding activations (grad_wrt_activations) has the same shape as
      # the activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]


def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
    A tensor that will be provided using the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype not in _SUPPORTED_INFEED_DTYPES:
    raise TypeError(
        "Operation '{}' has type {} which is not a supported TPU infeed type. "
        "Supported types are: {}".format(name, dtype,
                                         list(_SUPPORTED_INFEED_DTYPES)))

  return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
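
# Example (illustrative sketch, not from the original source): inside a TPU
# computation, infeed_dequeue yields a tensor whose value is supplied by a
# matching host-side infeed enqueue. The dtype and shape are assumptions.
#
#   def computation():
#     x = infeed_dequeue(dtypes.float32, shape=[128, 64])
#     return x * 2.0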


# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s that has length `>= 1`. The element types
      of each element in `outputs`.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). The
      shapes of each tensor in `outputs`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `dtypes`.
    A list of tensors that will be provided using the infeed mechanism.

  Raises:
    TypeError: If a type in `dtypes` is not a supported infeed type.
  """
  for dtype in dtypes:
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
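
# Example (illustrative sketch, not from the original source): dequeueing a
# (features, labels) pair in a single op. The dtypes and shapes below are
# assumptions.
#
#   features, labels = infeed_dequeue_tuple(
#       dtypes=[dtypes.float32, dtypes.int32],
#       shapes=[[128, 64], [128]])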


# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      This argument has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but contains gradients of the model's loss
      with respect to the embedding activations. The embedding tables are
      updated from these gradients via the optimizers specified in the TPU
      embedding configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag: see the comments in
      //third_party/tensorflow/core/protobuf/tpu/
      optimization_parameters.proto. Multiple tables can share the same
      dynamic learning rate tag as specified in the configuration. If the
      learning rates for all tables are constant, this list should be empty.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  if learning_rates is None:
    learning_rates = []
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=learning_rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
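
# Example (illustrative sketch, not from the original source): sending the
# gradients of the loss w.r.t. the received embedding activations back to
# the embedding layer. `loss`, `activations`, and `config_proto` are assumed
# to come from the surrounding training setup.
#
#   grads = tf.gradients(loss, activations)
#   send_op = send_tpu_embedding_gradients(inputs=grads, config=config_proto)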


# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one for each embedding table, containing the
      indices into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
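
# Example (illustrative sketch, not from the original source): enqueueing one
# batch of IDs for each of two embedding tables on core 0. The id tensors
# below are assumptions.
#
#   ids_table0 = tf.constant([3, 1, 4], dtype=tf.int32)
#   ids_table1 = tf.constant([1, 5, 9], dtype=tf.int32)
#   enqueue_op = enqueue_tpu_embedding_integer_batch(
#       batch=[ids_table0, ids_table1], device_ordinal=0)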


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors specifying the training example
      and feature to which the corresponding embedding_indices and
      aggregation_weights values belong. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed, and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per sample --
      i.e., per (training example, feature) -- aggregation weights. Both
      float32 and float64 are allowed and will be converted to float32
      internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
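
# Example (illustrative sketch, not from the original source): a single table
# where sample 0 looks up ids 7 and 2 and sample 1 looks up id 5, all with
# weight 1.0. The tensors below are assumptions.
#
#   enqueue_op = enqueue_tpu_embedding_sparse_batch(
#       sample_indices=[tf.constant([0, 0, 1], dtype=tf.int32)],
#       embedding_indices=[tf.constant([7, 2, 5], dtype=tf.int32)],
#       aggregation_weights=[tf.constant([1.0, 1.0, 1.0])],
#       device_ordinal=0,
#       combiners=["sum"])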


# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors specifying the training example
      to which the corresponding embedding_indices and aggregation_weights
      values belong. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume each embedding_indices belongs to a different sample. Both int32
      and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, we
      assume all weights are 1. Both float32 and float64 are allowed and will
      be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      sample_indices. If equal to 0, the corresponding feature is considered
      to be a non-sequence feature. If greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      then we assume a list of all zeroes.
    num_features: A list of integers, the size of which is equal to
      sample_indices. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_indices is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified',
      the mode set in TPUEmbeddingConfiguration is used, otherwise
      mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
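
# Example (illustrative sketch, not from the original source): enqueueing the
# contents of a tf.sparse.SparseTensor `sp_ids` for table 0 with unit
# weights. `sp_ids` is an assumption.
#
#   enqueue_op = enqueue_tpu_embedding_sparse_tensor_batch(
#       sample_indices=[tf.cast(sp_ids.indices, tf.int32)],
#       embedding_indices=[tf.cast(sp_ids.values, tf.int32)],
#       aggregation_weights=[tf.ones_like(sp_ids.values, dtype=tf.float32)],
#       table_ids=[0],
#       device_ordinal=0)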


# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_splits: A list of rank 1 Tensors specifying the break points for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to ids.values in embedding_lookup(), when ids is
      a RaggedTensor. Both int32 and int64 are allowed and will be converted
      to int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to the values field of a
      RaggedTensor with the same row_splits as ids in embedding_lookup(),
      when ids is a RaggedTensor. Both float32 and float64 are allowed and
      will be converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_splits, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      sample_splits. If equal to 0, the corresponding feature is considered
      to be a non-sequence feature. If greater than 0, the corresponding
      feature is a sequence feature with the given maximal length. If None,
      then we assume a list of all zeroes.
    num_features: A list of integers, the size of which must be equal to
      sample_splits. If non-empty, entries in this list must be at least 1.
      For each batch element, we will take num_features rows of the input
      tensor for embedding lookup. E.g., when sample_splits is empty, the
      embedding indices must be of shape (batch_size*num_features).
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to
      'unspecified', the mode set in TPUEmbeddingConfiguration is used,
      otherwise mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
      sample_splits=sample_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
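
# Example (illustrative sketch, not from the original source): enqueueing a
# tf.RaggedTensor `rt` of ids for table 0, using its row_splits and values
# with unit weights. `rt` is an assumption.
#
#   enqueue_op = enqueue_tpu_embedding_ragged_tensor_batch(
#       sample_splits=[tf.cast(rt.row_splits, tf.int32)],
#       embedding_indices=[tf.cast(rt.values, tf.int32)],
#       aggregation_weights=[tf.ones_like(rt.values, dtype=tf.float32)],
#       table_ids=[0],
#       device_ordinal=0)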


def enqueue_tpu_embedding_arbitrary_tensor_batch(sample_indices_or_row_splits,
                                                 embedding_indices,
                                                 aggregation_weights,
                                                 device_ordinal,
                                                 combiners=None,
                                                 mode_override=None,
                                                 name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices_or_row_splits: A list of rank 1 or 2 Tensors. When rank 2,
      the tensors specify the training example to which the corresponding
      embedding_indices and aggregation_weights values belong. If the size of
      its first dimension is 0, we assume each embedding_indices belongs to a
      different sample. Both int32 and int64 are allowed and will be
      converted to int32 internally. When rank 1, the tensors specify the row
      splits for splitting embedding_indices and aggregation_weights into
      rows. It corresponds to ids.row_splits in embedding_lookup(), when ids
      is a RaggedTensor. When enqueuing an N-D ragged tensor, only the last
      dimension is allowed to be ragged; the row splits is a 1-D dense
      tensor. When empty, we assume a dense tensor is passed to the op. Both
      int32 and int64 are allowed and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. Both float32 and float64 are allowed and
      will be converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one for each embedding table,
      specifying how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of
      the squared weights be 0 for 'sqrtn'. If combiners isn't passed, the
      default is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. When set to
      'unspecified', the mode set in TPUEmbeddingConfiguration is used,
      otherwise mode_override is used (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingArbitraryTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
      sample_indices_or_row_splits=sample_indices_or_row_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode_override,
      name=name)


enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch.__doc__)
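
# Example (illustrative sketch, not from the original source): enqueueing a
# dense batch of ids by passing an empty sample_indices_or_row_splits tensor,
# which the op interprets as dense input. The tensors below are assumptions.
#
#   enqueue_op = enqueue_tpu_embedding_arbitrary_tensor_batch(
#       sample_indices_or_row_splits=[tf.zeros([0], dtype=tf.int32)],
#       embedding_indices=[tf.constant([4, 8, 15], dtype=tf.int32)],
#       aggregation_weights=[tf.constant([1.0, 1.0, 1.0])],
#       device_ordinal=0)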