Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_gather_ops.py: 15%

187 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Gather operations for RaggedTensors.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import indexed_slices 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_shape 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import gen_ragged_array_ops 

23from tensorflow.python.ops import math_ops 

24from tensorflow.python.ops.ragged import ragged_array_ops 

25from tensorflow.python.ops.ragged import ragged_math_ops 

26from tensorflow.python.ops.ragged import ragged_tensor 

27from tensorflow.python.util import dispatch 

28 

29 

30#=============================================================================== 

31# ragged_gather 

32#=============================================================================== 

@dispatch.dispatch_for_api(array_ops.gather_v2)
def gather(params: ragged_tensor.RaggedOrDense,
           indices: ragged_tensor.RaggedOrDense,
           validate_indices=None,
           axis=None,
           batch_dims=0,
           name=None):
  """Gathers ragged slices from `params` along `axis` according to `indices`.

  See `tf.gather` for full documentation.  (This version has the same API
  as `tf.gather`, but supports ragged `params` and `indices`.)

  Examples:

  >>> params = tf.constant(['a', 'b', 'c', 'd', 'e'])
  >>> indices = tf.constant([3, 1, 2, 1, 0])
  >>> ragged_params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']])
  >>> ragged_indices = tf.ragged.constant([[3, 1, 2], [1], [], [0]])

  >>> tf.gather(params, ragged_indices)
  <tf.RaggedTensor [[b'd', b'b', b'c'], [b'b'], [], [b'a']]>

  >>> tf.gather(ragged_params, indices)
  <tf.RaggedTensor [[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']]>

  >>> tf.gather(ragged_params, ragged_indices)
  <tf.RaggedTensor [[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]]>

  Args:
    params: The potentially ragged tensor from which to gather values. Must be
      at least rank 1.
    indices: The potentially ragged tensor indicating which values to gather.
      Must have dtype `int32` or `int64`.  Values must be in the half-open
      range `[0, params.shape[axis])`.
    validate_indices: Ignored.
    axis: The axis in `params` to gather `indices` from.  Defaults to
      `batch_dims` when `None`.  Negative values are normalized with
      `array_ops.get_positive_axis`.
    batch_dims: The number of batch dimensions.  Negative values are
      normalized with `array_ops.get_positive_axis` against rank(indices).
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with `output.dtype=params.dtype`.  (A dense
    `Tensor` is returned if neither `params` nor `indices` is ragged.)  For
    `axis=0` and `batch_dims=0`, `output.shape=indices.shape +
    params.shape[1:]`; and when `indices` is dense and `params` is a
    `RaggedTensor`, `output.ragged_rank=indices.shape.ndims +
    params.ragged_rank - 1`.

  Raises:
    ValueError: If `batch_dims` is not less than rank(params) or not in
      `[0, rank(indices)]`, if `axis` is out of range or less than
      `batch_dims`, or if indices.shape.ndims is not known statically when it
      is needed.
  """
  del validate_indices

  with ops.name_scope(name, 'RaggedGather', [params, indices]):
    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)

    # `batch_dims == rank(indices)` is accepted (checked below), so only
    # normalize when it differs.  NOTE(review): presumably this skip exists
    # because get_positive_axis rejects axis == ndims -- confirm.
    if batch_dims != indices.shape.rank:
      batch_dims = array_ops.get_positive_axis(
          batch_dims,
          indices.shape.rank,
          axis_name='batch_dims',
          ndims_name='rank(indices)')
    if params.shape.rank is not None and batch_dims >= params.shape.rank:
      raise ValueError('batch_dims must be less than rank(params)')
    if axis is None:
      axis = batch_dims
    axis = array_ops.get_positive_axis(
        axis, params.shape.rank, ndims_name='rank(params)')
    if axis < batch_dims:
      raise ValueError('axis must be greater than or equal to batch_dims')
    if indices.shape.rank is not None:
      if not 0 <= batch_dims <= indices.shape.rank:
        raise ValueError(
            'batch_dims=%s must be between 0 and rank(indices)=%s' %
            (batch_dims, indices.shape.rank))

    return _gather(params, indices, axis, batch_dims)

110 

111 

