Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py: 88%

8 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Batch gather operations for RaggedTensors.""" 

16 

17from tensorflow.python.ops import array_ops 

18from tensorflow.python.ops.ragged import ragged_gather_ops 

19from tensorflow.python.ops.ragged import ragged_tensor 

20from tensorflow.python.util import dispatch 

21 

22 

23#=============================================================================== 

24# ragged.batch_gather 

25#=============================================================================== 

@dispatch.dispatch_for_api(array_ops.batch_gather)
def batch_gather(params: ragged_tensor.RaggedOrDense,
                 indices: ragged_tensor.RaggedOrDense,
                 name=None):
  """Gathers slices from `params` according to `indices` with batch dims.

  Works like `gather`, except that the leading `N` dimensions shared by
  `indices` and `params` are treated as batch dimensions. The gather is then
  carried out independently inside each batch. With `N` batch dimensions
  `B1...BN`:

  * `indices` has shape `[B1...BN, I]`
  * `params` has shape `[B1...BN, P1...PM]`.
  * `result` has shape `[B1...BN, I, P2...PM]`.
  * `result[b1...bN, i, p2...pM] =
    params[b1...bN, indices[b1...bN, i], p2...pM]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BN, P1...PM]`
      (`N>=0`, `M>0`).
    indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
    `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.

  #### Example:

  >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]])
  >>> tf.compat.v1.batch_gather(params, indices)
  <tf.RaggedTensor [[b'b', b'c', b'a'], [], [], [b'e', b'e']]>
  """
  # A negative `batch_dims` is counted from the rank of `indices`, so -1
  # marks every dimension of `indices` except the innermost (`I`) as a batch
  # dimension. These are the `B1...BN` dimensions described above.
  leading_batch_dims = -1
  return ragged_gather_ops.gather(
      params, indices, batch_dims=leading_batch_dims, name=name)