Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_math_ops.py: 28%
386 statements
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

import functools
import typing

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# ragged.range
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts,
          limits=None,
          deltas=1,
          dtype=None,
          name=None,
          row_splits_dtype=dtypes.int64):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i]` and `deltas[i] > 0`, then `output[i]` will be
  an empty list. Similarly, if `starts[i] <= limits[i]` and `deltas[i] < 0`,
  then `output[i]` will be an empty list. This behavior is consistent with
  the Python `range` function, but differs from the `tf.range` op, which
  raises an error for these cases.

  Examples:

  >>> tf.ragged.range([3, 5, 2]).to_list()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list()
  [[0, 2], [], [8, 10]]

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size. Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`. Specifies the first entry for each range
      if `limits` is not `None`; otherwise, specifies the range limits, and the
      first entries default to `0`.
    limits: Vector or scalar `Tensor`. Specifies the exclusive upper limits for
      each range.
    deltas: Vector or scalar `Tensor`. Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor. If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.
    row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
      tensor. One of `tf.int32` or `tf.int64`.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # infer dtype if not explicitly provided
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)
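
# A minimal usage sketch (illustrative; assumes eager TF 2.x and the public
# `tf.ragged.range` wrapper for this function): the RaggedRange kernel emits
# one flat values vector plus a row-splits vector, which `from_row_splits`
# wraps without copying:
#
#   >>> rt = tf.ragged.range([0, 5, 8], [3, 3, 12])
#   >>> rt.values.numpy().tolist()      # flat values for all rows
#   [0, 1, 2, 8, 9, 10, 11]
#   >>> rt.row_splits.numpy().tolist()  # row i is values[splits[i]:splits[i+1]]
#   [0, 3, 3, 7]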


def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype."""
  assert all(t.dtype in dtype_hierarchy for t in tensors)
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]
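
# Illustrative example (assumes eager TF 2.x): with the hierarchy used by
# `range` above, the widest dtype present wins, and every tensor is cast
# to it:
#
#   >>> ts = [tf.constant(1, tf.int32), tf.constant(2.0, tf.float32)]
#   >>> [t.dtype for t in _infer_matching_dtype(
#   ...     ts, [tf.int32, tf.int64, tf.float32, tf.float64])]
#   [tf.float32, tf.float32]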


ops.no_gradient('RaggedRange')


#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`. If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or
      `int32`. `segment_ids.shape` must be a prefix of `data.shape`.
      Must be greater than or equal to zero, and less than `num_segments`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""


def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`. The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`. If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row. Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or
      `int32`. `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string. Defaults to None. The separator to use when
      joining. Only used for string types.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # A non-None separator means the op is unsorted_segment_join, which
      # takes the separator as an extra positional argument.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # Find the length of each row in data. (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have. The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`. (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`, i.e. the target output values index for
    # each data values index.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)
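
# Worked example (illustrative; assumes eager TF 2.x): rows aggregated into
# the same segment align position by position, and the output row takes the
# longest contributing length:
#
#   >>> data = tf.ragged.constant([[1, 2], [3, 4, 5], [6]])
#   >>> tf.math.unsorted_segment_sum(data, [0, 0, 1], num_segments=2).to_list()
#   [[4, 6, 5], [6]]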


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sum)
def segment_sum(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentSum'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_prod)
def segment_prod(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentProd'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_min)
def segment_min(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMin'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_max)
def segment_max(data: ragged_tensor.RaggedOrDense,
                segment_ids: ragged_tensor.RaggedOrDense,
                num_segments,
                name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMax'))


@dispatch.dispatch_for_api(math_ops.unsorted_segment_mean)
def segment_mean(data: ragged_tensor.RaggedOrDense,
                 segment_ids: ragged_tensor.RaggedOrDense,
                 num_segments,
                 name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count
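
# Worked example (illustrative): the ones/count trick above divides each
# output position's sum by its number of contributors:
#
#   >>> data = tf.ragged.constant([[1., 2.], [3., 4., 5.]])
#   >>> # sums are [4, 6, 5]; counts are [2, 2, 1]
#   >>> tf.math.unsorted_segment_mean(data, [0, 0], num_segments=1).to_list()
#   [[2.0, 3.0, 5.0]]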


@dispatch.dispatch_for_api(math_ops.unsorted_segment_sqrt_n)
def segment_sqrt_n(data: ragged_tensor.RaggedOrDense,
                   segment_ids: ragged_tensor.RaggedOrDense,
                   num_segments,
                   name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values /
                                    math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)


def _set_ragged_segment_docstring(func, combination, combined):
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict(
      combination=combination, combined=combined)


_set_ragged_segment_docstring(segment_sum, 'sum', 'summed')
_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied')
_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized')
_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized')
_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged')
_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
                              'summed')

#===============================================================================
# ragged_reduce_<AGGREGATE>
#===============================================================================

# Docstring template used for ragged_reduce_<AGGREGATE> ops.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values. If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`. If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.
  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce. May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value. Must be in
      the range `[0, input_tensor.rank]`.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values. The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `input_tensor.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  #### Example:
    %(example)s
"""
_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_sum(rt, axis=0).numpy()  # = [3+1+9+2, 1+5+6, 4]
    array([15, 12,  4], dtype=int32)
    >>> tf.reduce_sum(rt, axis=1).numpy()  # = [3+1+4, 1+5, 9, 2+6]
    array([8, 6, 9, 8], dtype=int32)
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_prod(rt, axis=0).numpy()  # = [3*1*9*2, 1*5*6, 4]
    array([54, 30,  4], dtype=int32)
    >>> tf.reduce_prod(rt, axis=1).numpy()  # = [3*1*4, 1*5, 9, 2*6]
    array([12,  5,  9, 12], dtype=int32)
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_min(rt, axis=0).numpy()
    array([1, 1, 4], dtype=int32)
    >>> tf.reduce_min(rt, axis=1).numpy()
    array([1, 1, 9, 2], dtype=int32)
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_max(rt, axis=0).numpy()
    array([9, 6, 4], dtype=int32)
    >>> tf.reduce_max(rt, axis=1).numpy()
    array([4, 5, 9, 6], dtype=int32)
"""
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_mean(rt, axis=0).numpy()
    array([3.75, 4.  , 4.  ])
    >>> tf.reduce_mean(rt, axis=1).numpy()
    array([2.66666667, 3.        , 9.        , 4.        ])
"""
_RAGGED_REDUCE_VARIANCE_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 1, 4], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_variance(rt, axis=0).numpy()
    array([1.25, 0.  , 0.  ])
    >>> tf.math.reduce_variance(rt, axis=1).numpy()
    array([2.  , 0.25, 0.  , 2.25])
"""
_RAGGED_REDUCE_STD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 0], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_std(rt, axis=0).numpy()
    array([1.11803399, 0.47140452])
    >>> tf.math.reduce_std(rt, axis=1).numpy()
    array([0.5, 0.5, 0. , 1.5])
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_all(rt, axis=0).numpy()
    array([False,  True, False,  True])
    >>> tf.reduce_all(rt, axis=1).numpy()
    array([ True, False, False])
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_any(rt, axis=0).numpy()
    array([ True,  True, False,  True])
    >>> tf.reduce_any(rt, axis=1).numpy()
    array([ True,  True,  True])
"""


def ragged_reduce_aggregate(reduce_op,
                            unsorted_segment_op,
                            rt_input,
                            axis,
                            keepdims,
                            separator=None,
                            name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`. The rank of the
  tensor is reduced by 1 for each entry in `axis`. If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results. (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions. Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions. Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce. May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
      given set of axes), or a `Tensor` with a constant value. Must be in the
      range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: An optional string. Defaults to None. The separator to use when
      joining. Must only be set for string data; a non-None separator means
      that string joining ops are used.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values. The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """
  # When separator is not None, we infer that the dtype is string and
  # reduce_join will be called.
  if separator is None:
    maybe_separator = {}
  else:
    maybe_separator = {'separator': separator}

  if not ragged_tensor.is_ragged(rt_input):
    return reduce_op(
        rt_input, axis, keepdims=keepdims, name=name, **maybe_separator)

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    result = reduce_op(rt_input.flat_values, None, keepdims=keepdims,
                       name=name, **maybe_separator)
    if keepdims:
      # Expand the result to the input number of dimensions.
      for _ in rt_input.shape[1:]:
        result = array_ops.expand_dims(result, axis=0)
    return result

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, we reduce one at a time (see below), so
        # any negative axes must be converted to positive ones up front:
        # sorting a mix of negative and positive axes would give the wrong
        # order. See GitHub issue 27497.
        axis = [
            array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
                                        'rank(input_tensor)')
            for i, a in enumerate(axis)
        ]
        # When reducing multiple axes, just reduce one at a time. This is less
        # efficient, and only works for associative ops. (In particular, it
        # does not work for reduce_mean.) However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
        axis = sorted(axis)
        inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                rt_input, axis[-1], keepdims,
                                                separator)
        return ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                       inner_reduced, axis[:-1], keepdims,
                                       separator)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = array_ops.get_positive_axis(
        axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=0)
      return result
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=1)
      return result
    else:
      # out[i_0, ..., i_[axis-1], i_[axis+1], ..., i_N] =
      #     sum_{j} rt_input[i_0, ..., i_[axis-1], j, i_[axis+1], ..., i_N]
      return rt_input.with_values(
          ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                  rt_input.values, axis - 1, keepdims,
                                  separator))
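
# Worked example (illustrative): multiple axes are reduced one at a time,
# innermost first, which for an associative op matches reducing all at once:
#
#   >>> rt = tf.ragged.constant([[1, 2], [3]])
#   >>> tf.reduce_sum(rt, axis=[0, 1]).numpy()  # axis 1 -> [3, 3]; axis 0 -> 6
#   6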


@dispatch.dispatch_for_api(math_ops.reduce_sum)
def reduce_sum(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""

  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceSum'))


@dispatch.dispatch_for_api(math_ops.reduce_prod)
def reduce_prod(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceProd'))


@dispatch.dispatch_for_api(math_ops.reduce_min)
def reduce_min(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMin'))


@dispatch.dispatch_for_api(math_ops.reduce_max)
def reduce_max(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMax'))


@dispatch.dispatch_for_api(math_ops.reduce_mean)
def reduce_mean(input_tensor: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if ragged_tensor.is_ragged(input_tensor):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(input_tensor)
    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          total.flat_values / count.flat_values,
          total.nested_row_splits,
          validate=False)
    else:
      return total / count


@dispatch.dispatch_for_api(math_ops.reduce_variance)
def reduce_variance(input_tensor: ragged_tensor.Ragged,
                    axis=None,
                    keepdims=False,
                    name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceVariance', [input_tensor, axis]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input_tensor, name='input_tensor')
    if input_tensor.dtype.is_complex:
      raise ValueError(
          'reduce_variance is not supported for RaggedTensors with complex '
          'dtypes.')
    square_of_input = math_ops.square(input_tensor)
    mean_of_square = reduce_mean(square_of_input, axis=axis, keepdims=keepdims)
    mean = reduce_mean(input_tensor, axis=axis, keepdims=keepdims)
    square_of_mean = math_ops.square(mean)
    # Note: the above method of computing variance is not numerically stable,
    # and can result in negative variances. Here we clip to >= 0.
    return math_ops.maximum(mean_of_square - square_of_mean, 0)
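
# Worked example (illustrative): variance is computed as E[x**2] - (E[x])**2,
# with the result clipped at zero to absorb floating-point error:
#
#   >>> rt = tf.ragged.constant([[1., 2., 3.]], dtype=tf.float64)
#   >>> tf.math.reduce_variance(rt, axis=1).numpy()  # 14/3 - 2**2 = 2/3
#   array([0.66666667])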


@dispatch.dispatch_for_api(math_ops.reduce_std)
def reduce_std(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=False,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceStd', [input_tensor, axis]):
    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
    return math_ops.sqrt(variance)


def _cast(input_tensor, dtype):
  return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor,
                                               dtype)


@dispatch.dispatch_for_api(math_ops.reduce_all)
def reduce_all(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    return _cast(
        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)
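
# Note (illustrative): reduce_all is expressed as an integer reduce_prod, so
# any False (cast to 0) zeroes out the product for its reduction group:
#
#   >>> rt = tf.ragged.constant([[True, False], [True]])
#   >>> tf.reduce_all(rt, axis=1).numpy()
#   array([False,  True])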


@dispatch.dispatch_for_api(math_ops.reduce_any)
def reduce_any(input_tensor: ragged_tensor.Ragged,
               axis=None,
               keepdims=None,
               name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    return _cast(
        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict(
      combination=combination,
      combined=combined,
      default=default,
      example=example)


_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_variance, 'variance', 'averaged', 'NaN',
                             _RAGGED_REDUCE_VARIANCE_EXAMPLE)
_set_ragged_reduce_docstring(reduce_std, 'std', 'averaged', 'NaN',
                             _RAGGED_REDUCE_STD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)


#===============================================================================
# ragged.matmul
#===============================================================================
@dispatch.dispatch_for_api(math_ops.matmul)
def matmul(a: ragged_tensor.RaggedOrDense,
           b: ragged_tensor.RaggedOrDense,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           output_type=None,
           name=None):
  """Multiplies matrix `a` by matrix `b`.

  If all transpose or adjoint attributes are `False` then:

  ```
  output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j.
  ```

  The inputs `a` and `b` must have `rank >= 2`, where the outermost `rank - 2`
  dimensions are batch dimensions. The inputs must have the same dtype. See
  `tf.matmul` for more information.

  Args:
    a: `tf.Tensor` or `RaggedTensor` with `rank > 1`.
    b: `tf.Tensor` or `RaggedTensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated & transposed before multiplication.
    adjoint_b: If `True`, `b` is conjugated & transposed before multiplication.
    a_is_sparse: If `True`, optimize assuming `a` is mostly zero.
    b_is_sparse: If `True`, optimize assuming `b` is mostly zero.
    output_type: The output datatype (optional).
    name: Name for the operation (optional).

  Returns:
    A `Tensor` or `RaggedTensor` with the same rank and shape as `a`, where
    each inner-most matrix is the product of the corresponding matrices in `a`
    and `b`.
  """
  if transpose_a and adjoint_a:
    raise ValueError('Only one of transpose_a and adjoint_a can be True.')
  if transpose_b and adjoint_b:
    raise ValueError('Only one of transpose_b and adjoint_b can be True.')

  kwargs = dict(
      transpose_a=transpose_a,
      transpose_b=transpose_b,
      adjoint_a=adjoint_a,
      adjoint_b=adjoint_b,
      a_is_sparse=a_is_sparse,
      b_is_sparse=b_is_sparse,
      output_type=output_type)

  with ops.name_scope(name, 'RaggedMatMul', [a, b]) as name:
    a = ragged_tensor.convert_to_tensor_or_ragged_tensor(a, name='a')
    b = ragged_tensor.convert_to_tensor_or_ragged_tensor(b, name='b')

    a_is_ragged = isinstance(a, ragged_tensor.RaggedTensor)
    b_is_ragged = isinstance(b, ragged_tensor.RaggedTensor)
    if not (a_is_ragged or b_is_ragged):
      return math_ops.matmul(a, b, **kwargs)

    if a.dtype != b.dtype:
      raise ValueError('`a` and `b` must have the same dtype.')

    # TODO(edloper): Support broadcasting inputs. (Broadcast support is not
    # documented by https://www.tensorflow.org/api_docs/python/tf/linalg/matmul,
    # but it is supported by the op.)

    # Find the rank of the input tensors.
    if a.shape.rank is None:
      if b.shape.rank is None:
        raise ValueError('matmul requires at least one input to have known '
                         'rank if either input is ragged.')
      rank = b.shape.rank
    else:
      if b.shape.rank is not None and a.shape.rank != b.shape.rank:
        raise ValueError('`a` and `b` must have the same rank.')
      rank = a.shape.rank

    # At least one of `a` and `b` is ragged; and ragged tensors always have
    # rank>=2.
    if rank < 2:
      # This can happen if e.g. `a` is a 1D dense tensor and `b` is a
      # ragged tensor with unknown rank. Since ragged tensors always have
      # `rank>=2`, this implies that `a` and `b` have different ranks.
      raise ValueError('`a` and `b` must have the same rank.')

    # Rank>3: We have multiple batch dimensions. Merge them into a single
    # batch dimension, recursively call `matmul`, and then restore the original
    # batch dimension (using a.row_splits).
    if rank > 3:
      shape_err = 'Batch dimensions of `a` and `b` do not have the same size.'
      if not a_is_ragged:
        a = ragged_tensor.RaggedTensor.from_tensor(a, ragged_rank=1)
      if not b_is_ragged:
        b = ragged_tensor.RaggedTensor.from_tensor(b, ragged_rank=1)
      with ops.control_dependencies([
          check_ops.assert_equal(a.row_splits, b.row_splits, message=shape_err)
      ]):
        flat_result = matmul(a.values, b.values, **kwargs)
        return a.with_values(flat_result)

    if rank == 2:
      return _matmul_2d(a, b, **kwargs)

    assert rank == 3  # I.e., we have a single batch dimension.

    a_ragged_rank = a.ragged_rank if a_is_ragged else 0
    if a_ragged_rank == 1 and not (b_is_ragged or transpose_a or adjoint_a):
      # If `a.shape=[B, (I), J]` and `b.shape=[B, J, K]`, then we can compute
      # the result with a single dense `matmul`.
      return _matmul_3d_with_batch_dim_folding(a, b, **kwargs)
    else:
      # Otherwise, fall back on using `map_fn`.
      return _matmul_3d_with_map_fn(a, b, **kwargs)


def _matmul_2d(a, b, **kwargs):
  """Multiplies potentially ragged 2D tensors.

  Args:
    a: A 2D Tensor or RaggedTensor with `shape=[I, J]`.
    b: A 2D Tensor or RaggedTensor with `shape=[J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 2D Tensor with `shape=[I, K]`.
  """
  # Multiplying `a` and `b` is only well-defined if `a` and `b` are
  # actually uniform (and just happened to be stored as ragged tensors).
  # Check that they're uniform, and convert them to tf.Tensor.
  ragged_err = ('The matrices in `a` and `b` may not be '
                'ragged in their innermost dimension.')
  checks = []
  if isinstance(a, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(a.flat_values)
    a = a.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(a), message=ragged_err))
  if isinstance(b, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(b.flat_values)
    b = b.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(b), message=ragged_err))
  with ops.control_dependencies(checks):
    return math_ops.matmul(a, b, **kwargs)
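
# Note (illustrative): a ragged operand is accepted here only when its rows
# happen to be uniform (the assert_equal checks above fire otherwise):
#
#   >>> a = tf.ragged.constant([[1., 2.], [3., 4.]])  # uniform, ragged-typed
#   >>> tf.matmul(a, tf.eye(2)).numpy().tolist()
#   [[1.0, 2.0], [3.0, 4.0]]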


def _matmul_3d_with_map_fn(a, b, **kwargs):
  """Multiplies batches of 2D matrices using map_fn.

  `output[n, i, k] = sum_j (a[n, i, j] * b[n, j, k])` (for all `n`, `i`, `k`).

  Requires that `a[n, i].nrows()` == `b[n].nrows()` (for all `n` and `i`).

  Args:
    a: A 3D Tensor or RaggedTensor with `shape=[B, I, J]`, where dimensions `I`
      and `J` may be ragged.
    b: A 3D Tensor or RaggedTensor with `shape=[B, J, K]`, where dimensions `J`
      and `K` may be ragged.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 3D RaggedTensor with `shape=[B, (I), (K)]`.
  """
  # Determine the ragged rank of the result. In the normal case, we have:
  #   [B, I, J] * [B, J, K] -> [B, I, K]
  # Or if we're using transpose_b, then we have:
  #   [B, I, J] * [B, K, J] -> [B, I, K]
  # In either case, output_ragged_rank=2 iff the K dimension is ragged.
  if (isinstance(b, ragged_tensor.RaggedTensor) and
      (b.ragged_rank == 2 or kwargs.get('transpose_b') or
       kwargs.get('adjoint_b'))):
    output_ragged_rank = 2
  else:
    output_ragged_rank = 1

  def single_batch_matmul(x):
    out = _matmul_2d(x[0], x[1], **kwargs)
    if output_ragged_rank == 2:
      out = ragged_tensor.RaggedTensor.from_tensor(out)
    return out

  fn_out_shape = None  # Figure out proper shape.
  row_splits_dtype = (
      a.row_splits.dtype
      if isinstance(a, ragged_tensor.RaggedTensor) else b.row_splits.dtype)
  output_type = kwargs['output_type']
  if output_type is None:
    output_type = a.dtype
  spec = ragged_tensor.RaggedTensorSpec(
      shape=fn_out_shape,
      dtype=output_type,
      ragged_rank=output_ragged_rank - 1,
      row_splits_dtype=row_splits_dtype)
  result = map_fn.map_fn(
      single_batch_matmul, elems=(a, b), fn_output_signature=spec)

  # map_fn loses shape information; restore it, where possible.
  # pylint: disable=protected-access
  if kwargs.get('transpose_a') or kwargs.get('adjoint_a'):
    result._set_shape(a.shape[:-2] + a.shape[-1:] + [None])
  else:
    result._set_shape(a.shape[:-2] + a.shape[-2:-1] + [None])
  if kwargs.get('transpose_b') or kwargs.get('adjoint_b'):
    result._set_shape(b.shape[:-2] + [None] + b.shape[-2:-1])
  else:
    result._set_shape(b.shape[:-2] + [None] + b.shape[-1:])

  return result


def _matmul_3d_with_batch_dim_folding(a, b, **kwargs):
  """Multiply batches of 2D matrices where only `a.shape[1]` is ragged.

  Args:
    a: A RaggedTensor with `shape=[B, (I), J]`. (ragged_rank must be 1.)
    b: A Tensor with `shape=[B, J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).
      transpose_a and adjoint_a must not be true.

  Returns:
    A RaggedTensor with `shape=[B, (I), K]`.
  """
  # reshaped_a.shape = [sum(i_1, i_2, ..., i_B), 1, J]
  reshaped_a = array_ops.expand_dims(a.values, 1)
  # reshaped_b.shape = [sum(i_1, i_2, ..., i_B), J, K]
  reshaped_b = array_ops.repeat(b, a.row_lengths(), axis=0)
  # flat_result.shape = [sum(i_1, i_2, ..., i_B), 1, K]
  flat_result = math_ops.matmul(reshaped_a, reshaped_b, **kwargs)
  # result.shape = [B, (I), K]
  return a.with_values(array_ops.squeeze(flat_result, axis=1))
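
# Shape walk-through (illustrative): with B=2, row lengths [2, 1], J=3, K=4:
# `a.values` is [3, 3]; expand_dims gives [3, 1, 3]; repeating each batch
# matrix of `b` once per row gives [3, 3, 4]; the dense matmul then yields
# [3, 1, 4]; squeezing and re-attaching a's row partition restores
# shape [2, (I), 4].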


#===============================================================================
# ragged.softmax
#===============================================================================
@dispatch.dispatch_for_api(nn_ops.softmax_v2)
def softmax(logits: ragged_tensor.Ragged, axis=None, name=None):
  """Computes softmax activations.

  Used for multi-class predictions. The sum of all outputs generated by softmax
  is 1.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Example usage:

  >>> softmax = tf.nn.softmax([-1, 0., 1.])
  >>> softmax
  <tf.Tensor: shape=(3,), dtype=float32,
  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>
  >>> sum(softmax)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1 which
      indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  if axis is None:
    axis = -1

  with ops.name_scope(name, 'RaggedSoftmax', [logits]) as name:
    max_input = reduce_max(logits, axis=axis, keepdims=True)
    logits_exp = math_ops.exp(math_ops.subtract(logits, max_input))
    denominator = reduce_sum(logits_exp, axis=axis, keepdims=True)
    return math_ops.divide(logits_exp, denominator)
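
# Example (illustrative): each ragged row is normalized independently along
# the last axis, so uniform logits give uniform probabilities per row:
#
#   >>> rt = tf.ragged.constant([[0., 0.], [0., 0., 0., 0.]])
#   >>> tf.nn.softmax(rt).to_list()
#   [[0.5, 0.5], [0.25, 0.25, 0.25, 0.25]]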


#===============================================================================
# ragged.add_n
#===============================================================================
@dispatch.dispatch_for_api(math_ops.add_n)
def add_n(inputs: typing.List[ragged_tensor.RaggedOrDense], name=None):
  """RaggedTensor implementation for tf.math.add_n."""
  if len(inputs) < 1:
    raise ValueError('tf.add_n: expected at least one input.')
  with ops.name_scope(name, 'RaggedAddN', inputs):
    return ragged_functional_ops.map_flat_values(math_ops.add_n, inputs)
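
# Example (illustrative): map_flat_values adds the flat values of inputs that
# share the same row partition, then reattaches that partition:
#
#   >>> x = tf.ragged.constant([[1, 2], [3]])
#   >>> y = tf.ragged.constant([[10, 20], [30]])
#   >>> tf.math.add_n([x, y]).to_list()
#   [[11, 22], [33]]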


#===============================================================================
# Ragged version of nn_ops.dropout
#===============================================================================
@dispatch.dispatch_for_api(nn_ops.dropout)
def dropout_v1(x: ragged_tensor.Ragged,
               keep_prob=None,
               noise_shape=None,
               seed=None,
               name=None,
               rate=None):
  """Ragged dispatch target for tf.nn.dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.dropout(
            x.flat_values, keep_prob=keep_prob, seed=seed, rate=rate))


@dispatch.dispatch_for_api(nn_ops.dropout_v2)
def dropout_v2(x: ragged_tensor.Ragged,
               rate,
               noise_shape=None,
               seed=None,
               name=None):
  """Ragged dispatch target for tf.nn.dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.dropout_v2(x.flat_values, rate=rate, seed=seed))


@dispatch.dispatch_for_api(nn_ops.stateless_dropout)
def stateless_dropout(x: ragged_tensor.Ragged,
                      rate,
                      seed,
                      rng_alg=None,
                      noise_shape=None,
                      name=None):
  """Ragged dispatch target for tf.nn.experimental.stateless_dropout."""
  if noise_shape is not None:
    raise ValueError('noise_shape is not supported yet for RaggedTensor x')
  with ops.name_scope(name, 'RaggedNNStatelessDropout', [x, rate]):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
    return x.with_flat_values(
        nn_ops.stateless_dropout(
            x.flat_values, rate=rate, seed=seed, rng_alg=rng_alg))


#===============================================================================
# Ragged version of Tensor.__eq__ and Tensor.__ne__
#===============================================================================
@dispatch.dispatch_for_api(math_ops.tensor_equals)
def tensor_equals(self: ragged_tensor.RaggedOrDense,
                  other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__eq__`."""
  if other is None:
    return False
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is other
  else:
    try:
      return math_ops.equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return False  # values are not broadcast-compatible.


@dispatch.dispatch_for_api(math_ops.tensor_not_equals)
def tensor_not_equals(self: ragged_tensor.RaggedOrDense,
                      other: ragged_tensor.RaggedOrDense):
  """Ragged version of the operation invoked by `Tensor.__ne__`."""
  if other is None:
    return True
  elif _use_legacy_mode_for_tensor_equality(self):
    return self is not other
  else:
    try:
      return math_ops.not_equal(self, other)
    except (errors.InvalidArgumentError, ValueError):
      return True  # values are not broadcast-compatible.


def _use_legacy_mode_for_tensor_equality(self):
  g = getattr(self, 'graph', None)
  return not (ops.Tensor._USE_EQUALITY and  # pylint: disable=protected-access
              ops.executing_eagerly_outside_functions() and
              (g is None or g.building_function))


def _cumsum_flat_values_at_ragged_rank(last_rp, flat_values, exclusive=False,
                                       reverse=False):
  """Calculate flat_values for math_ops.cumsum when axis==ragged_rank."""
  if not exclusive:
    partial = _cumsum_flat_values_at_ragged_rank(
        last_rp, flat_values, exclusive=True, reverse=reverse)
    return partial + flat_values

  if reverse:
    youngest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids() + 1) - 1
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True, reverse=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=youngest_sibling)

    return new_flat_values - initial_values
  else:
    eldest_sibling = array_ops.gather(
        params=last_rp.row_splits(), indices=last_rp.value_rowids())
    new_flat_values = math_ops.cumsum(flat_values, exclusive=True)
    initial_values = array_ops.gather(params=new_flat_values,
                                      indices=eldest_sibling)
    return new_flat_values - initial_values
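
# Worked example (illustrative): an exclusive per-row cumsum is one global
# exclusive cumsum minus each row's starting offset. With flat_values
# [1, 2, 3, 4] and row_splits [0, 2, 4]:
#   global exclusive cumsum -> [0, 1, 3, 6]
#   eldest-sibling gather   -> offsets [0, 0, 3, 3]
#   difference              -> [0, 1, 0, 3], i.e. rows [[0, 1], [0, 3]]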


@dispatch.dispatch_for_api(math_ops.cumsum)
def ragged_cumsum(x: ragged_tensor.Ragged,
                  axis: int = 0,
                  exclusive: bool = False,
                  reverse: bool = False,
                  name: typing.Optional[str] = None):
  """Calculate math_ops.cumsum for a RaggedTensor.

  Given a ragged tensor `x`, the `result` is a ragged tensor with the same
  shape. One can calculate the value of `result[i_1...i_k]` as follows:
  ```
  dense_result = tf.math.cumsum(x.to_tensor(), axis=axis, exclusive=exclusive,
                                reverse=reverse)
  result[i_1...i_k] = dense_result[i_1...i_k]
  ```

  Args:
    x: the original ragged tensor to sum.
    axis: the axis along which to sum; must be in the range
      `-rank <= axis < rank`.
    exclusive: is the sum exclusive or inclusive? If True, then result[0]=0.
      If False, then result[0]=x[0].
    reverse: If True, sum from back to front.
    name: the name of the op.
  Returns:
    the cumulative sum.
  """
  with ops.name_scope(name, 'RaggedCumSum', [x, axis, exclusive, reverse]):
    axis = array_ops.get_positive_axis(axis, x.shape.rank, ndims_name='rank')
    if axis == x.ragged_rank:
      last_rp = x._nested_row_partitions[-1]  # pylint: disable=protected-access
      return x.with_flat_values(
          _cumsum_flat_values_at_ragged_rank(last_rp, x.flat_values,
                                             exclusive=exclusive,
                                             reverse=reverse))
    elif axis > x.ragged_rank:
      new_axis = axis - x.ragged_rank
      cumsum_bound = functools.partial(
          math_ops.cumsum, axis=new_axis, exclusive=exclusive, reverse=reverse)
      return ragged_functional_ops.map_flat_values(cumsum_bound, x)
    else:
      dense_version = x.to_tensor()
      result = math_ops.cumsum(
          dense_version, axis, exclusive=exclusive, reverse=reverse, name=name)
      return ragged_tensor.RaggedTensor.from_tensor(
          result, lengths=x.nested_row_lengths())