def _gather(params, indices, axis, batch_dims):
  """Recursive worker behind the public ragged `gather()`.

  The caller must already have converted `params` and `indices` to
  `Tensor`/`RaggedTensor` values and normalized `axis` and `batch_dims` to
  non-negative integers, so recursive invocations can skip that work.

  Args:
    params: Potentially ragged tensor to gather values from.
    indices: Potentially ragged tensor of indices.
    axis: Non-negative axis of `params` that `indices` selects along.
    batch_dims: Non-negative number of leading batch dimensions.

  Returns:
    A potentially ragged tensor.

  Raises:
    ValueError: If the dense-indices path is reached and rank(indices) is not
      statically known.
  """
  ragged_params = ragged_tensor.is_ragged(params)
  ragged_indices = ragged_tensor.is_ragged(indices)

  # Both inputs dense: the ordinary dense gather handles everything.
  if not ragged_params and not ragged_indices:
    return array_ops.gather(params, indices, axis=axis, batch_dims=batch_dims)

  # Reduce the batched and non-zero-axis cases to simpler sub-problems.
  if batch_dims > 0:
    return _batch_gather(params, indices, axis, batch_dims)
  if axis > 0:
    return _axis_gather(params, indices, axis)

  # Ragged indices: gather using the index values, then wrap the result back
  # in the row structure of `indices`.
  if ragged_indices:
    return indices.with_values(_gather(params, indices.values, 0, 0))

  # Remaining case: ragged params, dense indices, axis == batch_dims == 0.
  indices_rank = indices.shape.ndims
  if indices_rank is None:
    raise ValueError('rank(indices) must be known statically')

  kernel_out = gen_ragged_array_ops.ragged_gather(
      indices=indices,
      params_dense_values=params.flat_values,
      params_nested_splits=params.nested_row_splits,
      OUTPUT_RAGGED_RANK=indices_rank + len(params.nested_row_splits) - 1)
  gathered = ragged_tensor.RaggedTensor.from_nested_row_splits(
      kernel_out.output_dense_values,
      kernel_out.output_nested_splits,
      validate=False)

  # The outer dimensions that come from the dense `indices` are uniform, so
  # record nrows / uniform_row_length on each of those partitions.
  # TODO(edloper): Change this to construct the result using RowPartition
  # objects instead, so we don't need to modify private variables.
  if indices_rank > 1:
    dims = array_ops.shape(indices, out_type=params.row_splits.dtype)
    running_sizes = math_ops.cumprod(dims)
    level = gathered
    for d in range(indices_rank - 1):
      # pylint: disable=protected-access
      level._cached_nrows = running_sizes[d]
      level._uniform_row_length = dims[d + 1]
      level = level.values

  return gathered

172 

173 

