Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/array_ops_stack.py: 34%
32 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# Tests for this file live in python/kernel_tests/array_ops_test.py
16"""Operations to stack and unstack tensors."""
18from tensorflow.python.framework import ops
19from tensorflow.python.ops import gen_array_ops
20from tensorflow.python.util import dispatch
21from tensorflow.python.util.tf_export import tf_export
@tf_export("stack")
@dispatch.add_dispatch_support
def stack(values, axis=0, name="stack"):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

  Related ops: `tf.concat`, `tf.tile`, `tf.repeat`.

  The tensors in `values` are packed together along a new dimension inserted
  at position `axis`, so the result has rank one higher than each input.
  For a list of `N` tensors that each have shape `(A, B, C)`:

  * with `axis == 0` the result has shape `(N, A, B, C)`;
  * with `axis == 1` the result has shape `(A, N, B, C)`;
  * and so on for larger axes.

  For example:

  >>> x = tf.constant([1, 4])
  >>> y = tf.constant([2, 5])
  >>> z = tf.constant([3, 6])
  >>> tf.stack([x, y, z])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)>
  >>> tf.stack([x, y, z], axis=1)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[1, 2, 3],
         [4, 5, 6]], dtype=int32)>

  `tf.stack` is the opposite of `tf.unstack` and mirrors NumPy's `np.stack`:

  >>> np.array_equal(np.stack([x, y, z]), tf.stack([x, y, z]))
  True

  Args:
    values: A list of `Tensor` objects, all with the same shape and type.
    axis: An `int` naming the axis to stack along; defaults to the first
      dimension. Negative values wrap around, so the valid range is
      `[-(R+1), R+1)`.
    name: A name for this operation (optional).

  Returns:
    output: A stacked `Tensor` with the same type as `values`.

  Raises:
    ValueError: If `axis` is out of the range `[-(R+1), R+1)`. This check is
      performed here only when the static shape of `values[0]` is available.
  """
  if axis == 0:
    # Fast path: when `values` converts in one shot (for example a list of
    # constants), the converted tensor is returned directly as the result.
    try:
      return ops.convert_to_tensor(values, name=name)
    except (TypeError, ValueError, NotImplementedError):
      # Something in the list cannot be converted directly (e.g. non-constant
      # tensors); fall through to the generic Pack op below.
      pass

  # Only the first element is inspected to learn the element rank.
  first_value = ops.convert_to_tensor(values[0], name=name)
  element_dims = first_value._shape_tuple()  # pylint: disable=protected-access
  # When `_shape_tuple()` gives None (presumably an unknown static rank), the
  # range check is skipped and `axis` is handed straight to `pack`.
  if element_dims is not None:
    output_rank = 1 + len(element_dims)
    if axis < -output_rank or axis >= output_rank:
      raise ValueError(f"Argument `axis` = {axis} not in range "
                       f"[{-output_rank}, {output_rank})")
  return gen_array_ops.pack(values, axis=axis, name=name)
@tf_export("unstack")
@dispatch.add_dispatch_support
def unstack(value, num=None, axis=0, name="unstack"):
  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.

  Unpacks tensors from `value` by chipping it along the `axis` dimension.

  >>> x = tf.reshape(tf.range(12), (3,4))
  >>>
  >>> p, q, r = tf.unstack(x)
  >>> p.shape.as_list()
  [4]

  >>> i, j, k, l = tf.unstack(x, axis=1)
  >>> i.shape.as_list()
  [3]

  This is the opposite of `tf.stack`.

  >>> x = tf.stack([i, j, k, l], axis=1)

  More generally, if you have a tensor of shape `(A, B, C, D)`:

  >>> A, B, C, D = [2, 3, 4, 5]
  >>> t = tf.random.normal(shape=[A, B, C, D])

  The number of tensors returned is equal to the length of the target `axis`:

  >>> axis = 2
  >>> items = tf.unstack(t, axis=axis)
  >>> len(items) == t.shape[axis]
  True

  The shape of each result tensor is equal to the shape of the input tensor,
  with the target `axis` removed.

  >>> items[0].shape.as_list()  # [A, B, D]
  [2, 3, 5]

  The value of each tensor `items[i]` is equal to the slice of `t` across
  `axis` at index `i`:

  >>> for i in range(len(items)):
  ...   expected = t[:,:,i,:]
  ...   assert tf.reduce_all(expected == items[i])

  #### Python iterable unpacking

  With eager execution you _can_ unstack the 0th axis of a tensor using
  Python's iterable unpacking:

  >>> t = tf.constant([1,2,3])
  >>> a,b,c = t

  `unstack` is still necessary because iterable unpacking doesn't work in
  a `@tf.function`: Symbolic tensors are not iterable.

  You need to use `tf.unstack` here:

  >>> @tf.function
  ... def bad(t):
  ...   a,b,c = t
  ...   return a
  >>>
  >>> bad(t)
  Traceback (most recent call last):
  ...
  OperatorNotAllowedInGraphError: ...

  >>> @tf.function
  ... def good(t):
  ...   a,b,c = tf.unstack(t)
  ...   return a
  >>>
  >>> good(t).numpy()
  1

  #### Unknown shapes

  Eager tensors have concrete values, so their shape is always known.
  Inside a `tf.function` the symbolic tensors may have unknown shapes.
  If the length of `axis` is unknown, `tf.unstack` will fail because it cannot
  handle an unknown number of tensors:

  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  ... def bad(t):
  ...   tensors = tf.unstack(t)
  ...   return tensors[0]
  >>>
  >>> bad(tf.constant([1.0, 2.0, 3.0]))
  Traceback (most recent call last):
  ...
  ValueError: Cannot infer argument `num` from shape (None,)

  If you know the `axis` length you can pass it as the `num` argument, but it
  must be a constant value.

  If you actually need a variable number of tensors in a single `tf.function`
  trace, you will need to use explicit loops and a `tf.TensorArray` instead.

  Args:
    value: A rank `R > 0` `Tensor` to be unstacked.
    num: An `int`. The length of the dimension `axis`. Automatically inferred if
      `None` (the default).
    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-R, R)`.
    name: A name for the operation (optional).

  Returns:
    The list of `Tensor` objects unstacked from `value`.

  Raises:
    ValueError: If `axis` is out of the range `[-R, R)`. This check is done
      here only when `num` is `None` and the rank of `value` is statically
      known.
    ValueError: If `num` is unspecified and cannot be inferred.
    InvalidArgumentError: If `num` does not match the shape of `value`.
  """
  if num is None:
    # `num` must be inferred from the static shape, so materialize `value` as
    # a tensor to read that shape.
    value = ops.convert_to_tensor(value)
    value_shape = value.get_shape()
    if value_shape.ndims is not None:
      # The rank is known: validate `axis` before indexing `dims` with it.
      if axis < -value_shape.ndims or axis >= value_shape.ndims:
        raise ValueError(f"Argument `axis` = {axis} not in range "
                         f"[{-value_shape.ndims}, {value_shape.ndims})")
      # May still be None when this particular dimension's size is unknown.
      num = value_shape.dims[axis].value
    if num is None:
      raise ValueError(f"Cannot infer argument `num` from shape {value_shape}")
  return gen_array_ops.unpack(value, num=num, axis=axis, name=name)