Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_dispatch.py: 62%

74 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operator dispatch for RaggedTensors.""" 

16 

17from tensorflow.python.ops import logging_ops 

18from tensorflow.python.ops import math_ops 

19from tensorflow.python.ops import string_ops 

20from tensorflow.python.ops.ragged import ragged_tensor 

21from tensorflow.python.ops.ragged import ragged_tensor_shape 

22from tensorflow.python.util import dispatch 

23from tensorflow.python.util import tf_decorator 

24from tensorflow.python.util import tf_export 

25from tensorflow.python.util import tf_inspect 

26 

27 

@dispatch.dispatch_for_unary_elementwise_apis(ragged_tensor.Ragged)
def ragged_unary_elementwise_op(op, x):
  """Applies a unary elementwise `op` to the values of a ragged input.

  Registered (via the decorator) as the unary elementwise API handler for
  `ragged_tensor.Ragged` inputs.

  Args:
    op: The elementwise op being dispatched; it is called on `x.values`.
    x: The input, converted with `convert_to_tensor_or_ragged_tensor`.

  Returns:
    The converted `x` with its values replaced by `op(x.values)`.
  """
  converted = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
  new_values = op(converted.values)
  return converted.with_values(new_values)

33 

34 

35# TODO(martinz): This is deprecated. Delete. 

def ragged_binary_elementwise_op(op, x, y):
  """Binary elementwise api handler for RaggedTensors.

  Converts both operands to (ragged) tensors. When both are ragged, their
  row-splits dtypes are matched. When appropriate, both are broadcast to a
  common ragged shape. Finally `op` is applied to their flat values.

  Args:
    op: Binary callable applied to the two flat-value tensors.
    x: Left operand (ragged, or convertible to a tensor).
    y: Right operand (ragged, or convertible to a tensor).

  Returns:
    If `op` returns a Python `bool`, that bool unchanged. Otherwise
    `x.with_flat_values(result)` when the (possibly broadcast) `x` is ragged,
    else `y.with_flat_values(result)`.
  """
  # Raggedness of the *original* arguments drives conversion and broadcasting.
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)
  both_ragged = x_is_ragged and y_is_ragged

  # Convert args to tensors; each side prefers the other side's dtype.
  x_dtype_hint = y.dtype if y_is_ragged else None
  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      x, preferred_dtype=x_dtype_hint)
  y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      y, preferred_dtype=x.dtype)

  if both_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)

  # Broadcast when appropriate: both are ragged, or the ragged side's flat
  # values have rank no greater than the other side's rank.
  should_broadcast = (
      both_ragged or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims))
  if should_broadcast:
    # At least one side is ragged here. If both are, their row_splits dtypes
    # were matched above, so taking x's when available is safe.
    splits_source = x if x_is_ragged else y
    dim_size_dtype = splits_source.row_splits.dtype
    target_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            x, dim_size_dtype=dim_size_dtype),
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
            y, dim_size_dtype=dim_size_dtype))
    x = ragged_tensor_shape.broadcast_to(
        x, target_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, target_shape, broadcast_inner_dimensions=False)

  # Re-check raggedness on the current operands rather than reusing the flags
  # computed before broadcasting.
  x_now_ragged = ragged_tensor.is_ragged(x)
  x_values = x.flat_values if x_now_ragged else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = op(x_values, y_values)
  if isinstance(mapped_values, bool):
    return mapped_values  # Special case for tensor_equals.
  template = x if x_now_ragged else y
  return template.with_flat_values(mapped_values)

79 

80 

81# TODO(edloper): Update the documentation generation tools to automatically 

82# build lists of which types are supported by which ops (and then delete all 

83# the following code). 

84 

85 

# V2 ops that their v1 counterparts delegate to. The v2 ops already have a
# dispatch handler, so no separate handler is registered for the v1 variants.
# `_op_is_in_tf_version(op, 1)` treats members of this list as present in
# TF v1, so they still show up in the `ragged_op_list()` output.
_V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS = [
    # Reductions from math_ops.
    math_ops.reduce_sum,
    math_ops.reduce_prod,
    math_ops.reduce_min,
    math_ops.reduce_max,
    math_ops.reduce_mean,
    math_ops.reduce_variance,
    math_ops.reduce_std,
    math_ops.reduce_any,
    math_ops.reduce_all,
    # String ops from string_ops.
    string_ops.string_to_number,
    string_ops.string_to_hash_bucket,
    string_ops.reduce_join_v2,
]

103 

104 

def _ragged_op_signature(op, ragged_args, ragged_varargs=False):
  """Returns a signature for the given op, marking ragged args in bold.

  Args:
    op: The API symbol to describe. Its name comes from
      `tf_export.get_canonical_name_for_symbol`, and its arguments from
      `tf_inspect.getfullargspec`.
    ragged_args: Iterable of positional indices into the op's `argspec.args`
      naming the arguments that accept `RaggedTensor`s.
    ragged_varargs: If true, the op's `*varargs` parameter (if any) is also
      rendered in bold.

  Returns:
    A markdown bullet line of the form "* `tf.<name>`(<arg>, ...)".
  """
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  # Work on a copy. The loops below rewrite entries in place, and
  # `argspec.args` is not guaranteed to be a fresh list: an argspec object
  # may be shared. NOTE(review): e.g. one stored on a TFDecorator; confirm.
  # Editing the shared list would leak the markdown into later lookups.
  arg_names = list(argspec.args)

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults; `defaults` align with the last len(defaults) args,
  # hence the negative indexing.
  if argspec.defaults is not None:
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
      arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args.
  if argspec.varargs:
    if ragged_varargs:
      # A literal "*" followed by the bold varargs name.
      arg_names.append('***' + argspec.varargs + '**')
    else:
      arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))

130 

131 

def _op_is_in_tf_version(op, version):
  """Returns whether `op` is exported in TF API `version`.

  Args:
    op: A possibly-decorated API symbol. It is unwrapped with
      `tf_decorator.unwrap` before its exported names are looked up.
    version: The TF API version, `1` or `2`.

  Returns:
    A truthy value if `op` has exported names for `version`. For version 1,
    ops listed in `_V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS` also count. The
    value is whatever `tf_export.get_v1_names` / `get_v2_names` (or the
    membership test) yields, not necessarily a `bool`.

  Raises:
    ValueError: If `version` is neither 1 nor 2.
  """
  if version not in (1, 2):
    raise ValueError('Expected version 1 or 2.')
  unwrapped = tf_decorator.unwrap(op)[1]
  if version == 2:
    return tf_export.get_v2_names(unwrapped)
  return (tf_export.get_v1_names(unwrapped) or
          op in _V2_OPS_THAT_ARE_DELEGATED_TO_FROM_V1_OPS)

140 

141 

def ragged_op_list(tf_version=2):
  """Returns a string listing operations that have dispatchers registered.

  Args:
    tf_version: The TF API version (1 or 2) whose ops should be listed. Each
      candidate op is filtered through `_op_is_in_tf_version`.

  Returns:
    A markdown string with three parts, ending in a newline:
      * a section header;
      * a note explaining that ragged-accepting arguments are bold;
      * one sorted bullet line per op.
  """
  lines = []
  api_signatures = dispatch.type_based_dispatch_signatures_for(
      ragged_tensor.RaggedTensor)
  for api, signatures in api_signatures.items():
    arg_names = tf_inspect.getargspec(api).args
    ragged_args = set()
    for signature in signatures:
      # Signature entries identify a parameter either by position (int) or
      # by name; normalize both to positional indices.
      for arg in signature:
        ragged_args.add(arg if isinstance(arg, int) else arg_names.index(arg))
    if _op_is_in_tf_version(api, tf_version):
      lines.append(_ragged_op_signature(api, ragged_args))

  # `print_v2` is added by hand, with its varargs marked as ragged.
  lines.append(
      _ragged_op_signature(logging_ops.print_v2, [], ragged_varargs=True))
  # Terminate with a real newline; the previous literal was 'n', which
  # appended a stray letter "n" to the generated listing.
  return ('\n\n### Additional ops that support `RaggedTensor`\n\n'
          'Arguments that accept `RaggedTensor`s are marked in **bold**.\n\n' +
          '\n'.join(sorted(lines)) + '\n')