Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/shape_util.py: 47%

19 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tensor shape utilities.""" 

16from tensorflow.python.eager import context 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.framework import ops 

19from tensorflow.python.framework import tensor_shape 

20from tensorflow.python.framework import tensor_util 

21 

22 

def shape_tensor(shape):  # pylint: disable=invalid-name
  """Builds a shape tensor (int32 or int64) from `shape`; empty means int32.

  Args:
    shape: A value to convert. Tuples and lists get special handling; anything
      else is handed to `ops.convert_to_tensor` unchanged.

  Returns:
    The result of `ops.convert_to_tensor(..., name="shape")`.
  """
  if not isinstance(shape, (tuple, list)):
    # Not a plain Python sequence: convert as-is, without forcing a dtype.
    return ops.convert_to_tensor(shape, dtype=None, name="shape")

  if not shape:
    # An empty sequence is explicitly converted as int32.
    return ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")

  # Mixed v1/v2 TensorShape usage during partial conversions can yield
  # sequences such as (1, 2, Dimension(5)). Such mixed content cannot be
  # converted to a Tensor, so every entry is unwrapped first.
  unwrapped = tuple(tensor_shape.dimension_value(dim) for dim in shape)
  return ops.convert_to_tensor(unwrapped, dtype=None, name="shape")

36 

37 

38# DO NOT USE: For testing only. Module-level switch checked first by `maybe_set_static_shape`; when it is falsy that function does nothing.

39_ENABLE_MAYBE_SET_STATIC_SHAPE = True

40 

41 

def maybe_set_static_shape(tensor, shape):  # pylint: disable=invalid-name
  """Applies the constant value of `shape` to `tensor` as its static shape.

  Temporary workaround for shape inference across functional op boundaries.
  Consider:

  ```python
  shape = tf.constant([3])
  @tf.function
  def f():
    u = tf.random_uniform(shape)
    return u
  ```

  C++ shape inference alone would leave the shape of `u` inside `f` unknown:
  it knows nothing about the outer graph, so backtracing the captured `shape`
  tensor ends at a Placeholder node. This function instead derives the static
  value of `shape` by walking across `FuncGraph` boundaries and sets it on
  `tensor`.

  The longer-term fix is to repair C++ shape inference.

  The function does nothing unless all of the following hold: the module-level
  switch is on, execution is not eager, the default graph is building a
  function, `tensor`'s shape is not fully defined, and `shape` is a tensor.

  Args:
    tensor: A tensor.
    shape: A shape tensor.
  """
  # The guards below run in the same order as a single short-circuiting `and`.
  if not _ENABLE_MAYBE_SET_STATIC_SHAPE:
    return
  if context.executing_eagerly():
    return
  if not ops.get_default_graph().building_function:
    return
  if tensor.shape.is_fully_defined():
    return
  if not tensor_util.is_tensor(shape):
    return

  tensor.set_shape(tensor_util.constant_value_as_shape(shape_tensor(shape)))