Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_where_op.py: 24%

80 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""where operation for RaggedTensors.""" 

16 

17import typing 

18 

19from tensorflow.python.framework import ops 

20from tensorflow.python.ops import array_ops 

21from tensorflow.python.ops import math_ops 

22from tensorflow.python.ops.ragged import ragged_concat_ops 

23from tensorflow.python.ops.ragged import ragged_functional_ops 

24from tensorflow.python.ops.ragged import ragged_gather_ops 

25from tensorflow.python.ops.ragged import ragged_tensor 

26from tensorflow.python.ops.ragged import ragged_tensor_shape 

27from tensorflow.python.util import dispatch 

28 

29 

@dispatch.dispatch_for_api(array_ops.where_v2)
def where_v2(condition: ragged_tensor.RaggedOrDense,
             x: typing.Optional[ragged_tensor.RaggedOrDense] = None,
             y: typing.Optional[ragged_tensor.RaggedOrDense] = None,
             name=None):
  """Return the elements where `condition` is `True`.

  : If both `x` and `y` are None: Retrieve indices of true elements.

    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`: Multiplex between `x` and `y`.

    Choose an output shape from the shapes of `condition`, `x`, and `y` that
    all three shapes are broadcastable to; and then use the broadcasted
    `condition` tensor as a mask that chooses whether the corresponding element
    in the output should be taken from `x` (if `condition` is true) or `y` (if
    `condition` is false).

  >>> # Example: retrieve indices of true elements
  >>> tf.where(tf.ragged.constant([[True, False], [True]]))
  <tf.Tensor: shape=(2, 2), dtype=int64, numpy= array([[0, 0], [1, 0]])>

  >>> # Example: multiplex between `x` and `y`
  >>> tf.where(tf.ragged.constant([[True, False], [True, False, True]]),
  ...          tf.ragged.constant([['A', 'B'], ['C', 'D', 'E']]),
  ...          tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']]))
  <tf.RaggedTensor [[b'A', b'b'], [b'C', b'd', b'E']]>

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type as `x` and `y`, and whose
      shape is broadcast-compatible with `x`, `y`, and `condition`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.
  """
  # Supplying only one of `x`/`y` is ambiguous, so reject it before any ops
  # are created.
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')

  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      # Coordinate mode: return the indices of the true elements.
      return _coordinate_where(condition)

    # Multiplex mode: convert inputs, then align their row-splits dtypes so
    # the ragged inputs can be broadcast and combined together.
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
    condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
    return _elementwise_where_v2(condition, x, y)

93 

94 

@dispatch.dispatch_for_api(array_ops.where)
def where(condition: ragged_tensor.RaggedOrDense,
          x: typing.Optional[ragged_tensor.RaggedOrDense] = None,
          y: typing.Optional[ragged_tensor.RaggedOrDense] = None,
          name=None):
  """Return the elements, either from `x` or `y`, depending on the `condition`.

  : If both `x` and `y` are `None`:
    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`:
    Returns a tensor formed by selecting values from `x` where condition is
    true, and from `y` when condition is false.  In particular:

    : If `condition`, `x`, and `y` all have the same shape:

      * `result[i1...iN] = x[i1...iN]` if `condition[i1...iN]` is true.
      * `result[i1...iN] = y[i1...iN]` if `condition[i1...iN]` is false.

    : Otherwise:

      * `condition` must be a vector.
      * `x` and `y` must have the same number of dimensions.
      * The outermost dimensions of `condition`, `x`, and `y` must all have the
        same size.
      * `result[i] = x[i]` if `condition[i]` is true.
      * `result[i] = y[i]` if `condition[i]` is false.

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type, rank, and outermost
      dimension size as `x` and `y`.
      `result.ragged_rank = max(x.ragged_rank, y.ragged_rank)`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.

  #### Examples:

  >>> # Coordinates where condition is true.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> print(where(condition))
  tf.Tensor( [[0 0] [0 2] [1 1]], shape=(3, 2), dtype=int64)

  >>> # Elementwise selection between x and y, based on condition.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'b', b'C'], [b'd', b'E']]>

  >>> # Row selection between x and y, based on condition.
  >>> condition = [True, False]
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'B', b'C'], [b'd', b'e']]>
  """
  # Supplying only one of `x`/`y` is ambiguous, so reject it before any ops
  # are created.
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')

  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      # Coordinate mode: return the indices of the true elements.
      return _coordinate_where(condition)

    # Selection mode: convert inputs, then align their row-splits dtypes so
    # ragged inputs can be concatenated / mapped over together.
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
    condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
    return _elementwise_where(condition, x, y)

178 

179 

def _elementwise_where(condition, x, y):
  """Implements `tf.where(condition, x, y)` where any input may be ragged.

  Supported input combinations:

  * all three dense: delegate to `array_ops.where`;
  * all three ragged: apply `array_ops.where` to the flat values;
  * `condition` dense (and a vector), at least one of `x`/`y` ragged: pick
    whole rows from `x` or `y` according to `condition`.

  A ragged `condition` combined with a dense `x` or `y` is rejected.

  Args:
    condition: A dense or ragged boolean tensor.
    x: A dense or ragged tensor.
    y: A dense or ragged tensor.

  Returns:
    A dense or ragged tensor containing the selected values.

  Raises:
    ValueError: If `condition` is ragged but `x` or `y` is not.
  """
  raggedness = tuple(
      isinstance(t, ragged_tensor.RaggedTensor) for t in (condition, x, y))

  if not any(raggedness):
    return array_ops.where(condition, x, y)

  if all(raggedness):
    return ragged_functional_ops.map_flat_values(array_ops.where, condition, x,
                                                 y)

  if raggedness[0]:
    # Ragged condition, but at least one of x/y is dense.
    raise ValueError('Input shapes do not match.')

  # Dense vector condition: stack x on top of y, then gather row `i` of x
  # (when condition[i] is true) or row `i` of y (offset by x's row count).
  condition.shape.assert_has_rank(1)
  stacked = ragged_concat_ops.concat([x, y], axis=0)
  splits_dtype = stacked.row_splits.dtype
  num_x_rows = _nrows(x, out_type=splits_dtype)
  num_y_rows = _nrows(y, out_type=splits_dtype)
  x_row_ids = math_ops.range(num_x_rows)
  y_row_ids = num_x_rows + math_ops.range(num_y_rows)
  chosen_rows = array_ops.where(condition, x_row_ids, y_row_ids)
  return ragged_gather_ops.gather(stacked, chosen_rows)

204 

205 

def _elementwise_where_v2(condition, x, y):
  """Implements `tf.where_v2(condition, x, y)` where any input may be ragged.

  Unless all three inputs have identical, fully-defined static shapes, they
  are first broadcast to a common (possibly ragged) dynamic shape.  The
  selection is then done with `array_ops.where_v2`, applied directly for
  dense inputs or over the flat values when any input is ragged.

  Args:
    condition: A dense or ragged boolean tensor.
    x: A dense or ragged tensor.
    y: A dense or ragged tensor.

  Returns:
    A dense or ragged tensor with the broadcast shape of the inputs.
  """
  shapes_already_equal = (
      condition.shape.is_fully_defined() and x.shape.is_fully_defined() and
      y.shape.is_fully_defined() and x.shape == y.shape and
      condition.shape == x.shape)

  if not shapes_already_equal:
    to_dynamic_shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor
    cond_shape, x_shape, y_shape = (
        to_dynamic_shape(t) for t in (condition, x, y))
    target_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        cond_shape,
        ragged_tensor_shape.broadcast_dynamic_shape(x_shape, y_shape))
    condition, x, y = (
        ragged_tensor_shape.broadcast_to(t, target_shape)
        for t in (condition, x, y))

  if any(isinstance(t, ragged_tensor.RaggedTensor) for t in (condition, x, y)):
    return ragged_functional_ops.map_flat_values(array_ops.where_v2, condition,
                                                 x, y)
  return array_ops.where_v2(condition, x, y)

230 

231 

def _coordinate_where(condition):
  """Implements `tf.where(condition)` for a dense or ragged `condition`.

  For a ragged `condition`, the coordinates of the true elements of
  `condition.values` are computed recursively.  Each coordinate's first
  component (an index into `condition.values`) is then split into a row
  index and a column index within that row.

  Args:
    condition: A dense or ragged boolean tensor.

  Returns:
    A 2-D `Tensor` whose `i`th row holds the coordinates of the `i`th true
    element of `condition`.
  """
  if isinstance(condition, ragged_tensor.RaggedTensor):
    # Coordinates of each true entry of the flattened values.
    inner_coords = _coordinate_where(condition.values)

    # Match row_splits dtype to the coordinate dtype so the arithmetic below
    # combines tensors of a single dtype.
    condition = condition.with_row_splits_dtype(inner_coords.dtype)
    value_index = inner_coords[:, 0]
    row_index = array_ops.gather(condition.value_rowids(), value_index)
    row_start = array_ops.gather(condition.row_splits, row_index)
    col_index = value_index - row_start

    # Prepend [row, col] to the coordinates of any inner dimensions.
    pieces = [
        array_ops.expand_dims(row_index, 1),
        array_ops.expand_dims(col_index, 1),
        inner_coords[:, 1:],
    ]
    return array_ops.concat(pieces, axis=1)

  return array_ops.where(condition)

253 

254 

def _nrows(rt_input, out_type):
  """Returns the outermost dimension size of `rt_input` as an `out_type` value.

  Accepts either a `RaggedTensor` (uses its `nrows`) or a dense tensor (uses
  the first entry of its dynamic shape).
  """
  if not isinstance(rt_input, ragged_tensor.RaggedTensor):
    return array_ops.shape(rt_input, out_type=out_type)[0]
  return rt_input.nrows(out_type=out_type)