Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_squeeze_op.py: 28%
58 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operator Squeeze for RaggedTensors."""
17from tensorflow.python.framework import constant_op
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import control_flow_assert
22from tensorflow.python.ops import control_flow_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops.ragged import ragged_tensor
25from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
26from tensorflow.python.util import deprecation
27from tensorflow.python.util import dispatch
@dispatch.dispatch_for_api(array_ops.squeeze_v2)
def squeeze(input: ragged_tensor.Ragged, axis=None, name=None):  # pylint: disable=redefined-builtin
  """Ragged compatible squeeze.

  If `input` is a `tf.Tensor`, then this calls `tf.squeeze`.

  If `input` is a `tf.RaggedTensor`, then this operation takes `O(N)` time,
  where `N` is the number of elements in the squeezed dimensions.

  Args:
    input: A potentially ragged tensor. The input to squeeze.
    axis: An int, or a list/tuple of ints. Defaults to `None`.
      - If `input` is not ragged, `axis` is forwarded to `tf.squeeze` as is.
      - If `input` is ragged, `axis` is required, and only the listed
        dimensions are squeezed. An empty list squeezes no dimensions.
      - It is an error to squeeze a dimension whose size is not 1.
      - Each value must be in the range `[-rank(input), rank(input))`.
        Normalization and range checking are done by
        `array_ops.get_positive_axis`.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor. Contains the same data as input,
    but has one or more dimensions of size 1 removed.

  Raises:
    ValueError: If `input` is ragged and `axis` is `None`.
    TypeError: If `input` is ragged and `axis` is neither an int nor a
      list/tuple of ints.
  """
  with ops.name_scope(name, 'RaggedSqueeze', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    if isinstance(input, ops.Tensor):
      return array_ops.squeeze(input, axis, name)

    if axis is None:
      raise ValueError('Ragged.squeeze must have an axis argument.')
    if isinstance(axis, int):
      axis = [axis]
    elif ((not isinstance(axis, (list, tuple))) or
          (not all(isinstance(d, int) for d in axis))):
      raise TypeError('Axis must be a list or tuple of integers.')

    # Normalize all the dims in axis to be non-negative.
    axis = [
        array_ops.get_positive_axis(d, input.shape.ndims, 'axis[%d]' % i,
                                    'rank(input)') for i, d in enumerate(axis)
    ]

    # Dimensions 0..ragged_rank are described by row partitions.
    # Deeper dimensions live in flat_values. Their indices are shifted by
    # ragged_rank so they address flat_values' own axes.
    dense_dims = []
    ragged_dims = []
    for dim in axis:
      if dim > input.ragged_rank:
        dense_dims.append(dim - input.ragged_rank)
      else:
        ragged_dims.append(dim)

    # Make sure the specified ragged dimensions are squeezable: every row
    # length of a squeezed inner ragged dimension must equal 1.
    assertion_list = []
    scalar_tensor_one = constant_op.constant(1, dtype=input.row_splits.dtype)
    for i, r in enumerate(input.nested_row_lengths()):
      # The i-th row-lengths tensor describes dimension i + 1.
      if i + 1 in ragged_dims:
        assertion_list.append(
            control_flow_assert.Assert(
                math_ops.reduce_all(math_ops.equal(r, scalar_tensor_one)),
                ['the given axis (axis = %d) is not squeezable!' % (i + 1)]))
    if 0 in ragged_dims:
      # The outermost dimension is squeezable only if it has a single row,
      # i.e. row_splits has exactly two entries.
      scalar_tensor_two = constant_op.constant(2, dtype=dtypes.int32)
      assertion_list.append(
          control_flow_assert.Assert(
              math_ops.equal(
                  array_ops.size(input.row_splits), scalar_tensor_two),
              ['the given axis (axis = 0) is not squeezable!']))

    # Till now, we are sure that the ragged dimensions are squeezable.
    # Attach the assertions as dependencies of flat_values.
    squeezed_rt = control_flow_ops.with_dependencies(assertion_list,
                                                     input.flat_values)

    if dense_dims:
      # Gives error if the dense dimension is not squeezable.
      squeezed_rt = array_ops.squeeze(squeezed_rt, dense_dims)

    # Keep the row_splits of every ragged dimension that is not squeezed.
    # nested_row_splits[i] belongs to dimension i + 1.
    remaining_row_splits = [
        row_split for i, row_split in enumerate(input.nested_row_splits)
        if (i + 1) not in ragged_dims
    ]
    # When axis 0 is squeezed, drop the outermost remaining partition. The
    # assertion above requires the outer dimension to have exactly one row.
    if remaining_row_splits and 0 in ragged_dims:
      remaining_row_splits.pop(0)

    squeezed_rt = RaggedTensor.from_nested_row_splits(squeezed_rt,
                                                      remaining_row_splits)

    # Corner case: when removing all the ragged dimensions and the output is
    # a scalar tensor e.g. ragged.squeeze(ragged.constant([[[1]]])).
    if set(range(0, input.ragged_rank + 1)).issubset(set(ragged_dims)):
      squeezed_rt = array_ops.squeeze(squeezed_rt, [0], name)

    return squeezed_rt
@dispatch.dispatch_for_api(array_ops.squeeze)
def _ragged_squeeze_v1(input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
                       axis=None,
                       name=None,
                       squeeze_dims=None):
  """Ragged dispatch target for the v1 `squeeze` API.

  The v1 API still accepts the deprecated `squeeze_dims` keyword as an alias
  for `axis`. The two are reconciled by
  `deprecation.deprecated_argument_lookup`, and the call is then delegated
  to `squeeze`.

  Args:
    input: A potentially ragged tensor to squeeze.
    axis: Dimensions to squeeze (see `squeeze`).
    name: A name for the operation (optional).
    squeeze_dims: Deprecated alias for `axis`.

  Returns:
    Whatever `squeeze` returns for the resolved arguments.
  """
  resolved_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'squeeze_dims', squeeze_dims)
  return squeeze(input, resolved_axis, name)