Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/operators/slices.py: 40%

52 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operators specific to slicing operations.""" 

16 

17import collections 

18 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import tensor_util 

21from tensorflow.python.ops import gen_array_ops 

22from tensorflow.python.ops import gen_string_ops 

23from tensorflow.python.ops import list_ops 

24from tensorflow.python.ops import tensor_array_ops 

25 

26 

27# TODO(mdan): Support extended slices. 

28 

29 

class GetItemOpts(
    collections.namedtuple('GetItemOpts', ('element_dtype',))):
  """Options bundle accepted by `get_item`.

  A one-field named tuple:

  element_dtype: passed as the `element_dtype` argument of the tensor list
    read when the target is a variant-dtype tensor; a value of None makes
    that read raise ValueError.
  """

32 

33 

def get_item(target, i, opts):
  """Slice read operator (the functional form of `target[i]`).

  Picks a read overload based on what `target` is:

  * a `TensorArray` is read through its `read` method;
  * a TF-typed value with variant dtype is read as a tensor list;
  * a TF-typed value with string dtype and zero dimensions is read as a
    one-element substring;
  * any other TF-typed value is indexed directly;
  * everything else is treated as a plain Python object and indexed.

  Note: whether `target` is mutated by this call is unspecified.

  Args:
    target: An entity that supports getitem semantics.
    i: Index to read from.
    opts: A GetItemOpts object.

  Returns:
    The element that was read.

  Raises:
    ValueError: if `target` is a variant-dtype tensor and
      `opts.element_dtype` is None.
  """
  assert isinstance(opts, GetItemOpts)

  # TensorArray is tested first, before the generic TF-type check.
  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_get_item(target, i)

  # Anything that is not a TF type falls back to ordinary Python indexing.
  if not tensor_util.is_tf_type(target):
    return _py_get_item(target, i)

  if target.dtype == dtypes.variant:
    return _tf_tensor_list_get_item(target, i, opts)

  # Only scalar (rank-0) strings take the substring path.
  if target.dtype == dtypes.string and target.shape.ndims == 0:
    return _tf_tensor_string_get_item(target, i)

  return _tf_tensor_get_item(target, i)

64 

65 

def _tf_tensorarray_get_item(target, i):
  """get_item overload for TensorArray targets.

  Args:
    target: Object exposing a `read(index)` method.
    i: Index handed to `read`.

  Returns:
    Whatever `target.read(i)` returns.
  """
  element = target.read(i)
  return element

69 

70 

def _tf_tensor_list_get_item(target, i, opts):
  """get_item overload for tensor lists.

  Args:
    target: The tensor list to read from.
    i: Index of the element to read.
    opts: Options object; its `element_dtype` attribute is forwarded to the
      list read.

  Returns:
    The result of `list_ops.tensor_list_get_item`.

  Raises:
    ValueError: if `opts.element_dtype` is None.
  """
  element_dtype = opts.element_dtype
  # The list read is always given an explicit dtype, so refuse up front when
  # none was supplied.
  if element_dtype is None:
    raise ValueError('cannot retrieve from a list without knowing its '
                     'element type; use set_element_type to annotate it')
  return list_ops.tensor_list_get_item(
      target, i, element_dtype=element_dtype)

78 

79 

def _tf_tensor_get_item(target, i):
  """get_item overload for Tensors that are not tensor lists.

  Args:
    target: The value to index.
    i: Index to read.

  Returns:
    `target[i]`.
  """
  element = target[i]
  return element

83 

84 

def _tf_tensor_string_get_item(target, i):
  """get_item overload for string Tensors.

  Args:
    target: The string tensor to read from.
    i: Position passed to the substring op.

  Returns:
    The result of `gen_string_ops.substr(target, i, 1)`.
  """
  # The trailing 1 is the third positional argument of substr
  # (the substring length — confirm against the op's documentation).
  return gen_string_ops.substr(target, i, 1)

89 

90 

def _py_get_item(target, i):
  """get_item overload for plain Python objects: an ordinary indexed read.

  Args:
    target: Any object supporting `target[i]`.
    i: Index or key to read.

  Returns:
    `target[i]`.
  """
  element = target[i]
  return element

94 

95 

def set_item(target, i, x):
  """Slice write operator (the functional form of `target[i] = x`).

  Picks a write overload based on what `target` is:

  * a `TensorArray` is written through its `write` method;
  * a TF-typed value with variant dtype is updated as a tensor list;
  * any other TF-typed value is updated with a tensor scatter update;
  * everything else is treated as a plain Python object and assigned to
    in place.

  Note: whether `target` itself is mutated is unspecified. In the plain
  Python case it is mutated.

  Args:
    target: An entity that supports setitem semantics.
    i: Index to modify.
    x: The new element value.

  Returns:
    The value produced by the selected overload, representing `target` after
    the update.
  """
  # TensorArray is tested first, before the generic TF-type check.
  if isinstance(target, tensor_array_ops.TensorArray):
    return _tf_tensorarray_set_item(target, i, x)

  # Anything that is not a TF type falls back to ordinary Python assignment.
  if not tensor_util.is_tf_type(target):
    return _py_set_item(target, i, x)

  if target.dtype == dtypes.variant:
    return _tf_tensor_list_set_item(target, i, x)

  return _tf_tensor_set_item(target, i, x)

122 

123 

def _tf_tensorarray_set_item(target, i, x):
  """set_item overload for TensorArray targets.

  Args:
    target: Object exposing a `write(index, value)` method.
    i: Index handed to `write`.
    x: Value handed to `write`.

  Returns:
    Whatever `target.write(i, x)` returns.
  """
  written = target.write(i, x)
  return written

127 

128 

def _tf_tensor_list_set_item(target, i, x):
  """set_item overload for tensor lists.

  Args:
    target: The tensor list to update.
    i: Index of the element to replace.
    x: The new element value.

  Returns:
    The result of `list_ops.tensor_list_set_item(target, i, x)`.
  """
  updated_list = list_ops.tensor_list_set_item(target, i, x)
  return updated_list

132 

133 

def _tf_tensor_set_item(target, i, x):
  """set_item overload for Tensors that are not tensor lists.

  Args:
    target: The tensor to update.
    i: Index to modify.
    x: The new element value.

  Returns:
    The result of `gen_array_ops.tensor_scatter_update`.
  """
  # A single update is described by nesting the index two tuples deep and
  # wrapping the value in a one-element tuple.
  indices = ((i,),)
  updates = (x,)
  return gen_array_ops.tensor_scatter_update(target, indices, updates)

137 

138 

def _py_set_item(target, i, x):
  """set_item overload for plain Python objects: an in-place assignment.

  Args:
    target: Any object supporting `target[i] = x`.
    i: Index or key to assign to.
    x: The new element value.

  Returns:
    The same `target` object, after it has been modified.
  """
  container = target
  container[i] = x
  return container