Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py: 88%
8 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Batch gather operations for RaggedTensors."""
17from tensorflow.python.ops import array_ops
18from tensorflow.python.ops.ragged import ragged_gather_ops
19from tensorflow.python.ops.ragged import ragged_tensor
20from tensorflow.python.util import dispatch
23#===============================================================================
24# ragged.batch_gather
25#===============================================================================
@dispatch.dispatch_for_api(array_ops.batch_gather)
def batch_gather(params: ragged_tensor.RaggedOrDense,
                 indices: ragged_tensor.RaggedOrDense,
                 name=None):
  """Performs a per-batch gather of slices from `params` using `indices`.

  Like `gather`, except that the leading `N` dimensions of both `indices` and
  `params` are treated as batch dimensions and the lookup is carried out
  independently inside each batch. With `N` batch dimensions `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]`
      (`N>=0`, `M>0`); the values to gather from.
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`);
      the positions to read along the first non-batch dimension of `params`.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`, where
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:

  >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
  >>> tf.compat.v1.batch_gather(params, indices)
  <tf.RaggedTensor [[b'b', b'c', b'a'], [], [], [b'e', b'e']]>
  """
  # Delegate to the general ragged gather. `batch_dims=-1` is what expresses
  # the contract above: all but the innermost dimension of `indices` are
  # batch dimensions shared with `params`.
  gathered = ragged_gather_ops.gather(
      params, indices, batch_dims=-1, name=name)
  return gathered