Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/array_ops_stack.py: 34%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# Tests for this file live in python/kernel_tests/array_ops_test.py 

16"""Operations to stack and unstack tensors.""" 

17 

18from tensorflow.python.framework import ops 

19from tensorflow.python.ops import gen_array_ops 

20from tensorflow.python.util import dispatch 

21from tensorflow.python.util.tf_export import tf_export 

22 

23 

@tf_export("stack")
@dispatch.add_dispatch_support
def stack(values, axis=0, name="stack"):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

  See also `tf.concat`, `tf.tile`, `tf.repeat`.

  Packs the list of tensors in `values` into a single tensor whose rank is one
  higher than that of each input, inserting the new dimension at position
  `axis`. Given a list of `N` tensors, each of shape `(A, B, C)`:

  if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
  if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
  Etc.

  For example:

  >>> x = tf.constant([1, 4])
  >>> y = tf.constant([2, 5])
  >>> z = tf.constant([3, 6])
  >>> tf.stack([x, y, z])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)>
  >>> tf.stack([x, y, z], axis=1)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[1, 2, 3],
         [4, 5, 6]], dtype=int32)>

  This is the inverse of `tf.unstack`. The NumPy equivalent is `np.stack`:

  >>> np.array_equal(np.stack([x, y, z]), tf.stack([x, y, z]))
  True

  Args:
    values: A list of `Tensor` objects with the same shape and type.
    axis: An `int`. The axis to stack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
    name: A name for this operation (optional).

  Returns:
    output: A stacked `Tensor` with the same type as `values`.

  Raises:
    ValueError: If `axis` is out of the range `[-(R+1), R+1)`.
  """
  if axis == 0:
    # Fast path: if the whole list can be converted at once (e.g. it is a
    # list of constants), that conversion already yields the stacked result.
    try:
      return ops.convert_to_tensor(values, name=name)
    except (TypeError, ValueError, NotImplementedError):
      # Not convertible as a whole (e.g. it contains non-constant tensors);
      # fall back to the Pack op below.
      pass

  # The static rank of the first element bounds the legal `axis` values.
  # NOTE(review): an empty list for `values` with `axis != 0` fails on the
  # `values[0]` lookup with an IndexError rather than a descriptive error.
  head = ops.convert_to_tensor(values[0], name=name)
  head_shape = head._shape_tuple()  # pylint: disable=protected-access
  if head_shape is not None:
    # The output has one more dimension than each input.
    out_rank = len(head_shape) + 1
    if axis < -out_rank or axis >= out_rank:
      raise ValueError(f"Argument `axis` = {axis} not in range "
                       f"[{-out_rank}, {out_rank})")

  return gen_array_ops.pack(values, axis=axis, name=name)

86 

87 

@tf_export("unstack")
@dispatch.add_dispatch_support
def unstack(value, num=None, axis=0, name="unstack"):
  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.

  Unpacks tensors from `value` by chipping it along the `axis` dimension.

  >>> x = tf.reshape(tf.range(12), (3,4))
  >>>
  >>> p, q, r = tf.unstack(x)
  >>> p.shape.as_list()
  [4]

  >>> i, j, k, l = tf.unstack(x, axis=1)
  >>> i.shape.as_list()
  [3]

  This is the inverse of `tf.stack`.

  >>> x = tf.stack([i, j, k, l], axis=1)

  More generally, if you have a tensor of shape `(A, B, C, D)`:

  >>> A, B, C, D = [2, 3, 4, 5]
  >>> t = tf.random.normal(shape=[A, B, C, D])

  The number of tensors returned is equal to the length of the target `axis`:

  >>> axis = 2
  >>> items = tf.unstack(t, axis=axis)
  >>> len(items) == t.shape[axis]
  True

  The shape of each result tensor is equal to the shape of the input tensor,
  with the target `axis` removed.

  >>> items[0].shape.as_list()  # [A, B, D]
  [2, 3, 5]

  The value of each tensor `items[i]` is equal to the slice of `value` across
  `axis` at index `i`:

  >>> for i in range(len(items)):
  ...   slice = t[:,:,i,:]
  ...   assert tf.reduce_all(slice == items[i])

  #### Python iterable unpacking

  With eager execution you _can_ unstack the 0th axis of a tensor using
  Python's iterable unpacking:

  >>> t = tf.constant([1,2,3])
  >>> a,b,c = t

  `unstack` is still necessary because iterable unpacking doesn't work in
  a `@tf.function`: Symbolic tensors are not iterable.

  You need to use `tf.unstack` here:

  >>> @tf.function
  ... def bad(t):
  ...   a,b,c = t
  ...   return a
  >>>
  >>> bad(t)
  Traceback (most recent call last):
  ...
  OperatorNotAllowedInGraphError: ...

  >>> @tf.function
  ... def good(t):
  ...   a,b,c = tf.unstack(t)
  ...   return a
  >>>
  >>> good(t).numpy()
  1

  #### Unknown shapes

  Eager tensors have concrete values, so their shape is always known.
  Inside a `tf.function` the symbolic tensors may have unknown shapes.
  If the length of the `axis` dimension is unknown, `tf.unstack` will fail
  because it cannot handle an unknown number of tensors:

  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  ... def bad(t):
  ...   tensors = tf.unstack(t)
  ...   return tensors[0]
  >>>
  >>> bad(tf.constant([1.0, 2.0, 3.0]))
  Traceback (most recent call last):
  ...
  ValueError: Cannot infer argument `num` from shape (None,)

  If you know the `axis` length you can pass it as the `num` argument, but it
  must be a constant value.

  If you actually need a variable number of tensors in a single `tf.function`
  trace, you will need to use explicit loops and a `tf.TensorArray` instead.

  Args:
    value: A rank `R > 0` `Tensor` to be unstacked.
    num: An `int`. The length of the dimension `axis`. Automatically inferred if
      `None` (the default).
    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-R, R)`.
    name: A name for the operation (optional).

  Returns:
    The list of `Tensor` objects unstacked from `value`.

  Raises:
    ValueError: If `num` is `None`, the rank of `value` is statically known,
      and `axis` is out of the range `[-R, R)`.
    ValueError: If `num` is unspecified and cannot be inferred.
    InvalidArgumentError: If `num` does not match the shape of `value`.
  """
  if num is None:
    # No explicit count given: derive it from the static length of `axis`.
    value = ops.convert_to_tensor(value)
    static_shape = value.get_shape()
    rank = static_shape.ndims
    if rank is not None:
      # The axis can only be range-checked when the rank is statically known.
      if axis < -rank or axis >= rank:
        raise ValueError(f"Argument `axis` = {axis} not in range "
                         f"[{-rank}, {rank})")
      num = static_shape.dims[axis].value
    # Either the rank or the length of `axis` is unknown here, and unstack
    # cannot produce an unknown number of tensors.
    if num is None:
      raise ValueError(f"Cannot infer argument `num` from shape {static_shape}")
  return gen_array_ops.unpack(value, num=num, axis=axis, name=name)