Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_where_op.py: 24%
80 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""where operation for RaggedTensors."""
17import typing
19from tensorflow.python.framework import ops
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops.ragged import ragged_concat_ops
23from tensorflow.python.ops.ragged import ragged_functional_ops
24from tensorflow.python.ops.ragged import ragged_gather_ops
25from tensorflow.python.ops.ragged import ragged_tensor
26from tensorflow.python.ops.ragged import ragged_tensor_shape
27from tensorflow.python.util import dispatch
@dispatch.dispatch_for_api(array_ops.where_v2)
def where_v2(condition: ragged_tensor.RaggedOrDense,
             x: typing.Optional[ragged_tensor.RaggedOrDense] = None,
             y: typing.Optional[ragged_tensor.RaggedOrDense] = None,
             name=None):
  """Return the elements where `condition` is `True`.

  : If both `x` and `y` are None: Retrieve indices of true elements.

    Returns the coordinates of true elements of `condition`.  The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`: Multiplex between `x` and `y`.

    Choose an output shape from the shapes of `condition`, `x`, and `y` that
    all three shapes are broadcastable to; and then use the broadcasted
    `condition` tensor as a mask that chooses whether the corresponding element
    in the output should be taken from `x` (if `condition` is true) or `y` (if
    `condition` is false).

  >>> # Example: retrieve indices of true elements
  >>> tf.where(tf.ragged.constant([[True, False], [True]]))
  <tf.Tensor: shape=(2, 2), dtype=int64, numpy= array([[0, 0], [1, 0]])>

  >>> # Example: multiplex between `x` and `y`
  >>> tf.where(tf.ragged.constant([[True, False], [True, False, True]]),
  ...          tf.ragged.constant([['A', 'B'], ['C', 'D', 'E']]),
  ...          tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']]))
  <tf.RaggedTensor [[b'A', b'b'], [b'C', b'd', b'E']]>

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type as `x` and `y`, and whose
      shape is broadcast-compatible with `x`, `y`, and `condition`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.
  """
  # Validate the x/y pairing before opening a name scope or converting inputs.
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')

  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      # Single-argument form: coordinates of the true elements.
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      # Give all ragged inputs a common row_splits dtype before combining them.
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where_v2(condition, x, y)
@dispatch.dispatch_for_api(array_ops.where)
def where(condition: ragged_tensor.RaggedOrDense,
          x: typing.Optional[ragged_tensor.RaggedOrDense] = None,
          y: typing.Optional[ragged_tensor.RaggedOrDense] = None,
          name=None):
  """Return the elements, either from `x` or `y`, depending on the `condition`.

  : If both `x` and `y` are `None`:
    Returns the coordinates of true elements of `condition`.  The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`:
    Returns a tensor formed by selecting values from `x` where condition is
    true, and from `y` when condition is false.  In particular:

    : If `condition`, `x`, and `y` all have the same shape:

      * `result[i1...iN] = x[i1...iN]` if `condition[i1...iN]` is true.
      * `result[i1...iN] = y[i1...iN]` if `condition[i1...iN]` is false.

    : Otherwise:

      * `condition` must be a vector.
      * `x` and `y` must have the same number of dimensions.
      * The outermost dimensions of `condition`, `x`, and `y` must all have the
        same size.
      * `result[i] = x[i]` if `condition[i]` is true.
      * `result[i] = y[i]` if `condition[i]` is false.

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type, rank, and outermost
      dimension size as `x` and `y`.
      `result.ragged_rank = max(x.ragged_rank, y.ragged_rank)`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; when `condition`
      is ragged but `x` or `y` is not; or when `condition`, `x`, and `y` have
      incompatible shapes.

  #### Examples:

  >>> # Coordinates where condition is true.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> print(where(condition))
  tf.Tensor( [[0 0] [0 2] [1 1]], shape=(3, 2), dtype=int64)

  >>> # Elementwise selection between x and y, based on condition.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'b', b'C'], [b'd', b'E']]>

  >>> # Row selection between x and y, based on condition.
  >>> condition = [True, False]
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'B', b'C'], [b'd', b'e']]>
  """
  # Validate the x/y pairing before opening a name scope or converting inputs.
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')
  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      # Single-argument form: coordinates of the true elements.
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      # Give all ragged inputs a common row_splits dtype before combining them.
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where(condition, x, y)
def _elementwise_where(condition, x, y):
  """Ragged version of tf.where(condition, x, y).

  Dispatches on which of the three inputs are `RaggedTensor`s:

  * none ragged: defer to the dense `array_ops.where`.
  * all ragged: apply `array_ops.where` to the flat values.
  * `condition` dense, `x` and/or `y` ragged: `condition` must be a vector,
    and whole rows are selected from `x` or `y`.
  * `condition` ragged but `x` or `y` dense: unsupported.

  Raises:
    ValueError: If `condition` is ragged but `x` or `y` is not, or if a dense
      `condition` fails `assert_has_rank(1)` while `x` or `y` is ragged.
  """
  ragged_flags = [
      isinstance(t, ragged_tensor.RaggedTensor) for t in (condition, x, y)
  ]

  if not any(ragged_flags):
    return array_ops.where(condition, x, y)

  if all(ragged_flags):
    return ragged_functional_ops.map_flat_values(array_ops.where, condition,
                                                 x, y)

  if ragged_flags[0]:
    raise ValueError('Input shapes do not match.')

  # Row selection: stack x on top of y, then for each position i gather row i
  # of x (index i) or row i of y (index i offset by x's row count).
  condition.shape.assert_has_rank(1)
  stacked = ragged_concat_ops.concat([x, y], axis=0)
  splits_dtype = stacked.row_splits.dtype
  num_x_rows = _nrows(x, out_type=splits_dtype)
  num_y_rows = _nrows(y, out_type=splits_dtype)
  row_indices = array_ops.where(condition, math_ops.range(num_x_rows),
                                num_x_rows + math_ops.range(num_y_rows))
  return ragged_gather_ops.gather(stacked, row_indices)
def _elementwise_where_v2(condition, x, y):
  """Ragged version of tf.where_v2(condition, x, y).

  Unless `condition`, `x`, and `y` already share one identical, fully-defined
  static shape, all three are first broadcast to a common
  `RaggedTensorDynamicShape`.  The selection is then done with the dense
  `array_ops.where_v2`, applied to the flat values if any input is ragged.
  """
  shapes_already_match = (
      condition.shape.is_fully_defined() and x.shape.is_fully_defined() and
      y.shape.is_fully_defined() and x.shape == y.shape and
      condition.shape == x.shape)

  if not shapes_already_match:
    dyn_shape_c = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        condition)
    dyn_shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x)
    dyn_shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)
    # Broadcast x with y first, then fold in condition.
    xy_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        dyn_shape_x, dyn_shape_y)
    common_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        dyn_shape_c, xy_shape)
    condition = ragged_tensor_shape.broadcast_to(condition, common_shape)
    x = ragged_tensor_shape.broadcast_to(x, common_shape)
    y = ragged_tensor_shape.broadcast_to(y, common_shape)

  if any(isinstance(t, ragged_tensor.RaggedTensor) for t in (condition, x, y)):
    return ragged_functional_ops.map_flat_values(array_ops.where_v2,
                                                 condition, x, y)
  return array_ops.where_v2(condition, x, y)
def _coordinate_where(condition):
  """Ragged version of tf.where(condition).

  Returns a 2-D tensor with one row of coordinates per true value in
  `condition`.  For a ragged `condition`, coordinates are first computed
  (recursively) for `condition.values`; the leading flattened index of each
  coordinate is then split back into a (row, column) pair.
  """
  if isinstance(condition, ragged_tensor.RaggedTensor):
    # Coordinates of each true element, relative to `condition.values`.
    inner_coords = _coordinate_where(condition.values)

    # Use the coordinate dtype for row_splits so the index arithmetic agrees.
    condition = condition.with_row_splits_dtype(inner_coords.dtype)
    flat_index = inner_coords[:, 0]
    row_ids = array_ops.gather(condition.value_rowids(), flat_index)
    row_starts = array_ops.gather(condition.row_splits, row_ids)
    col_ids = flat_index - row_starts

    # Replace the flat index with (row, column); keep the inner-dim indices.
    pieces = [
        array_ops.expand_dims(row_ids, 1),
        array_ops.expand_dims(col_ids, 1),
        inner_coords[:, 1:],
    ]
    return array_ops.concat(pieces, axis=1)

  return array_ops.where(condition)
def _nrows(rt_input, out_type):
  """Returns the number of rows of `rt_input` as a scalar of dtype `out_type`.

  Handles both `RaggedTensor`s (via `nrows`) and dense tensors (via the size
  of the outermost dimension).
  """
  if not isinstance(rt_input, ragged_tensor.RaggedTensor):
    return array_ops.shape(rt_input, out_type=out_type)[0]
  return rt_input.nrows(out_type=out_type)