Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/autograph_ops.py: 25%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Autograph specific overrides for objects covered by tensor_util.is_tf_type.""" 

16 

17from tensorflow.python.autograph.operators import py_builtins 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_util 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import check_ops 

23from tensorflow.python.ops import script_ops 

24from tensorflow.python.ops import sort_ops 

25from tensorflow.python.ops.parallel_for import control_flow_ops as parallel_ops 

26 

27 

def wrap_py_func(f, args, kwargs=None):
  """Helper that wraps a callable to py_func.

  The helper passes tensor arguments through the py_func interface. Non-tensor
  arguments are allowed, and will be passed to f directly. Note that non-tensor
  arguments that are captured by f will not update every time the wrapper is
  called (this is consistent with its argument list, which only includes
  the tensor arguments). In general, it's safest not to reuse this wrapper.

  Args:
    f: Callable
    args: Positional arguments for f, as list or tuple.
    kwargs: Keyword arguments for f, as dict with string keys. May be None.

  Returns:
    The result of `script_ops.eager_py_func` for the wrapped call, requested
    with dtype int32. The wrapped function always returns the constant 1; the
    return value of f itself is discarded.
  """
  tensor_args = []
  # Maps a positional index (int) or keyword name (str) to the slot that the
  # corresponding tensor occupies in `tensor_args`.
  tensor_args_idx = {}

  # Of the positional arguments, only grab the tensor ones to be passed through
  # the py_func.
  arg_is_tensor = tuple(map(tensor_util.is_tf_type, args))
  for i, (arg, is_tensor) in enumerate(zip(args, arg_is_tensor)):
    if is_tensor:
      tensor_args_idx[i] = len(tensor_args)
      tensor_args.append(arg)

  # We essentially take the tensor kwargs, if any, and add them to the list of
  # positional arguments. The kwargs are then reconstructed inside the py_func.
  #
  # For example, if
  #
  #     args = [Tensor(1), 'foo']
  #     kwargs = {'a': Tensor(2), 'b': 'bar'}
  #
  # Then
  #
  #     tensor_args = (Tensor(1), Tensor(2))
  #     kwarg_keys = ('a', 'b')
  kwarg_keys = tuple(kwargs) if kwargs else ()
  # Always bound (empty when there are no kwargs) so that the closure below
  # never refers to an undefined name.
  kwarg_is_tensor = {k: tensor_util.is_tf_type(kwargs[k]) for k in kwarg_keys}
  for k in kwarg_keys:
    if kwarg_is_tensor[k]:
      tensor_args_idx[k] = len(tensor_args)
      tensor_args.append(kwargs[k])

  def f_wrapper(*tensor_args):
    # Rebuild the original call: tensor slots are filled from the values handed
    # to the py_func, everything else is read from the captured args/kwargs.
    f_args = tuple(
        tensor_args[tensor_args_idx[i]] if arg_is_tensor[i] else a
        for i, a in enumerate(args)
    )
    f_kwargs = {
        k: tensor_args[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k]
        for k in kwarg_keys
    }
    f(*f_args, **f_kwargs)
    # Dummy value matching the int32 output dtype requested below.
    return 1

  return script_ops.eager_py_func(f_wrapper, tensor_args, dtypes.int32)

94 

95 

def _tf_py_func_print(*objects, **kwargs):
  """Overload of print_ that performs the printing inside a py_func.

  Args:
    *objects: Values to print.
    **kwargs: Keyword arguments for `print`. Entries whose value is
      `py_builtins.UNSPECIFIED` are dropped before the call.

  Returns:
    The result of `wrap_py_func` for the wrapped print call.
  """
  print_kwargs = {}
  for name, value in kwargs.items():
    if value is not py_builtins.UNSPECIFIED:
      print_kwargs[name] = value
  # Defaulting to flushing the console in graph mode, which helps reduce
  # garbled output in IPython.
  print_kwargs.setdefault('flush', True)

  def print_wrapper(*vals, **print_options):
    # First materialize every tensor as its numpy value.
    materialized = []
    for val in vals:
      materialized.append(val.numpy() if tensor_util.is_tf_type(val) else val)
    # TensorFlow doesn't seem to generate Unicode when passing strings to
    # py_func. This causes the print to add a "b'" wrapper to the output,
    # which is probably never what you want.
    printable = [
        val.decode('utf-8') if isinstance(val, bytes) else val
        for val in materialized
    ]
    print(*printable, **print_options)

  return wrap_py_func(print_wrapper, objects, print_kwargs)

115 

116 

def _tf_sorted(iterable, key, reverse):
  """Overload of sorted_ for Tensor iterable.

  Args:
    iterable: Tensor to sort. Only rank 1 is supported.
    key: `py_builtins.UNSPECIFIED`, or a callable that is applied to `iterable`
      through `vectorized_map`; the mapped values then determine the order.
    reverse: `py_builtins.UNSPECIFIED`, or a value whose truthiness selects
      descending order, as with the built-in `sorted`.

  Returns:
    The sorted tensor. When `key` is given, the elements of `iterable` gathered
    in the order obtained by sorting the mapped values.

  Raises:
    ValueError: if the tensor being sorted (the mapped values when `key` is
      given, `iterable` otherwise) has a statically known rank other than 1.
  """
  if reverse is py_builtins.UNSPECIFIED:
    descending = False
  elif tensor_util.is_tf_type(reverse):
    # NOTE(review): a tensor-valued `reverse` is not converted to a Python bool
    # here; it keeps the previous behavior of this overload, which selected
    # descending order for any specified value.
    descending = True
  else:
    # An explicit falsy value (e.g. reverse=False) must sort ascending.
    descending = bool(reverse)
  direction = 'DESCENDING' if descending else 'ASCENDING'

  if key is not py_builtins.UNSPECIFIED:
    mapped = parallel_ops.vectorized_map(key, iterable)
    # Fail early when the rank is statically known; the assert op below covers
    # the case where it is not.
    if mapped.shape.rank is not None and mapped.shape.rank != 1:
      raise ValueError('sort only supports only 1D tensors')
    with ops.control_dependencies([
        check_ops.assert_rank_v2(mapped, 1,
                                 'sort only supports only 1D tensors')
    ]):
      order = sort_ops.argsort(mapped, direction=direction)
      return array_ops.gather_v2(iterable, order)

  if iterable.shape.rank is not None and iterable.shape.rank != 1:
    raise ValueError('sort only supports only 1D tensors')
  with ops.control_dependencies([
      check_ops.assert_rank_v2(iterable, 1,
                               'sort only supports only 1D tensors')
  ]):
    return sort_ops.sort(iterable, direction=direction)

140 

# Register the overloads above for the classes in tensor_util.tf_type_classes.
py_builtins.print_registry.register(tensor_util.tf_type_classes, _tf_py_func_print)
py_builtins.sorted_registry.register(tensor_util.tf_type_classes, _tf_sorted)