Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_util.py: 38%
32 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Private convenience functions for RaggedTensors.
17None of these methods are exposed in the main "ragged" package.
18"""
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import math_ops
def assert_splits_match(nested_splits_lists):
  """Checks that the given splits lists are identical.

  The number of splits tensors in each list is compared statically, in Python.
  The splits tensors themselves are compared pairwise -- every list against the
  first one -- with `check_ops.assert_equal`, and those assertion ops are
  returned so the caller can attach them as control dependencies.

  Args:
    nested_splits_lists: A list of nested_splits_lists, where each split_list is
      a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
      ragged dimension to innermost ragged dimension.

  Returns:
    A list of control dependency op tensors (empty when fewer than two splits
    lists are given).

  Raises:
    ValueError: If the splits lists do not all contain the same number of
      splits tensors.  Differences in the tensors' contents are reported by the
      returned assertion ops, not raised here.
  """
  mismatch_msg = "Inputs must have identical ragged splits"
  if any(len(candidate) != len(nested_splits_lists[0])
         for candidate in nested_splits_lists):
    raise ValueError(mismatch_msg)

  equality_checks = []
  for other_splits in nested_splits_lists[1:]:
    for first_split, other_split in zip(nested_splits_lists[0], other_splits):
      equality_checks.append(
          check_ops.assert_equal(first_split, other_split,
                                 message=mismatch_msg))
  return equality_checks
# Note: these helpers are defined in array_ops and imported here to avoid a
# circular dependency of array_ops; this module re-exports them under the
# local names below (`repeat` is array_ops.repeat_with_axis).
get_positive_axis, convert_to_int_tensor, repeat = (
    array_ops.get_positive_axis,
    array_ops.convert_to_int_tensor,
    array_ops.repeat_with_axis,
)
def lengths_to_splits(lengths):
  """Converts row lengths into the corresponding splits vector.

  The result is `0` followed by the cumulative sum of `lengths`, concatenated
  along the last axis; e.g. lengths `[2, 0, 3]` give splits `[0, 2, 2, 5]`.
  """
  cumulative_lengths = math_ops.cumsum(lengths)
  return array_ops.concat([[0], cumulative_lengths], axis=-1)
def _with_nonnegative_integer_checks(value, arg_name):
  """Returns `value` gated on checks that it is integer-typed and non-negative.

  The dtype check comes first. `check_ops.assert_integer` rejects a non-integer
  dtype with a `TypeError`, so inputs such as strings or bools get that message
  instead of failing inside the non-negativity comparison.

  Args:
    value: The `Tensor` to validate.
    arg_name: Argument name reported in the error messages.

  Returns:
    `value` wrapped by `control_flow_ops.with_dependencies` on the checks.
  """
  checks = [
      check_ops.assert_integer(
          value,
          message=(
              f"Input argument '{arg_name}' must be integer, but got"
              f" {value.dtype} instead"
          ),
      ),
      check_ops.assert_non_negative(
          value, message=f"Input argument '{arg_name}' must be non-negative"
      ),
  ]
  return control_flow_ops.with_dependencies(checks, value)


def repeat_ranges(params, splits, repeats):
  """Repeats each range of `params` (as specified by `splits`) `repeats` times.

  Let the `i`th range of `params` be defined as
  `params[splits[i]:splits[i + 1]]`.  Then this function returns a tensor
  containing range 0 repeated `repeats[0]` times, followed by range 1 repeated
  `repeats[1]`, ..., followed by the last range repeated `repeats[-1]` times.

  Args:
    params: The `Tensor` whose values should be repeated.
    splits: A splits tensor indicating the ranges of `params` that should be
      repeated.  Elements should be non-negative integers.  Non-`Tensor` values
      (e.g. a Python list) are converted with `convert_to_tensor`.
    repeats: The number of times each range should be repeated.  Supports
      broadcasting from a scalar value (a Python int is also accepted).
      Elements should be non-negative integers.

  Returns:
    A `Tensor` with the same rank and type as `params`.

  Raises:
    TypeError: If `splits` or `repeats` does not have an integer dtype.
    InvalidArgumentError: If `splits` or `repeats` contains a negative value
      (raised when the corresponding assertion executes).

  #### Example:

  >>> print(repeat_ranges(
  ...     params=tf.constant(['a', 'b', 'c']),
  ...     splits=tf.constant([0, 2, 3]),
  ...     repeats=tf.constant(3)))
  tf.Tensor([b'a' b'b' b'a' b'b' b'a' b'b' b'c' b'c' b'c'],
      shape=(9,), dtype=string)
  """
  # Convert first so that `.dtype` (used in the error messages) and `.shape`
  # (used below to pick the scalar fast path) exist even for non-`Tensor`
  # inputs.  `Tensor` inputs are passed through unchanged by convert_to_tensor.
  splits = _with_nonnegative_integer_checks(
      ops.convert_to_tensor(splits, name="splits"), "splits")
  repeats = _with_nonnegative_integer_checks(
      ops.convert_to_tensor(repeats, name="repeats"), "repeats")

  # Divide `splits` into starts and limits, and repeat them `repeats` times.
  # Unknown rank (ndims is None) also takes this general path.
  if repeats.shape.ndims != 0:
    repeated_starts = repeat(splits[:-1], repeats, axis=0)
    repeated_limits = repeat(splits[1:], repeats, axis=0)
  else:
    # Optimization: with a scalar `repeats`, repeat the whole splits vector
    # once; dropping its last `repeats` entries yields the starts and dropping
    # its first `repeats` entries yields the limits.
    repeated_splits = repeat(splits, repeats, axis=0)
    n_splits = array_ops.shape(repeated_splits, out_type=repeats.dtype)[0]
    repeated_starts = repeated_splits[:n_splits - repeats]
    repeated_limits = repeated_splits[repeats:]

  # Get indices for each range from starts to limits, and use those to gather
  # the values in the desired repetition pattern.
  one = array_ops.ones((), repeated_starts.dtype)
  offsets = gen_ragged_math_ops.ragged_range(
      repeated_starts, repeated_limits, one)
  return array_ops.gather(params, offsets.rt_dense_values)