# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util


class RaggedTensorDynamicShape:
  """A collection of tensors encoding the shape of a potentially ragged tensor.

  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
  sizes.  There are two dimension types:

    * "Uniform dimensions" are dimensions where all slices have the same
      length.  `RaggedTensorDynamicShape` records the size of each uniform
      dimension using a single scalar integer.

    * "Ragged dimensions" are dimensions whose slices may have different
      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
      dimension using an integer vector containing the slice lengths for all
      the slices across that dimension.

  Furthermore, there are two ways a dimension might be encoded:

49 * "Partitioned dimensions" are dimensions that are encoded using a 

50 `RaggedTensor`'s `nested_row_splits`. The outermostmost partitioned 

51 dimension must be uniform, and the innermost partitioned dimension must 

52 be ragged. 


    * "Inner dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.

  The sizes of partitioned dimensions are recorded using
  `partitioned_dim_sizes` and `inner_dim_sizes`:

    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
      dimension).

      * For uniform dimensions, the tensor is an integer scalar specifying the
        size of all slices across that dimension.
      * For ragged dimensions, the tensor is an integer vector specifying the
        size of each slice across that dimension.

    * `inner_dim_sizes` is a single integer vector, where each element
      specifies the size of a single inner dimension.

  Examples:

  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
                                 : Rank   :                        : Sizes
  ------------------------------ | ------ | ---------------------- | ----------
  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | `2`
  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
  """

  def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
               dim_size_dtype=None):
    """Creates a RaggedTensorDynamicShape.

    Args:
      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
        each partitioned dimension.  If dimension `d` is uniform, then
        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
        size of all slices across dimension `d`.  If dimension `d` is ragged,
        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
        the size of each slice across dimension `d`.
      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
      dim_size_dtype: dtype for dimension sizes.  If not specified, then it
        is chosen based on the dtypes of `partitioned_dim_sizes` and
        `inner_dim_sizes`.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))

    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      partitioned_dim_sizes = tuple(
          ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
          for (i, size) in enumerate(partitioned_dim_sizes))
      inner_dim_sizes = ops.convert_to_tensor(
          inner_dim_sizes, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      # Convert dimension size tensors to a single dtype.
      if dim_size_dtype is None:
        dim_size_dtypes = set(
            p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1)
        if not dim_size_dtypes:
          dim_size_dtype = dtypes.int64
        elif len(dim_size_dtypes) == 1:
          dim_size_dtype = dim_size_dtypes.pop()
        else:
          if not ragged_config.auto_cast_partition_dtype():
            raise ValueError('partitioned_dim_sizes must have matching dtypes')
          dim_size_dtype = dtypes.int64
      partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
                                    for p in partitioned_dim_sizes)
      inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes

  def __repr__(self):
    return ('RaggedTensorDynamicShape'
            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
            (self._partitioned_dim_sizes, self._inner_dim_sizes))

  @staticmethod
  def from_dim_sizes(dim_sizes):
    """Constructs a ragged shape from a list of dimension sizes.

    This list contains a single tensor for each dimension, where the tensor
    is a scalar if the dimension is uniform, or a vector if the dimension is
    ragged.

    Args:
      dim_sizes: List of int32 or int64 scalars or vectors.

    Returns:
      A RaggedTensorDynamicShape.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
                        [dim_sizes]):
      dim_sizes = tuple(
          ops.convert_to_tensor(size, preferred_dtype=dtypes.int64,
                                name='dim_sizes') for size in dim_sizes)
      # Split the dimensions into partitioned & inner dimensions.
      inner_split = 0
      for dim, dim_size in enumerate(dim_sizes):
        if dim_size.shape.ndims == 1:
          inner_split = dim + 1
        elif dim_size.shape.ndims != 0:
          raise ValueError('Each dim_size must be a scalar or a vector')
      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
                                      dim_sizes[inner_split:])

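  # Illustrative sketch (assumes eager execution; values are made up for the
  # example): building the shape of `[[1, 2], [], [3, 4, 5]]` directly from
  # its dimension sizes.
  #
  #   shape = RaggedTensorDynamicShape.from_dim_sizes(
  #       [3, constant_op.constant([2, 0, 3])])
  #   shape.rank                          # 2
  #   shape.num_partitioned_dimensions    # 2
  #   shape.is_ragged(1)                  # True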

  @classmethod
  def from_tensor(cls, rt_input, dim_size_dtype=None):
    """Constructs a ragged shape for a potentially ragged tensor."""
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
      if not ragged_tensor.is_ragged(rt_input):
        return cls([], array_ops.shape(rt_input), dim_size_dtype=dim_size_dtype)
      else:
        partitioned_dim_sizes = (
            (rt_input.nrows(),) + rt_input.nested_row_lengths())
        return RaggedTensorDynamicShape(
            partitioned_dim_sizes,
            array_ops.shape(rt_input.flat_values)[1:],
            dim_size_dtype=dim_size_dtype)

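  # Illustrative sketch (assumes eager execution): deriving the dynamic shape
  # of a ragged tensor with row lengths 2, 0, and 3.
  #
  #   rt = ragged_tensor.RaggedTensor.from_row_lengths(
  #       values=[1, 2, 3, 4, 5], row_lengths=[2, 0, 3])
  #   shape = RaggedTensorDynamicShape.from_tensor(rt)
  #   # partitioned_dim_sizes -> (3, [2, 0, 3]); inner_dim_sizes -> []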

  def dimension_size(self, axis):
    """Returns the size of slices across the specified dimension."""
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    partitioned_ndims = len(self._partitioned_dim_sizes)
    if axis < partitioned_ndims:
      return self._partitioned_dim_sizes[axis]
    else:
      return self._inner_dim_sizes[axis - partitioned_ndims]

  def is_ragged(self, axis):
    """Returns true if the indicated dimension is ragged."""
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    rank = self.rank
    if axis < 0:
      raise ValueError('Negative axis values are not supported')
    elif rank is not None and axis >= rank:
      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
    else:
      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
              self._partitioned_dim_sizes[axis].shape.ndims == 1)

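  # Illustrative sketch (assumes eager execution): querying the shape built
  # for `[[1, 2], [], [3, 4, 5]]`, whose dimension sizes are `3, (2, 0, 3)`.
  #
  #   shape = RaggedTensorDynamicShape.from_dim_sizes(
  #       [3, constant_op.constant([2, 0, 3])])
  #   shape.dimension_size(0)   # 3
  #   shape.dimension_size(1)   # [2, 0, 3]
  #   shape.is_ragged(0)        # False -- the outermost dimension is uniform
  #   shape.is_ragged(1)        # True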

  @property
  def rank(self):
    """The number of dimensions in this shape, or None if unknown."""
    inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
    if inner_ndims is None:
      return None
    else:
      return len(self._partitioned_dim_sizes) + inner_ndims

  @property
  def partitioned_dim_sizes(self):
    """The partitioned dimension sizes for this shape.

    Returns:
      A `list` of 0-D or 1-D integer `Tensor`.
    """
    return self._partitioned_dim_sizes

  @property
  def inner_dim_sizes(self):
    """The inner dimension sizes for this shape.

    Returns:
      A 1-D integer `Tensor`.
    """
    return self._inner_dim_sizes

  @property
  def num_partitioned_dimensions(self):
    """The number of partitioned dimensions in this shape."""
    return len(self._partitioned_dim_sizes)

  @property
  def num_inner_dimensions(self):
    """The number of inner dimensions, or `None` if not statically known."""
    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])

  @property
  def dim_size_dtype(self):
    """DType used by this shape for dimension sizes."""
    return self._inner_dim_sizes.dtype

  def broadcast_to_rank(self, rank):
    """Adds leading size-1 dimensions to broadcast `self` to the given rank.

    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
    is `[1, 1, 3, (D2), 4]`.

    Args:
      rank: The rank for the returned shape.

    Returns:
      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
      have the same size as `self` and whose outer dimensions have size `1`.

    Raises:
      ValueError: If `self.rank` is unknown or greater than `rank`.
    """
    if self.rank is None:
      raise ValueError('Unable to broadcast: self.rank is unknown')
    dims_to_add = rank - self.rank
    if dims_to_add < 0:
      raise ValueError('Unable to broadcast: rank=%d must be greater than '
                       'self.rank=%d.' % (rank, self.rank))
    elif dims_to_add == 0:
      return self
    elif self._partitioned_dim_sizes:
      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
      return RaggedTensorDynamicShape(partitioned_dims, self.inner_dim_sizes,
                                      self.dim_size_dtype)
    else:
      inner_dims = array_ops.concat(
          [array_ops.ones([dims_to_add], self.dim_size_dtype),
           self.inner_dim_sizes],
          axis=0)
      return RaggedTensorDynamicShape([], inner_dims, self.dim_size_dtype)

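  # Illustrative sketch of the docstring example above (assumes eager
  # execution), using `(2, 0, 3)` for the ragged dimension `(D2)`:
  #
  #   shape1 = RaggedTensorDynamicShape.from_dim_sizes(
  #       [3, constant_op.constant([2, 0, 3]), 4])
  #   shape1.broadcast_to_rank(5)
  #   # partitioned_dim_sizes -> (1, 1, 3, [2, 0, 3]); inner_dim_sizes -> [4]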

  def broadcast_dimension(self, axis, lengths):
    """Returns a shape that is broadcast-compatible with self & lengths.

    * If dimension[axis] is uniform and lengths is a scalar, then check
      that lengths==1 or dimension[axis]==1 or dimension[axis]==lengths,
      and tile dimension[axis] with lengths repeats if dimension[axis]==1.

    * If dimension[axis] is uniform and lengths is a vector, then check
      that dimension[axis]==1, and raggedly tile dimension[axis] with
      lengths repeats.  (The tiling can be skipped if we statically know
      that lengths == 1.)

    * If dimension[axis] is ragged and lengths is a scalar, then check
      that lengths==1.

    * If dimension[axis] is ragged and lengths is a vector, then check
      that self.dimension_size(axis) == lengths.

    Args:
      axis: `int`.  The dimension to broadcast.
      lengths: 0-D or 1-D integer `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape`.
    """
    lengths = ragged_util.convert_to_int_tensor(
        lengths, name='lengths', dtype=self.dim_size_dtype)
    # Check whether lengths is a scalar (for uniform dimensions) or
    # vector (for ragged dimensions).
    if lengths.shape.ndims is None:
      raise ValueError('lengths must have a known rank.')
    elif lengths.shape.ndims > 1:
      raise ValueError('lengths must be a scalar or vector')
    else:
      lengths_is_scalar = (lengths.shape.ndims == 0)

    # Verify that the shapes are compatible.
    if self.is_ragged(axis):
      if lengths_is_scalar:
        condition = math_ops.equal(lengths, 1)
      else:
        condition = math_ops.reduce_all(
            math_ops.equal(lengths, self.dimension_size(axis)))
    else:
      axis_dim_size = self.dimension_size(axis)
      if lengths_is_scalar:
        condition = (
            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
            | math_ops.equal(axis_dim_size, lengths))
      else:
        condition = math_ops.equal(axis_dim_size, 1)
    broadcast_err = [
        'Unable to broadcast: dimension size mismatch in dimension', axis,
        'lengths=', lengths, 'dim_size=',
        self.dimension_size(axis)
    ]
    broadcast_check = control_flow_assert.Assert(
        condition, data=broadcast_err, summarize=10)

    with ops.control_dependencies([broadcast_check]):
      # Partitioned dimensions:
      if axis < self.num_partitioned_dimensions:
        if self.is_ragged(axis):
          # Use an identity op to make sure the check actually gets run.
          return RaggedTensorDynamicShape(
              self._partitioned_dim_sizes,
              array_ops.identity(self.inner_dim_sizes), self.dim_size_dtype)
        else:
          return self._broadcast_uniform_partitioned_dimension(axis, lengths)

      # Inner dimensions:
      else:
        if lengths_is_scalar:
          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
        else:
          if axis == 0:
            raise ValueError('Unable to broadcast: '
                             'outermost dimension must be uniform.')
          return self._broadcast_inner_dimension_to_ragged(axis, lengths)

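  # Illustrative sketch (assumes eager execution): broadcasting a size-1
  # uniform dimension against ragged row lengths.  Starting from the shape
  # `[3, 1]`, broadcasting dimension 1 with lengths `[2, 0, 3]` yields the
  # ragged shape `[3, (2, 0, 3)]`.
  #
  #   shape = RaggedTensorDynamicShape.from_dim_sizes([3, 1])
  #   shape.broadcast_dimension(1, [2, 0, 3])
  #   # partitioned_dim_sizes -> (3, [2, 0, 3]); inner_dim_sizes -> []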

  def num_slices_in_dimension(self, axis):
    """Returns the total number of slices across the indicated dimension."""
    if axis < 0:
      return constant_op.constant(1, dtype=self.dim_size_dtype)
    elif self.is_ragged(axis):
      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
    else:
      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)

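  # Illustrative numbers (assumes eager execution): for dimension sizes
  # `3, (2, 0, 3)`, dimension 0 contains 3 slices and dimension 1 contains
  # 2 + 0 + 3 = 5 slices.
  #
  #   shape = RaggedTensorDynamicShape.from_dim_sizes(
  #       [3, constant_op.constant([2, 0, 3])])
  #   shape.num_slices_in_dimension(0)   # 3
  #   shape.num_slices_in_dimension(1)   # 5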

  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
    """Broadcasts the partitioned dimension `axis` to match `lengths`."""
    axis_dim_size = self.dimension_size(axis)
    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])

    if lengths.shape.ndims == 0:
      lengths = array_ops.where(
          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
      splits = array_ops_stack.stack([0, self.num_slices_in_dimension(axis)])
    else:
      splits = math_ops.range(
          array_ops.size(lengths, out_type=self.dim_size_dtype) + 1)
      repeats = lengths

    partitioned_sizes.append(lengths)

    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
      if dim_size.shape.ndims == 0:
        partitioned_sizes.append(dim_size)
        splits *= dim_size
      else:
        partitioned_sizes.append(
            ragged_util.repeat_ranges(dim_size, splits, repeats))
        splits = array_ops.gather(
            ragged_util.lengths_to_splits(dim_size), splits)
    inner_sizes = self._inner_dim_sizes
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes,
                                    self.dim_size_dtype)

  def _broadcast_inner_dimension_to_uniform(self, axis, length):
    """Broadcasts the inner dimension `axis` to match `length`."""
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat(
        [
            self._inner_dim_sizes[:axis_in_inner_dims],
            [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
            self._inner_dim_sizes[axis_in_inner_dims + 1:]
        ],
        axis=0)
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes,
                                    self.dim_size_dtype)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes + tuple([
            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
        ]) + (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def with_dim_size_dtype(self, dtype):
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('dtype must be int32 or int64')
    if self.dim_size_dtype == dtype:
      return self
    return RaggedTensorDynamicShape(
        [math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes],
        math_ops.cast(self._inner_dim_sizes, dtype))


def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.

  Args:
    shape_x: A `RaggedTensorDynamicShape`
    shape_y: A `RaggedTensorDynamicShape`

  Returns:
    A `RaggedTensorDynamicShape`.
  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
  """
  if not isinstance(shape_x, RaggedTensorDynamicShape):
    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
  if not isinstance(shape_y, RaggedTensorDynamicShape):
    raise TypeError('shape_y must be a RaggedTensorDynamicShape')

  # Broadcast both shapes to have the same rank.
  if shape_x.rank is None or shape_y.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  broadcast_rank = max(shape_x.rank, shape_y.rank)
  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
  shape_y = shape_y.broadcast_to_rank(broadcast_rank)

  # Broadcast dimensions one at a time, starting from the outermost dimension.
  for axis in range(broadcast_rank):
    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))

  return shape_x

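# Illustrative sketch (assumes eager execution): broadcasting the shape of a
# column vector, `[3, 1]`, against the ragged shape `[3, (2, 0, 3)]`.
#
#   x = RaggedTensorDynamicShape.from_dim_sizes([3, 1])
#   y = RaggedTensorDynamicShape.from_dim_sizes(
#       [3, constant_op.constant([2, 0, 3])])
#   broadcast_dynamic_shape(x, y)   # dimension sizes 3, (2, 0, 3)
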

def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
  """Broadcasts a potentially ragged tensor to a ragged shape.

  Tiles `rt_input` as necessary to match the given shape.

  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.

  Args:
    rt_input: The potentially ragged tensor to broadcast.
    shape: A `RaggedTensorDynamicShape`
    broadcast_inner_dimensions: If false, then inner dimensions will not be
      tiled.

  Returns:
    A potentially ragged tensor whose values are taken from
    `rt_input`, and whose shape matches `shape`.
  """
  if not isinstance(shape, RaggedTensorDynamicShape):
    raise TypeError('shape must be a RaggedTensorDynamicShape')
  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)

  # Broadcasting to a uniform shape.
  if shape.num_partitioned_dimensions == 0:
    return _broadcast_to_uniform_shape(rt_input, shape,
                                       broadcast_inner_dimensions)
  else:
    return _broadcast_to_ragged_shape(rt_input, shape,
                                      broadcast_inner_dimensions)

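# Illustrative sketch (assumes eager execution): broadcasting a dense column
# vector to the ragged shape `[3, (2, 0, 3)]`.
#
#   x = constant_op.constant([[10], [20], [30]])   # shape [3, 1]
#   shape = RaggedTensorDynamicShape.from_dim_sizes(
#       [3, constant_op.constant([2, 0, 3])])
#   broadcast_to(x, shape)
#   # -> <tf.RaggedTensor [[10, 10], [], [30, 30, 30]]>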



def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the uniform shape `shape`."""
  if isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise ValueError('Incompatible with shape: ragged rank mismatch')
  if broadcast_inner_dimensions:
    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
  else:
    return rt_input


def _broadcast_to_ragged_shape(rt_input, dst_shape,
                               broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`."""
  # Check that rt_input and dst_shape have the same row_splits dtype.
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError('rt_input and dst_shape have different row_split '
                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
                       'convert to a compatible dtype.')
    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)

  # dst_shape's rank and ragged_rank must be greater than or equal to
  # rt_input's.
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
                                                             validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops_stack.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    new_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(
            rt_input.flat_values, out_type=dst_shape.dim_size_dtype),
        array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0))
    rt_input = rt_input.with_flat_values(
        array_ops.broadcast_to(rt_input.flat_values, new_shape))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input


def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
  """Tile a dimension of a RaggedTensor to match a ragged shape."""
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)

  if axis > 1:
    return rt_input.with_values(
        _ragged_tile_axis(rt_input.values, axis - 1, repeats,
                          row_splits_dtype))
  else:
    src_row_splits = rt_input.nested_row_splits
    src_row_lengths = rt_input.nested_row_lengths()
    splits = src_row_splits[0]

    dst_row_lengths = [repeats]
    for i in range(1, len(src_row_lengths)):
      dst_row_lengths.append(
          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
      splits = array_ops.gather(src_row_splits[i], splits)
    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
                                           repeats)
    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
        dst_values, dst_row_lengths, validate=False)