Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_array_ops.py: 22%

363 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Array operations for RaggedTensors."""

from typing import Optional
from typing import Union

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_ragged_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.ragged import dynamic_ragged_shape
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.types import core as core_types
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export

# ===============================================================================
# Masking
# ===============================================================================

@tf_export('ragged.boolean_mask')
@dispatch.add_dispatch_support
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

  Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor. `mask`'s shape must be a prefix
      of `data`'s shape. `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """

  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged. Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension. For each innermost cell, get the
        # number of values it contains. Then flatten that to get a list of
        # cell lengths, and convert it to splits. Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values

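# A worked trace of the dense rank>2 wrapping above (illustrative note, not
# part of the original module). For a dense mask of shape [2, 3, 4],
# `cumprod(mask_shape)` is [2, 6, 24], so `split_size` is [3, 7, 25]. The
# innermost ragged dimension is built from `count_nonzero(mask, axis=-1)`
# reshaped to 6 cell lengths; the loop then runs once with dim=0, where
# elt_size = mask_shape[1] = 3 and masked_splits = range(3) * 3 = [0, 3, 6],
# grouping the 6 inner rows back into 2 outer rows of 3. A minimal
# doctest-style sketch, assuming the public `tf` API with eager execution:
#
#   >>> import tensorflow as tf
#   >>> data = tf.reshape(tf.range(24), [2, 3, 4])
#   >>> mask = tf.cast(data % 2, tf.bool)  # keep the odd entries
#   >>> tf.ragged.boolean_mask(data, mask).to_list()
#   [[[1, 3], [5, 7], [9, 11]], [[13, 15], [17, 19], [21, 23]]]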

# ===============================================================================
# Tiling
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.tile)
def tile(input: ragged_tensor.Ragged, multiples, name=None):  # pylint: disable=redefined-builtin
  """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.

  The values of `input` are replicated `multiples[i]` times along the
  `i`th dimension (for each dimension `i`). For every dimension `axis` in
  `input`, the length of each output element in that dimension is the
  length of the corresponding input element multiplied by `multiples[axis]`.

  Args:
    input: A `RaggedTensor`.
    multiples: A 1-D integer `Tensor`. Length must be the same as the number of
      dimensions in `input`.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> tf.tile(rt, [3, 2]).to_list()
  [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
  """

  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if not ragged_tensor.is_ragged(input):
      return array_ops.tile(input, multiples, name)
    multiples = ragged_util.convert_to_int_tensor(
        multiples, name='multiples', dtype=input.row_splits.dtype)
    multiples.shape.assert_has_rank(1)

    # If the constant value of `multiples` is available, then we can use it
    # to skip tiling dimensions where `multiples=1`.
    const_multiples = tensor_util.constant_value(multiples)

    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        _tile_ragged_values(input, multiples, const_multiples),
        _tile_ragged_splits(input, multiples, const_multiples),
        validate=False)


def _tile_ragged_values(rt_input, multiples, const_multiples=None):
  """Builds flat_values tensor for a tiled `RaggedTensor`.

  Returns a tensor that repeats the values in `rt_input.flat_values` in the
  appropriate pattern to construct a `RaggedTensor` that tiles `rt_input` as
  specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` whose values should be repeated.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples. Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A `Tensor` with the same type and rank as `rt_input.flat_values`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy()
  array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32)
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # Pointers to the values in `rt_input.flat_values`.
  inner_value_ids = math_ops.range(nested_splits[-1][-1])

  # For each ragged dimension (working from the innermost to outermost),
  # expand `inner_value_ids` as necessary to tile that dimension.
  prev_splits = None
  for axis in range(ragged_rank, 0, -1):
    # Ragged splits for this dimension.
    splits = nested_splits[axis - 1]

    # Adjust splits so they point into `inner_value_ids` (instead of just
    # pointing into the next dimension's values).
    if prev_splits is not None:  # Not the first pass through the loop.
      splits = array_ops.gather(prev_splits * multiples[axis + 1], splits)

    # Repeat each element in this ragged dimension `multiples[axis]` times.
    if const_multiples is None or const_multiples[axis] != 1:
      inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits,
                                                  multiples[axis])

    prev_splits = splits

  # Gather the tiled inner values.
  ragged_tiled_values = array_ops.gather(rt_input.flat_values, inner_value_ids)

  # Tile the flat_values for the uniform dimensions (i.e., for `axis=0` plus
  # `axis=range(ragged_rank, rank)`).
  inner_repeats = array_ops.concat([multiples[:1], multiples[ragged_rank + 1:]],
                                   axis=0)
  return array_ops.tile(ragged_tiled_values, inner_repeats)

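# A worked trace of the loop above for the docstring example (illustrative
# only). For rt = [[1, 2], [3]] and multiples = [3, 2]: nested_splits is
# [[0, 2, 3]] and inner_value_ids starts as [0, 1, 2]. On the single pass
# (axis=1), repeat_ranges([0, 1, 2], [0, 2, 3], 2) repeats each row's id
# range twice, giving [0, 1, 0, 1, 2, 2]; gathering flat_values [1, 2, 3]
# yields [1, 2, 1, 2, 3, 3]. Finally inner_repeats is [3] (multiples[0]),
# and tiling produces the 18-element result shown in the docstring.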

def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
  """Builds nested_split tensors for a tiled `RaggedTensor`.

  Returns a list of split tensors that can be used to construct the
  `RaggedTensor` that tiles `rt_input` as specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` that is being tiled.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples. Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A list of 1-D integer `Tensor`s (one for each ragged dimension in
    `rt_input`).

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_splits(rt, [3, 2])
  [<tf.Tensor: shape=(7,), dtype=int64,
  numpy=array([ 0,  4,  6, 10, 12, 16, 18])>]
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # projected_splits[src_axis, dst_axis] contains the split points that divide
  # the rows from src_axis in the list of dst_axis values. E.g.,
  # projected_splits[i, i] = nested_splits[i], and
  # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]).
  projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)]
  for src_axis in range(ragged_rank):
    for dst_axis in range(src_axis + 1, ragged_rank - 1):
      projected_splits[src_axis][dst_axis] = array_ops.gather(
          nested_splits[dst_axis], projected_splits[src_axis][dst_axis - 1])

  # For each ragged dimension: nested_splits[axis] -> result_splits[axis].
  result_splits = []
  for axis in range(ragged_rank):
    # Get the length of each row for the input tensor for this dimension.
    input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1]

    # Multiply those lengths by the `multiples` of dimension axis+1, since
    # each value will be repeated that number of times.
    output_lengths = input_lengths * multiples[axis + 1]

    # Repeat ranges of the row lengths as necessary for them to be tiled in
    # each ragged dimension `d < axis`. (Start with dimension d=axis-1, and
    # work our way up to dimension d=0.)
    repeats = 1
    for d in range(axis - 1, -1, -1):
      if const_multiples is None or const_multiples[d + 1] != 1:
        splits = projected_splits[d][axis - 1] * repeats
        output_lengths = ragged_util.repeat_ranges(output_lengths, splits,
                                                   multiples[d + 1])
      repeats *= multiples[d + 1]

    # Tile splits for the outermost (uniform) dimension.
    output_lengths = array_ops.tile(output_lengths, multiples[:1])

    # Convert to splits.
    result_splits.append(ragged_util.lengths_to_splits(output_lengths))

  return result_splits

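# A worked trace of the docstring example above (illustrative only). For
# rt = [[1, 2], [3]] and multiples = [3, 2]: input_lengths is [2, 1];
# multiplying by multiples[1] = 2 gives output_lengths = [4, 2]; there are
# no outer ragged dimensions, so the inner loop is a no-op; tiling by
# multiples[0] = 3 gives [4, 2, 4, 2, 4, 2]; and lengths_to_splits converts
# that to the splits [0, 4, 6, 10, 12, 16, 18] shown in the docstring.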

# ===============================================================================
# Reshaping
# ===============================================================================


@dispatch.dispatch_for_api(array_ops.expand_dims_v2)
def expand_dims(input: ragged_tensor.Ragged, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a dimension with shape 1 into a potentially ragged tensor's shape.

  Given a potentially ragged tensor `input`, this operation inserts a
  dimension with size 1 at the dimension `axis` of `input`'s shape.

  The following table gives some examples showing how `ragged.expand_dims`
  impacts the shapes of different input tensors. Ragged dimensions are
  indicated by enclosing them in parentheses.

  input.shape             | axis | result.shape
  ----------------------- | ---- | -----------------------------
  `[D1, D2]`              |  `0` | `[1, D1, D2]`
  `[D1, D2]`              |  `1` | `[D1, 1, D2]`
  `[D1, D2]`              |  `2` | `[D1, D2, 1]`
  `[D1, (D2), (D3), D4]`  |  `0` | `[1, D1, (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `1` | `[D1, 1, (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `2` | `[D1, (D2), 1, (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `3` | `[D1, (D2), (D3), 1, D4]`
  `[D1, (D2), (D3), D4]`  |  `4` | `[D1, (D2), (D3), D4, 1]`

  Args:
    input: The potentially ragged tensor that should be expanded with a new
      dimension.
    axis: An integer constant indicating where the new dimension should be
      inserted.
    name: A name for the operation (optional).

  Returns:
    A tensor with the same values as `input`, with an added dimension of
    size 1 at `axis`.

  #### Examples:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> print(rt.shape)
  (2, None)

  >>> expanded = tf.expand_dims(rt, axis=0)
  >>> print(expanded.shape, expanded)
  (1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]>

  >>> expanded = tf.expand_dims(rt, axis=1)
  >>> print(expanded.shape, expanded)
  (2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>

  >>> expanded = tf.expand_dims(rt, axis=2)
  >>> print(expanded.shape, expanded)
  (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]>
  """

  with ops.name_scope(name, 'RaggedExpandDims', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')

    if not ragged_tensor.is_ragged(input):
      return array_ops.expand_dims(input, axis)

    ndims = None if input.shape.ndims is None else input.shape.ndims + 1
    axis = array_ops.get_positive_axis(axis, ndims, ndims_name='rank(input)')

    if axis == 0:
      return ragged_tensor.RaggedTensor.from_uniform_row_length(
          input, uniform_row_length=input.nrows(), nrows=1, validate=False)
    elif axis == 1:
      return ragged_tensor.RaggedTensor.from_uniform_row_length(
          input, uniform_row_length=1, nrows=input.nrows(), validate=False)
    else:
      if ragged_tensor.is_ragged(input.values):
        return input.with_values(expand_dims(input.values, axis - 1))
      else:
        return input.with_values(array_ops.expand_dims(input.values, axis - 1))


@dispatch.dispatch_for_api(array_ops.expand_dims)
def _ragged_expand_dims_v1(
    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
    axis=None,
    name=None,
    dim=None):
  if dim is not None:
    axis = dim
  return expand_dims(input=input, axis=axis, name=name)


# ===============================================================================
# RaggedTensor Size
# ===============================================================================


@dispatch.dispatch_for_api(array_ops.size_v2)
def size(input: ragged_tensor.Ragged, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the size of a potentially ragged tensor.

  The size of a ragged tensor is the size of its inner values.

  #### Example:

  >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy()
  3

  Args:
    input: A potentially ragged `Tensor`.
    out_type: The numeric output type for the operation.
    name: A name for the operation (optional).

  Returns:
    A Tensor of type `out_type`.
  """
  if ragged_tensor.is_ragged(input):
    return array_ops.size(input.flat_values, out_type=out_type, name=name)
  else:
    return array_ops.size(input, out_type=out_type, name=name)


@dispatch.dispatch_for_api(array_ops.size)
def _ragged_size_v1(
    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
    name=None,
    out_type=dtypes.int32):
  return size(input=input, out_type=out_type, name=name)


# ===============================================================================
# ragged.rank
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.rank)
def rank(input: ragged_tensor.Ragged, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a RaggedTensor.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.

  #### Example:

  >>> # shape of tensor 't' is [2, None, None]
  >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
  >>> tf.rank(t).numpy()
  3

  Args:
    input: A `RaggedTensor`
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, 'RaggedRank', [input]) as name:
    if not ragged_tensor.is_ragged(input):
      return array_ops.rank(input, name)

    return input.ragged_rank + array_ops.rank(input.flat_values)


# ===============================================================================
# ragged.one_hot
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.one_hot)
def ragged_one_hot(indices: ragged_tensor.Ragged,
                   depth,
                   on_value=None,
                   off_value=None,
                   axis=None,
                   dtype=None,
                   name=None):
  """Applies tf.one_hot along the values of a RaggedTensor."""
  # Get the adjusted axis value for the call to array_ops.one_hot.
  # Note: the only negative `axis` value supported by array_ops.one_hot is -1.
  if isinstance(axis, int) and axis >= 0:
    if axis <= indices.ragged_rank:
      raise ValueError('axis (%d) must be greater than indices.ragged_rank '
                       '(%d).' % (axis, indices.ragged_rank))
    axis -= indices.ragged_rank

  with ops.name_scope(name, 'RaggedOneHot',
                      [indices, depth, on_value, off_value, axis]):
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    return indices.with_flat_values(
        array_ops.one_hot(indices.flat_values, depth, on_value, off_value,
                          axis, dtype, name))

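# An illustrative doctest-style example (not in the original docstring),
# assuming the public `tf` API with eager execution: the one-hot expansion is
# applied to the flat values, so each ragged cell becomes a depth-length row.
#
#   >>> indices = tf.ragged.constant([[0, 2], [1]])
#   >>> tf.one_hot(indices, depth=3).to_list()
#   [[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]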

# ===============================================================================
# ragged.stack_dynamic_partitions
# ===============================================================================
@tf_export('ragged.stack_dynamic_partitions')
@dispatch.add_dispatch_support
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
  """Stacks dynamic partitions of a Tensor or RaggedTensor.

  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
  `partitions[j1...jN] = i`. Slices of `data` are stacked in row-major
  order.

  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.

  #### Example:

  >>> data           = ['a', 'b', 'c', 'd', 'e']
  >>> partitions     = [  3,   0,   2,   2,   3]
  >>> num_partitions = 5
  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>

  Args:
    data: A `Tensor` or `RaggedTensor` containing the values to stack.
    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
      partition that each slice of `data` should be added to. `partitions.shape`
      must be a prefix of `data.shape`. Values must be greater than or equal to
      zero, and less than `num_partitions`. `partitions` is not required to be
      sorted.
    num_partitions: An `int32` or `int64` scalar specifying the number of
      partitions to output. This determines the number of rows in `output`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the stacked partitions. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
    ragged dimension whose length is the number of data slices stacked for
    each `partition`.
  """

  with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    row_splits_dtype = (
        data.row_splits.dtype
        if isinstance(data, ragged_tensor.RaggedTensor) else None)
    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        partitions, name='partitions', preferred_dtype=row_splits_dtype)
    num_partitions = ops.convert_to_tensor(
        num_partitions, name='num_partitions', preferred_dtype=partitions.dtype)
    if row_splits_dtype is not None:
      partitions = math_ops.cast(partitions, row_splits_dtype)
    num_partitions = math_ops.cast(num_partitions, partitions.dtype)

    # Sanity-checks for shapes.
    partitions_rank = partitions.shape.ndims
    if partitions_rank is None:
      raise ValueError('partitions must have known rank.')
    num_partitions.shape.assert_has_rank(0)
    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])

    if partitions_rank == 0:
      # If partitions is a scalar, then just create a RaggedTensor containing
      # the complete `data` value, placed in the specified row.
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=array_ops_stack.stack([data]),
          value_rowids=array_ops_stack.stack([partitions]),
          nrows=num_partitions,
          validate=False)

    elif partitions_rank == 1:
      # If partitions is a vector (the typical case): we can just use data and
      # partitions as the `values` and `value_rowids` for `from_value_rowids`,
      # as long as we sort them first.
      permutation = sort_ops.argsort(partitions, stable=True)
      value_rowids = array_ops.gather(partitions, permutation)
      values = array_ops.gather(data, permutation)
      checks = [
          check_ops.assert_less(
              value_rowids[-1:], num_partitions,
              message='partitions must be less than num_partitions'),
          check_ops.assert_non_negative(
              partitions, message='partitions must be non-negative.')
      ]
      with ops.control_dependencies(checks):
        return ragged_tensor.RaggedTensor.from_value_rowids(
            values, value_rowids, nrows=num_partitions, validate=False)

    else:
      # Handle higher-dimensional partitions via recursion.
      if not isinstance(data, ragged_tensor.RaggedTensor):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, row_splits_dtype=partitions.dtype, ragged_rank=1)
      if not isinstance(partitions, ragged_tensor.RaggedTensor):
        partitions = ragged_tensor.RaggedTensor.from_tensor(
            partitions,
            row_splits_dtype=partitions.dtype,
            ragged_rank=max(data.ragged_rank, partitions_rank - 1))
      check = check_ops.assert_equal(
          data.row_splits,
          partitions.row_splits,
          message='data and partitions have incompatible ragged shapes')
      with ops.control_dependencies([check]):
        return stack_dynamic_partitions(data.values, partitions.values,
                                        num_partitions)

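# A worked trace of the rank-1 path above for the docstring example
# (illustrative only). With partitions = [3, 0, 2, 2, 3], the stable argsort
# gives permutation = [1, 2, 3, 0, 4], so value_rowids = [0, 2, 2, 3, 3] and
# values = ['b', 'c', 'd', 'a', 'e']; from_value_rowids with nrows=5 then
# yields [['b'], [], ['c', 'd'], ['a', 'e'], []], matching the docstring.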

# ===============================================================================
# Reverse
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.reverse)
def reverse(tensor: ragged_tensor.Ragged, axis, name=None):
  """Reverses a RaggedTensor along the specified axes.

  #### Example:

  >>> data = tf.ragged.constant([
  ...     [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
  >>> tf.reverse(data, axis=[0, 2])
  <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>

  Args:
    tensor: A 'RaggedTensor' to reverse.
    axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of
      the axes to reverse.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A 'RaggedTensor'.
  """
  type_error_msg = ('`axis` must be a list of int or a constant tensor '
                    'when reversing axes in a ragged tensor')

  with ops.name_scope(name, 'Reverse', [tensor, axis]):
    if isinstance(axis, ops.Tensor):
      axis = tensor_util.constant_value(axis)
      if axis is None:
        raise TypeError(type_error_msg)
    elif not (isinstance(axis, (list, tuple)) and
              all(isinstance(dim, int) for dim in axis)):
      raise TypeError(type_error_msg)

    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        tensor, name='tensor')

    # Allow usage of negative values to specify innermost axes.
    axis = [
        array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i,
                                    'rank(tensor)')
        for i, dim in enumerate(axis)
    ]

    # We only need to slice up to the max axis. If the axis list
    # is empty, it should be 0.
    slices = [slice(None)] * (max(axis) + 1 if axis else 0)

    for dim in axis:
      slices[dim] = slice(None, None, -1)

    return tensor[tuple(slices)]

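# A worked trace of the slice construction above (illustrative only). For
# axis=[0, 2] on a rank-3 tensor, max(axis) + 1 = 3, so slices starts as
# [slice(None)] * 3 and becomes
# [slice(None, None, -1), slice(None), slice(None, None, -1)], i.e. the
# return is tensor[::-1, :, ::-1]: the outermost and innermost dimensions
# are reversed while the middle (ragged) dimension is left intact.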

# ===============================================================================
# Cross
# ===============================================================================


@tf_export('ragged.cross')
@dispatch.add_dispatch_support
def cross(inputs, name=None):
  """Generates feature cross from a list of tensors.

  The input tensors must have `rank=2`, and must all have the same number of
  rows. The result is a `RaggedTensor` with the same number of rows as the
  inputs, where `result[row]` contains a list of all combinations of values
  formed by taking a single value from each input's corresponding row
  (`inputs[i][row]`). Values are combined by joining their strings with '_X_'.
  E.g.:

  >>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]),
  ...                  tf.ragged.constant([['d'], ['e']]),
  ...                  tf.ragged.constant([['f'], ['g']])])
  <tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]>

  Args:
    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
    name: Optional name for the op.

  Returns:
    A 2D `RaggedTensor` of type `string`.
  """
  return _cross_internal(inputs=inputs, hashed_output=False, name=name)


@tf_export('ragged.cross_hashed')
@dispatch.add_dispatch_support
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
  """Generates hashed feature cross from a list of tensors.

  The input tensors must have `rank=2`, and must all have the same number of
  rows. The result is a `RaggedTensor` with the same number of rows as the
  inputs, where `result[row]` contains a list of all combinations of values
  formed by taking a single value from each input's corresponding row
  (`inputs[i][row]`). Values are combined by hashing together their
  fingerprints. E.g.:

  >>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]),
  ...                         tf.ragged.constant([['d'], ['e']]),
  ...                         tf.ragged.constant([['f'], ['g']])],
  ...                        num_buckets=100)
  <tf.RaggedTensor [[78], [66, 74]]>

  Args:
    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
    num_buckets: A non-negative `int` used to bucket the hashed values. If
      `num_buckets != 0`, then `output = hashed_value % num_buckets`.
    hash_key: Integer hash_key that will be used by the `FingerprintCat64`
      function. If not given, a default key is used.
    name: Optional name for the op.

  Returns:
    A 2D `RaggedTensor` of type `int64`.
  """
  return _cross_internal(
      inputs=inputs,
      hashed_output=True,
      num_buckets=num_buckets,
      hash_key=hash_key,
      name=name)


_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE


def _cross_internal(inputs,
                    hashed_output=False,
                    num_buckets=0,
                    hash_key=None,
                    name=None):
  """Generates feature cross from a list of ragged and dense tensors."""
  if not isinstance(inputs, (tuple, list)):
    raise TypeError('Inputs must be a list')

  if hash_key is None:
    hash_key = _DEFAULT_CROSS_HASH_KEY

  ragged_inputs = []
  sparse_inputs = []
  dense_inputs = []
  input_order = []
  with ops.name_scope(name, 'RaggedCross', inputs):
    for i, t in enumerate(inputs):
      if sparse_tensor.is_sparse(t):
        t = sparse_tensor.SparseTensor.from_value(t)
      else:
        t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
      if t.dtype.is_integer:
        t = math_ops.cast(t, dtypes.int64)
      elif t.dtype != dtypes.string:
        raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
      if isinstance(t, ragged_tensor.RaggedTensor):
        if t.ragged_rank != 1:
          raise ValueError('tf.ragged.cross only supports inputs with rank=2')
        ragged_inputs.append(t)
        input_order.append('R')
      elif isinstance(t, sparse_tensor.SparseTensor):
        sparse_inputs.append(t)
        input_order.append('S')
      else:
        dense_inputs.append(t)
        input_order.append('D')

    out_values_type = dtypes.int64 if hashed_output else dtypes.string
    if ragged_inputs and all(
        t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
      out_row_splits_type = dtypes.int32
    else:
      out_row_splits_type = dtypes.int64

    # Convert hash_key from uint64 -> int64, since we need to pass it via
    # an int64 attr.
    if hash_key > 2**63:
      hash_key -= 2**64

    values_out, splits_out = gen_ragged_array_ops.ragged_cross(
        ragged_values=[rt.values for rt in ragged_inputs],
        ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
        sparse_indices=[st.indices for st in sparse_inputs],
        sparse_values=[st.values for st in sparse_inputs],
        sparse_shape=[st.dense_shape for st in sparse_inputs],
        dense_inputs=dense_inputs,
        input_order=''.join(input_order),
        hashed_output=hashed_output,
        num_buckets=num_buckets,
        hash_key=hash_key,
        out_values_type=out_values_type.as_datatype_enum,
        out_row_splits_type=out_row_splits_type.as_datatype_enum,
        name=name)

    return ragged_tensor.RaggedTensor.from_row_splits(
        values_out, splits_out, validate=False)

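# A note on the per-row combinatorics above (illustrative only): each output
# row is the Cartesian product of the corresponding input rows, so its length
# is the product of the input row lengths. In the `tf.ragged.cross` docstring
# example, row 1 combines ['b', 'c'] x ['e'] x ['g'], giving the 2 * 1 * 1 = 2
# crossed values shown there. The `input_order` string (e.g. 'RRD') tells the
# kernel how to interleave the ragged, sparse, and dense input lists back into
# their original argument order.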

def fill_empty_rows(ragged_input, default_value, name=None):
  """Fills empty rows in a rank-2 `RaggedTensor` with a default value.

  This op adds entries with the specified `default_value` for any row in the
  input that does not already have a value.

  The op also returns an indicator vector such that

      empty_row_indicator[i] = True iff row i was an empty row.

  Args:
    ragged_input: A `RaggedTensor` with rank 2.
    default_value: The value to fill for empty rows, with the same type as
      `ragged_input`.
    name: A name prefix for the returned tensors (optional).

  Returns:
    ragged_ordered_output: A `RaggedTensor` with all empty rows filled in with
      `default_value`.
    empty_row_indicator: A bool vector indicating whether each input row was
      empty.

  Raises:
    TypeError: If `ragged_input` is not a `RaggedTensor`.
  """
  with ops.name_scope(name, 'RaggedFillEmptyRows', [ragged_input]):
    if not isinstance(ragged_input, ragged_tensor.RaggedTensor):
      raise TypeError(
          'ragged_input must be RaggedTensor, got'
          f' {type(ragged_input)}'
      )
    default_value = ops.convert_to_tensor(
        default_value, dtype=ragged_input.dtype
    )
    (
        output_value_rowids,
        output_values,
        empty_row_indicator,
        unused_reverse_index_map,
    ) = gen_ragged_array_ops.ragged_fill_empty_rows(
        value_rowids=ragged_input.value_rowids(),
        values=ragged_input.values,
        nrows=ragged_input.nrows(),
        default_value=default_value,
    )
    return (
        ragged_tensor.RaggedTensor.from_value_rowids(
            values=output_values,
            value_rowids=output_value_rowids,
            validate=False,
        ),
        empty_row_indicator,
    )

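# An illustrative doctest-style example for the module-level helper above
# (not in the original docstring), assuming eager execution and the expected
# kernel semantics of inserting one `default_value` entry per empty row:
#
#   >>> rt = tf.ragged.constant([[1, 2], [], [3]])
#   >>> filled, empty = fill_empty_rows(rt, default_value=0)
#   >>> filled.to_list()
#   [[1, 2], [0], [3]]
#   >>> empty.numpy()
#   array([False,  True, False])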

@ops.RegisterGradient('RaggedFillEmptyRows')
def _ragged_fill_empty_rows_grad(
    op,
    unused_grad_output_indices,
    output_grad_values,
    unused_grad_empty_row_indicator,
    unused_grad_reverse_index_map,
):
  """Gradients for RaggedFillEmptyRows."""
  reverse_index_map = op.outputs[3]

  d_values, d_default_value = gen_ragged_array_ops.ragged_fill_empty_rows_grad(
      reverse_index_map=reverse_index_map, grad_values=output_grad_values
  )

  # d_value_rowids, d_values, d_nrows, d_default_value.
  return [None, d_values, None, d_default_value]


# ===============================================================================
# dynamic_partition
# ===============================================================================
@dispatch.dispatch_for_api(data_flow_ops.dynamic_partition)
def dynamic_partition(data: ragged_tensor.RaggedOrDense,
                      partitions: ragged_tensor.RaggedOrDense,
                      num_partitions,
                      name=None):
  """RaggedTensor dispatch override for tf.dynamic_partition."""
  if not isinstance(num_partitions, int) or num_partitions < 0:
    raise TypeError('num_partitions must be a non-negative integer')
  result = stack_dynamic_partitions(data, partitions, num_partitions, name)
  return [result[i] for i in range(num_partitions)]

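# An illustrative doctest-style example of this override (not in the original
# docstring), assuming the public `tf` API with eager execution: the stacked
# result is simply unpacked row by row into a Python list.
#
#   >>> data = tf.ragged.constant([[1], [2, 3], [4]])
#   >>> parts = tf.dynamic_partition(data, [1, 0, 1], num_partitions=2)
#   >>> [p.to_list() for p in parts]
#   [[[2, 3]], [[1], [4]]]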

# ===============================================================================
# split
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.split)
def split(value: ragged_tensor.Ragged,
          num_or_size_splits,
          axis=0,
          num=None,
          name=None):
  """Splits a RaggedTensor `value` into a list of sub RaggedTensors.

  If `num_or_size_splits` is an `int`, then it splits `value` along the
  dimension `axis` into `num_or_size_splits` smaller RaggedTensors. This
  requires that `value.shape[axis]` is divisible by `num_or_size_splits`.

  If `num_or_size_splits` is a 1-D Tensor (or list), then `value` is split into
  `len(num_or_size_splits)` elements. The shape of the `i`-th element has the
  same size as the `value` except along dimension `axis` where the size is
  `num_or_size_splits[i]`.

  Splitting along a ragged dimension is not allowed.

  For example:

  >>> rt = tf.RaggedTensor.from_row_lengths(
  ...     np.arange(6 * 3).reshape(6, 3), row_lengths=[1, 2, 2, 1])
  >>> rt.shape
  TensorShape([4, None, 3])
  >>>
  >>> rt1, rt2 = tf.split(rt, 2)  # uniform splits
  >>> rt1.shape
  TensorShape([2, None, 3])
  >>> rt2.shape
  TensorShape([2, None, 3])
  >>>
  >>> rt3, rt4, rt5 = tf.split(rt, [1, 2, 1])  # ragged splits
  >>> rt3.shape
  TensorShape([1, None, 3])
  >>> rt4.shape
  TensorShape([2, None, 3])
  >>> rt5.shape
  TensorShape([1, None, 3])
  >>>
  >>> rt6, rt7 = tf.split(rt, [1, 2], axis=2)  # splits along axis 2
  >>> rt6.shape
  TensorShape([4, None, 1])
  >>> rt7.shape
  TensorShape([4, None, 2])

  Args:
    value: The `RaggedTensor` to split.
    num_or_size_splits: Either an `int` indicating the number of splits
      along `axis` or a 1-D integer `Tensor` or Python list containing the
      sizes of each output tensor along `axis`. If a Python int, then it must
      evenly divide `value.shape[axis]`; otherwise the sum of sizes along the
      split axis must match that of the `value`.
    axis: An `int` or scalar `int32` `Tensor`. The dimension along which
      to split. Must be in the range `[-rank(value), rank(value))`. Defaults
      to 0.
    num: An `int` used to specify the number of outputs when
      `num_or_size_splits` is a 1-D list or `Tensor` and its length is
      statically unknown, e.g., specifying `tf.TensorSpec(None)` with
      the `input_signature` argument of `tf.function` (optional).
    name: A name for the operation (optional).

  Returns:
    if `num_or_size_splits` is an `int` returns a list of `num_or_size_splits`
    `RaggedTensor` objects; if `num_or_size_splits` is a 1-D Tensor returns
    `num_or_size_splits.get_shape[0]` `RaggedTensor` objects resulting from
    splitting `value`.

  Raises:
    ValueError: If the dimension `axis` of `value` is a ragged dimension.
    ValueError: If `num` is unspecified and cannot be inferred.
    ValueError: If `num` is specified but doesn't match the length of
      `num_or_size_splits`.
    ValueError: If `num_or_size_splits` is an `int` and less than 1.
    TypeError: If `num_or_size_splits` is not an `int` or 1-D
      list or 1-D `Tensor`.
    InvalidArgumentError: If the `axis` of `value` cannot be split evenly
      by `num_or_size_splits`.
    InvalidArgumentError: If `num_or_size_splits` contains negative integers.
    InvalidArgumentError: If `num_or_size_splits`'s static shape is unknown and
      its dynamic shape is inconsistent with `num`.
    InvalidArgumentError: If `num_or_size_splits`'s static rank is unknown and
      `axis` is a negative integer.
  """

  with ops.name_scope(name, 'RaggedSplit'):
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        value, name='value')
    if isinstance(num_or_size_splits, int) and num_or_size_splits == 1:
      return [value]

    # static assert
    check_ops.assert_integer_v2(
        num_or_size_splits,
        message=('`num_or_size_splits` must be an `int` or 1-D list or '
                 '`Tensor` of integers.'))
    value_shape = dynamic_ragged_shape.DynamicRaggedShape.from_tensor(value)
    axis = array_ops.get_positive_axis(axis, value_shape.rank)
    try:
      dim_size = value_shape[axis]
    except ValueError:
      raise ValueError('Cannot split a ragged dimension. Got `value` with '
                       f'shape {value_shape} and `axis` {axis}.')
    if isinstance(num_or_size_splits, int):
      # Uniform split
      num_splits = num_or_size_splits
      if num_splits < 1:
        raise ValueError('`num_or_size_splits` must be >=1 if it is an '
                         f'`int`. Received {num_or_size_splits}.')
      split_length = math_ops.floordiv(dim_size, num_splits)
      split_lengths = array_ops.repeat(split_length, num_splits)
    else:
      # Ragged split
      num_splits = None
      split_lengths = ops.convert_to_tensor(num_or_size_splits)
      if split_lengths.shape.ndims is not None:
        if split_lengths.shape.ndims != 1:
          raise TypeError('`num_or_size_splits` must be an `int` or 1-D list '
                          f'or `Tensor`. Received {num_or_size_splits}.')
        num_splits = tensor_shape.dimension_value(split_lengths.shape[0])

      if num_splits is None:
        if num is None:
          raise ValueError('`num` must be specified as an `int` when the '
                           'size of `num_or_size_split` is statically '
                           f'unknown. Received `num`: {num} and '
                           f'`num_or_size_split`: {num_or_size_splits}.')
        num_splits = num
      else:
        if num is not None and num != num_splits:
          raise ValueError('`num` does not match the size of '
                           f'`num_or_size_split`. Received `num`: {num} and '
                           f'size of `num_or_size_split`: {num_splits}.')

    splits = array_ops.concat([[0], math_ops.cumsum(split_lengths)], axis=0)
    checks = []
    checks.append(
        check_ops.assert_non_negative_v2(
            num_or_size_splits,
            message='`num_or_size_splits` must be non-negative.'))
    checks.append(
        check_ops.assert_equal_v2(
            num_splits,
            array_ops.shape(split_lengths)[0],
            message='`num` is inconsistent with `num_or_size_split.shape[0]`.'))
    checks.append(
        check_ops.assert_equal_v2(
            math_ops.cast(dim_size, splits.dtype),
            splits[-1],
            message=('Cannot exactly split the `axis` dimension of `value` '
                     'with the given `num_or_size_split`.')))
    splits = control_flow_ops.with_dependencies(checks, splits)
    split_rts = []
    slices = [slice(None)] * (axis + 1)
    for i in range(num_splits):
      slices[-1] = slice(splits[i], splits[i + 1])
      split_rts.append(value[tuple(slices)])
    return split_rts

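# A worked trace of the slicing above (illustrative only). For the docstring
# example tf.split(rt, [1, 2, 1]) with axis=0: split_lengths is [1, 2, 1], so
# splits = concat([[0], cumsum([1, 2, 1])]) = [0, 1, 3, 4], and the loop
# returns [value[0:1], value[1:3], value[3:4]]. For axis=2, slices keeps
# slice(None) in the leading positions, e.g. value[:, :, 0:1] and
# value[:, :, 1:3].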

# ===============================================================================
# RaggedTensor shape operations
# ===============================================================================


@dispatch.dispatch_for_api(array_ops.reshape)
def ragged_reshape(
    tensor: ragged_tensor.RaggedOrDense,
    shape: dynamic_ragged_shape.DenseOrRaggedShape
) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
  """Reshapes a tensor or ragged tensor."""
  tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      tensor, name='tensor')
  if isinstance(tensor, ragged_tensor.RaggedTensor):
    tensor = tensor.values

  if isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape):
    flat_values = array_ops.reshape(tensor, shape.inner_shape)
    return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
        flat_values,
        shape.row_partitions,
        validate=False)
  else:
    shape = ops.convert_to_tensor(shape, name='shape')
    return array_ops.reshape(tensor, shape)

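# An illustrative sketch of this dispatcher (not in the original module),
# assuming `tf.experimental.DynamicRaggedShape.from_lengths` behaves as in
# current TF releases (a tuple entry denotes ragged row lengths):
#
#   >>> shape = tf.experimental.DynamicRaggedShape.from_lengths([2, (2, 1)])
#   >>> tf.reshape(tf.constant([10, 20, 30]), shape).to_list()
#   [[10, 20], [30]]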

@dispatch.dispatch_for_api(array_ops.broadcast_to)
def broadcast_to(
    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
    shape: dynamic_ragged_shape.DynamicRaggedShape
) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
  """Broadcasts a potentially ragged tensor to a ragged shape.

  Tiles `input` as necessary to match the given shape.

  Behavior is undefined if `input` is not broadcast-compatible with `shape`.

  Args:
    input: The potentially ragged tensor to broadcast.
    shape: A `DynamicRaggedShape`

  Returns:
    A potentially ragged tensor whose values are taken from
    `input`, and whose shape matches `shape`.
  """
  return dynamic_ragged_shape.broadcast_to(input, shape)


# Note: default value for out_type needs to be int32, to match the
# default for tf.shape's out_type parameter.
@dispatch.dispatch_for_api(array_ops.shape)
def ragged_shape(
    input: ragged_tensor.Ragged,  # pylint: disable=redefined-builtin
    name: Optional[str] = None,
    out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the shape of a RaggedTensor.

  Args:
    input: A `RaggedTensor`
    name: A name for the operation (optional).
    out_type: dtype used to encode the shape.

  Returns:
    A `tf.experimental.DynamicRaggedShape`
  """
  with ops.name_scope(name, 'RaggedShape', [input]):
    return dynamic_ragged_shape.DynamicRaggedShape.from_tensor(input, out_type)


@dispatch.dispatch_for_api(array_ops.broadcast_dynamic_shape)
def broadcast_dynamic_shape(
    shape_x: dynamic_ragged_shape.DenseOrRaggedShape,
    shape_y: dynamic_ragged_shape.DenseOrRaggedShape
) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the shape formed by broadcasting two shapes to be compatible.

  1. If shape_x and shape_y both have row_partitions, then fail if their
     dtypes don't match.
  2. If neither has row_partitions and they have different dtypes,
     go with int64.
  3. If one has row_partitions, go with that dtype.

  Args:
    shape_x: A `DynamicRaggedShape`
    shape_y: A `DynamicRaggedShape`

  Returns:
    A `DynamicRaggedShape`.
  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
  """
  if not isinstance(shape_x, dynamic_ragged_shape.DynamicRaggedShape):
    shape_x = dynamic_ragged_shape.DynamicRaggedShape([], shape_x)
  if not isinstance(shape_y, dynamic_ragged_shape.DynamicRaggedShape):
    shape_y = dynamic_ragged_shape.DynamicRaggedShape([], shape_y)
  return dynamic_ragged_shape.broadcast_dynamic_shape(shape_x, shape_y)


@dispatch.dispatch_for_api(array_ops.ones)
def ones(shape: dynamic_ragged_shape.DynamicRaggedShape,
         dtype=dtypes.float32,
         name=None) -> ragged_tensor.RaggedOrDense:
  """Returns a potentially ragged tensor of ones with the given shape."""
  flat_values = array_ops.ones(shape.inner_shape, dtype=dtype, name=name)
  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access


@dispatch.dispatch_for_api(array_ops.zeros)
def zeros(shape: dynamic_ragged_shape.DynamicRaggedShape,
          dtype=dtypes.float32,
          name=None) -> ragged_tensor.RaggedOrDense:
  """Returns a potentially ragged tensor of zeros with the given shape."""
  flat_values = array_ops.zeros(shape.inner_shape, dtype=dtype, name=name)
  return shape._add_row_partitions(flat_values)  # pylint: disable=protected-access


@dispatch.dispatch_for_api(array_ops.fill)
def fill(dims: dynamic_ragged_shape.DynamicRaggedShape,
         value: core_types.TensorLike,
         name: Optional[str] = None) -> ragged_tensor.RaggedOrDense:
  """Creates a tensor with shape `dims` and fills it with `value`."""
  flat_values = array_ops.fill(dims.inner_shape, value, name=name)
  return dims._add_row_partitions(flat_values)  # pylint: disable=protected-access

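# An illustrative sketch of these dispatchers (not in the original module),
# assuming `tf.experimental.DynamicRaggedShape.from_lengths` is available:
# the flat values are created densely, then the ragged row partitions are
# reattached.
#
#   >>> shape = tf.experimental.DynamicRaggedShape.from_lengths([2, (2, 1)])
#   >>> tf.zeros(shape).to_list()
#   [[0.0, 0.0], [0.0]]
#   >>> tf.fill(shape, 7).to_list()
#   [[7, 7], [7]]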

# ===============================================================================
# bitcast
# ===============================================================================
@dispatch.dispatch_for_api(array_ops.bitcast)
def bitcast(
    input: ragged_tensor.RaggedOrDense,  # pylint: disable=redefined-builtin
    type,  # pylint: disable=redefined-builtin
    name=None) -> ragged_tensor.RaggedOrDense:
  """RaggedTensor dispatch override for tf.bitcast."""
  type = dtypes.as_dtype(type)
  with ops.name_scope(name, 'Bitcast', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if (input.dtype.size < type.size and input.flat_values.shape.rank < 2):
      raise ValueError('`input.flat_values` is required to have rank >= 2 when '
                       'input.dtype.size < type.size. Actual rank: '
                       f'{input.flat_values.shape.rank}')
    return input.with_flat_values(array_ops.bitcast(input.flat_values, type))
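# A short note on the rank requirement above (illustrative only): bitcasting
# to a wider dtype consumes the innermost dimension (e.g. an int8 tensor of
# shape [..., 4] bitcasts to an int32 tensor of shape [...]), so the flat
# values need an innermost dimension to consume. A sketch, assuming the
# public `tf` API (exact values are endianness-dependent, so only the shape
# is shown):
#
#   >>> rt = tf.ragged.constant([[[0, 0, 0, 1]], [[0, 0, 0, 2], [0, 0, 0, 3]]],
#   ...                         dtype=tf.int8, ragged_rank=1)
#   >>> tf.bitcast(rt, tf.int32).shape.as_list()
#   [2, None]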