Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_math_ops.py: 28% of 386 statements

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

import functools
import typing

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# ragged.range
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts,
          limits=None,
          deltas=1,
          dtype=None,
          name=None,
          row_splits_dtype=dtypes.int64):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i]` and `deltas[i] > 0`, then `output[i]` will be
  an empty list. Similarly, if `starts[i] <= limits[i]` and `deltas[i] < 0`,
  then `output[i]` will be an empty list. This behavior is consistent with
  the Python `range` function, but differs from the `tf.range` op, which
  returns an error for these cases.

  Examples:

  >>> tf.ragged.range([3, 5, 2]).to_list()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list()
  [[0, 2], [], [8, 10]]

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size. Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`. Specifies the first entry for each range
      if `limits` is not `None`; otherwise, specifies the range limits, and the
      first entries default to `0`.
    limits: Vector or scalar `Tensor`. Specifies the exclusive upper limits for
      each range.
    deltas: Vector or scalar `Tensor`. Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor. If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.
    row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
      tensor. One of `tf.int32` or `tf.int64`.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # infer dtype if not explicitly provided
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)

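# A quick illustration of the empty-row behavior documented above -- a hedged
# sketch, not part of the TensorFlow source: with a negative delta, rows whose
# start is not strictly greater than the limit come back empty.
#
#   >>> tf.ragged.range([5, 0], [0, 5], -1).to_list()
#   [[5, 4, 3, 2, 1], []]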

def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype."""
  assert all(t.dtype in dtype_hierarchy for t in tensors)
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]

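# For example (an illustrative sketch, not part of the original source): given
# one int32 tensor and one float32 tensor with the hierarchy
# [int32, int64, float32, float64], the inferred dtype is float32, since it
# sits highest in the hierarchy, and both tensors are cast to float32.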

ops.no_gradient('RaggedRange')

#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`. If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or
      `int32`. `segment_ids.shape` must be a prefix of `data.shape`.
      Its values must be greater than or equal to zero and less than
      `num_segments`. `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""


def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`. The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`. If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row. Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or
      `int32`. `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string. Defaults to None. The separator to use when
      joining. Only used for string types.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If segment_ids.shape is not a prefix of data.shape.
  """

  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # The separator is only meaningful for unsorted_segment_join.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # Find the length of each row in data. (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have. The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`. (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the index in `output.values`
    # where it will be aggregated.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sum)
def segment_sum(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentSum'))

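# A worked example of the ragged segmentation semantics above -- an
# illustrative sketch, not part of the TensorFlow source. Rows 0 and 1 both
# map to segment 0, so they are combined elementwise and the output row takes
# the longer length:
#
#   >>> rt = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
#   >>> tf.math.unsorted_segment_sum(rt, [0, 0, 1], 2).to_list()
#   [[4, 2], [4, 5, 6]]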

@dispatch.dispatch_for_api(math_ops.unsorted_segment_prod)
def segment_prod(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentProd'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_min)
def segment_min(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMin'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_max)
def segment_max(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMax'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_mean)
def segment_mean(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sqrt_n)
def segment_sqrt_n(data: ragged_tensor.RaggedOrDense,
                   segment_ids: ragged_tensor.RaggedOrDense,
                   num_segments,
                   name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values /
                                    math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)


def _set_ragged_segment_docstring(func, combination, combined):
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict(
      combination=combination, combined=combined)


_set_ragged_segment_docstring(segment_sum, 'sum', 'summed')
_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied')
_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized')
_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized')
_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged')
_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
                              'summed')

#===============================================================================
# ragged_reduce_<AGGREGATE>
#===============================================================================

# Docstring template used for ragged_reduce_<AGGREGATE> ops.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values. If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`. If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.
  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce. May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value. Must be in
      the range `[0, input_tensor.rank)`.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values. The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `input_tensor.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  #### Example:
    %(example)s
"""

_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_sum(rt, axis=0).numpy()  # = [3+1+9+2, 1+5+6, 4]
    array([15, 12, 4], dtype=int32)
    >>> tf.reduce_sum(rt, axis=1).numpy()  # = [3+1+4, 1+5, 9, 2+6]
    array([8, 6, 9, 8], dtype=int32)
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_prod(rt, axis=0).numpy()  # = [3*1*9*2, 1*5*6, 4]
    array([54, 30, 4], dtype=int32)
    >>> tf.reduce_prod(rt, axis=1).numpy()  # = [3*1*4, 1*5, 9, 2*6]
    array([12, 5, 9, 12], dtype=int32)
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_min(rt, axis=0).numpy()
    array([1, 1, 4], dtype=int32)
    >>> tf.reduce_min(rt, axis=1).numpy()
    array([1, 1, 9, 2], dtype=int32)
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_max(rt, axis=0).numpy()
    array([9, 6, 4], dtype=int32)
    >>> tf.reduce_max(rt, axis=1).numpy()
    array([4, 5, 9, 6], dtype=int32)
"""
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_mean(rt, axis=0).numpy()
    array([3.75, 4.  , 4.  ])
    >>> tf.reduce_mean(rt, axis=1).numpy()
    array([2.66666667, 3.        , 9.        , 4.        ])
"""
_RAGGED_REDUCE_VARIANCE_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 1, 4], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_variance(rt, axis=0).numpy()
    array([1.25, 0.  , 0.  ])
    >>> tf.math.reduce_variance(rt, axis=1).numpy()
    array([2.  , 0.25, 0.  , 2.25])
"""
_RAGGED_REDUCE_STD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 0], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_std(rt, axis=0).numpy()
    array([1.11803399, 0.47140452])
    >>> tf.math.reduce_std(rt, axis=1).numpy()
    array([0.5, 0.5, 0. , 1.5])
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_all(rt, axis=0).numpy()
    array([False,  True, False,  True])
    >>> tf.reduce_all(rt, axis=1).numpy()
    array([ True, False, False])
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_any(rt, axis=0).numpy()
    array([ True,  True, False,  True])
    >>> tf.reduce_any(rt, axis=1).numpy()
    array([ True,  True,  True])
"""


def ragged_reduce_aggregate(reduce_op,
                            unsorted_segment_op,
                            rt_input,
                            axis,
                            keepdims,
                            separator=None,
                            name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`. The rank of the
  tensor is reduced by 1 for each entry in `axis`. If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results. (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions. Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions. Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce. May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
      given set of axes), or a `Tensor` with a constant value. Must be in the
      range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: An optional string. Defaults to None. The separator to use when
      joining. The separator may only be set for string data types; if
      `separator` is not None, then string joining ops are used.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values. The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """

  # When separator is not None, we infer that the dtype is string and that
  # reduce_join will be called.
  if separator is None:
    maybe_separator = {}
  else:
    maybe_separator = {'separator': separator}

  if not ragged_tensor.is_ragged(rt_input):
    return reduce_op(
        rt_input, axis, keepdims=keepdims, name=name, **maybe_separator)

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    result = reduce_op(rt_input.flat_values, None, keepdims=keepdims,
                       name=name, **maybe_separator)
    if keepdims:
      # Expand the result to the input number of dimensions.
      for _ in rt_input.shape[1:]:
        result = array_ops.expand_dims(result, axis=0)
    return result

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, we reduce them one at a time (see
        # below), so any negative axis must be converted to its positive form
        # up front: sorting a mix of negative and positive axis values would
        # otherwise give the wrong reduction order. See GitHub issue 27497.
        axis = [
            array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
                                        'rank(input_tensor)')
            for i, a in enumerate(axis)
        ]
        # When reducing multiple axes, just reduce one at a time. This is less
        # efficient, and only works for associative ops. (In particular, it
        # does not work for reduce_mean.) However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
        axis = sorted(axis)
        inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                rt_input, axis[-1], keepdims,
                                                separator)
        return ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                       inner_reduced, axis[:-1], keepdims,
                                       separator)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = array_ops.get_positive_axis(
        axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=0)
      return result
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=1)
      return result
    else:
      # out[i_0, ..., i_[axis-1], i_[axis+1], ..., i_N] =
      #     sum_{j} rt_input[i_0, ..., i_[axis-1], j, i_[axis+1], ..., i_N]
      return rt_input.with_values(
          ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                  rt_input.values, axis - 1, keepdims,
                                  separator))

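# A worked trace of the axis == 0 branch above -- an illustrative sketch, not
# part of the TensorFlow source. For rt_input = [[3, 1, 4], [1, 5]]:
# row_lengths = [3, 2], num_segments = 3, and
# segment_ids = range([3, 2]).values = [0, 1, 2, 0, 1], so column j of every
# row is sent to segment j. unsorted_segment_sum([3, 1, 4, 1, 5],
# [0, 1, 2, 0, 1], 3) then yields [4, 6, 4], i.e. [3+1, 1+5, 4].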

@dispatch.dispatch_for_api(math_ops.reduce_sum)
def reduce_sum(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""

  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceSum'))


@dispatch.dispatch_for_api(math_ops.reduce_prod)
def reduce_prod(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceProd'))


@dispatch.dispatch_for_api(math_ops.reduce_min)
def reduce_min(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMin'))


@dispatch.dispatch_for_api(math_ops.reduce_max)
def reduce_max(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMax'))


@dispatch.dispatch_for_api(math_ops.reduce_mean)
def reduce_mean(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if ragged_tensor.is_ragged(input_tensor):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(input_tensor)
    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          total.flat_values / count.flat_values,
          total.nested_row_splits,
          validate=False)
    else:
      return total / count

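# To see why the ones/count trick above is needed -- an illustrative sketch,
# not part of the TensorFlow source: for rt = [[3, 1, 4], [1, 5]] with axis=1,
# total = [8, 6] while count = [3, 2], so the mean is [8/3, 3.0]. A dense
# reduce_mean could not account for per-row element counts like this.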

@dispatch.dispatch_for_api(math_ops.reduce_variance)
def reduce_variance(input_tensor: ragged_tensor.Ragged,
                    axis=None,
                    keepdims=False,
                    name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceVariance', [input_tensor, axis]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input_tensor, name='input_tensor')
    if input_tensor.dtype.is_complex:
      raise ValueError(
          'reduce_variance is not supported for RaggedTensors with complex dtypes.'
      )
    square_of_input = math_ops.square(input_tensor)
    mean_of_square = reduce_mean(square_of_input, axis=axis, keepdims=keepdims)
    mean = reduce_mean(input_tensor, axis=axis, keepdims=keepdims)
    square_of_mean = math_ops.square(mean)
    # Note: the above method of computing variance is not numerically stable,
    # and can result in negative variances. Here we clip to >= 0.
    return math_ops.maximum(mean_of_square - square_of_mean, 0)


@dispatch.dispatch_for_api(math_ops.reduce_std)
def reduce_std(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=False,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceStd', [input_tensor, axis]):
    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
    return math_ops.sqrt(variance)


def _cast(input_tensor, dtype):
  return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor,
                                               dtype)


@dispatch.dispatch_for_api(math_ops.reduce_all)
def reduce_all(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    return _cast(
        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


@dispatch.dispatch_for_api(math_ops.reduce_any)
def reduce_any(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    return _cast(
        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict(
      combination=combination,
      combined=combined,
      default=default,
      example=example)


_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_variance, 'variance', 'averaged', 'NaN',
                             _RAGGED_REDUCE_VARIANCE_EXAMPLE)
_set_ragged_reduce_docstring(reduce_std, 'std', 'averaged', 'NaN',
                             _RAGGED_REDUCE_STD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)


#===============================================================================
# ragged.matmul
#===============================================================================
@dispatch.dispatch_for_api(math_ops.matmul)
def matmul(a: ragged_tensor.RaggedOrDense,
           b: ragged_tensor.RaggedOrDense,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           output_type=None,
           name=None):
  """Multiplies matrix `a` by matrix `b`.

  If all transpose or adjoint attributes are `False` then:

  ```
  output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j.
  ```

  The inputs `a` and `b` must have `rank >= 2`, where the outermost `rank - 2`
  dimensions are batch dimensions. The inputs must have the same dtype. See
  `tf.matmul` for more information.

  Args:
    a: `tf.Tensor` or `RaggedTensor` with `rank > 1`.
    b: `tf.Tensor` or `RaggedTensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated & transposed before multiplication.
    adjoint_b: If `True`, `b` is conjugated & transposed before multiplication.
    a_is_sparse: If `True`, optimize assuming `a` is mostly zero.
    b_is_sparse: If `True`, optimize assuming `b` is mostly zero.
    output_type: The output datatype (optional).
    name: Name for the operation (optional).

  Returns:
    A `Tensor` or `RaggedTensor` with the same rank and shape as `a`, where
    each inner-most matrix is the product of the corresponding matrices in `a`
    and `b`.
  """
  if transpose_a and adjoint_a:
    raise ValueError('Only one of transpose_a and adjoint_a can be True.')
  if transpose_b and adjoint_b:
    raise ValueError('Only one of transpose_b and adjoint_b can be True.')

  kwargs = dict(
      transpose_a=transpose_a,
      transpose_b=transpose_b,
      adjoint_a=adjoint_a,
      adjoint_b=adjoint_b,
      a_is_sparse=a_is_sparse,
      b_is_sparse=b_is_sparse,
      output_type=output_type)

  with ops.name_scope(name, 'RaggedMatMul', [a, b]) as name:
    a = ragged_tensor.convert_to_tensor_or_ragged_tensor(a, name='a')
    b = ragged_tensor.convert_to_tensor_or_ragged_tensor(b, name='b')

    a_is_ragged = isinstance(a, ragged_tensor.RaggedTensor)
    b_is_ragged = isinstance(b, ragged_tensor.RaggedTensor)
    if not (a_is_ragged or b_is_ragged):
      return math_ops.matmul(a, b, **kwargs)

    if a.dtype != b.dtype:
      raise ValueError('`a` and `b` must have the same dtype.')

    # TODO(edloper): Support broadcasting inputs. (Broadcast support is not
    # documented by https://www.tensorflow.org/api_docs/python/tf/linalg/matmul,
    # but it is supported by the op.)

    # Find the rank of the input tensors.
    if a.shape.rank is None:
      if b.shape.rank is None:
        raise ValueError('matmul requires at least one input to have known '
                         'rank if either input is ragged.')
      rank = b.shape.rank
    else:
      if b.shape.rank is not None and a.shape.rank != b.shape.rank:
        raise ValueError('`a` and `b` must have the same rank.')
      rank = a.shape.rank

    # At least one of `a` and `b` is ragged; and ragged tensors always have
    # rank>=2.
    if rank < 2:
      # This can happen if e.g. `a` is a 1D dense tensor and `b` is a
      # ragged tensor with unknown rank. Since ragged tensors always have
      # `rank>=2`, this implies that `a` and `b` have different ranks.
      raise ValueError('`a` and `b` must have the same rank.')

    # Rank>3: We have multiple batch dimensions. Merge them into a single
    # batch dimension, recursively call `matmul`, and then restore the original
    # batch dimension (using a.row_splits).
    if rank > 3:
      shape_err = 'Batch dimensions of `a` and `b` do not have the same size.'
      if not a_is_ragged:
        a = ragged_tensor.RaggedTensor.from_tensor(a, ragged_rank=1)
      if not b_is_ragged:
        b = ragged_tensor.RaggedTensor.from_tensor(b, ragged_rank=1)
      with ops.control_dependencies([
          check_ops.assert_equal(a.row_splits, b.row_splits, message=shape_err)
      ]):
        flat_result = matmul(a.values, b.values, **kwargs)
      return a.with_values(flat_result)

    if rank == 2:
      return _matmul_2d(a, b, **kwargs)

    assert rank == 3  # I.e., we have a single batch dimension.

    a_ragged_rank = a.ragged_rank if a_is_ragged else 0

    if a_ragged_rank == 1 and not (b_is_ragged or transpose_a or adjoint_a):
      # If `a.shape=[B, (I), J]` and `b.shape=[B, J, K]`, then we can compute
      # the result with a single dense `matmul`.
      return _matmul_3d_with_batch_dim_folding(a, b, **kwargs)
    else:
      # Otherwise, fall back on using `map_fn`.
      return _matmul_3d_with_map_fn(a, b, **kwargs)

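# Example usage -- an illustrative sketch, not part of the TensorFlow source.
# Batch 0 has one row and batch 1 has two, so the result keeps the ragged row
# structure of `a`:
#
#   >>> a = tf.ragged.constant([[[1., 2., 3.]], [[4., 5., 6.], [7., 8., 9.]]],
#   ...                        ragged_rank=1)  # shape [2, (I), 3]
#   >>> b = tf.ones([2, 3, 2])                 # shape [2, 3, 2]
#   >>> tf.matmul(a, b).to_list()
#   [[[6.0, 6.0]], [[15.0, 15.0], [24.0, 24.0]]]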

def _matmul_2d(a, b, **kwargs):
  """Multiplies potentially ragged 2D tensors.

  Args:
    a: A 2D Tensor or RaggedTensor with `shape=[I, J]`
    b: A 2D Tensor or RaggedTensor with `shape=[J, K]`
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 2D Tensor with `shape=[I, K]`.
  """
  # Multiplying `a` and `b` is only well-defined if `a` and `b` are
  # actually uniform (and just happened to be stored as ragged tensors).
  # Check that they're uniform, convert them to tf.Tensor.
  ragged_err = ('The matrices in `a` and `b` may not be '
                'ragged in their innermost dimension.')
  checks = []
  if isinstance(a, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(a.flat_values)
    a = a.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(a), message=ragged_err))
  if isinstance(b, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(b.flat_values)
    b = b.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(b), message=ragged_err))
  with ops.control_dependencies(checks):
    return math_ops.matmul(a, b, **kwargs)

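# How the uniformity check above works -- an illustrative sketch, not part of
# the TensorFlow source: a ragged-but-uniform input such as
# tf.ragged.constant([[1., 2.], [3., 4.]]) has 4 flat values and converts to a
# [2, 2] tensor of size 4, so the assert passes. A truly ragged input such as
# tf.ragged.constant([[1.], [2., 3.]]) has 3 flat values but pads out to a
# [2, 2] tensor of size 4, so the size check fails and ragged_err is raised.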

def _matmul_3d_with_map_fn(a, b, **kwargs):
  """Multiplies batches of 2D matrices using map_fn.

  `output[n, i, k] = sum_j(a[n, i, j] * b[n, j, k])` (for all `n`, `i`, `k`).

  Requires that each row `a[n, i]` has length equal to `b[n].nrows()` (for
  all `n` and `i`).

  Args:
    a: A 3D Tensor or RaggedTensor with `shape=[B, I, J]`, where dimensions `I`
      and `J` may be ragged.
    b: A 3D Tensor or RaggedTensor with `shape=[B, J, K]`, where dimensions `J`
      and `K` may be ragged.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 3D RaggedTensor with `shape=[B, (I), (K)]`.
  """

  # Determine the ragged rank of the result. In the normal case, we have:
  #   [B, I, J] * [B, J, K] -> [B, I, K]
  # Or if we're using transpose_b, then we have:
  #   [B, I, J] * [B, K, J] -> [B, I, K]
  # In either case, output_ragged_rank=2 iff the K dimension is ragged.
  if (isinstance(b, ragged_tensor.RaggedTensor) and
      (b.ragged_rank == 2 or kwargs.get('transpose_b') or
       kwargs.get('adjoint_b'))):
    output_ragged_rank = 2
  else:
    output_ragged_rank = 1

  def single_batch_matmul(x):
    out = _matmul_2d(x[0], x[1], **kwargs)
    if output_ragged_rank == 2:
      out = ragged_tensor.RaggedTensor.from_tensor(out)
    return out

  fn_out_shape = None  # Figure out proper shape.
  row_splits_dtype = (
      a.row_splits.dtype
      if isinstance(a, ragged_tensor.RaggedTensor) else b.row_splits.dtype)
  output_type = kwargs['output_type']
  if output_type is None:
    output_type = a.dtype
  spec = ragged_tensor.RaggedTensorSpec(
      shape=fn_out_shape,
      dtype=output_type,
      ragged_rank=output_ragged_rank - 1,
      row_splits_dtype=row_splits_dtype)
  result = map_fn.map_fn(
      single_batch_matmul, elems=(a, b), fn_output_signature=spec)

  # map_fn loses shape information; restore it, where possible.
  # pylint: disable=protected-access
  if kwargs.get('transpose_a') or kwargs.get('adjoint_a'):
    result._set_shape(a.shape[:-2] + a.shape[-1:] + [None])
  else:
    result._set_shape(a.shape[:-2] + a.shape[-2:-1] + [None])
  if kwargs.get('transpose_b') or kwargs.get('adjoint_b'):
    result._set_shape(b.shape[:-2] + [None] + b.shape[-2:-1])
  else:
    result._set_shape(b.shape[:-2] + [None] + b.shape[-1:])

  return result


def _matmul_3d_with_batch_dim_folding(a, b, **kwargs):
  """Multiply batches of 2D matrices where only `a.shape[1]` is ragged.

  Args:
    a: A RaggedTensor with `shape=[B, (I), J]`. (ragged_rank must be 1.)
    b: A Tensor with `shape=[B, J, K]`
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).
      transpose_a and adjoint_a must not be true.

  Returns:
    A RaggedTensor with `shape=[B, (I), K]`.
  """
  # reshaped_a.shape = [sum(i_1, i_2, ..., i_B), 1, J]
  reshaped_a = array_ops.expand_dims(a.values, 1)
  # reshaped_b.shape = [sum(i_1, i_2, ..., i_B), J, K]
  reshaped_b = array_ops.repeat(b, a.row_lengths(), axis=0)
  # flat_result.shape = [sum(i_1, i_2, ..., i_B), 1, K]
  flat_result = math_ops.matmul(reshaped_a, reshaped_b, **kwargs)
  # result.shape = [B, (I), K]
  return a.with_values(array_ops.squeeze(flat_result, axis=1))

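# A shape trace of the folding trick above -- an illustrative sketch, not part
# of the TensorFlow source. With B=2, row lengths [1, 2], J=3, and K=2:
# a.values is [3, 3], expand_dims gives [3, 1, 3], repeating b once for batch
# 0 and twice for batch 1 gives [3, 3, 2], the dense matmul gives [3, 1, 2],
# and squeezing plus a.with_values() restores shape [2, (I), 2].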

#===============================================================================
# ragged.softmax
#===============================================================================
@dispatch.dispatch_for_api(nn_ops.softmax_v2)
def softmax(logits: ragged_tensor.Ragged, axis=None, name=None):
  """Computes softmax activations.

  Used for multi-class predictions. The sum of all outputs generated by softmax
  is 1.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Example usage:

  >>> softmax = tf.nn.softmax([-1, 0., 1.])
  >>> softmax
  <tf.Tensor: shape=(3,), dtype=float32,
  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>
  >>> sum(softmax)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  if axis is None:
    axis = -1

  with ops.name_scope(name, 'RaggedSoftmax', [logits]) as name:
    max_input = reduce_max(logits, axis=axis, keepdims=True)
    logits_exp = math_ops.exp(math_ops.subtract(logits, max_input))
    denominator = reduce_sum(logits_exp, axis=axis, keepdims=True)
    return math_ops.divide(logits_exp, denominator)

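# For a ragged input -- an illustrative sketch, not part of the TensorFlow
# source -- the reductions above run per row, so each row sums to 1
# independently: tf.nn.softmax(tf.ragged.constant([[-1., 0., 1.], [0., 0.]]))
# gives approximately [[0.090, 0.245, 0.665], [0.5, 0.5]].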

#===============================================================================
# ragged.add_n
#===============================================================================
@dispatch.dispatch_for_api(math_ops.add_n)
def add_n(inputs: typing.List[ragged_tensor.RaggedOrDense], name=None):
  """RaggedTensor implementation for tf.math.add_n."""
  # Note: `len(inputs) < 0` can never be true, so the guard must use `< 1`
  # for the empty-input check to actually fire.
  if len(inputs) < 1:
    raise ValueError('tf.add_n: expected at least one input.')
  with ops.name_scope(name, 'RaggedAddN', inputs):
    return ragged_functional_ops.map_flat_values(math_ops.add_n, inputs)

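# Example usage -- an illustrative sketch, not part of the TensorFlow source.
# The inputs must share the same row structure, since only the flat values are
# summed:
#
#   >>> x = tf.ragged.constant([[1, 2], [3]])
#   >>> y = tf.ragged.constant([[10, 20], [30]])
#   >>> tf.math.add_n([x, y]).to_list()
#   [[11, 22], [33]]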

#===============================================================================
# Ragged version of nn_ops.dropout
#===============================================================================
@dispatch.dispatch_for_api(nn_ops.dropout)
def dropout_v1(x: ragged_tensor.Ragged,
               keep_prob=None,
               noise_shape=None,
               seed=None,
               name=None,
               rate=None):
  """Ragged dispatch target for tf.nn.dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.dropout(
            x.flat_values, keep_prob=keep_prob, seed=seed, rate=rate))


@dispatch.dispatch_for_api(nn_ops.dropout_v2)
def dropout_v2(x: ragged_tensor.Ragged,
               rate,
               noise_shape=None,
               seed=None,
               name=None):
  """Ragged dispatch target for tf.nn.dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.dropout_v2(x.flat_values, rate=rate, seed=seed))


@dispatch.dispatch_for_api(nn_ops.stateless_dropout)
def stateless_dropout(x: ragged_tensor.Ragged,
                      rate,
                      seed,
                      rng_alg=None,
                      noise_shape=None,
                      name=None):
  """Ragged dispatch target for tf.nn.experimental.stateless_dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNStatelessDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.stateless_dropout(
            x.flat_values, rate=rate, seed=seed, rng_alg=rng_alg))


#===============================================================================
# Ragged version of Tensor.__eq__ and Tensor.__ne__
#===============================================================================

@dispatch.dispatch_for_api(math_ops.tensor_equals)
def tensor_equals(self: ragged_tensor.RaggedOrDense,
                  other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__eq__`."""
  if other is None:
    return False
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is other
  else:
    try:
      return math_ops.equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return False  # values are not broadcast-compatible.


@dispatch.dispatch_for_api(math_ops.tensor_not_equals)
def tensor_not_equals(self: ragged_tensor.RaggedOrDense,
                      other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__ne__`."""
  if other is None:
    # A tensor is never equal to None, so `__ne__` must report True here.
    return True
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is not other
  else:
    try:
      return math_ops.not_equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return True  # values are not broadcast-compatible.


def _use_legacy_mode_for_tensor_equality(self):
  g = getattr(self, 'graph', None)
  return not (ops.Tensor._USE_EQUALITY and  # pylint: disable=protected-access
              ops.executing_eagerly_outside_functions() and
              (g is None or g.building_function))


def _cumsum_flat_values_at_ragged_rank(last_rp, flat_values, exclusive=False,
                                       reverse=False):
  """Calculate flat_values for math_ops.cumsum when axis==ragged_rank."""
  if not exclusive:
    partial = _cumsum_flat_values_at_ragged_rank(
        last_rp, flat_values, exclusive=True, reverse=reverse)
    return partial + flat_values

  if reverse:
    youngest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids() + 1) - 1
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True, reverse=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=youngest_sibling)

    return new_flat_values - initial_values
  else:
    eldest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids())
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=eldest_sibling)
    return new_flat_values - initial_values

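# A worked trace of the exclusive cumsum trick above -- an illustrative
# sketch, not part of the TensorFlow source. For x = [[1, 2], [3, 4, 5]] with
# axis=1: flat_values = [1, 2, 3, 4, 5] and the exclusive cumsum is
# [0, 1, 3, 6, 10]. Each value's eldest sibling index is [0, 0, 2, 2, 2], so
# the gathered row offsets [0, 0, 3, 3, 3] are subtracted, giving
# [0, 1, 0, 3, 7]; adding flat_values back (the non-exclusive case) yields
# [1, 3, 3, 7, 12], i.e. tf.math.cumsum(x, axis=1) == [[1, 3], [3, 7, 12]].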

@dispatch.dispatch_for_api(math_ops.cumsum)
def ragged_cumsum(x: ragged_tensor.Ragged,
                  axis: int = 0,
                  exclusive: bool = False,
                  reverse: bool = False,
                  name: typing.Optional[str] = None):
  """Calculate math_ops.cumsum for a RaggedTensor.

  Given a ragged tensor `x`, the `result` is a ragged tensor with the same
  shape. One can calculate the value of `result[i_1...i_k]` as follows:
  ```
  dense_result = tf.math.cumsum(rt.to_tensor(), axis=axis, exclusive=exclusive,
                                reverse=reverse)
  result[i_1...i_k] = dense_result[i_1...i_k]
  ```

  Args:
    x: the original ragged tensor to sum.
    axis: the axis along which to sum; may range over `-rank <= axis < rank`.
    exclusive: is the sum exclusive or inclusive? If True, then result[0]=0.
      If False, then result[0]=x[0].
    reverse: If True, sum from back to front.
    name: the name of the op.
  Returns:
    the cumulative sum.
  """
  with ops.name_scope(name, 'RaggedCumSum', [x, axis, exclusive, reverse]):
    axis = array_ops.get_positive_axis(axis, x.shape.rank, ndims_name='rank')
    if axis == x.ragged_rank:
      last_rp = x._nested_row_partitions[-1]  # pylint: disable=protected-access
      return x.with_flat_values(
          _cumsum_flat_values_at_ragged_rank(last_rp, x.flat_values,
                                             exclusive=exclusive,
                                             reverse=reverse))
    elif axis > x.ragged_rank:
      new_axis = axis - x.ragged_rank
      cumsum_bound = functools.partial(
          math_ops.cumsum, axis=new_axis, exclusive=exclusive, reverse=reverse)
      return ragged_functional_ops.map_flat_values(cumsum_bound, x)
    else:
      dense_version = x.to_tensor()
      result = math_ops.cumsum(
          dense_version, axis, exclusive=exclusive, reverse=reverse, name=name)
      return ragged_tensor.RaggedTensor.from_tensor(
          result, lengths=x.nested_row_lengths())