Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/shape_util.py: 47%
19 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tensor shape utilities."""
16from tensorflow.python.eager import context
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.framework import tensor_util
def shape_tensor(shape):  # pylint: disable=invalid-name
  """Converts `shape` to a tensor named "shape".

  Intended to yield an int32 or int64 tensor. An empty tuple or list is
  converted with an explicit `int32` dtype. For a non-empty tuple or list,
  each element is first passed through `tensor_shape.dimension_value`, and
  the dtype is left for `ops.convert_to_tensor` to infer. Any other value is
  handed to `ops.convert_to_tensor` as-is, with no dtype.

  Args:
    shape: A tuple or list of dimensions, or any other value accepted by
      `ops.convert_to_tensor`.

  Returns:
    The tensor returned by `ops.convert_to_tensor`.
  """
  if not isinstance(shape, (tuple, list)):
    return ops.convert_to_tensor(shape, dtype=None, name="shape")

  if not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")

  # Unwrap any Dimension objects. If v1 and v2 TensorShape objects get mixed
  # during partial conversions, shapes like (1, 2, Dimension(5)) can appear.
  # Such mixed content cannot be converted to a Tensor, so normalize every
  # entry first.
  unwrapped = tuple(tensor_shape.dimension_value(dim) for dim in shape)
  return ops.convert_to_tensor(unwrapped, dtype=None, name="shape")
# DO NOT USE: For testing only.
# Master switch for `maybe_set_static_shape`, which checks it first. When
# falsy, that function returns without touching the tensor's shape.
_ENABLE_MAYBE_SET_STATIC_SHAPE: bool = True
def maybe_set_static_shape(tensor, shape):  # pylint: disable=invalid-name
  """Sets `tensor`'s shape to the constant value of `shape`, if inferrable.

  This is a temporary workaround for shape inference across functional-op
  boundaries. For example:

  ```python
  shape = tf.constant([3])
  @tf.function
  def f():
    u = tf.random.uniform(shape)
    return u
  ```

  With C++ shape inference alone, the shape of `u` inside `f` would be
  unknown. C++ shape inference is not aware of the outer graph: when it
  traces the captured `shape` tensor back, all it sees is a Placeholder
  node. This helper instead computes the static value of `shape` by
  traversing the `FuncGraph` boundaries, using
  `tensor_util.constant_value_as_shape`. It then applies that value with
  `tensor.set_shape`.

  A longer-term solution would be to fix C++ shape inference.

  The function does nothing unless all of the following hold, checked in
  this order:

  - The module-level `_ENABLE_MAYBE_SET_STATIC_SHAPE` flag is truthy.
  - Execution is not eager.
  - The default graph is building a function.
  - `tensor`'s shape is not already fully defined.
  - `shape` is a tensor.

  Args:
    tensor: A tensor whose static shape may be updated via `set_shape`.
    shape: A shape tensor. Values that are not tensors are ignored.
  """
  # Each guard below bails out early. Together they mirror a single
  # short-circuiting condition, evaluated in the same order.
  if not _ENABLE_MAYBE_SET_STATIC_SHAPE:
    return
  if context.executing_eagerly():
    return
  if not ops.get_default_graph().building_function:
    return
  if tensor.shape.is_fully_defined():
    return
  if not tensor_util.is_tensor(shape):
    return

  static_shape = tensor_util.constant_value_as_shape(shape_tensor(shape))
  tensor.set_shape(static_shape)