Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_dispatch.py: 62%
74 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operator dispatch for RaggedTensors."""
17from tensorflow.python.ops import logging_ops
18from tensorflow.python.ops import math_ops
19from tensorflow.python.ops import string_ops
20from tensorflow.python.ops.ragged import ragged_tensor
21from tensorflow.python.ops.ragged import ragged_tensor_shape
22from tensorflow.python.util import dispatch
23from tensorflow.python.util import tf_decorator
24from tensorflow.python.util import tf_export
25from tensorflow.python.util import tf_inspect
@dispatch.dispatch_for_unary_elementwise_apis(ragged_tensor.Ragged)
def ragged_unary_elementwise_op(op, x):
  """Unary elementwise api handler for RaggedTensors.

  Registered through `dispatch.dispatch_for_unary_elementwise_apis` for
  `ragged_tensor.Ragged` operands.

  Args:
    op: The elementwise function to apply to the values of `x`.
    x: The operand; it is first passed through
      `ragged_tensor.convert_to_tensor_or_ragged_tensor`.

  Returns:
    The converted `x` with its `values` replaced by `op(x.values)`.
  """
  rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
  transformed = op(rt.values)
  return rt.with_values(transformed)
# TODO(martinz): This is deprecated. Delete.
def ragged_binary_elementwise_op(op, x, y):
  """Binary elementwise api handler for RaggedTensors.

  Converts both operands and broadcasts them against each other when their
  ranks allow it. It then applies `op` to the flat values and re-wraps the
  result.

  Args:
    op: A binary function applied to the (flat) values of `x` and `y`.
    x: The first operand (ragged, or convertible to a tensor).
    y: The second operand (ragged, or convertible to a tensor).

  Returns:
    If `op` returns a Python `bool`, that bool unchanged. Otherwise, `op`'s
    result re-wrapped with the ragged structure of `x` when `x` is ragged
    after broadcasting, or of `y` when it is not.
  """
  x_was_ragged = ragged_tensor.is_ragged(x)
  y_was_ragged = ragged_tensor.is_ragged(y)
  both_ragged = x_was_ragged and y_was_ragged

  # Convert the operands. A ragged `y` supplies the preferred dtype for `x`;
  # the converted `x` then supplies the preferred dtype for `y`.
  x_dtype_hint = y.dtype if y_was_ragged else None
  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      x, preferred_dtype=x_dtype_hint)
  y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      y, preferred_dtype=x.dtype)

  if both_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Broadcast when both operands are ragged, or when the ragged operand's
  # flat_values rank does not exceed the other operand's rank.
  should_broadcast = (
      both_ragged or
      (x_was_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_was_ragged and y.flat_values.shape.ndims <= x.shape.ndims))
  if should_broadcast:
    # When both are ragged, match_row_splits_dtypes above has already given
    # them the same row_splits dtype, so either operand can supply it.
    splits_owner = x if x_was_ragged else y
    dim_size_dtype = splits_owner.row_splits.dtype
    shape_x, shape_y = [
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            operand, dim_size_dtype=dim_size_dtype)
        for operand in (x, y)
    ]
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(shape_x, shape_y)
    x, y = [
        ragged_tensor_shape.broadcast_to(
            operand, bcast_shape, broadcast_inner_dimensions=False)
        for operand in (x, y)
    ]

  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = op(x_values, y_values)
  if isinstance(mapped_values, bool):
    return mapped_values  # Special case for tensor_equals.
  template = x if ragged_tensor.is_ragged(x) else y
  return template.with_flat_values(mapped_values)
81# TODO(edloper): Update the documentation generation tools to automatically
82# build lists of which types are supported by which ops (and then delete all
83# the following code).
# V2 ops whose v1 counterparts simply delegate to them. The v2 handler
# already covers those v1 ops, so no separate delegation handler is
# registered for them. `_op_is_in_tf_version` still treats every op listed
# here as present in TF1, so these ops appear in `ragged_op_list()` output
# for version 1.
_V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS = [
    # math_ops reductions.
    math_ops.reduce_sum, math_ops.reduce_prod, math_ops.reduce_min,
    math_ops.reduce_max, math_ops.reduce_mean, math_ops.reduce_variance,
    math_ops.reduce_std, math_ops.reduce_any, math_ops.reduce_all,
    # string_ops.
    string_ops.string_to_number, string_ops.string_to_hash_bucket,
    string_ops.reduce_join_v2,
]
def _ragged_op_signature(op, ragged_args, ragged_varargs=False):
  """Returns a signature for the given op, marking ragged args in bold.

  Args:
    op: The API symbol to describe. Its `tf.` name comes from
      `tf_export.get_canonical_name_for_symbol`.
    ragged_args: Iterable of positional indices into the op's argument list.
      Each indexed argument accepts `RaggedTensor`s and its name is wrapped
      in `**...**`.
    ragged_varargs: If true, the op's `*args` parameter (if any) is also
      wrapped in bold markers.

  Returns:
    A markdown bullet of the form "* `tf.<name>`(<args>)". Default values
    are appended as "=`<repr>`", and the varargs / varkw parameters are
    listed last.
  """
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  # Work on a copy. The `args` list returned by getfullargspec may be shared
  # with other holders of the same argspec (NOTE(review): e.g. an argspec
  # stored on a decorator). Editing it in place could then leak the markup
  # into later calls, e.g. "****x****" on a second call.
  arg_names = list(argspec.args)

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults; they line up with the trailing entries of `args`,
  # hence the negative indices.
  if argspec.defaults is not None:
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
      arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args.
  if argspec.varargs:
    if ragged_varargs:
      # The leading '*' denotes varargs; the surrounding '**' marks bold.
      arg_names.append('***' + argspec.varargs + '**')
    else:
      arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
132def _op_is_in_tf_version(op, version):
133 if version == 1:
134 return (tf_export.get_v1_names(tf_decorator.unwrap(op)[1]) or
135 op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)
136 elif version == 2:
137 return tf_export.get_v2_names(tf_decorator.unwrap(op)[1])
138 else:
139 raise ValueError('Expected version 1 or 2.')
def ragged_op_list(tf_version=2):
  """Returns a string listing operations that have dispatchers registered.

  Args:
    tf_version: The TensorFlow API version (1 or 2) to list ops for. Ops that
      `_op_is_in_tf_version` reports as absent from that version are skipped.

  Returns:
    A markdown string with these parts, in order:
      * a "### Additional ops that support `RaggedTensor`" heading;
      * a note that ragged-accepting arguments are in bold;
      * one sorted bullet line per op.
    The string ends with a newline.
  """
  lines = []
  api_signatures = dispatch.type_based_dispatch_signatures_for(
      ragged_tensor.RaggedTensor)
  for api, signatures in api_signatures.items():
    arg_names = tf_inspect.getargspec(api).args
    # Signature entries identify arguments either by position (int) or by
    # name (str); normalize them all to positional indices.
    ragged_args = set()
    for signature in signatures:
      for arg in signature:
        ragged_args.add(arg if isinstance(arg, int) else arg_names.index(arg))
    if _op_is_in_tf_version(api, tf_version):
      lines.append(_ragged_op_signature(api, ragged_args))

  # logging_ops.print_v2 is listed explicitly, with its varargs marked bold.
  lines.append(
      _ragged_op_signature(logging_ops.print_v2, [], ragged_varargs=True))
  # The trailing newline was previously the literal 'n' (a typo for '\n'),
  # which left a stray "n" after the last op line.
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')