Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_squeeze_op.py: 28%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operator Squeeze for RaggedTensors.""" 

16 

17from tensorflow.python.framework import constant_op 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.ops import array_ops 

21from tensorflow.python.ops import control_flow_assert 

22from tensorflow.python.ops import control_flow_ops 

23from tensorflow.python.ops import math_ops 

24from tensorflow.python.ops.ragged import ragged_tensor 

25from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor 

26from tensorflow.python.util import deprecation 

27from tensorflow.python.util import dispatch 

28 

29 

@dispatch.dispatch_for_api(array_ops.squeeze_v2)
def squeeze(input: ragged_tensor.Ragged, axis=None, name=None):  # pylint: disable=redefined-builtin
  """Ragged compatible squeeze.

  If `input` is a `tf.Tensor`, then this calls `tf.squeeze`.

  If `input` is a `tf.RaggedTensor`, then this operation takes `O(N)` time,
  where `N` is the number of elements in the squeezed dimensions.

  Args:
    input: A potentially ragged tensor. The input to squeeze.
    axis: An optional int or list/tuple of ints. Defaults to `None`. If
      `input` is ragged, `axis` is required and only the dimensions listed
      are squeezed; an empty list squeezes nothing and returns a ragged
      tensor equivalent to `input`. If `input` is not ragged it calls
      `tf.squeeze`. Note that it is an error to squeeze a dimension that is
      not 1. Each entry must be in the range of [-rank(input), rank(input)).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor. Contains the same data as input,
    but has one or more dimensions of size 1 removed.

  Raises:
    ValueError: If `input` is ragged and `axis` is `None`.
    TypeError: If `input` is ragged and `axis` is neither an int nor a list
      or tuple of ints.
  """
  with ops.name_scope(name, 'RaggedSqueeze', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    if isinstance(input, ops.Tensor):
      return array_ops.squeeze(input, axis, name)

    if axis is None:
      raise ValueError('Ragged.squeeze must have an axis argument.')
    if isinstance(axis, int):
      axis = [axis]
    elif ((not isinstance(axis, (list, tuple))) or
          (not all(isinstance(d, int) for d in axis))):
      raise TypeError('Axis must be a list or tuple of integers.')

    # Normalize all the dims in axis to be positive.
    axis = [
        array_ops.get_positive_axis(d, input.shape.ndims, 'axis[%d]' % i,
                                    'rank(input)') for i, d in enumerate(axis)
    ]
    # Dims 0..ragged_rank are described by row partitions. Deeper dims live in
    # `flat_values`, whose dim k corresponds to input dim `ragged_rank + k`.
    dense_dims = [
        dim - input.ragged_rank for dim in axis if dim > input.ragged_rank
    ]
    ragged_dims = {dim for dim in axis if dim <= input.ragged_rank}

    # Make sure the specified ragged dimensions are squeezable.
    assertion_list = []
    scalar_tensor_one = constant_op.constant(1, dtype=input.row_splits.dtype)
    for i, row_lengths in enumerate(input.nested_row_lengths()):
      # nested_row_lengths()[i] holds the row lengths along dimension i + 1;
      # that dimension can only be squeezed if every length is 1.
      if i + 1 in ragged_dims:
        assertion_list.append(
            control_flow_assert.Assert(
                math_ops.reduce_all(
                    math_ops.equal(row_lengths, scalar_tensor_one)),
                ['the given axis (axis = %d) is not squeezable!' % (i + 1)]))
    if 0 in ragged_dims:
      # Dimension 0 is squeezable only when `row_splits` has exactly two
      # entries, i.e. the tensor has a single row.
      scalar_tensor_two = constant_op.constant(2, dtype=dtypes.int32)
      assertion_list.append(
          control_flow_assert.Assert(
              math_ops.equal(
                  array_ops.size(input.row_splits), scalar_tensor_two),
              ['the given axis (axis = 0) is not squeezable!']))

    # Tie the assertions to the values so they are evaluated before any
    # result derived from them is produced.
    flat_values = control_flow_ops.with_dependencies(assertion_list,
                                                     input.flat_values)

    if dense_dims:
      # Gives error if the dense dimension is not squeezable.
      flat_values = array_ops.squeeze(flat_values, dense_dims)

    # nested_row_splits[i] is the partition for dimension i + 1; drop the
    # partitions of every squeezed ragged dimension.
    remaining_row_splits = [
        row_splits for i, row_splits in enumerate(input.nested_row_splits)
        if i + 1 not in ragged_dims
    ]
    # Take care of the first row if it is to be squeezed: removing the single
    # outermost row drops the outermost remaining partition.
    if remaining_row_splits and 0 in ragged_dims:
      remaining_row_splits.pop(0)

    squeezed_rt = RaggedTensor.from_nested_row_splits(flat_values,
                                                      remaining_row_splits)

    # Corner case: when removing all the ragged dimensions and the output is
    # a scalar tensor e.g. ragged.squeeze(ragged.constant([[[1]]])).
    if set(range(0, input.ragged_rank + 1)).issubset(ragged_dims):
      squeezed_rt = array_ops.squeeze(squeezed_rt, [0], name)

    return squeezed_rt

123 

124 

@dispatch.dispatch_for_api(array_ops.squeeze)
def _ragged_squeeze_v1(input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
                       axis=None,
                       name=None,
                       squeeze_dims=None):
  """Ragged dispatch target for the TF1 `squeeze` API.

  Folds the deprecated `squeeze_dims` alias into `axis` via
  `deprecation.deprecated_argument_lookup`, then delegates to `squeeze`.

  Args:
    input: A potentially ragged tensor to squeeze.
    axis: Dimensions to squeeze (see `squeeze`).
    name: A name for the operation (optional).
    squeeze_dims: Deprecated alias for `axis`.

  Returns:
    The result of `squeeze(input, axis, name)` with the resolved axis.
  """
  resolved_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'squeeze_dims', squeeze_dims)
  return squeeze(input, axis=resolved_axis, name=name)