def _batch_gather(params, indices, axis, batch_dims):
  """Helper that implements the body for ragged gather() when batch_dims>0.

  `_gather` only dispatches here when at least one of `params` / `indices`
  is ragged.  The problem is reduced step by step:

  * `batch_dims > 1`: both inputs are turned into `RaggedTensor`s (dense ones
    via `from_tensor(ragged_rank=1)`), and the two outermost batch dimensions
    are merged by recursing on `.values` with `axis - 1, batch_dims - 1`.
  * `axis > 1`: each batch row of `indices` is repeated once per row of the
    corresponding `params` batch, turning the next `params` dimension into
    an extra batch dimension (recursing with `batch_dims + 1`).
  * Otherwise (`batch_dims == 1`): `params` has its two outer dimensions
    merged and `indices` is offset by each batch row's starting position, so
    an ordinary `batch_dims=0` gather along `axis - 1` finishes the job.

  Args:
    params: The tensor from which to gather values.
    indices: The indices of values to gather.
    axis: The axis in `params` to gather `indices` from.
    batch_dims: The number of batch dimensions.

  Returns:
    A potentially ragged tensor.

  Raises:
    ValueError: If the static batch shapes of `params` and `indices` are
      incompatible, if a ragged dimension of one input lines up with a
      uniform dimension of the other, or if rank(indices) is not statically
      known at the final step.
  """
  # Perform static checks that `params` and `indices` have compatible batch
  # dimensions.  Note: we do not perform *runtime* checks that `params` and
  # `indices` actually have the same row-splits (because we wish to avoid the
  # runtime cost of those checks).  If `params` and `indices` are
  # incompatible, the resulting `RaggedTensor` may be nonsensical.
  if not params.shape[:batch_dims].is_compatible_with(
      indices.shape[:batch_dims]):
    raise ValueError('batch shape from indices %s does not match params '
                     'shape %s' % (indices.shape[:batch_dims], params.shape))

  if batch_dims > 1:
    # Convert params & indices to ragged tensors.  A dense input may only be
    # converted if the other input's outer dimension is uniform there
    # (uniform_row_length is not None); otherwise the batch shapes disagree.
    if not isinstance(params, ragged_tensor.RaggedTensor):
      if indices.uniform_row_length is None:
        raise ValueError(
            'batch shape from indices does not match params shape: ragged '
            'indices dimension corresponds to uniform params dimension')
      params = ragged_tensor.RaggedTensor.from_tensor(
          params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
    if not isinstance(indices, ragged_tensor.RaggedTensor):
      if params.uniform_row_length is None:
        raise ValueError(
            'batch shape from indices does not match params shape: ragged '
            'params dimension corresponds to uniform indices dimension')
      indices = ragged_tensor.RaggedTensor.from_tensor(
          indices, ragged_rank=1, row_splits_dtype=params.row_splits.dtype)
    # Flatten the two outer batch dimensions into a single batch dimension,
    # and recurse.
    return params.with_values(
        _gather(params.values, indices.values, axis - 1, batch_dims - 1))

  if axis > 1:
    # Convert an axis dimension into a batch dimension, by adding a dimension
    # to `indices`, and tiling it to match `params`.  E.g., if `params`
    # had shape `[B, P1, P2]`, and `indices` had shape `[B, I1, I2]`, then we
    # tile `indices` to have shape `[B, P1, I1, I2]`.  That way, we can treat
    # the `P1` dimension as a batch dimension.
    if not isinstance(indices, ragged_tensor.RaggedTensor):
      # Dense indices: `params` is presumably ragged here (the caller only
      # routes mixed/ragged inputs to this helper) -- it must provide
      # row_lengths().  Repeat indices[b] row_lengths()[b] times.
      adjusted_indices = params.with_values(
          array_ops.repeat(indices, params.row_lengths(), 0))
    else:
      if not isinstance(params, ragged_tensor.RaggedTensor):
        params = ragged_tensor.RaggedTensor.from_tensor(
            params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
      # Ragged indices: build [[0]*len0, [1]*len1, ...] with params' row
      # structure and use it to gather (duplicate) rows of `indices`.
      adjusted_indices = _gather(
          indices,
          params.with_values(
              array_ops.repeat(
                  math_ops.range(params.nrows()), params.row_lengths())), 0, 0)
    return _batch_gather(params, adjusted_indices, axis, batch_dims + 1)

  if indices.shape.rank is None:
    raise ValueError('rank(indices) must be known statically')

  assert batch_dims == 1
  # If params.shape=[B, P1...PN] and indices.shape=[B, I1...IM], then:
  #
  #     output[b, i1...im, p2...pn] =
  #         params[b, indices[b, i1...im], p2...pn]
  #
  # We construct `output` by flattening `params`, adjusting the `indices` to
  # point into that flattened list, and recursively calling `gather`.
  flat_params = _flatten_dims_0_and_1(params)
  adjustments = _row_starts(params, indices.dtype)  # offset for each batch
  # increase adjustments's rank so it broadcasts w/ the outer dim of indices
  adjustments = _increase_rank_to(adjustments, indices.shape.ndims)
  adjusted_indices = indices + adjustments
  return _gather(flat_params, adjusted_indices, axis - 1, 0)

254 

255 

def _axis_gather(params, indices, axis):
  """Ragged gather for the case `axis > 0` and `batch_dims == 0`.

  When `axis > 1`, the outermost dimension of `params` is peeled off
  (converting a dense `params` to a ragged-rank-1 `RaggedTensor` first) and
  the gather is applied to its `values` with `axis - 1`; `indices` itself is
  not flattened.  When `axis == 1`, the two outer dimensions of `params` are
  merged and `indices` gains a leading dimension holding per-row offsets, so a
  single `axis=0` gather produces the result.

  Args:
    params: Tensor or RaggedTensor to gather values from.
    indices: Tensor or RaggedTensor of indices.
    axis: Positive axis of `params` that `indices` selects along.

  Returns:
    A potentially ragged tensor.

  Raises:
    ValueError: If `axis <= 1` and rank(indices) is not statically known.
  """
  if axis <= 1:
    if indices.shape.rank is None:
      raise ValueError('rank(indices) must be known statically')

    # Note: there is no checking of indices.  If there is some index
    # out of bounds, the results may be nonsensical.
    assert axis == 1

    # With params.shape=[P1...PN] and indices.shape=[I1...IM]:
    #
    #     output[p1, i1...im, p3...pn] = params[p1, indices[i1...im], p3...pn]
    #
    # Merge params' outer two dims, then shift `indices` by the start offset
    # of every p1 row (broadcast over a new leading dimension).
    merged_params = _flatten_dims_0_and_1(params)
    row_offsets = _increase_rank_to(
        _row_starts(params, indices.dtype), indices.shape.ndims + 1)
    return _gather(merged_params, indices + row_offsets, axis - 1, 0)

  # axis > 1: strip one outer dimension of params and recurse on its values,
  # leaving `indices` untouched.
  if not isinstance(params, ragged_tensor.RaggedTensor):
    params = ragged_tensor.RaggedTensor.from_tensor(
        params, ragged_rank=1, row_splits_dtype=indices.row_splits.dtype)
  return params.with_values(_gather(params.values, indices, axis - 1, 0))

294 

295 

def _flatten_dims_0_and_1(t):
  """Merges the two outermost dimensions of `t` into a single dimension.

  For a `RaggedTensor` this is just `t.values`; a dense tensor is reshaped to
  `[-1] + shape(t)[2:]`.
  """
  if not isinstance(t, ragged_tensor.RaggedTensor):
    dense_shape = array_ops.shape(t)
    merged_shape = array_ops.concat([[-1], dense_shape[2:]], axis=0)
    return array_ops.reshape(t, merged_shape)
  return t.values

303 

304 

def _row_starts(t, dtype):
  """Returns, as `dtype`, the offset at which each outer row of `t` begins.

  For a dense tensor whose shape starts `[N, M, ...]` the offsets are
  `0, M, 2*M, ..., (N-1)*M`, i.e. positions in `t` after its two outer
  dimensions are merged.  For a `RaggedTensor` they are its `row_starts()`.
  """
  if not isinstance(t, ragged_tensor.RaggedTensor):
    dims = array_ops.shape(t, out_type=dtype)
    num_rows, row_width = dims[0], dims[1]
    return math_ops.range(num_rows) * row_width
  return math_ops.cast(t.row_starts(), dtype)

312 

313 

def _increase_rank_to(t, rank):
  """Adds *trailing* size-1 dimensions to `t` until it has the given rank.

  Args:
    t: A `Tensor` or `RaggedTensor`.
    rank: The desired rank of the result; expected to be >= rank(t).

  Returns:
    `t` reshaped with trailing size-1 dimensions appended.  For a
    `RaggedTensor`, its outer (ragged) dimension is kept and the padding is
    applied to its `values`, whose target rank is one less.
  """
  if isinstance(t, ragged_tensor.RaggedTensor):
    # Recurse on `t.values` (one rank lower).  Recursing on `t` itself, as
    # before, never terminates.
    return t.with_values(_increase_rank_to(t.values, rank - 1))
  old_dims = array_ops.shape(t)
  new_dims = array_ops.ones([rank - array_ops.rank(t)], old_dims.dtype)
  new_shape = array_ops.concat([old_dims, new_dims], axis=0)
  return array_ops.reshape(t, new_shape)

323 

324 

@dispatch.dispatch_for_api(array_ops.gather)
def _ragged_gather_v1(params: ragged_tensor.RaggedOrDense,
                      indices: ragged_tensor.RaggedOrDense,
                      validate_indices=None,
                      name=None,
                      axis=None,
                      batch_dims=0):
  """Ragged dispatch target for the v1 `gather` API.

  Same as `gather`, but with the v1 parameter order (`name` before `axis`).
  `axis` defaults to `None`, mirroring `gather`, which then uses
  `axis = batch_dims`.  With the previous default of `0`, a call passing only
  `batch_dims > 0` always failed the `axis >= batch_dims` check.  Explicit
  `axis` values behave exactly as before.
  """
  return gather(params, indices, validate_indices, axis, batch_dims, name)

333 

334 

335#=============================================================================== 

336# ragged.gather_nd 

337#=============================================================================== 

@dispatch.dispatch_for_api(array_ops.gather_nd_v2)
def gather_nd(params: ragged_tensor.RaggedOrDense,
              indices: ragged_tensor.RaggedOrDense,
              batch_dims=0,
              name=None):
  """Gather slices from `params` using `n`-dimensional indices.

  This operation is similar to `gather`, but it uses the innermost dimension
  of `indices` to define a slice into `params`.  In particular, if:

  * `indices` has shape `[A1...AN, I]`
  * `params` has shape `[B1...BM]`

  Then:

  * `result` has shape `[A1...AN, B_{I+1}...BM]`.
  * `result[a1...aN] = params[indices[a1...aN, :]]`

  Args:
    params: A potentially ragged tensor with shape `[B1...BM]`.
    indices: A potentially ragged tensor with shape `[A1...AN, I]`.  Its
      innermost dimension `I` must be uniform and statically known.
    batch_dims: Must be zero.
    name: A name for the operation (optional).

  Returns:
    A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

  Raises:
    ValueError: If `batch_dims` is not the integer `0`; if the rank or the
      innermost dimension size of `indices` is not statically known; if
      `indices` is a scalar; or if the innermost dimension of a ragged
      `indices` is ragged.

  #### Examples:

  >>> params = tf.ragged.constant(
  ...     [ [ ['000', '001'], ['010'              ]          ],
  ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
  ...       [ [            ], ['210'              ]          ] ])

  >>> # Gather 2D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2], [0]])
  <tf.RaggedTensor [[[], [b'210']], [[b'000', b'001'], [b'010']]]>

  >>> # Gather 1D slices from a 3D tensor
  >>> tf.gather_nd(params, [[2, 1], [0, 0]])
  <tf.RaggedTensor [[b'210'], [b'000', b'001']]>

  >>> # Gather scalars from a 3D tensor
  >>> tf.gather_nd(params, [[0, 0, 1], [1, 1, 2]]).numpy()
  array([b'001', b'112'], dtype=object)
  """
  if not isinstance(batch_dims, int) or batch_dims != 0:
    raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
  if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
    return array_ops.gather_nd(params, indices, name)

  with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

    params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        params, name='params')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
    indices_shape = indices.shape
    indices_ndims = indices_shape.ndims
    if indices_ndims is None:
      raise ValueError('indices.rank must be statically known.')
    if indices_ndims == 0:
      raise ValueError('indices.rank must be at least 1.')
    if (ragged_tensor.is_ragged(indices) and
        indices_ndims == indices.ragged_rank + 1):
      raise ValueError('The innermost dimension of indices may not be ragged')

    # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
    # that each index slices into.
    index_size = tensor_shape.dimension_value(indices_shape[-1])
    if index_size is None:
      raise ValueError('indices.shape[-1] must be statically known.')

    # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
    # dense, then we convert it to ragged before recursing, and then convert
    # the result back to `dense` if appropriate.
    if indices_ndims > 2:
      indices_is_dense = not ragged_tensor.is_ragged(indices)
      if indices_is_dense:
        indices = ragged_tensor.RaggedTensor.from_tensor(
            indices, ragged_rank=indices_ndims - 2,
            row_splits_dtype=params.row_splits.dtype)
      result = indices.with_flat_values(gather_nd(params, indices.flat_values))
      if (indices_is_dense and ragged_tensor.is_ragged(result) and
          result.ragged_rank == indices_ndims - 2):
        result = result.to_tensor()
      return result

    # indices_ndims <= 2, and the innermost dimension of indices may not be
    # ragged, so `indices` must not be ragged.
    assert not ragged_tensor.is_ragged(indices)
    assert ragged_tensor.is_ragged(params)

    # Handle corner case: An empty index tuple selects the entire `params`
    # value.  So if `index_size` is zero, then tile `params`.
    if index_size == 0:
      params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
      for _ in range(indices_ndims - 1):
        params = ragged_array_ops.expand_dims(params, axis=0)
      multiples = array_ops.concat([
          array_ops.shape(indices)[:-1],
          array_ops.ones([params_ndims], dtypes.int32)
      ],
                                   axis=0)
      return ragged_array_ops.tile(params, multiples)

    # When index_size=1, we can just flatten the index tuples and use gather.
    # NOTE(review): for 1-D `indices` (shape [1]) the reshape keeps shape [1],
    # so the result keeps a leading size-1 dimension that a dense gather_nd
    # would drop -- confirm whether this is intended.
    elif index_size == 1:
      flattened_index_tuples = array_ops.reshape(indices, [-1])
      return gather(params, flattened_index_tuples)

    # Otherwise, params is a RaggedTensor, and indices is a 1D or 2D Tensor.
    # Flatten both the index tuples and the params, such that the flattened
    # index tuples point to the correct values in the flattened params; and
    # then use ragged.gather on the flattened index tuples & params.
    else:
      indices = math_ops.cast(indices, params.row_splits.dtype)

      # Flatten the outermost 2 dimensions of the index tuples & params.
      flattened_index_tuples = array_ops.gather(params.row_splits,
                                                indices[..., 0])
      flattened_index_tuples += indices[..., 1]
      flattened_params = params.values

      # Flatten any remaining dimensions.
      for dim in range(2, index_size):
        if not ragged_tensor.is_ragged(flattened_params):
          # The remaining dimensions are uniform: finish with a dense
          # gather_nd using the flattened prefix plus the leftover columns.
          flattened_index_tuples = array_ops.expand_dims(
              flattened_index_tuples, axis=1)
          flattened_index_tuples = array_ops.concat(
              [flattened_index_tuples, indices[..., dim:]], axis=1)
          return array_ops.gather_nd(flattened_params, flattened_index_tuples)

        flattened_index_tuples = array_ops.gather(
            flattened_params.row_starts(), flattened_index_tuples)
        flattened_index_tuples += indices[..., dim]
        flattened_params = flattened_params.values

      # Gather using the flattened index tuples and params.
      return gather(flattened_params, flattened_index_tuples)

479 

480 

@dispatch.dispatch_for_api(array_ops.gather_nd)
def _ragged_gather_nd_v1(params: ragged_tensor.RaggedOrDense,
                         indices: ragged_tensor.RaggedOrDense,
                         name=None,
                         batch_dims=0):
  """Ragged dispatch target for the v1 `gather_nd` API.

  Only the parameter order differs (`name` precedes `batch_dims`); all work is
  delegated to `gather_nd`.
  """
  return gather_nd(params, indices, batch_dims=batch_dims, name=name)

487 

488 

489#=============================================================================== 

490# Gradient for the RaggedGather kernel 

491#=============================================================================== 

@ops.RegisterGradient('RaggedGather')
def _ragged_gather_grad(op, *grads):
  """Gradient for the RaggedGather op.

  Only the dense-values input of `params` receives a gradient, returned as an
  `IndexedSlices`.  The row-splits inputs and `indices` get `None`.
  """
  splits_inputs = op.inputs[:-2]
  dense_values_input = op.inputs[-2]
  gathered_indices = op.inputs[-1]
  # grads[-1] is used as the gradient w.r.t. the output's dense values.
  values_grad = grads[-1]

  # Compose the nested row splits so that outermost row `i` of params covers
  # `dense_values_input[outer_to_flat[i]:outer_to_flat[i+1]]`.
  outer_to_flat = splits_inputs[0]
  for inner_splits in splits_inputs[1:]:
    outer_to_flat = array_ops.gather(inner_splits, outer_to_flat)

  # The flattened indices correspond 1:1 with the gathered rows whose value
  # ranges were concatenated into the output, so the value gradient maps back
  # onto the concatenation of those ranges.
  flat_index_list = array_ops.reshape(gathered_indices, [-1])
  range_starts = array_ops.gather(outer_to_flat, flat_index_list)
  range_limits = array_ops.gather(outer_to_flat[1:], flat_index_list)
  value_positions = ragged_math_ops.range(range_starts, range_limits).values

  dense_values_grad = indexed_slices.IndexedSlices(
      values=values_grad,
      indices=value_positions,
      dense_shape=array_ops.shape(dense_values_input))
  return [None] * len(splits_inputs) + [dense_values_grad, None]